From 072bfe29d3b10447b0702196028af2f0fb87ad6b Mon Sep 17 00:00:00 2001
From: hoshi-hiyouga
Date: Wed, 30 Apr 2025 06:56:43 +0800
Subject: [PATCH] [data] add eval_on_each_dataset arg (#7912)

---
 src/llamafactory/data/loader.py           | 15 ++++++++++-----
 src/llamafactory/hparams/data_args.py     |  4 ++++
 src/llamafactory/hparams/training_args.py |  1 +
 3 files changed, 15 insertions(+), 5 deletions(-)

diff --git a/src/llamafactory/data/loader.py b/src/llamafactory/data/loader.py
index 29eb5f17..645794c1 100644
--- a/src/llamafactory/data/loader.py
+++ b/src/llamafactory/data/loader.py
@@ -168,7 +168,7 @@ def _get_merged_dataset(
     data_args: "DataArguments",
     training_args: "Seq2SeqTrainingArguments",
     stage: Literal["pt", "sft", "rm", "ppo", "kto"],
-    merge: bool = True,
+    return_dict: bool = False,
 ) -> Optional[Union["Dataset", "IterableDataset", dict[str, "Dataset"]]]:
     r"""Return the merged datasets in the standard format."""
     if dataset_names is None:
@@ -181,10 +181,10 @@ def _get_merged_dataset(
 
         datasets[dataset_name] = _load_single_dataset(dataset_attr, model_args, data_args, training_args)
 
-    if merge:
-        return merge_dataset(list(datasets.values()), data_args, seed=training_args.seed)
-    else:
+    if return_dict:
         return datasets
+    else:
+        return merge_dataset(list(datasets.values()), data_args, seed=training_args.seed)
 
 
 def _get_dataset_processor(
@@ -303,7 +303,12 @@ def get_dataset(
     with training_args.main_process_first(desc="load dataset"):
         dataset = _get_merged_dataset(data_args.dataset, model_args, data_args, training_args, stage)
         eval_dataset = _get_merged_dataset(
-            data_args.eval_dataset, model_args, data_args, training_args, stage, merge=training_args.do_predict
+            data_args.eval_dataset,
+            model_args,
+            data_args,
+            training_args,
+            stage,
+            return_dict=data_args.eval_on_each_dataset,
         )
 
     with training_args.main_process_first(desc="pre-process dataset"):
diff --git a/src/llamafactory/hparams/data_args.py b/src/llamafactory/hparams/data_args.py
index 96dfb391..60d3036e 100644
--- a/src/llamafactory/hparams/data_args.py
+++ b/src/llamafactory/hparams/data_args.py
@@ -99,6 +99,10 @@ class DataArguments:
         default=0.0,
         metadata={"help": "Size of the validation set, should be an integer or a float in range `[0,1)`."},
     )
+    eval_on_each_dataset: bool = field(
+        default=False,
+        metadata={"help": "Whether or not to evaluate on each dataset separately."},
+    )
     packing: Optional[bool] = field(
         default=None,
         metadata={"help": "Enable sequences packing in training. Will automatically enable in pre-training."},
diff --git a/src/llamafactory/hparams/training_args.py b/src/llamafactory/hparams/training_args.py
index fae9a6a3..b37c0a2f 100644
--- a/src/llamafactory/hparams/training_args.py
+++ b/src/llamafactory/hparams/training_args.py
@@ -64,6 +64,7 @@ class RayArguments:
                 raise ValueError(
                     f"ray_storage_filesystem must be one of ['s3', 'gs', 'gcs'], got {self.ray_storage_filesystem}"
                 )
+
             import pyarrow.fs as fs
 
             if self.ray_storage_filesystem == "s3":
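
A minimal sketch of the behavior the new flag enables, assuming the dict-valued eval dataset is passed through to an HF Trainer (which accepts a Dict[str, Dataset] and logs metrics per key); the dataset names below are hypothetical, not taken from the patch:

from datasets import Dataset

# With eval_on_each_dataset=True, _get_merged_dataset returns a per-name dict
# like this instead of one merged dataset, so each split is scored on its own.
eval_dataset = {
    "alpaca_en": Dataset.from_dict({"text": ["sample a"]}),  # hypothetical name
    "identity": Dataset.from_dict({"text": ["sample b"]}),   # hypothetical name
}

# An HF Trainer given a dict evaluates each entry separately and prefixes the
# metrics with the key, e.g. eval_alpaca_en_loss vs. eval_identity_loss,
# rather than reporting a single eval_loss over the merged mixture.
# trainer = Seq2SeqTrainer(..., eval_dataset=eval_dataset)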