[data] fix predict dataset (#6972)

Former-commit-id: bdb581c4a82d02458766e73c87b7a92ea31796ec
hoshi-hiyouga 2025-02-17 20:29:40 +08:00 committed by GitHub
parent 475a355b82
commit a8c9d5663d


@@ -288,7 +288,7 @@ def get_dataset(
     r"""
     Gets the train dataset and optionally gets the evaluation dataset.
     """
-    # Load tokenized dataset
+    # Load tokenized dataset if path exists
     if data_args.tokenized_path is not None:
         if has_tokenized_data(data_args.tokenized_path):
             logger.warning_rank0("Loading dataset from disk will ignore other data arguments.")
@@ -303,7 +303,7 @@ def get_dataset(
             if "validation" in tokenized_data:
                 dataset_module["eval_dataset"] = tokenized_data["validation"]

-        else:  # Dataset
+        else:  # single dataset
             dataset_module["train_dataset"] = tokenized_data

         if data_args.streaming:
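The reworded comment reflects what this branch actually distinguishes: `datasets.load_from_disk` returns a `DatasetDict` when named splits were saved, and a bare `Dataset` otherwise, so the `else` branch handles a single dataset with no splits. A small runnable illustration of the two shapes:

    from datasets import Dataset, DatasetDict

    splits = DatasetDict({"train": Dataset.from_dict({"x": [1, 2]})})
    single = Dataset.from_dict({"x": [1, 2]})

    for tokenized_data in (splits, single):
        if isinstance(tokenized_data, DatasetDict):
            print("named splits:", list(tokenized_data.keys()))  # ['train']
        else:  # single dataset
            print("single dataset with", tokenized_data.num_rows, "rows")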
@@ -318,7 +318,7 @@ def get_dataset(
     with training_args.main_process_first(desc="load dataset"):
         dataset = _get_merged_dataset(data_args.dataset, model_args, data_args, training_args, stage)
         eval_dataset = _get_merged_dataset(
-            data_args.eval_dataset, model_args, data_args, training_args, stage, merge=False
+            data_args.eval_dataset, model_args, data_args, training_args, stage, merge=training_args.do_predict
         )

     with training_args.main_process_first(desc="pre-process dataset"):
@@ -356,10 +356,10 @@ def get_dataset(
             dataset_dict = DatasetDict(dataset_dict)

-        if data_args.tokenized_path is not None:
+        if data_args.tokenized_path is not None:  # save tokenized dataset to disk and exit
             if training_args.should_save:
                 dataset_dict.save_to_disk(data_args.tokenized_path)
-                logger.info_rank0(f"Tokenized dataset saved at {data_args.tokenized_path}.")
+                logger.info_rank0(f"Tokenized dataset is saved at {data_args.tokenized_path}.")
                 logger.info_rank0(f"Please restart the training with `tokenized_path: {data_args.tokenized_path}`.")

             sys.exit(0)
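This branch persists every preprocessed split and exits, so the next run can start from `tokenized_path` and skip tokenization entirely. A round-trip sketch of that workflow (the directory name is hypothetical):

    from datasets import Dataset, DatasetDict, load_from_disk

    dataset_dict = DatasetDict({"train": Dataset.from_dict({"text": ["a", "b"]})})
    dataset_dict.save_to_disk("tokenized_demo")  # first run: save, then sys.exit(0)

    reloaded = load_from_disk("tokenized_demo")  # second run with tokenized_path set
    print(list(reloaded.keys()))                 # ['train']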
@@ -370,13 +370,13 @@ def get_dataset(
         if "validation" in dataset_dict:
             dataset_module["eval_dataset"] = dataset_dict["validation"]
-
-        eval_datasets_map = {}
-        for key in dataset_dict.keys():
-            if key.startswith("validation_"):
-                eval_datasets_map[key[len("validation_") :]] = dataset_dict[key]
-
-        if len(eval_datasets_map):
-            dataset_module["eval_dataset"] = eval_datasets_map
+        else:
+            eval_dataset = {}
+            for key in dataset_dict.keys():
+                if key.startswith("validation_"):
+                    eval_dataset[key[len("validation_") :]] = dataset_dict[key]
+
+            if len(eval_dataset):
+                dataset_module["eval_dataset"] = eval_dataset

         return dataset_module
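The new `else` branch only builds the multi-dataset mapping when there is no plain "validation" split: it strips the `validation_` prefix from each matching key and assigns `eval_dataset` only if at least one such split exists. The key handling in isolation:

    dataset_dict = {"train": ..., "validation_alpaca": ..., "validation_mmlu": ...}

    eval_dataset = {}
    for key in dataset_dict.keys():
        if key.startswith("validation_"):
            eval_dataset[key[len("validation_") :]] = dataset_dict[key]

    print(list(eval_dataset.keys()))  # ['alpaca', 'mmlu']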