mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-02 03:32:50 +08:00
[data] add eval_on_each_dataset arg (#7912)
This commit is contained in:
parent
c5b1d07e7c
commit
072bfe29d3
@ -168,7 +168,7 @@ def _get_merged_dataset(
|
||||
data_args: "DataArguments",
|
||||
training_args: "Seq2SeqTrainingArguments",
|
||||
stage: Literal["pt", "sft", "rm", "ppo", "kto"],
|
||||
merge: bool = True,
|
||||
return_dict: bool = False,
|
||||
) -> Optional[Union["Dataset", "IterableDataset", dict[str, "Dataset"]]]:
|
||||
r"""Return the merged datasets in the standard format."""
|
||||
if dataset_names is None:
|
||||
@ -181,10 +181,10 @@ def _get_merged_dataset(
|
||||
|
||||
datasets[dataset_name] = _load_single_dataset(dataset_attr, model_args, data_args, training_args)
|
||||
|
||||
if merge:
|
||||
return merge_dataset(list(datasets.values()), data_args, seed=training_args.seed)
|
||||
else:
|
||||
if return_dict:
|
||||
return datasets
|
||||
else:
|
||||
return merge_dataset(list(datasets.values()), data_args, seed=training_args.seed)
|
||||
|
||||
|
||||
def _get_dataset_processor(
|
||||
@ -303,7 +303,12 @@ def get_dataset(
|
||||
with training_args.main_process_first(desc="load dataset"):
|
||||
dataset = _get_merged_dataset(data_args.dataset, model_args, data_args, training_args, stage)
|
||||
eval_dataset = _get_merged_dataset(
|
||||
data_args.eval_dataset, model_args, data_args, training_args, stage, merge=training_args.do_predict
|
||||
data_args.eval_dataset,
|
||||
model_args,
|
||||
data_args,
|
||||
training_args,
|
||||
stage,
|
||||
return_dict=data_args.eval_on_each_dataset,
|
||||
)
|
||||
|
||||
with training_args.main_process_first(desc="pre-process dataset"):
|
||||
|
@ -99,6 +99,10 @@ class DataArguments:
|
||||
default=0.0,
|
||||
metadata={"help": "Size of the validation set, should be an integer or a float in range `[0,1)`."},
|
||||
)
|
||||
eval_on_each_dataset: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Whether or not to evaluate on each dataset separately."},
|
||||
)
|
||||
packing: Optional[bool] = field(
|
||||
default=None,
|
||||
metadata={"help": "Enable sequences packing in training. Will automatically enable in pre-training."},
|
||||
|
@ -64,6 +64,7 @@ class RayArguments:
|
||||
raise ValueError(
|
||||
f"ray_storage_filesystem must be one of ['s3', 'gs', 'gcs'], got {self.ray_storage_filesystem}"
|
||||
)
|
||||
|
||||
import pyarrow.fs as fs
|
||||
|
||||
if self.ray_storage_filesystem == "s3":
|
||||
|
Loading…
x
Reference in New Issue
Block a user