Mirror of https://github.com/hiyouga/LLaMA-Factory.git, synced 2025-08-02 03:32:50 +08:00
[data] evaluate on each dataset (#5522)
* [Update] loader.py: evaluation now runs separately on each eval dataset. "If you pass a dictionary with names of datasets as keys and datasets as values, evaluate will run separate evaluations on each dataset. This can be useful to monitor how training affects other datasets or simply to get a more fine-grained evaluation." Seq2SeqTrainer supports `eval_dataset` as a Dict.
* fix format
* fix
* fix

---------

Co-authored-by: hiyouga <hiyouga@buaa.edu.cn>

Former-commit-id: 1e35967ae159038a66f3203dd0e6ec51eea9208f
Parent: 1adb46875f
Commit: 0ad9f7f058
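For context on the commit message above: the quoted sentence describes the Hugging Face Trainer's handling of a dict-valued `eval_dataset`, where each entry is evaluated separately and the dataset name is reflected in the metric keys (e.g. `eval_alpaca_loss`). The sketch below mimics only that dispatch on toy data; `evaluate_single` and the loss values are hypothetical stand-ins, not transformers APIs.

```python
# A minimal sketch of dict-valued eval_dataset dispatch: each entry is
# evaluated on its own and the dataset name is folded into the metric keys.
# `evaluate_single` and the toy loss values are illustrative stand-ins,
# not transformers APIs.
from typing import Dict, List, Union


def evaluate_single(dataset: List[dict], metric_key_prefix: str = "eval") -> Dict[str, float]:
    # Toy "evaluation": average a precomputed per-sample loss.
    avg_loss = sum(sample["loss"] for sample in dataset) / len(dataset)
    return {f"{metric_key_prefix}_loss": avg_loss}


def evaluate(eval_dataset: Union[List[dict], Dict[str, List[dict]]]) -> Dict[str, float]:
    if isinstance(eval_dataset, dict):
        metrics = {}
        for name, dataset in eval_dataset.items():
            # One evaluation pass per named dataset, with the name in the prefix.
            metrics.update(evaluate_single(dataset, metric_key_prefix=f"eval_{name}"))
        return metrics
    return evaluate_single(eval_dataset)


print(evaluate({"alpaca": [{"loss": 0.5}, {"loss": 0.75}], "identity": [{"loss": 0.25}]}))
# {'eval_alpaca_loss': 0.625, 'eval_identity_loss': 0.25}
```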
loader.py

@@ -164,21 +164,25 @@ def _get_merged_dataset(
     data_args: "DataArguments",
     training_args: "Seq2SeqTrainingArguments",
     stage: Literal["pt", "sft", "rm", "ppo", "kto"],
-) -> Optional[Union["Dataset", "IterableDataset"]]:
+    merge: bool = True,
+) -> Optional[Union["Dataset", "IterableDataset", Dict[str, "Dataset"]]]:
     r"""
     Returns the merged datasets in the standard format.
     """
     if dataset_names is None:
         return None
 
-    datasets = []
-    for dataset_attr in get_dataset_list(dataset_names, data_args.dataset_dir):
+    datasets = {}
+    for dataset_name, dataset_attr in zip(dataset_names, get_dataset_list(dataset_names, data_args.dataset_dir)):
         if (stage == "rm" and dataset_attr.ranking is False) or (stage != "rm" and dataset_attr.ranking is True):
             raise ValueError("The dataset is not applicable in the current training stage.")
 
-        datasets.append(_load_single_dataset(dataset_attr, model_args, data_args, training_args))
+        datasets[dataset_name] = _load_single_dataset(dataset_attr, model_args, data_args, training_args)
 
-    return merge_dataset(datasets, data_args, seed=training_args.seed)
+    if merge:
+        return merge_dataset(list(datasets.values()), data_args, seed=training_args.seed)
+    else:
+        return datasets
 
 
 def _get_dataset_processor(
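As a reading aid for the hunk above, here is a self-contained sketch of the new merge-or-dict behavior on toy data; `load_toy_dataset` and `get_eval_splits` are made-up helpers (not LLaMA-Factory functions), and plain concatenation merely stands in for whatever mixing the real `merge_dataset` applies.

```python
# Standalone sketch of the merge-or-dict pattern, on toy data from the
# `datasets` library. The helpers below are illustrative, not LLaMA-Factory
# functions.
from typing import Dict, List, Union

from datasets import Dataset, concatenate_datasets


def load_toy_dataset(name: str) -> Dataset:
    # Stand-in for a per-dataset loader: each "dataset" is just two rows.
    return Dataset.from_dict({"text": [f"{name} sample 1", f"{name} sample 2"]})


def get_eval_splits(names: List[str], merge: bool = True) -> Union[Dataset, Dict[str, Dataset]]:
    datasets = {name: load_toy_dataset(name) for name in names}
    if merge:
        # Old behavior: a single concatenated evaluation set.
        return concatenate_datasets(list(datasets.values()))
    # New behavior: keep the splits separate, keyed by dataset name.
    return datasets


merged = get_eval_splits(["alpaca", "identity"], merge=True)
separate = get_eval_splits(["alpaca", "identity"], merge=False)
print(len(merged))            # 4
print(list(separate.keys()))  # ['alpaca', 'identity']
```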
@@ -313,15 +317,23 @@ def get_dataset(
     # Load and preprocess dataset
     with training_args.main_process_first(desc="load dataset"):
         dataset = _get_merged_dataset(data_args.dataset, model_args, data_args, training_args, stage)
-        eval_dataset = _get_merged_dataset(data_args.eval_dataset, model_args, data_args, training_args, stage)
+        eval_dataset = _get_merged_dataset(
+            data_args.eval_dataset, model_args, data_args, training_args, stage, merge=False
+        )
 
     with training_args.main_process_first(desc="pre-process dataset"):
         dataset = _get_preprocessed_dataset(
             dataset, data_args, training_args, stage, template, tokenizer, processor, is_eval=False
         )
-        eval_dataset = _get_preprocessed_dataset(
-            eval_dataset, data_args, training_args, stage, template, tokenizer, processor, is_eval=True
-        )
+        if isinstance(eval_dataset, dict):
+            for eval_name, eval_data in eval_dataset.items():
+                eval_dataset[eval_name] = _get_preprocessed_dataset(
+                    eval_data, data_args, training_args, stage, template, tokenizer, processor, is_eval=True
+                )
+        else:
+            eval_dataset = _get_preprocessed_dataset(
+                eval_dataset, data_args, training_args, stage, template, tokenizer, processor, is_eval=True
+            )
 
         if data_args.val_size > 1e-6:
             dataset_dict = split_dataset(dataset, data_args, seed=training_args.seed)
@@ -334,10 +346,13 @@ def get_dataset(
                 dataset_dict["train"] = dataset
 
             if eval_dataset is not None:
-                if data_args.streaming:
-                    eval_dataset = eval_dataset.shuffle(buffer_size=data_args.buffer_size, seed=training_args.seed)
+                if isinstance(eval_dataset, dict):
+                    dataset_dict.update({f"validation_{name}": data for name, data in eval_dataset.items()})
+                else:
+                    if data_args.streaming:
+                        eval_dataset = eval_dataset.shuffle(buffer_size=data_args.buffer_size, seed=training_args.seed)
 
-                dataset_dict["validation"] = eval_dataset
+                    dataset_dict["validation"] = eval_dataset
 
             dataset_dict = DatasetDict(dataset_dict)
 
@@ -356,4 +371,12 @@ def get_dataset(
     if "validation" in dataset_dict:
         dataset_module["eval_dataset"] = dataset_dict["validation"]
 
+    eval_datasets_map = {}
+    for key in dataset_dict.keys():
+        if key.startswith("validation_"):
+            eval_datasets_map[key[len("validation_") :]] = dataset_dict[key]
+
+    if len(eval_datasets_map):
+        dataset_module["eval_dataset"] = eval_datasets_map
+
     return dataset_module
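To illustrate how the `validation_<name>` keys introduced above round-trip into the final `eval_dataset` mapping, here is a toy sketch (not the LLaMA-Factory module itself):

```python
# Toy round trip of the last two hunks: named eval splits are stored in the
# DatasetDict under "validation_<name>" keys, then collected back into the
# plain dict handed to the trainer. Toy data only.
from datasets import Dataset, DatasetDict

dataset_dict = DatasetDict(
    {
        "train": Dataset.from_dict({"text": ["a", "b"]}),
        "validation_alpaca": Dataset.from_dict({"text": ["c"]}),
        "validation_identity": Dataset.from_dict({"text": ["d"]}),
    }
)

# Mirror of the "validation_" prefix scan in the last hunk.
eval_datasets_map = {
    key[len("validation_"):]: dataset_dict[key]
    for key in dataset_dict.keys()
    if key.startswith("validation_")
}
print(sorted(eval_datasets_map.keys()))  # ['alpaca', 'identity']
```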