diff --git a/src/llmtuner/extras/template.py b/src/llmtuner/extras/template.py index 2eedfa9d..33e2c5f7 100644 --- a/src/llmtuner/extras/template.py +++ b/src/llmtuner/extras/template.py @@ -285,7 +285,7 @@ register_template( "Always answer as helpfully as possible, while being safe. " "Your answers should not include any harmful, unethical, " "racist, sexist, toxic, dangerous, or illegal content. " - "Please ensure that your responses are socially unbiased and positive in nature.\n" + "Please ensure that your responses are socially unbiased and positive in nature.\n\n" "If a question does not make any sense, or is not factually coherent, " "explain why instead of answering something not correct. " "If you don't know the answer to a question, please don't share false information." diff --git a/src/llmtuner/tuner/core/parser.py b/src/llmtuner/tuner/core/parser.py index 8651d91c..4cd90af9 100644 --- a/src/llmtuner/tuner/core/parser.py +++ b/src/llmtuner/tuner/core/parser.py @@ -186,6 +186,18 @@ def get_train_args( else: model_args.compute_dtype = torch.float16 + # transfer training stage to dataset stage + dataset_stage = general_args.stage + if general_args.stage == "ppo": + dataset_stage = "sft" + elif general_args.stage == "dpo": + dataset_stage = "rm" + + for dataset_attr in data_args.dataset_list: + if dataset_attr.stage and dataset_attr.stage != dataset_stage: + raise ValueError("Dataset {} is not supported for the stage {}" + .format(dataset_attr.dataset_name, general_args.stage)) + model_args.model_max_length = data_args.max_source_length + data_args.max_target_length # Log on each process the small summary: