From c955d9267c389b7e3adfb3f51a04349dbedf36ee Mon Sep 17 00:00:00 2001 From: codemayq Date: Wed, 30 Aug 2023 16:23:08 +0800 Subject: [PATCH] add dataset stage check Former-commit-id: f7fdc088d49564f7d436fd445e7e1987a9a00a0b --- src/llmtuner/tuner/core/parser.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/src/llmtuner/tuner/core/parser.py b/src/llmtuner/tuner/core/parser.py index 8651d91c..4cd90af9 100644 --- a/src/llmtuner/tuner/core/parser.py +++ b/src/llmtuner/tuner/core/parser.py @@ -186,6 +186,18 @@ def get_train_args( else: model_args.compute_dtype = torch.float16 + # transfer training stage to dataset stage + dataset_stage = general_args.stage + if general_args.stage == "ppo": + dataset_stage = "sft" + elif general_args.stage == "dpo": + dataset_stage = "rm" + + for dataset_attr in data_args.dataset_list: + if dataset_attr.stage and dataset_attr.stage != dataset_stage: + raise ValueError("Dataset {} is not supported for the stage {}" + .format(dataset_attr.dataset_name, general_args.stage)) + model_args.model_max_length = data_args.max_source_length + data_args.max_target_length # Log on each process the small summary: