Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-08-02 03:32:50 +08:00)
[data] fix predict dataset (#6972)

Former-commit-id: bdb581c4a82d02458766e73c87b7a92ea31796ec

Parent: 475a355b82
Commit: a8c9d5663d
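Summary of the change: when building the evaluation dataset, `merge=False` becomes `merge=training_args.do_predict`, so multiple eval datasets are merged into a single dataset for prediction runs instead of being kept as separate named splits. Alongside that, the local `eval_dataset` mapping is renamed to `eval_datasets_map` (presumably to avoid confusion with the merged `eval_dataset` returned by `_get_merged_dataset` earlier in the function), and a few comments and log messages are tightened.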
@@ -288,7 +288,7 @@ def get_dataset(
     r"""
     Gets the train dataset and optionally gets the evaluation dataset.
     """
-    # Load tokenized dataset
+    # Load tokenized dataset if path exists
     if data_args.tokenized_path is not None:
         if has_tokenized_data(data_args.tokenized_path):
             logger.warning_rank0("Loading dataset from disk will ignore other data arguments.")
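Note: `has_tokenized_data` is defined elsewhere in the repo and is not part of this diff. As a rough, hypothetical sketch of what such a guard typically checks (the real helper may differ), assuming a path counts as usable once `save_to_disk` has populated it:

import os

# Hypothetical stand-in for has_tokenized_data; the actual helper
# is not shown in this commit.
def has_tokenized_data_sketch(path: str) -> bool:
    # A directory that save_to_disk has written to exists and is non-empty.
    return os.path.isdir(path) and len(os.listdir(path)) > 0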
@@ -303,7 +303,7 @@ def get_dataset(
             if "validation" in tokenized_data:
                 dataset_module["eval_dataset"] = tokenized_data["validation"]

-            else:  # Dataset
+            else:  # single dataset
                 dataset_module["train_dataset"] = tokenized_data

             if data_args.streaming:
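The reworded `single dataset` comment reflects that a dataset reloaded from disk is either a `DatasetDict` with named splits or a bare `Dataset`. A minimal sketch of that branching, assuming an isinstance check like the one implied by the surrounding code (the full loader body, including the `train` split assignment, is not shown in this hunk):

from datasets import Dataset, DatasetDict

def build_dataset_module(tokenized_data):
    # Simplified sketch; mirrors the branch structure visible in the diff.
    dataset_module = {}
    if isinstance(tokenized_data, DatasetDict):  # saved with named splits
        dataset_module["train_dataset"] = tokenized_data["train"]
        if "validation" in tokenized_data:
            dataset_module["eval_dataset"] = tokenized_data["validation"]
    else:  # single dataset: all of it is training data
        dataset_module["train_dataset"] = tokenized_data
    return dataset_module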
@@ -318,7 +318,7 @@ def get_dataset(
     with training_args.main_process_first(desc="load dataset"):
         dataset = _get_merged_dataset(data_args.dataset, model_args, data_args, training_args, stage)
         eval_dataset = _get_merged_dataset(
-            data_args.eval_dataset, model_args, data_args, training_args, stage, merge=False
+            data_args.eval_dataset, model_args, data_args, training_args, stage, merge=training_args.do_predict
         )

     with training_args.main_process_first(desc="pre-process dataset"):
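This is the core of the fix: `_get_merged_dataset` (not shown in this hunk) now merges the eval datasets exactly when running prediction. A minimal sketch of the intended `merge` semantics, assuming the helper concatenates the named eval datasets when `merge=True` and otherwise keeps a name-to-dataset mapping; the helper body, names, and data below are illustrative, not the repo's actual implementation:

from typing import Optional, Union

from datasets import Dataset, concatenate_datasets

# Hypothetical stand-in for the eval branch of _get_merged_dataset.
def merge_eval_datasets(
    eval_sets: dict[str, Dataset], merge: bool
) -> Optional[Union[Dataset, dict[str, Dataset]]]:
    if not eval_sets:
        return None
    if merge:  # do_predict: one flat dataset, so predictions form a single table
        return concatenate_datasets(list(eval_sets.values()))
    return eval_sets  # evaluation: keep per-dataset splits for per-dataset metrics

# Toy usage: two eval sets with the same schema.
a = Dataset.from_dict({"text": ["x", "y"]})
b = Dataset.from_dict({"text": ["z"]})
print(merge_eval_datasets({"alpaca": a, "identity": b}, merge=True))   # 3-row Dataset
print(merge_eval_datasets({"alpaca": a, "identity": b}, merge=False))  # dict of 2 datasets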
@@ -356,10 +356,10 @@ def get_dataset(

         dataset_dict = DatasetDict(dataset_dict)

-        if data_args.tokenized_path is not None:
+        if data_args.tokenized_path is not None:  # save tokenized dataset to disk and exit
             if training_args.should_save:
                 dataset_dict.save_to_disk(data_args.tokenized_path)
-                logger.info_rank0(f"Tokenized dataset saved at {data_args.tokenized_path}.")
+                logger.info_rank0(f"Tokenized dataset is saved at {data_args.tokenized_path}.")
                 logger.info_rank0(f"Please restart the training with `tokenized_path: {data_args.tokenized_path}`.")

             sys.exit(0)
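The save-and-exit branch pairs with the load branch at the top of the function: the first run tokenizes and saves, then the user restarts with `tokenized_path` set and the loader reads from disk. A minimal sketch of that round trip using the standard `datasets` API (the path and columns are illustrative):

from datasets import Dataset, DatasetDict, load_from_disk

# First run: tokenize, save, then exit (mirrors dataset_dict.save_to_disk + sys.exit(0)).
dataset_dict = DatasetDict({
    "train": Dataset.from_dict({"input_ids": [[1, 2], [3, 4]]}),
    "validation": Dataset.from_dict({"input_ids": [[5, 6]]}),
})
dataset_dict.save_to_disk("/tmp/tokenized_demo")

# Second run: tokenized_path is set, so other data arguments are ignored.
reloaded = load_from_disk("/tmp/tokenized_demo")
print(reloaded["train"].num_rows, reloaded["validation"].num_rows)  # 2 1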
@@ -370,13 +370,13 @@ def get_dataset(

         if "validation" in dataset_dict:
             dataset_module["eval_dataset"] = dataset_dict["validation"]
         else:
-            eval_dataset = {}
-            for key in dataset_dict.keys():
-                if key.startswith("validation_"):
-                    eval_dataset[key[len("validation_") :]] = dataset_dict[key]
+            eval_datasets_map = {}
+            for key in dataset_dict.keys():
+                if key.startswith("validation_"):
+                    eval_datasets_map[key[len("validation_") :]] = dataset_dict[key]

-            if len(eval_dataset):
-                dataset_module["eval_dataset"] = eval_dataset
+            if len(eval_datasets_map):
+                dataset_module["eval_dataset"] = eval_datasets_map

         return dataset_module
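The renamed `eval_datasets_map` gathers every `validation_<name>` split back under its dataset name, so multi-dataset evaluation keeps one entry per dataset. A runnable miniature of exactly that key mapping (split names are made up):

from datasets import Dataset, DatasetDict

dataset_dict = DatasetDict({
    "train": Dataset.from_dict({"text": ["a"]}),
    "validation_alpaca": Dataset.from_dict({"text": ["b"]}),
    "validation_identity": Dataset.from_dict({"text": ["c"]}),
})

# Strip the "validation_" prefix to recover the original dataset names.
eval_datasets_map = {}
for key in dataset_dict.keys():
    if key.startswith("validation_"):
        eval_datasets_map[key[len("validation_") :]] = dataset_dict[key]

print(sorted(eval_datasets_map))  # ['alpaca', 'identity']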