diff --git a/src/llmtuner/data/loader.py b/src/llmtuner/data/loader.py index 41c12422..70beea05 100644 --- a/src/llmtuner/data/loader.py +++ b/src/llmtuner/data/loader.py @@ -59,6 +59,13 @@ def get_dataset( dataset_name=data_path, subset_name=data_name, ).to_hf_dataset() + + def map_func(example): + # do something to example + example['input'] = example['input'] or '' + return example + + dataset = dataset.ds_instance.map(map_func) else: dataset = load_dataset( path=data_path,