diff --git a/src/llmtuner/data/utils.py b/src/llmtuner/data/utils.py index 83ee0610..dc189609 100644 --- a/src/llmtuner/data/utils.py +++ b/src/llmtuner/data/utils.py @@ -78,9 +78,9 @@ def split_dataset( if training_args.do_train: if data_args.val_size > 1e-6: # Split the dataset if data_args.streaming: + dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=training_args.seed) val_set = dataset.take(int(data_args.val_size)) train_set = dataset.skip(int(data_args.val_size)) - dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=training_args.seed) return {"train_dataset": train_set, "eval_dataset": val_set} else: val_size = int(data_args.val_size) if data_args.val_size > 1 else data_args.val_size