Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-12-23 23:30:36 +08:00)
[misc] Support splitting eval_dataset when "predict_with_generate" is explicitly set (#9604)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
@@ -81,41 +81,48 @@ def split_dataset(
     eval_dataset: Optional[Union["Dataset", "IterableDataset", dict[str, "Dataset"]]],
     data_args: "DataArguments",
     seed: int,
-) -> "DatasetDict":
-    r"""Split the dataset and returns a dataset dict containing train set and validation set.
+) -> tuple[dict, dict]:
+    r"""Split the dataset and return two dicts containing the train set and the validation set.
 
     Support both map dataset and iterable dataset.
 
+    Returns:
+        train_dict: Dictionary containing training data with key "train".
+        eval_dict: Dictionary containing evaluation data with keys "validation" or "validation_{name}".
     """
     if eval_dataset is not None and data_args.val_size > 1e-6:
         raise ValueError("Cannot specify `val_size` if `eval_dataset` is not None.")
 
-    dataset_dict = {}
+    # return the train and eval data as separate dicts so that callers can handle each split cleanly
+    train_dict, eval_dict = {}, {}
     if dataset is not None:
         if data_args.streaming:
             dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=seed)
 
         if data_args.val_size > 1e-6:
             if data_args.streaming:
-                dataset_dict["validation"] = dataset.take(int(data_args.val_size))
-                dataset_dict["train"] = dataset.skip(int(data_args.val_size))
+                eval_dict["validation"] = dataset.take(int(data_args.val_size))
+                train_dict["train"] = dataset.skip(int(data_args.val_size))
             else:
                 val_size = int(data_args.val_size) if data_args.val_size > 1 else data_args.val_size
-                dataset = dataset.train_test_split(test_size=val_size, seed=seed)
-                dataset_dict = {"train": dataset["train"], "validation": dataset["test"]}
+                split_result = dataset.train_test_split(test_size=val_size, seed=seed)
+                train_dict["train"] = split_result["train"]
+                eval_dict["validation"] = split_result["test"]
         else:
-            dataset_dict["train"] = dataset
+            train_dict["train"] = dataset
 
     if eval_dataset is not None:
         if isinstance(eval_dataset, dict):
-            dataset_dict.update({f"validation_{name}": data for name, data in eval_dataset.items()})
+            for name, data in eval_dataset.items():
+                eval_dict[f"validation_{name}"] = data
         else:
             if data_args.streaming:
                 eval_dataset = eval_dataset.shuffle(buffer_size=data_args.buffer_size, seed=seed)
 
-            dataset_dict["validation"] = eval_dataset
+            eval_dict["validation"] = eval_dataset
 
-    return DatasetDict(dataset_dict)
+    return train_dict, eval_dict
 
 
 def get_dataset_module(dataset: Union["Dataset", "DatasetDict"]) -> "DatasetModule":
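For context, here is a minimal sketch of how the new two-dict contract behaves. It assumes `split_dataset` is importable from `llamafactory.data.data_utils` and uses a `SimpleNamespace` stand-in for `data_args` with only the fields the function reads; neither assumption is part of this commit.

    # Minimal sketch of the new split_dataset contract (illustrative assumptions noted above).
    from types import SimpleNamespace

    from datasets import Dataset, DatasetDict
    from llamafactory.data.data_utils import split_dataset  # assumed import path

    data_args = SimpleNamespace(streaming=False, val_size=0.1, buffer_size=16384)
    dataset = Dataset.from_dict({"text": [f"example {i}" for i in range(100)]})

    train_dict, eval_dict = split_dataset(dataset, eval_dataset=None, data_args=data_args, seed=42)
    print(sorted(train_dict))  # ['train']
    print(sorted(eval_dict))   # ['validation']

    # callers recombine the two dicts only after per-split preprocessing:
    dataset_dict = DatasetDict({**train_dict, **eval_dict})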
@@ -16,7 +16,7 @@ import os
 from typing import TYPE_CHECKING, Literal, Optional, Union
 
 import numpy as np
-from datasets import Dataset, load_dataset, load_from_disk
+from datasets import Dataset, DatasetDict, load_dataset, load_from_disk
 
 from ..extras import logging
 from ..extras.constants import FILEEXT2TYPE
@@ -311,20 +311,22 @@ def get_dataset(
     )
 
     with training_args.main_process_first(desc="pre-process dataset", local=(not data_args.data_shared_file_system)):
-        dataset = _get_preprocessed_dataset(
-            dataset, data_args, training_args, stage, template, tokenizer, processor, is_eval=False
-        )
-        if isinstance(eval_dataset, dict):
-            for eval_name, eval_data in eval_dataset.items():
-                eval_dataset[eval_name] = _get_preprocessed_dataset(
-                    eval_data, data_args, training_args, stage, template, tokenizer, processor, is_eval=True
-                )
-        else:
-            eval_dataset = _get_preprocessed_dataset(
-                eval_dataset, data_args, training_args, stage, template, tokenizer, processor, is_eval=True
-            )
-
-        dataset_dict = split_dataset(dataset, eval_dataset, data_args, seed=training_args.seed)
+        # split first so that the eval data (whether provided explicitly or split off) is preprocessed appropriately
+        train_dict, eval_dict = split_dataset(dataset, eval_dataset, data_args, seed=training_args.seed)
+
+        if "train" in train_dict:
+            train_dict["train"] = _get_preprocessed_dataset(
+                train_dict["train"], data_args, training_args, stage, template, tokenizer, processor, is_eval=False
+            )
+
+        for key in eval_dict:
+            eval_dict[key] = _get_preprocessed_dataset(
+                eval_dict[key], data_args, training_args, stage, template, tokenizer, processor, is_eval=True
+            )
+
+        # combine the train and eval dicts into a single DatasetDict
+        dataset_dict = DatasetDict({**train_dict, **eval_dict})
 
         if data_args.tokenized_path is not None:  # save tokenized dataset to disk
             if training_args.should_save:
                 dataset_dict.save_to_disk(data_args.tokenized_path)
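The behavioral point of this hunk is that splitting now happens before tokenization, so every eval split is preprocessed with `is_eval=True`. A self-contained sketch of the preprocess-then-combine pattern follows; the `preprocess` helper and the dataset names are hypothetical stand-ins for `_get_preprocessed_dataset` and real eval sets.

    # Illustrative sketch of the preprocess-then-combine flow (stand-in names noted above).
    from datasets import Dataset, DatasetDict

    def preprocess(ds: Dataset, is_eval: bool) -> Dataset:
        # stand-in for tokenization; eval examples may be preprocessed differently,
        # e.g. when predict_with_generate changes the label layout
        return ds.map(lambda ex: {"n_chars": len(ex["text"]), "is_eval": is_eval})

    train_dict = {"train": Dataset.from_dict({"text": ["a", "b"]})}
    eval_dict = {
        "validation_alpaca": Dataset.from_dict({"text": ["c"]}),    # hypothetical name
        "validation_sharegpt": Dataset.from_dict({"text": ["d"]}),  # hypothetical name
    }

    train_dict = {k: preprocess(v, is_eval=False) for k, v in train_dict.items()}
    eval_dict = {k: preprocess(v, is_eval=True) for k, v in eval_dict.items()}

    dataset_dict = DatasetDict({**train_dict, **eval_dict})
    print(list(dataset_dict))  # ['train', 'validation_alpaca', 'validation_sharegpt']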
@@ -306,18 +306,15 @@ def get_train_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _
     if training_args.do_train and data_args.dataset is None:
         raise ValueError("Please specify dataset for training.")
 
-    if (training_args.do_eval or training_args.do_predict) and (
+    if (training_args.do_eval or training_args.do_predict or training_args.predict_with_generate) and (
         data_args.eval_dataset is None and data_args.val_size < 1e-6
     ):
-        raise ValueError("Please specify dataset for evaluation.")
+        raise ValueError("Please provide `eval_dataset` or set `val_size` > 1e-6 for evaluation.")
 
     if training_args.predict_with_generate:
         if is_deepspeed_zero3_enabled():
             raise ValueError("`predict_with_generate` is incompatible with DeepSpeed ZeRO-3.")
 
-        if data_args.eval_dataset is None:
-            raise ValueError("Cannot use `predict_with_generate` if `eval_dataset` is None.")
-
         if finetuning_args.compute_accuracy:
             raise ValueError("Cannot use `predict_with_generate` and `compute_accuracy` together.")
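Net effect of this hunk: `predict_with_generate` now shares the eval-data requirement with `do_eval`/`do_predict`, and the old hard requirement for an explicit `eval_dataset` is dropped, since a `val_size` split can supply the eval data. A minimal sketch of the revised check, using `SimpleNamespace` stand-ins rather than the real argument classes:

    # Minimal sketch of the revised validation (stand-in argument objects).
    from types import SimpleNamespace

    def check_eval_data(training_args, data_args) -> None:
        needs_eval = (
            training_args.do_eval or training_args.do_predict or training_args.predict_with_generate
        )
        if needs_eval and data_args.eval_dataset is None and data_args.val_size < 1e-6:
            raise ValueError("Please provide `eval_dataset` or set `val_size` > 1e-6 for evaluation.")

    # predict_with_generate with only a val_size split now passes validation:
    check_eval_data(
        SimpleNamespace(do_eval=False, do_predict=False, predict_with_generate=True),
        SimpleNamespace(eval_dataset=None, val_size=0.1),
    )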
@@ -42,6 +42,7 @@ TRAIN_ARGS = {
     "output_dir": "dummy_dir",
     "overwrite_output_dir": True,
     "fp16": True,
+    "report_to": "none",  # transformers compatibility
 }
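For context, `report_to="none"` is a standard `transformers.TrainingArguments` value that disables logging integrations (wandb, tensorboard, etc.), which keeps the dummy test args from trying to initialize reporters in CI. A tiny sketch:

    # Sketch: report_to="none" is normalized to an empty integration list.
    from transformers import TrainingArguments

    args = TrainingArguments(output_dir="dummy_dir", report_to="none")
    print(args.report_to)  # []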