diff --git a/src/llamafactory/data/loader.py b/src/llamafactory/data/loader.py
index cd5c2810..7782dd5b 100644
--- a/src/llamafactory/data/loader.py
+++ b/src/llamafactory/data/loader.py
@@ -288,7 +288,7 @@ def get_dataset(
     r"""
     Gets the train dataset and optionally gets the evaluation dataset.
     """
-    # Load tokenized dataset
+    # Load tokenized dataset if path exists
     if data_args.tokenized_path is not None:
         if has_tokenized_data(data_args.tokenized_path):
             logger.warning_rank0("Loading dataset from disk will ignore other data arguments.")
@@ -303,7 +303,7 @@ def get_dataset(
                 if "validation" in tokenized_data:
                     dataset_module["eval_dataset"] = tokenized_data["validation"]

-            else:  # Dataset
+            else:  # single dataset
                 dataset_module["train_dataset"] = tokenized_data

             if data_args.streaming:
@@ -318,7 +318,7 @@ def get_dataset(
     with training_args.main_process_first(desc="load dataset"):
         dataset = _get_merged_dataset(data_args.dataset, model_args, data_args, training_args, stage)
         eval_dataset = _get_merged_dataset(
-            data_args.eval_dataset, model_args, data_args, training_args, stage, merge=False
+            data_args.eval_dataset, model_args, data_args, training_args, stage, merge=training_args.do_predict
         )

     with training_args.main_process_first(desc="pre-process dataset"):
@@ -356,10 +356,10 @@ def get_dataset(

             dataset_dict = DatasetDict(dataset_dict)

-        if data_args.tokenized_path is not None:
+        if data_args.tokenized_path is not None:  # save tokenized dataset to disk and exit
             if training_args.should_save:
                 dataset_dict.save_to_disk(data_args.tokenized_path)
-                logger.info_rank0(f"Tokenized dataset saved at {data_args.tokenized_path}.")
+                logger.info_rank0(f"Tokenized dataset is saved at {data_args.tokenized_path}.")
                 logger.info_rank0(f"Please restart the training with `tokenized_path: {data_args.tokenized_path}`.")

             sys.exit(0)
@@ -370,13 +370,13 @@ def get_dataset(

         if "validation" in dataset_dict:
             dataset_module["eval_dataset"] = dataset_dict["validation"]
+        else:
+            eval_dataset = {}
+            for key in dataset_dict.keys():
+                if key.startswith("validation_"):
+                    eval_dataset[key[len("validation_") :]] = dataset_dict[key]

-        eval_datasets_map = {}
-        for key in dataset_dict.keys():
-            if key.startswith("validation_"):
-                eval_datasets_map[key[len("validation_") :]] = dataset_dict[key]
-
-        if len(eval_datasets_map):
-            dataset_module["eval_dataset"] = eval_datasets_map
+            if len(eval_dataset):
+                dataset_module["eval_dataset"] = eval_dataset

         return dataset_module
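Note for reviewers: below is a minimal, self-contained sketch of the eval-split mapping introduced in the last hunk. The split names (`validation_alpaca`, `validation_identity`) and the toy data are hypothetical, made up for illustration; only the key-mapping logic mirrors the new `else` branch in `get_dataset`.

```python
from datasets import Dataset, DatasetDict

# Hypothetical splits: one train split plus two named validation splits.
dataset_dict = DatasetDict({
    "train": Dataset.from_dict({"text": ["a", "b"]}),
    "validation_alpaca": Dataset.from_dict({"text": ["c"]}),
    "validation_identity": Dataset.from_dict({"text": ["d"]}),
})

dataset_module = {}
if "validation" in dataset_dict:
    # A single unnamed validation split is used as-is.
    dataset_module["eval_dataset"] = dataset_dict["validation"]
else:
    # Named splits "validation_<name>" are collected into a dict keyed by
    # <name>; Transformers' Trainer accepts a dict of eval datasets and
    # reports metrics for each one separately.
    eval_dataset = {}
    for key in dataset_dict.keys():
        if key.startswith("validation_"):
            eval_dataset[key[len("validation_") :]] = dataset_dict[key]

    if len(eval_dataset):
        dataset_module["eval_dataset"] = eval_dataset

print(sorted(dataset_module["eval_dataset"].keys()))  # -> ['alpaca', 'identity']
```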