From 9ef9cb316b971c61b6e336d3b4e534bfde0d9edf Mon Sep 17 00:00:00 2001
From: hiyouga
Date: Fri, 13 Oct 2023 16:27:59 +0800
Subject: [PATCH] fix webui

Former-commit-id: b240b6792fdb734dd77ed54861fdde059feb1855
---
 src/llmtuner/dsets/loader.py | 2 ++
 src/llmtuner/webui/runner.py | 7 ++++---
 2 files changed, 6 insertions(+), 3 deletions(-)

diff --git a/src/llmtuner/dsets/loader.py b/src/llmtuner/dsets/loader.py
index 08c35f27..3b42f17d 100644
--- a/src/llmtuner/dsets/loader.py
+++ b/src/llmtuner/dsets/loader.py
@@ -65,6 +65,8 @@ def get_dataset(
             max_samples_temp = min(len(dataset), max_samples)
             dataset = dataset.select(range(max_samples_temp))
 
+        # TODO: adapt to the sharegpt format
+
         for column_name in ["prompt", "query", "response", "history"]: # align datasets
             if getattr(dataset_attr, column_name) and getattr(dataset_attr, column_name) != column_name:
                 dataset = dataset.rename_column(getattr(dataset_attr, column_name), column_name)
diff --git a/src/llmtuner/webui/runner.py b/src/llmtuner/webui/runner.py
index 4ea54168..2af4d4a7 100644
--- a/src/llmtuner/webui/runner.py
+++ b/src/llmtuner/webui/runner.py
@@ -138,13 +138,14 @@ class Runner:
             lora_rank=lora_rank,
             lora_dropout=lora_dropout,
             lora_target=lora_target or DEFAULT_MODULE.get(model_name.split("-")[0], "q_proj,v_proj"),
-            resume_lora_training=(
-                False if TRAINING_STAGES[training_stage] in ["rm", "ppo", "dpo"] else resume_lora_training
-            ),
+            resume_lora_training=resume_lora_training,
             output_dir=output_dir
         )
         args[compute_type] = True
 
+        if TRAINING_STAGES[training_stage] in ["rm", "ppo", "dpo"] and args["quantization_bit"] is None:
+            args["resume_lora_training"] = False
+
         if args["quantization_bit"] is not None:
             args["upcast_layernorm"] = True
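
Note (not part of the patch): the runner.py change moves the resume_lora_training override out of the argument-dict construction and into a later step so it can also consult args["quantization_bit"]. A minimal sketch of that post-processing rule is shown below; the TRAINING_STAGES mapping and the function name postprocess_args are assumptions for illustration, only the two conditionals are taken from the diff.

    from typing import Any, Dict

    # Assumed mapping for illustration; the real TRAINING_STAGES is defined in the webui module.
    TRAINING_STAGES = {"Supervised Fine-Tuning": "sft", "Reward Modeling": "rm", "PPO": "ppo", "DPO": "dpo"}

    def postprocess_args(args: Dict[str, Any], training_stage: str) -> Dict[str, Any]:
        # RM/PPO/DPO runs on a non-quantized model start a fresh LoRA adapter
        # rather than resuming the previous one.
        if TRAINING_STAGES[training_stage] in ["rm", "ppo", "dpo"] and args["quantization_bit"] is None:
            args["resume_lora_training"] = False
        # LayerNorm upcasting stays forced on when quantization is enabled (unchanged by this patch).
        if args["quantization_bit"] is not None:
            args["upcast_layernorm"] = True
        return args

    # Example: a non-quantized reward-modeling run gets resume_lora_training disabled.
    args = {"resume_lora_training": True, "quantization_bit": None}
    postprocess_args(args, "Reward Modeling")  # args["resume_lora_training"] is now False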