[data] fix predict dataset (#6972)

Former-commit-id: bdb581c4a82d02458766e73c87b7a92ea31796ec
This commit is contained in:
hoshi-hiyouga 2025-02-17 20:29:40 +08:00 committed by GitHub
parent 475a355b82
commit a8c9d5663d

View File

@ -288,7 +288,7 @@ def get_dataset(
r"""
Gets the train dataset and optionally gets the evaluation dataset.
"""
# Load tokenized dataset
# Load tokenized dataset if path exists
if data_args.tokenized_path is not None:
if has_tokenized_data(data_args.tokenized_path):
logger.warning_rank0("Loading dataset from disk will ignore other data arguments.")
@ -303,7 +303,7 @@ def get_dataset(
if "validation" in tokenized_data:
dataset_module["eval_dataset"] = tokenized_data["validation"]
else: # Dataset
else: # single dataset
dataset_module["train_dataset"] = tokenized_data
if data_args.streaming:
@ -318,7 +318,7 @@ def get_dataset(
with training_args.main_process_first(desc="load dataset"):
dataset = _get_merged_dataset(data_args.dataset, model_args, data_args, training_args, stage)
eval_dataset = _get_merged_dataset(
data_args.eval_dataset, model_args, data_args, training_args, stage, merge=False
data_args.eval_dataset, model_args, data_args, training_args, stage, merge=training_args.do_predict
)
with training_args.main_process_first(desc="pre-process dataset"):
@ -356,10 +356,10 @@ def get_dataset(
dataset_dict = DatasetDict(dataset_dict)
if data_args.tokenized_path is not None:
if data_args.tokenized_path is not None: # save tokenized dataset to disk and exit
if training_args.should_save:
dataset_dict.save_to_disk(data_args.tokenized_path)
logger.info_rank0(f"Tokenized dataset saved at {data_args.tokenized_path}.")
logger.info_rank0(f"Tokenized dataset is saved at {data_args.tokenized_path}.")
logger.info_rank0(f"Please restart the training with `tokenized_path: {data_args.tokenized_path}`.")
sys.exit(0)
@ -370,13 +370,13 @@ def get_dataset(
if "validation" in dataset_dict:
dataset_module["eval_dataset"] = dataset_dict["validation"]
else:
eval_dataset = {}
for key in dataset_dict.keys():
if key.startswith("validation_"):
eval_dataset[key[len("validation_") :]] = dataset_dict[key]
eval_datasets_map = {}
for key in dataset_dict.keys():
if key.startswith("validation_"):
eval_datasets_map[key[len("validation_") :]] = dataset_dict[key]
if len(eval_datasets_map):
dataset_module["eval_dataset"] = eval_datasets_map
if len(eval_dataset):
dataset_module["eval_dataset"] = eval_dataset
return dataset_module