From b0d49e137f98f6782d176481ebe8109a8e20a89f Mon Sep 17 00:00:00 2001
From: ZIYI ZENG <1034337098@qq.com>
Date: Sat, 20 Dec 2025 01:46:00 +0800
Subject: [PATCH] [misc] Support splitting eval_dataset when
 "predict_with_generate" is explicitly set (#9604)

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
---
 src/llamafactory/data/data_utils.py       | 31 ++++++++++++++---
 src/llamafactory/data/loader.py           | 28 ++++++++++----------
 src/llamafactory/hparams/parser.py        |  7 ++---
 tests/data/processor/test_unsupervised.py |  1 +
 4 files changed, 37 insertions(+), 30 deletions(-)

diff --git a/src/llamafactory/data/data_utils.py b/src/llamafactory/data/data_utils.py
index 14e261290..139425a22 100644
--- a/src/llamafactory/data/data_utils.py
+++ b/src/llamafactory/data/data_utils.py
@@ -81,41 +81,48 @@ def split_dataset(
     eval_dataset: Optional[Union["Dataset", "IterableDataset", dict[str, "Dataset"]]],
     data_args: "DataArguments",
     seed: int,
-) -> "DatasetDict":
-    r"""Split the dataset and returns a dataset dict containing train set and validation set.
+) -> tuple[dict, dict]:
+    r"""Split the dataset and return two dicts containing the train set and the validation set.
 
     Support both map dataset and iterable dataset.
+
+    Returns:
+        train_dict: Dictionary containing the training data under the key "train".
+        eval_dict: Dictionary containing the evaluation data under the key "validation" or "validation_{name}".
     """
     if eval_dataset is not None and data_args.val_size > 1e-6:
         raise ValueError("Cannot specify `val_size` if `eval_dataset` is not None.")
 
-    dataset_dict = {}
+    # keep the train and eval splits in separate dicts for clearer code and easier handling by callers
+    train_dict, eval_dict = {}, {}
+
     if dataset is not None:
         if data_args.streaming:
             dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=seed)
 
         if data_args.val_size > 1e-6:
             if data_args.streaming:
-                dataset_dict["validation"] = dataset.take(int(data_args.val_size))
-                dataset_dict["train"] = dataset.skip(int(data_args.val_size))
+                eval_dict["validation"] = dataset.take(int(data_args.val_size))
+                train_dict["train"] = dataset.skip(int(data_args.val_size))
             else:
                 val_size = int(data_args.val_size) if data_args.val_size > 1 else data_args.val_size
-                dataset_dict = dataset.train_test_split(test_size=val_size, seed=seed)
-                dataset = dataset.train_test_split(test_size=val_size, seed=seed)
-                dataset_dict = {"train": dataset["train"], "validation": dataset["test"]}
+                split_result = dataset.train_test_split(test_size=val_size, seed=seed)
+                train_dict["train"] = split_result["train"]
+                eval_dict["validation"] = split_result["test"]
         else:
-            dataset_dict["train"] = dataset
+            train_dict["train"] = dataset
 
     if eval_dataset is not None:
         if isinstance(eval_dataset, dict):
-            dataset_dict.update({f"validation_{name}": data for name, data in eval_dataset.items()})
+            for name, data in eval_dataset.items():
+                eval_dict[f"validation_{name}"] = data
         else:
             if data_args.streaming:
                 eval_dataset = eval_dataset.shuffle(buffer_size=data_args.buffer_size, seed=seed)
 
-            dataset_dict["validation"] = eval_dataset
+            eval_dict["validation"] = eval_dataset
 
-    return DatasetDict(dataset_dict)
+    return train_dict, eval_dict
 
 
 def get_dataset_module(dataset: Union["Dataset", "DatasetDict"]) -> "DatasetModule":
diff --git a/src/llamafactory/data/loader.py b/src/llamafactory/data/loader.py
index cbb13455b..ad7667617 100644
--- a/src/llamafactory/data/loader.py
+++ b/src/llamafactory/data/loader.py
@@ -16,7 +16,7 @@ import os
 from typing import TYPE_CHECKING, Literal, Optional, Union
 
 import numpy as np
-from datasets import Dataset, load_dataset, load_from_disk
+from datasets import Dataset, DatasetDict, load_dataset, load_from_disk
 
 from ..extras import logging
 from ..extras.constants import FILEEXT2TYPE
@@ -311,20 +311,22 @@ def get_dataset(
         )
 
     with training_args.main_process_first(desc="pre-process dataset", local=(not data_args.data_shared_file_system)):
-        dataset = _get_preprocessed_dataset(
-            dataset, data_args, training_args, stage, template, tokenizer, processor, is_eval=False
-        )
-        if isinstance(eval_dataset, dict):
-            for eval_name, eval_data in eval_dataset.items():
-                eval_dataset[eval_name] = _get_preprocessed_dataset(
-                    eval_data, data_args, training_args, stage, template, tokenizer, processor, is_eval=True
-                )
-        else:
-            eval_dataset = _get_preprocessed_dataset(
-                eval_dataset, data_args, training_args, stage, template, tokenizer, processor, is_eval=True
+        # split first so that eval_dataset (whether passed in or split from the train set) is preprocessed appropriately
+        train_dict, eval_dict = split_dataset(dataset, eval_dataset, data_args, seed=training_args.seed)
+
+        if "train" in train_dict:
+            train_dict["train"] = _get_preprocessed_dataset(
+                train_dict["train"], data_args, training_args, stage, template, tokenizer, processor, is_eval=False
             )
 
-        dataset_dict = split_dataset(dataset, eval_dataset, data_args, seed=training_args.seed)
+        for key in eval_dict:
+            eval_dict[key] = _get_preprocessed_dataset(
+                eval_dict[key], data_args, training_args, stage, template, tokenizer, processor, is_eval=True
+            )
+
+        # Combine train and eval dictionaries
+        dataset_dict = DatasetDict({**train_dict, **eval_dict})
+
         if data_args.tokenized_path is not None:  # save tokenized dataset to disk
             if training_args.should_save:
                 dataset_dict.save_to_disk(data_args.tokenized_path)
diff --git a/src/llamafactory/hparams/parser.py b/src/llamafactory/hparams/parser.py
index 8f204e805..a3d9ddee2 100644
--- a/src/llamafactory/hparams/parser.py
+++ b/src/llamafactory/hparams/parser.py
@@ -306,18 +306,15 @@ def get_train_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _
     if training_args.do_train and data_args.dataset is None:
         raise ValueError("Please specify dataset for training.")
 
-    if (training_args.do_eval or training_args.do_predict) and (
+    if (training_args.do_eval or training_args.do_predict or training_args.predict_with_generate) and (
         data_args.eval_dataset is None and data_args.val_size < 1e-6
     ):
-        raise ValueError("Please specify dataset for evaluation.")
+        raise ValueError("Please provide `eval_dataset` or set `val_size` > 1e-6 for evaluation.")
 
     if training_args.predict_with_generate:
         if is_deepspeed_zero3_enabled():
             raise ValueError("`predict_with_generate` is incompatible with DeepSpeed ZeRO-3.")
 
-        if data_args.eval_dataset is None:
-            raise ValueError("Cannot use `predict_with_generate` if `eval_dataset` is None.")
-
         if finetuning_args.compute_accuracy:
             raise ValueError("Cannot use `predict_with_generate` and `compute_accuracy` together.")

diff --git a/tests/data/processor/test_unsupervised.py b/tests/data/processor/test_unsupervised.py
index d9a9c9c41..6566f1471 100644
--- a/tests/data/processor/test_unsupervised.py
+++ b/tests/data/processor/test_unsupervised.py
@@ -42,6 +42,7 @@ TRAIN_ARGS = {
     "output_dir": "dummy_dir",
     "overwrite_output_dir": True,
     "fp16": True,
+    "report_to": "none",  # transformers compatibility
 }
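
For reference, a minimal usage sketch of the new return shape (a sketch only,
not part of the patch; the toy datasets below are made up for illustration):

    # Python sketch, assuming the `datasets` library is installed.
    from datasets import Dataset, DatasetDict

    # Hypothetical stand-ins for the two dicts that split_dataset() now returns.
    train_dict = {"train": Dataset.from_dict({"text": ["a", "b", "c"]})}
    eval_dict = {"validation": Dataset.from_dict({"text": ["d"]})}

    # loader.py preprocesses each split separately (is_eval=False vs. is_eval=True),
    # then merges the two dicts back into a single DatasetDict.
    dataset_dict = DatasetDict({**train_dict, **eval_dict})
    assert set(dataset_dict) == {"train", "validation"}

Returning the splits as two separate dicts, rather than one DatasetDict, lets
get_dataset() apply the train-side and eval-side preprocessing independently
before merging.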