Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-12-23 23:30:36 +08:00)
[misc] Support splitting eval_dataset when "predict_with_generate" is explicitly set (#9604)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
@@ -81,41 +81,48 @@ def split_dataset(
     eval_dataset: Optional[Union["Dataset", "IterableDataset", dict[str, "Dataset"]]],
     data_args: "DataArguments",
     seed: int,
-) -> "DatasetDict":
-    r"""Split the dataset and returns a dataset dict containing train set and validation set.
+) -> tuple[dict, dict]:
+    r"""Split the dataset and returns two dicts containing train set and validation set.
 
     Support both map dataset and iterable dataset.
+
+    Returns:
+        train_dict: Dictionary containing training data with key "train"
+        eval_dict: Dictionary containing evaluation data with keys "validation" or "validation_{name}"
     """
     if eval_dataset is not None and data_args.val_size > 1e-6:
         raise ValueError("Cannot specify `val_size` if `eval_dataset` is not None.")
 
-    dataset_dict = {}
+    # return the train and eval splits as separate dicts so the caller can handle them cleanly
+    train_dict, eval_dict = {}, {}
     if dataset is not None:
         if data_args.streaming:
             dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=seed)
 
         if data_args.val_size > 1e-6:
             if data_args.streaming:
-                dataset_dict["validation"] = dataset.take(int(data_args.val_size))
-                dataset_dict["train"] = dataset.skip(int(data_args.val_size))
+                eval_dict["validation"] = dataset.take(int(data_args.val_size))
+                train_dict["train"] = dataset.skip(int(data_args.val_size))
             else:
                 val_size = int(data_args.val_size) if data_args.val_size > 1 else data_args.val_size
-                dataset = dataset.train_test_split(test_size=val_size, seed=seed)
-                dataset_dict = {"train": dataset["train"], "validation": dataset["test"]}
+                split_result = dataset.train_test_split(test_size=val_size, seed=seed)
+                train_dict["train"] = split_result["train"]
+                eval_dict["validation"] = split_result["test"]
         else:
-            dataset_dict["train"] = dataset
+            train_dict["train"] = dataset
 
     if eval_dataset is not None:
         if isinstance(eval_dataset, dict):
-            dataset_dict.update({f"validation_{name}": data for name, data in eval_dataset.items()})
+            for name, data in eval_dataset.items():
+                eval_dict[f"validation_{name}"] = data
         else:
             if data_args.streaming:
                 eval_dataset = eval_dataset.shuffle(buffer_size=data_args.buffer_size, seed=seed)
 
-            dataset_dict["validation"] = eval_dataset
+            eval_dict["validation"] = eval_dataset
 
-    return DatasetDict(dataset_dict)
+    return train_dict, eval_dict
 
 
 def get_dataset_module(dataset: Union["Dataset", "DatasetDict"]) -> "DatasetModule":
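
Note: since split_dataset now returns two plain dicts instead of a single DatasetDict, call sites need a small adjustment. A minimal caller-side sketch (illustrative only, not part of this commit; the import path, variable names, and seed value are assumptions):

from datasets import DatasetDict

from llamafactory.data.data_utils import split_dataset  # import path assumed

# Assuming `dataset`, `eval_dataset`, and `data_args` were built by the usual
# loading pipeline, the new signature hands back the two splits separately,
# which makes it easier to route eval data to its own preprocessing path when
# `predict_with_generate` is enabled.
train_dict, eval_dict = split_dataset(dataset, eval_dataset, data_args, seed=42)

# train_dict holds at most the key "train"; eval_dict holds "validation" or
# "validation_{name}" keys when a dict of eval datasets was passed in.
train_data = train_dict.get("train")
eval_data = dict(eval_dict)

# Callers that still want the old single-mapping behavior (non-streaming case)
# can simply merge the two dicts back into a DatasetDict.
merged = DatasetDict({**train_dict, **eval_dict})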
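
For reference, the non-streaming branch interprets val_size in two ways: values greater than 1 are cast to int and used as an absolute example count, while values in (0, 1] are passed to train_test_split as a fraction. A standalone sketch with toy data (illustrative only, not repository code):

from datasets import Dataset

# Toy dataset used only to demonstrate the two val_size interpretations.
ds = Dataset.from_dict({"text": [f"sample {i}" for i in range(10)]})

# Fractional val_size (0 < val_size <= 1): split off a proportion of examples.
frac_split = ds.train_test_split(test_size=0.2, seed=42)  # 8 train / 2 test

# Absolute val_size (> 1): split off a fixed number of examples.
abs_split = ds.train_test_split(test_size=2, seed=42)  # 8 train / 2 test

# In streaming mode the function uses dataset.take / dataset.skip instead, so
# val_size is always treated as an absolute count there.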