diff --git a/src/llamafactory/data/loader.py b/src/llamafactory/data/loader.py
index 079fe4d0..cd5c2810 100644
--- a/src/llamafactory/data/loader.py
+++ b/src/llamafactory/data/loader.py
@@ -164,21 +164,25 @@ def _get_merged_dataset(
     data_args: "DataArguments",
     training_args: "Seq2SeqTrainingArguments",
     stage: Literal["pt", "sft", "rm", "ppo", "kto"],
-) -> Optional[Union["Dataset", "IterableDataset"]]:
+    merge: bool = True,
+) -> Optional[Union["Dataset", "IterableDataset", Dict[str, "Dataset"]]]:
     r"""
     Returns the merged datasets in the standard format.
     """
     if dataset_names is None:
         return None
 
-    datasets = []
-    for dataset_attr in get_dataset_list(dataset_names, data_args.dataset_dir):
+    datasets = {}
+    for dataset_name, dataset_attr in zip(dataset_names, get_dataset_list(dataset_names, data_args.dataset_dir)):
         if (stage == "rm" and dataset_attr.ranking is False) or (stage != "rm" and dataset_attr.ranking is True):
             raise ValueError("The dataset is not applicable in the current training stage.")
 
-        datasets.append(_load_single_dataset(dataset_attr, model_args, data_args, training_args))
+        datasets[dataset_name] = _load_single_dataset(dataset_attr, model_args, data_args, training_args)
 
-    return merge_dataset(datasets, data_args, seed=training_args.seed)
+    if merge:
+        return merge_dataset(list(datasets.values()), data_args, seed=training_args.seed)
+    else:
+        return datasets
 
 
 def _get_dataset_processor(
@@ -313,15 +317,23 @@ def get_dataset(
     # Load and preprocess dataset
     with training_args.main_process_first(desc="load dataset"):
         dataset = _get_merged_dataset(data_args.dataset, model_args, data_args, training_args, stage)
-        eval_dataset = _get_merged_dataset(data_args.eval_dataset, model_args, data_args, training_args, stage)
+        eval_dataset = _get_merged_dataset(
+            data_args.eval_dataset, model_args, data_args, training_args, stage, merge=False
+        )
 
     with training_args.main_process_first(desc="pre-process dataset"):
         dataset = _get_preprocessed_dataset(
             dataset, data_args, training_args, stage, template, tokenizer, processor, is_eval=False
         )
-        eval_dataset = _get_preprocessed_dataset(
-            eval_dataset, data_args, training_args, stage, template, tokenizer, processor, is_eval=True
-        )
+        if isinstance(eval_dataset, dict):
+            for eval_name, eval_data in eval_dataset.items():
+                eval_dataset[eval_name] = _get_preprocessed_dataset(
+                    eval_data, data_args, training_args, stage, template, tokenizer, processor, is_eval=True
+                )
+        else:
+            eval_dataset = _get_preprocessed_dataset(
+                eval_dataset, data_args, training_args, stage, template, tokenizer, processor, is_eval=True
+            )
 
         if data_args.val_size > 1e-6:
             dataset_dict = split_dataset(dataset, data_args, seed=training_args.seed)
@@ -334,10 +346,13 @@ def get_dataset(
                 dataset_dict["train"] = dataset
 
             if eval_dataset is not None:
-                if data_args.streaming:
-                    eval_dataset = eval_dataset.shuffle(buffer_size=data_args.buffer_size, seed=training_args.seed)
+                if isinstance(eval_dataset, dict):
+                    dataset_dict.update({f"validation_{name}": data for name, data in eval_dataset.items()})
+                else:
+                    if data_args.streaming:
+                        eval_dataset = eval_dataset.shuffle(buffer_size=data_args.buffer_size, seed=training_args.seed)
 
-                dataset_dict["validation"] = eval_dataset
+                    dataset_dict["validation"] = eval_dataset
 
             dataset_dict = DatasetDict(dataset_dict)
 
@@ -356,4 +371,12 @@ def get_dataset(
     if "validation" in dataset_dict:
         dataset_module["eval_dataset"] = dataset_dict["validation"]
 
+    eval_datasets_map = {}
+    for key in dataset_dict.keys():
+        if key.startswith("validation_"):
+            eval_datasets_map[key[len("validation_") :]] = dataset_dict[key]
+
+    if len(eval_datasets_map):
+        dataset_module["eval_dataset"] = eval_datasets_map
+
     return dataset_module
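
A minimal consumer sketch (not part of the patch) of how the dict-valued eval split produced above can be used, assuming the Hugging Face Trainer, which accepts eval_dataset as a Dict[str, Dataset] and prefixes each entry's metrics with its key; model, training_args, and dataset_module (with "train_dataset"/"eval_dataset" keys) are assumed to come from the usual LLaMA-Factory training workflow, and the dataset names are illustrative only:

from transformers import Trainer

# With this change, dataset_module["eval_dataset"] is either a single Dataset
# (the legacy "validation" split) or a dict keyed by the names passed via
# --eval_dataset, e.g. {"alpaca_en": Dataset, "alpaca_zh": Dataset}.
trainer = Trainer(
    model=model,                                    # assumed: model built elsewhere
    args=training_args,
    train_dataset=dataset_module["train_dataset"],
    eval_dataset=dataset_module["eval_dataset"],    # Trainer accepts Dict[str, Dataset]
)

# Trainer.evaluate() then runs once per dict entry and reports per-dataset
# metrics such as eval_alpaca_en_loss and eval_alpaca_zh_loss instead of a
# single eval_loss.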