diff --git a/src/llamafactory/data/loader.py b/src/llamafactory/data/loader.py index f24c6cdb..a686a0a6 100644 --- a/src/llamafactory/data/loader.py +++ b/src/llamafactory/data/loader.py @@ -223,13 +223,14 @@ def get_dataset( dataset_module: Dict[str, "Dataset"] = {} if "train" in dataset_dict: dataset_module["train_dataset"] = dataset_dict["train"] + if "validation" in dataset_dict: dataset_module["eval_dataset"] = dataset_dict["validation"] if data_args.streaming: dataset_module = {k: v.to_iterable_dataset() for k, v in dataset_module.items()} - return dataset_module + return dataset_module, template if data_args.streaming: raise ValueError("Turn off `streaming` when saving dataset to disk.")