[data] add eval_on_each_dataset arg (#7912)

hoshi-hiyouga 2025-04-30 06:56:43 +08:00 committed by GitHub
parent c5b1d07e7c
commit 072bfe29d3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 15 additions and 5 deletions

View File

@@ -168,7 +168,7 @@ def _get_merged_dataset(
     data_args: "DataArguments",
     training_args: "Seq2SeqTrainingArguments",
     stage: Literal["pt", "sft", "rm", "ppo", "kto"],
-    merge: bool = True,
+    return_dict: bool = False,
 ) -> Optional[Union["Dataset", "IterableDataset", dict[str, "Dataset"]]]:
     r"""Return the merged datasets in the standard format."""
     if dataset_names is None:
@@ -181,10 +181,10 @@ def _get_merged_dataset(
         datasets[dataset_name] = _load_single_dataset(dataset_attr, model_args, data_args, training_args)

-    if merge:
-        return merge_dataset(list(datasets.values()), data_args, seed=training_args.seed)
-    else:
+    if return_dict:
         return datasets
+    else:
+        return merge_dataset(list(datasets.values()), data_args, seed=training_args.seed)


 def _get_dataset_processor(
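
The renamed switch is the core of the change: `return_dict=False` keeps the previous behavior of collapsing all named datasets into one, while `return_dict=True` hands back the per-name mapping untouched. A minimal standalone sketch of that branching (not the project's actual loader code; `concatenate_datasets` stands in here for `merge_dataset`, which also handles interleaving and seeding):

```python
# Minimal sketch, not LLaMA-Factory's loader: it only illustrates the new
# `return_dict` switch. `concatenate_datasets` is a stand-in for the
# project's `merge_dataset` helper.
from typing import Union

from datasets import Dataset, concatenate_datasets


def get_merged_dataset_sketch(
    named_datasets: dict[str, Dataset], return_dict: bool = False
) -> Union[Dataset, dict[str, Dataset]]:
    if return_dict:
        # Keep the datasets separate, keyed by name (per-dataset evaluation).
        return named_datasets
    else:
        # Previous default behavior: collapse everything into one dataset.
        return concatenate_datasets(list(named_datasets.values()))


if __name__ == "__main__":
    ds = {
        "alpaca": Dataset.from_dict({"text": ["a", "b"]}),
        "dolly": Dataset.from_dict({"text": ["c"]}),
    }
    print(get_merged_dataset_sketch(ds).num_rows)                  # 3 (merged)
    print(list(get_merged_dataset_sketch(ds, return_dict=True)))   # ['alpaca', 'dolly']
```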
@@ -303,7 +303,12 @@ def get_dataset(
     with training_args.main_process_first(desc="load dataset"):
         dataset = _get_merged_dataset(data_args.dataset, model_args, data_args, training_args, stage)
         eval_dataset = _get_merged_dataset(
-            data_args.eval_dataset, model_args, data_args, training_args, stage, merge=training_args.do_predict
+            data_args.eval_dataset,
+            model_args,
+            data_args,
+            training_args,
+            stage,
+            return_dict=data_args.eval_on_each_dataset,
         )

     with training_args.main_process_first(desc="pre-process dataset"):
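
Downstream, keeping the eval split as a dict is what enables per-dataset reporting: the Hugging Face Trainer accepts a dict-valued `eval_dataset` and evaluates each entry separately, prepending the dict key to the metric names. A self-contained toy illustration of that keying (dataset names and the metric are invented, not from this repo):

```python
# Toy illustration only: mimics (but is not) the Trainer's handling of a
# dict-valued eval_dataset, where each metric key carries the dataset name
# instead of producing a single pooled score.
from datasets import Dataset

eval_dataset = {
    "alpaca": Dataset.from_dict({"label": [1, 0, 1, 1]}),
    "dolly": Dataset.from_dict({"label": [0, 0, 1]}),
}


def toy_metric(ds: Dataset) -> float:
    # Stand-in for a real evaluation loop; here just the positive-label rate.
    return sum(ds["label"]) / len(ds)


metrics = {f"eval_{name}_score": toy_metric(ds) for name, ds in eval_dataset.items()}
print(metrics)  # {'eval_alpaca_score': 0.75, 'eval_dolly_score': 0.333...}
```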

View File

@@ -99,6 +99,10 @@ class DataArguments:
         default=0.0,
         metadata={"help": "Size of the validation set, should be an integer or a float in range `[0,1)`."},
     )
+    eval_on_each_dataset: bool = field(
+        default=False,
+        metadata={"help": "Whether or not to evaluate on each dataset separately."},
+    )
     packing: Optional[bool] = field(
         default=None,
         metadata={"help": "Enable sequences packing in training. Will automatically enable in pre-training."},

View File

@@ -64,6 +64,7 @@ class RayArguments:
                 raise ValueError(
                     f"ray_storage_filesystem must be one of ['s3', 'gs', 'gcs'], got {self.ray_storage_filesystem}"
                 )
+
             import pyarrow.fs as fs

             if self.ray_storage_filesystem == "s3":