[data] add eval_on_each_dataset arg (#7912)

hoshi-hiyouga 2025-04-30 06:56:43 +08:00 committed by GitHub
parent c5b1d07e7c
commit 072bfe29d3
3 changed files with 15 additions and 5 deletions


@@ -168,7 +168,7 @@ def _get_merged_dataset(
     data_args: "DataArguments",
     training_args: "Seq2SeqTrainingArguments",
     stage: Literal["pt", "sft", "rm", "ppo", "kto"],
-    merge: bool = True,
+    return_dict: bool = False,
 ) -> Optional[Union["Dataset", "IterableDataset", dict[str, "Dataset"]]]:
     r"""Return the merged datasets in the standard format."""
     if dataset_names is None:
@@ -181,10 +181,10 @@ def _get_merged_dataset(
         datasets[dataset_name] = _load_single_dataset(dataset_attr, model_args, data_args, training_args)
-    if merge:
-        return merge_dataset(list(datasets.values()), data_args, seed=training_args.seed)
-    else:
+    if return_dict:
         return datasets
+    else:
+        return merge_dataset(list(datasets.values()), data_args, seed=training_args.seed)
 def _get_dataset_processor(
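In effect, the old merge flag is replaced by its inverse: with the default return_dict=False the per-name datasets are still merged into one, while return_dict=True keeps them separate, keyed by dataset name. A minimal, self-contained sketch of the same control flow, using datasets.concatenate_datasets as a simplified stand-in for the project's merge_dataset helper:

from datasets import Dataset, concatenate_datasets


def merge_or_keep(named_datasets: dict[str, Dataset], return_dict: bool = False):
    """Sketch of the new behavior: keep datasets separate or merge them into one."""
    if return_dict:
        return named_datasets  # e.g. {"alpaca_en": Dataset, "alpaca_zh": Dataset}
    # simplified stand-in for merge_dataset(...)
    return concatenate_datasets(list(named_datasets.values()))


data = {
    "alpaca_en": Dataset.from_dict({"text": ["hello"]}),
    "alpaca_zh": Dataset.from_dict({"text": ["你好"]}),
}
print(list(merge_or_keep(data, return_dict=True)))      # ['alpaca_en', 'alpaca_zh']
print(merge_or_keep(data, return_dict=False).num_rows)  # 2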
@@ -303,7 +303,12 @@ def get_dataset(
     with training_args.main_process_first(desc="load dataset"):
         dataset = _get_merged_dataset(data_args.dataset, model_args, data_args, training_args, stage)
         eval_dataset = _get_merged_dataset(
-            data_args.eval_dataset, model_args, data_args, training_args, stage, merge=training_args.do_predict
+            data_args.eval_dataset,
+            model_args,
+            data_args,
+            training_args,
+            stage,
+            return_dict=data_args.eval_on_each_dataset,
         )

     with training_args.main_process_first(desc="pre-process dataset"):
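With eval_on_each_dataset enabled, get_dataset therefore hands back a dict of eval splits keyed by dataset name instead of a single merged split. The Hugging Face Trainer already accepts a dict-valued eval_dataset and evaluates each entry separately, prefixing the reported metrics with the key (e.g. eval_alpaca_en_loss), which appears to be what lets per-dataset evaluation work without further trainer changes. A rough sketch of that per-dataset loop, for illustration only (the real loop lives inside Trainer.evaluate):

def evaluate_each(trainer, eval_dataset):
    """Illustrative only: run evaluation per dataset when eval_dataset is a dict."""
    if isinstance(eval_dataset, dict):
        metrics = {}
        for name, split in eval_dataset.items():
            # metric keys come back prefixed, e.g. "eval_alpaca_en_loss"
            metrics.update(trainer.evaluate(eval_dataset=split, metric_key_prefix=f"eval_{name}"))
        return metrics
    return trainer.evaluate(eval_dataset=eval_dataset)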


@@ -99,6 +99,10 @@ class DataArguments:
         default=0.0,
         metadata={"help": "Size of the validation set, should be an integer or a float in range `[0,1)`."},
     )
+    eval_on_each_dataset: bool = field(
+        default=False,
+        metadata={"help": "Whether or not to evaluate on each dataset separately."},
+    )
     packing: Optional[bool] = field(
         default=None,
         metadata={"help": "Enable sequences packing in training. Will automatically enable in pre-training."},


@@ -64,6 +64,7 @@ class RayArguments:
                 raise ValueError(
                     f"ray_storage_filesystem must be one of ['s3', 'gs', 'gcs'], got {self.ray_storage_filesystem}"
                 )
+            import pyarrow.fs as fs
             if self.ray_storage_filesystem == "s3":
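The pyarrow.fs import sits inside the branch, so pyarrow only needs to be installed when a Ray storage filesystem is actually configured. A small sketch of that lazy-import pattern with illustrative names (not the repo's exact code):

from typing import Optional


def resolve_storage_filesystem(name: Optional[str]):
    """Return a pyarrow filesystem for the given backend, importing pyarrow lazily."""
    if name is None:
        return None
    if name not in ("s3", "gs", "gcs"):
        raise ValueError(f"ray_storage_filesystem must be one of ['s3', 'gs', 'gcs'], got {name}")
    import pyarrow.fs as fs  # deferred: only required when this code path runs

    if name == "s3":
        return fs.S3FileSystem()
    return fs.GcsFileSystem()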