Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-08-02 19:52:50 +08:00)
[data] fix predict dataset (#6972)
Former-commit-id: bdb581c4a82d02458766e73c87b7a92ea31796ec
parent 475a355b82
commit a8c9d5663d
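In short: `get_dataset` now passes `merge=training_args.do_predict` instead of `merge=False` when assembling the evaluation datasets, so that multiple eval datasets are concatenated into a single dataset during prediction, and the per-name `validation_*` mapping is built only when no plain `validation` split exists. The remaining changes are comment and log-message touch-ups.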
@@ -288,7 +288,7 @@ def get_dataset(
     r"""
     Gets the train dataset and optionally gets the evaluation dataset.
     """
-    # Load tokenized dataset
+    # Load tokenized dataset if path exists
     if data_args.tokenized_path is not None:
         if has_tokenized_data(data_args.tokenized_path):
             logger.warning_rank0("Loading dataset from disk will ignore other data arguments.")
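Context for the branch above: `has_tokenized_data` gates whether the cached dataset at `tokenized_path` can be loaded at all. A rough sketch of what such a check might look like (an assumption for illustration, not the repo's actual helper):

import os

def has_tokenized_data_sketch(path: str) -> bool:
    # Assumed behavior: the cache is usable when the path is an existing,
    # non-empty directory; the real helper in LLaMA-Factory may differ.
    return os.path.isdir(path) and len(os.listdir(path)) > 0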
@@ -303,7 +303,7 @@ def get_dataset(
             if "validation" in tokenized_data:
                 dataset_module["eval_dataset"] = tokenized_data["validation"]
 
-            else:  # Dataset
+            else:  # single dataset
                 dataset_module["train_dataset"] = tokenized_data
 
             if data_args.streaming:
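The reworded comment pins down what the `else` branch handles: `load_from_disk` returns a `DatasetDict` when the data was saved with named splits, and a bare, single `Dataset` otherwise. A minimal round-trip sketch (paths are illustrative):

from datasets import Dataset, DatasetDict, load_from_disk

data = Dataset.from_dict({"input_ids": [[1, 2], [3, 4]]})
DatasetDict({"train": data}).save_to_disk("/tmp/tok_splits")  # saved with a named split
data.save_to_disk("/tmp/tok_single")  # saved as a single split-less dataset

print(isinstance(load_from_disk("/tmp/tok_splits"), DatasetDict))  # True
print(isinstance(load_from_disk("/tmp/tok_single"), DatasetDict))  # False: bare Dataset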
@@ -318,7 +318,7 @@ def get_dataset(
     with training_args.main_process_first(desc="load dataset"):
         dataset = _get_merged_dataset(data_args.dataset, model_args, data_args, training_args, stage)
         eval_dataset = _get_merged_dataset(
-            data_args.eval_dataset, model_args, data_args, training_args, stage, merge=False
+            data_args.eval_dataset, model_args, data_args, training_args, stage, merge=training_args.do_predict
         )
 
     with training_args.main_process_first(desc="pre-process dataset"):
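This one-token change is the heart of the fix. With `merge=False`, several `eval_dataset` entries come back as separate datasets keyed by name, which suits per-dataset evaluation but not prediction, where the predict loop expects a single dataset. Passing `merge=training_args.do_predict` concatenates them when predicting. A sketch of the intended semantics, assuming the merge flag behaves roughly like this (names and data are made up; `_get_merged_dataset` itself is not shown):

from datasets import Dataset, concatenate_datasets

alpaca = Dataset.from_dict({"text": ["a1", "a2"]})
sharegpt = Dataset.from_dict({"text": ["s1"]})
named = {"alpaca": alpaca, "sharegpt": sharegpt}

def merged_or_separate(datasets_by_name, merge):
    # merge=True -> one concatenated dataset (what do_predict needs);
    # merge=False -> keep the per-name mapping for separate eval metrics.
    if merge:
        return concatenate_datasets(list(datasets_by_name.values()))
    return datasets_by_name

print(merged_or_separate(named, merge=True).num_rows)  # 3
print(list(merged_or_separate(named, merge=False)))    # ['alpaca', 'sharegpt']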
@@ -356,10 +356,10 @@ def get_dataset(
 
         dataset_dict = DatasetDict(dataset_dict)
 
-        if data_args.tokenized_path is not None:
+        if data_args.tokenized_path is not None:  # save tokenized dataset to disk and exit
             if training_args.should_save:
                 dataset_dict.save_to_disk(data_args.tokenized_path)
-                logger.info_rank0(f"Tokenized dataset saved at {data_args.tokenized_path}.")
+                logger.info_rank0(f"Tokenized dataset is saved at {data_args.tokenized_path}.")
                 logger.info_rank0(f"Please restart the training with `tokenized_path: {data_args.tokenized_path}`.")
 
             sys.exit(0)
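The added comment documents the two-phase workflow here: a first run tokenizes the data, saves it to `tokenized_path` (only on the process where `training_args.should_save` is true, so distributed workers do not race on the write), and exits; the user then relaunches with `tokenized_path` set, and the branch at the top of `get_dataset` loads the cache. A minimal sketch of phase one (the path is illustrative):

import sys
from datasets import Dataset, DatasetDict

dataset_dict = DatasetDict({"train": Dataset.from_dict({"input_ids": [[1, 2]]})})
dataset_dict.save_to_disk("/tmp/tokenized")  # write the cache once
print("Tokenized dataset is saved at /tmp/tokenized.")
sys.exit(0)  # relaunch with `tokenized_path: /tmp/tokenized` to reuse it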
@@ -370,13 +370,13 @@ def get_dataset(
 
         if "validation" in dataset_dict:
             dataset_module["eval_dataset"] = dataset_dict["validation"]
-
-        eval_datasets_map = {}
-        for key in dataset_dict.keys():
-            if key.startswith("validation_"):
-                eval_datasets_map[key[len("validation_") :]] = dataset_dict[key]
-
-        if len(eval_datasets_map):
-            dataset_module["eval_dataset"] = eval_datasets_map
+        else:
+            eval_dataset = {}
+            for key in dataset_dict.keys():
+                if key.startswith("validation_"):
+                    eval_dataset[key[len("validation_") :]] = dataset_dict[key]
+
+            if len(eval_dataset):
+                dataset_module["eval_dataset"] = eval_dataset
 
         return dataset_module
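The restructuring changes precedence rather than just renaming a variable: a plain `validation` split now wins outright, and the `validation_<name>` splits are collected only in its absence, whereas the old code built the mapping unconditionally and could overwrite the `eval_dataset` already taken from `validation`. A self-contained sketch of the new behavior (split names are hypothetical, and a plain dict stands in for the DatasetDict):

dataset_dict = {"train": "...", "validation_alpaca": "...", "validation_sharegpt": "..."}

dataset_module = {}
if "validation" in dataset_dict:
    dataset_module["eval_dataset"] = dataset_dict["validation"]
else:
    eval_dataset = {
        key[len("validation_"):]: value
        for key, value in dataset_dict.items()
        if key.startswith("validation_")
    }
    if len(eval_dataset):
        dataset_module["eval_dataset"] = eval_dataset

print(dataset_module["eval_dataset"])  # {'alpaca': '...', 'sharegpt': '...'}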