From 4e1997a34370f853a8511bbacc3351e57164e12d Mon Sep 17 00:00:00 2001 From: hiyouga Date: Tue, 18 Jul 2023 00:31:40 +0800 Subject: [PATCH] a monkey patch for lora_target Former-commit-id: 262252d67bbe4ebcbb315b5d7a34f9a091f8af0c --- src/llmtuner/extras/constants.py | 9 +++++++++ src/llmtuner/webui/runner.py | 2 ++ 2 files changed, 11 insertions(+) diff --git a/src/llmtuner/extras/constants.py b/src/llmtuner/extras/constants.py index be7d119d..1e6a4d9d 100644 --- a/src/llmtuner/extras/constants.py +++ b/src/llmtuner/extras/constants.py @@ -29,3 +29,12 @@ SUPPORTED_MODELS = { "InternLM-7B-Base": "internlm/internlm-7b", "InternLM-7B-Chat": "internlm/internlm-chat-7b" } + +DEFAULT_MODULE = { # will be deprecated + "LLaMA": "q_proj,v_proj", + "BLOOM": "query_key_value", + "BLOOMZ": "query_key_value", + "Falcon": "query_key_value", + "Baichuan": "W_pack", + "InternLM": "q_proj,v_proj" +} diff --git a/src/llmtuner/webui/runner.py b/src/llmtuner/webui/runner.py index 599d31c3..45d8b340 100644 --- a/src/llmtuner/webui/runner.py +++ b/src/llmtuner/webui/runner.py @@ -6,6 +6,7 @@ import transformers from typing import Optional, Tuple from llmtuner.extras.callbacks import LogCallback +from llmtuner.extras.constants import DEFAULT_MODULE # will be deprecated from llmtuner.extras.logging import LoggerHandler from llmtuner.extras.misc import torch_gc from llmtuner.tuner import get_train_args, run_sft @@ -79,6 +80,7 @@ class Runner: model_name_or_path=model_name_or_path, do_train=True, finetuning_type=finetuning_type, + lora_target=DEFAULT_MODULE.get(model_name.split("-")[0], None) or "q_proj,v_proj", prompt_template=template, dataset=",".join(dataset), dataset_dir=dataset_dir,