From 0ad9f7f0587446873a32b282593efeb7190f4b40 Mon Sep 17 00:00:00 2001 From: SrWYG Date: Thu, 13 Feb 2025 02:19:03 +0800 Subject: [PATCH] [data] evaluate on each dataset (#5522) * [Update] loader.py , evaluate will run separate evaluations on each dataset. `If you pass a dictionary with names of datasets as keys and datasets as values, evaluate will run separate evaluations on each dataset. This can be useful to monitor how training affects other datasets or simply to get a more fine-grained evaluation` seq2seqtrainner support eval_dataset as Dict. * fix format * fix * fix --------- Co-authored-by: hiyouga Former-commit-id: 1e35967ae159038a66f3203dd0e6ec51eea9208f --- src/llamafactory/data/loader.py | 47 ++++++++++++++++++++++++--------- 1 file changed, 35 insertions(+), 12 deletions(-) diff --git a/src/llamafactory/data/loader.py b/src/llamafactory/data/loader.py index 079fe4d0..cd5c2810 100644 --- a/src/llamafactory/data/loader.py +++ b/src/llamafactory/data/loader.py @@ -164,21 +164,25 @@ def _get_merged_dataset( data_args: "DataArguments", training_args: "Seq2SeqTrainingArguments", stage: Literal["pt", "sft", "rm", "ppo", "kto"], -) -> Optional[Union["Dataset", "IterableDataset"]]: + merge: bool = True, +) -> Optional[Union["Dataset", "IterableDataset", Dict[str, "Dataset"]]]: r""" Returns the merged datasets in the standard format. """ if dataset_names is None: return None - datasets = [] - for dataset_attr in get_dataset_list(dataset_names, data_args.dataset_dir): + datasets = {} + for dataset_name, dataset_attr in zip(dataset_names, get_dataset_list(dataset_names, data_args.dataset_dir)): if (stage == "rm" and dataset_attr.ranking is False) or (stage != "rm" and dataset_attr.ranking is True): raise ValueError("The dataset is not applicable in the current training stage.") - datasets.append(_load_single_dataset(dataset_attr, model_args, data_args, training_args)) + datasets[dataset_name] = _load_single_dataset(dataset_attr, model_args, data_args, training_args) - return merge_dataset(datasets, data_args, seed=training_args.seed) + if merge: + return merge_dataset(list(datasets.values()), data_args, seed=training_args.seed) + else: + return datasets def _get_dataset_processor( @@ -313,15 +317,23 @@ def get_dataset( # Load and preprocess dataset with training_args.main_process_first(desc="load dataset"): dataset = _get_merged_dataset(data_args.dataset, model_args, data_args, training_args, stage) - eval_dataset = _get_merged_dataset(data_args.eval_dataset, model_args, data_args, training_args, stage) + eval_dataset = _get_merged_dataset( + data_args.eval_dataset, model_args, data_args, training_args, stage, merge=False + ) with training_args.main_process_first(desc="pre-process dataset"): dataset = _get_preprocessed_dataset( dataset, data_args, training_args, stage, template, tokenizer, processor, is_eval=False ) - eval_dataset = _get_preprocessed_dataset( - eval_dataset, data_args, training_args, stage, template, tokenizer, processor, is_eval=True - ) + if isinstance(eval_dataset, dict): + for eval_name, eval_data in eval_dataset.items(): + eval_dataset[eval_name] = _get_preprocessed_dataset( + eval_data, data_args, training_args, stage, template, tokenizer, processor, is_eval=True + ) + else: + eval_dataset = _get_preprocessed_dataset( + eval_dataset, data_args, training_args, stage, template, tokenizer, processor, is_eval=True + ) if data_args.val_size > 1e-6: dataset_dict = split_dataset(dataset, data_args, seed=training_args.seed) @@ -334,10 +346,13 @@ def get_dataset( dataset_dict["train"] = dataset if eval_dataset is not None: - if data_args.streaming: - eval_dataset = eval_dataset.shuffle(buffer_size=data_args.buffer_size, seed=training_args.seed) + if isinstance(eval_dataset, dict): + dataset_dict.update({f"validation_{name}": data for name, data in eval_dataset.items()}) + else: + if data_args.streaming: + eval_dataset = eval_dataset.shuffle(buffer_size=data_args.buffer_size, seed=training_args.seed) - dataset_dict["validation"] = eval_dataset + dataset_dict["validation"] = eval_dataset dataset_dict = DatasetDict(dataset_dict) @@ -356,4 +371,12 @@ def get_dataset( if "validation" in dataset_dict: dataset_module["eval_dataset"] = dataset_dict["validation"] + eval_datasets_map = {} + for key in dataset_dict.keys(): + if key.startswith("validation_"): + eval_datasets_map[key[len("validation_") :]] = dataset_dict[key] + + if len(eval_datasets_map): + dataset_module["eval_dataset"] = eval_datasets_map + return dataset_module