diff --git a/README.md b/README.md
index d6c5fa47..010ade70 100644
--- a/README.md
+++ b/README.md
@@ -57,7 +57,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846
 | [BLOOM](https://huggingface.co/bigscience/bloom) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - |
 | [BLOOMZ](https://huggingface.co/bigscience/bloomz) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - |
 | [ChatGLM3](https://github.com/THUDM/ChatGLM3) | 6B | query_key_value | chatglm3 |
-| [Falcon](https://huggingface.co/tiiuae/falcon-7b) | 7B/40B/180B | query_key_value | - |
+| [Falcon](https://huggingface.co/tiiuae/falcon-7b) | 7B/40B/180B | query_key_value | falcon |
 | [InternLM](https://github.com/InternLM/InternLM) | 7B/20B | q_proj,v_proj | intern |
 | [LLaMA](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | q_proj,v_proj | - |
 | [LLaMA-2](https://huggingface.co/meta-llama) | 7B/13B/70B | q_proj,v_proj | llama2 |
diff --git a/README_zh.md b/README_zh.md
index 47342430..d8c5fe1c 100644
--- a/README_zh.md
+++ b/README_zh.md
@@ -57,7 +57,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846
 | [BLOOM](https://huggingface.co/bigscience/bloom) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - |
 | [BLOOMZ](https://huggingface.co/bigscience/bloomz) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - |
 | [ChatGLM3](https://github.com/THUDM/ChatGLM3) | 6B | query_key_value | chatglm3 |
-| [Falcon](https://huggingface.co/tiiuae/falcon-7b) | 7B/40B/180B | query_key_value | - |
+| [Falcon](https://huggingface.co/tiiuae/falcon-7b) | 7B/40B/180B | query_key_value | falcon |
 | [InternLM](https://github.com/InternLM/InternLM) | 7B/20B | q_proj,v_proj | intern |
 | [LLaMA](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | q_proj,v_proj | - |
 | [LLaMA-2](https://huggingface.co/meta-llama) | 7B/13B/70B | q_proj,v_proj | llama2 |
diff --git a/src/llmtuner/extras/constants.py b/src/llmtuner/extras/constants.py
index 26b86579..6627e95d 100644
--- a/src/llmtuner/extras/constants.py
+++ b/src/llmtuner/extras/constants.py
@@ -1,9 +1,11 @@
+from collections import defaultdict, OrderedDict
+from typing import Dict, Optional
+
+
 IGNORE_INDEX = -100
 
 LOG_FILE_NAME = "trainer_log.jsonl"
 
-LAYERNORM_NAMES = ["norm", "ln_f", "ln_attn", "ln_mlp", "ln_1", "ln_2", "ln1", "ln2"]
-
 METHODS = ["full", "freeze", "lora"]
 
 TRAINING_STAGES = {
@@ -14,79 +16,214 @@ TRAINING_STAGES = {
     "Pre-Training": "pt"
 }
 
-SUPPORTED_MODELS = {
-    "LLaMA-7B": "huggyllama/llama-7b",
-    "LLaMA-13B": "huggyllama/llama-13b",
-    "LLaMA-30B": "huggyllama/llama-30b",
-    "LLaMA-65B": "huggyllama/llama-65b",
-    "LLaMA2-7B": "meta-llama/Llama-2-7b-hf",
-    "LLaMA2-13B": "meta-llama/Llama-2-13b-hf",
-    "LLaMA2-70B": "meta-llama/Llama-2-70b-hf",
-    "LLaMA2-7B-Chat": "meta-llama/Llama-2-7b-chat-hf",
-    "LLaMA2-13B-Chat": "meta-llama/Llama-2-13b-chat-hf",
-    "LLaMA2-70B-Chat": "meta-llama/Llama-2-70b-chat-hf",
-    "ChineseLLaMA2-7B": "ziqingyang/chinese-llama-2-7b",
-    "ChineseLLaMA2-13B": "ziqingyang/chinese-llama-2-13b",
-    "ChineseLLaMA2-7B-Chat": "ziqingyang/chinese-alpaca-2-7b",
-    "ChineseLLaMA2-13B-Chat": "ziqingyang/chinese-alpaca-2-13b",
-    "BLOOM-560M": "bigscience/bloom-560m",
-    "BLOOM-3B": "bigscience/bloom-3b",
-    "BLOOM-7B1": "bigscience/bloom-7b1",
-    "BLOOMZ-560M": "bigscience/bloomz-560m",
-    "BLOOMZ-3B": "bigscience/bloomz-3b",
-    "BLOOMZ-7B1-mt": "bigscience/bloomz-7b1-mt",
-    "Falcon-7B": "tiiuae/falcon-7b",
-    "Falcon-40B": "tiiuae/falcon-40b",
-    "Falcon-7B-Chat": "tiiuae/falcon-7b-instruct",
"Falcon-40B-Chat": "tiiuae/falcon-40b-instruct", - "Baichuan-7B": "baichuan-inc/Baichuan-7B", - "Baichuan-13B": "baichuan-inc/Baichuan-13B-Base", - "Baichuan-13B-Chat": "baichuan-inc/Baichuan-13B-Chat", - "Baichuan2-7B": "baichuan-inc/Baichuan2-7B-Base", - "Baichuan2-13B": "baichuan-inc/Baichuan2-13B-Base", - "Baichuan2-7B-Chat": "baichuan-inc/Baichuan2-7B-Chat", - "Baichuan2-13B-Chat": "baichuan-inc/Baichuan2-13B-Chat", - "InternLM-7B": "internlm/internlm-7b", - "InternLM-20B": "internlm/internlm-20b", - "InternLM-7B-Chat": "internlm/internlm-chat-7b", - "InternLM-20B-Chat": "internlm/internlm-chat-20b", - "Qwen-7B": "Qwen/Qwen-7B", - "Qwen-14B": "Qwen/Qwen-14B", - "Qwen-7B-Chat": "Qwen/Qwen-7B-Chat", - "Qwen-14B-Chat": "Qwen/Qwen-14B-Chat", - "XVERSE-13B": "xverse/XVERSE-13B", - "XVERSE-13B-Chat": "xverse/XVERSE-13B-Chat", - "ChatGLM2-6B-Chat": "THUDM/chatglm2-6b", - "ChatGLM3-6B-Base": "THUDM/chatglm3-6b-base", - "ChatGLM3-6B-Chat": "THUDM/chatglm3-6b", - "Phi1.5-1.3B": "microsoft/phi-1_5" -} +LAYERNORM_NAMES = {"norm", "ln"} -DEFAULT_MODULE = { - "LLaMA": "q_proj,v_proj", - "LLaMA2": "q_proj,v_proj", - "ChineseLLaMA2": "q_proj,v_proj", - "BLOOM": "query_key_value", - "BLOOMZ": "query_key_value", - "Falcon": "query_key_value", - "Baichuan": "W_pack", - "Baichuan2": "W_pack", - "InternLM": "q_proj,v_proj", - "Qwen": "c_attn", - "XVERSE": "q_proj,v_proj", - "ChatGLM2": "query_key_value", - "ChatGLM3": "query_key_value", - "Phi1.5": "Wqkv" -} +SUPPORTED_MODELS = OrderedDict() -DEFAULT_TEMPLATE = { - "LLaMA2": "llama2", - "ChineseLLaMA2": "llama2_zh", - "Baichuan": "baichuan", - "Baichuan2": "baichuan2", - "InternLM": "intern", - "Qwen": "chatml", - "XVERSE": "xverse", - "ChatGLM2": "chatglm2", - "ChatGLM3": "chatglm3" -} +DEFAULT_MODULE = defaultdict(str) + +DEFAULT_TEMPLATE = defaultdict(str) + + +def register_model_group( + models: Dict[str, str], + module: Optional[str] = None, + template: Optional[str] = None +) -> None: + prefix = None + for name, path in models.items(): + if prefix is None: + prefix = name.split("-")[0] + else: + assert prefix == name.split("-")[0], "prefix should be identical." 
+        SUPPORTED_MODELS[name] = path
+    if module is not None:
+        DEFAULT_MODULE[prefix] = module
+    if template is not None:
+        DEFAULT_TEMPLATE[prefix] = template
+
+
+register_model_group(
+    models={
+        "Baichuan-7B-Base": "baichuan-inc/Baichuan-7B",
+        "Baichuan-13B-Base": "baichuan-inc/Baichuan-13B-Base",
+        "Baichuan-13B-Chat": "baichuan-inc/Baichuan-13B-Chat"
+    },
+    module="W_pack",
+    template="baichuan"
+)
+
+
+register_model_group(
+    models={
+        "Baichuan2-7B-Base": "baichuan-inc/Baichuan2-7B-Base",
+        "Baichuan2-13B-Base": "baichuan-inc/Baichuan2-13B-Base",
+        "Baichuan2-7B-Chat": "baichuan-inc/Baichuan2-7B-Chat",
+        "Baichuan2-13B-Chat": "baichuan-inc/Baichuan2-13B-Chat"
+    },
+    module="W_pack",
+    template="baichuan2"
+)
+
+
+register_model_group(
+    models={
+        "BLOOM-560M": "bigscience/bloom-560m",
+        "BLOOM-3B": "bigscience/bloom-3b",
+        "BLOOM-7B1": "bigscience/bloom-7b1"
+    },
+    module="query_key_value"
+)
+
+
+register_model_group(
+    models={
+        "BLOOMZ-560M": "bigscience/bloomz-560m",
+        "BLOOMZ-3B": "bigscience/bloomz-3b",
+        "BLOOMZ-7B1-mt": "bigscience/bloomz-7b1-mt"
+    },
+    module="query_key_value"
+)
+
+
+register_model_group(
+    models={
+        "BlueLM-7B-Base": "vivo-ai/BlueLM-7B-Base",
+        "BlueLM-7B-Chat": "vivo-ai/BlueLM-7B-Chat"
+    },
+    template="bluelm"
+)
+
+
+register_model_group(
+    models={
+        "ChatGLM2-6B-Chat": "THUDM/chatglm2-6b"
+    },
+    module="query_key_value",
+    template="chatglm2"
+)
+
+
+register_model_group(
+    models={
+        "ChatGLM3-6B-Base": "THUDM/chatglm3-6b-base",
+        "ChatGLM3-6B-Chat": "THUDM/chatglm3-6b"
+    },
+    module="query_key_value",
+    template="chatglm3"
+)
+
+
+register_model_group(
+    models={
+        "ChineseLLaMA2-7B": "ziqingyang/chinese-llama-2-7b",
+        "ChineseLLaMA2-13B": "ziqingyang/chinese-llama-2-13b",
+        "ChineseLLaMA2-7B-Chat": "ziqingyang/chinese-alpaca-2-7b",
+        "ChineseLLaMA2-13B-Chat": "ziqingyang/chinese-alpaca-2-13b"
+    },
+    template="llama2_zh"
+)
+
+
+register_model_group(
+    models={
+        "Falcon-7B": "tiiuae/falcon-7b",
+        "Falcon-40B": "tiiuae/falcon-40b",
+        "Falcon-180B": "tiiuae/falcon-180B",
+        "Falcon-7B-Chat": "tiiuae/falcon-7b-instruct",
+        "Falcon-40B-Chat": "tiiuae/falcon-40b-instruct",
+        "Falcon-180B-Chat": "tiiuae/falcon-180B-chat"
+    },
+    module="query_key_value",
+    template="falcon"
+)
+
+
+register_model_group(
+    models={
+        "InternLM-7B": "internlm/internlm-7b",
+        "InternLM-20B": "internlm/internlm-20b",
+        "InternLM-7B-Chat": "internlm/internlm-chat-7b",
+        "InternLM-20B-Chat": "internlm/internlm-chat-20b"
+    },
+    template="intern"
+)
+
+
+register_model_group(
+    models={
+        "LLaMA-7B": "huggyllama/llama-7b",
+        "LLaMA-13B": "huggyllama/llama-13b",
+        "LLaMA-30B": "huggyllama/llama-30b",
+        "LLaMA-65B": "huggyllama/llama-65b"
+    }
+)
+
+
+register_model_group(
+    models={
+        "LLaMA2-7B": "meta-llama/Llama-2-7b-hf",
+        "LLaMA2-13B": "meta-llama/Llama-2-13b-hf",
+        "LLaMA2-70B": "meta-llama/Llama-2-70b-hf",
+        "LLaMA2-7B-Chat": "meta-llama/Llama-2-7b-chat-hf",
+        "LLaMA2-13B-Chat": "meta-llama/Llama-2-13b-chat-hf",
+        "LLaMA2-70B-Chat": "meta-llama/Llama-2-70b-chat-hf"
+    },
+    template="llama2"
+)
+
+
+register_model_group(
+    models={
+        "Mistral-7B": "mistralai/Mistral-7B-v0.1",
+        "Mistral-7B-Chat": "mistralai/Mistral-7B-Instruct-v0.1"
+    },
+    template="mistral"
+)
+
+
+register_model_group(
+    models={
+        "Phi1.5-1.3B": "microsoft/phi-1_5"
+    },
+    module="Wqkv"
+)
+
+
+register_model_group(
+    models={
+        "Qwen-7B": "Qwen/Qwen-7B",
+        "Qwen-14B": "Qwen/Qwen-14B",
+        "Qwen-7B-Chat": "Qwen/Qwen-7B-Chat",
+        "Qwen-14B-Chat": "Qwen/Qwen-14B-Chat"
+    },
+    module="c_attn",
+    template="qwen"
+)
+
+
+register_model_group(
+    models={
+        "Skywork-13B-Base": "Skywork/Skywork-13B-base"
+    }
+)
+
+
+register_model_group(
+    models={
+        "XVERSE-7B": "xverse/XVERSE-7B",
+        "XVERSE-13B": "xverse/XVERSE-13B",
+        "XVERSE-65B": "xverse/XVERSE-65B",
+        "XVERSE-7B-Chat": "xverse/XVERSE-7B-Chat",
+        "XVERSE-13B-Chat": "xverse/XVERSE-13B-Chat"
+    },
+    template="xverse"
+)
+
+
+register_model_group(
+    models={
+        "Yi-6B": "01-ai/Yi-6B",
+        "Yi-34B": "01-ai/Yi-34B"
+    }
+)
diff --git a/src/llmtuner/extras/template.py b/src/llmtuner/extras/template.py
index 6e31fe83..bcb9ffa0 100644
--- a/src/llmtuner/extras/template.py
+++ b/src/llmtuner/extras/template.py
@@ -447,6 +447,25 @@ register_template(
 )
 
 
+r"""
+Supports: https://huggingface.co/tiiuae/falcon-180B-chat
+"""
+register_template(
+    name="falcon",
+    prefix=[
+        "{{system}}"
+    ],
+    prompt=[
+        "User: {{query}}\nFalcon:"
+    ],
+    system="",
+    sep=[
+        "\n"
+    ],
+    efficient_eos=True
+)
+
+
 r"""
 Supports: https://huggingface.co/internlm/internlm-chat-7b
           https://huggingface.co/internlm/internlm-chat-20b
diff --git a/src/llmtuner/tuner/core/utils.py b/src/llmtuner/tuner/core/utils.py
index 19fe42fd..5e56513c 100644
--- a/src/llmtuner/tuner/core/utils.py
+++ b/src/llmtuner/tuner/core/utils.py
@@ -1,5 +1,5 @@
 import torch
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple
 
 from llmtuner.extras.constants import LAYERNORM_NAMES
 from llmtuner.extras.logging import get_logger
@@ -56,7 +56,7 @@ def prepare_model_for_training(
     finetuning_args: "FinetuningArguments",
     output_layer_name: Optional[str] = "lm_head",
     use_gradient_checkpointing: Optional[bool] = True,
-    layernorm_names: Optional[List[str]] = LAYERNORM_NAMES
+    layernorm_names: Optional[Set[str]] = LAYERNORM_NAMES
 ) -> "PreTrainedModel":
     r"""
     Includes:
diff --git a/src/llmtuner/webui/common.py b/src/llmtuner/webui/common.py
index 5a6c16d3..6663254c 100644
--- a/src/llmtuner/webui/common.py
+++ b/src/llmtuner/webui/common.py
@@ -61,13 +61,17 @@ def get_model_path(model_name: str) -> str:
     return user_config["path_dict"].get(model_name, None) or SUPPORTED_MODELS.get(model_name, "")
 
 
+def get_prefix(model_name: str) -> str:
+    return model_name.split("-")[0]
+
+
 def get_module(model_name: str) -> str:
-    return DEFAULT_MODULE.get(model_name.split("-")[0], "q_proj,v_proj")
+    return DEFAULT_MODULE.get(get_prefix(model_name), "q_proj,v_proj")
 
 
 def get_template(model_name: str) -> str:
-    if model_name.endswith("Chat") and model_name.split("-")[0] in DEFAULT_TEMPLATE:
-        return DEFAULT_TEMPLATE[model_name.split("-")[0]]
+    if model_name.endswith("Chat") and get_prefix(model_name) in DEFAULT_TEMPLATE:
+        return DEFAULT_TEMPLATE[get_prefix(model_name)]
     return "default"
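
Note for reviewers: the `constants.py` refactor above replaces three hand-maintained dicts with a single `register_model_group()` helper that derives a group prefix from the first dash-separated token of each model name. A minimal, self-contained sketch of the resulting behavior (it mirrors the patched code so it can be run stand-alone; the assertions at the end are illustrative, not part of the patch):

```python
# Sketch of the registry pattern introduced in src/llmtuner/extras/constants.py.
from collections import OrderedDict, defaultdict
from typing import Dict, Optional

SUPPORTED_MODELS = OrderedDict()     # model name -> Hugging Face repo id
DEFAULT_MODULE = defaultdict(str)    # name prefix -> default LoRA target modules
DEFAULT_TEMPLATE = defaultdict(str)  # name prefix -> default chat template


def register_model_group(
    models: Dict[str, str],
    module: Optional[str] = None,
    template: Optional[str] = None
) -> None:
    prefix = None
    for name, path in models.items():
        if prefix is None:
            prefix = name.split("-")[0]  # e.g. "Falcon-7B-Chat" -> "Falcon"
        else:
            assert prefix == name.split("-")[0], "prefix should be identical."
        SUPPORTED_MODELS[name] = path
    if module is not None:
        DEFAULT_MODULE[prefix] = module
    if template is not None:
        DEFAULT_TEMPLATE[prefix] = template


register_model_group(
    models={
        "Falcon-7B": "tiiuae/falcon-7b",
        "Falcon-7B-Chat": "tiiuae/falcon-7b-instruct"
    },
    module="query_key_value",
    template="falcon"
)

# The web UI resolves defaults by prefix, as in the patched webui/common.py:
def get_prefix(model_name: str) -> str:
    return model_name.split("-")[0]

assert DEFAULT_MODULE[get_prefix("Falcon-7B-Chat")] == "query_key_value"
assert DEFAULT_TEMPLATE[get_prefix("Falcon-7B-Chat")] == "falcon"
```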
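The new `falcon` template is intentionally bare: an empty system prompt by default, turns joined by a single newline, and `efficient_eos=True` so only the final turn is terminated with EOS. Roughly, the rendered prompt looks like the sketch below; `render_falcon` is a hypothetical stand-in, since the real rendering is done by the `Template` machinery in `src/llmtuner/extras/template.py`:

```python
# Hypothetical illustration of the "falcon" template's prompt layout
# (not the actual implementation in template.py).
from typing import List, Tuple

def render_falcon(system: str, history: List[Tuple[str, str]], query: str) -> str:
    turns = [system] if system else []  # prefix: "{{system}}", default ""
    for old_query, old_response in history:
        turns.append("User: {}\nFalcon:{}".format(old_query, old_response))
    turns.append("User: {}\nFalcon:".format(query))
    return "\n".join(turns)             # sep: "\n"

print(render_falcon("", [("Hello!", " Hi! How can I help?")], "Tell me a joke."))
# User: Hello!
# Falcon: Hi! How can I help?
# User: Tell me a joke.
# Falcon:
```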
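Shrinking `LAYERNORM_NAMES` from eight entries to `{"norm", "ln"}` works because `prepare_model_for_training()` in `src/llmtuner/tuner/core/utils.py` matches these names as substrings of parameter names, so the two stems cover `ln_f`, `ln_attn`, `ln_1`, `input_layernorm`, and so on. A quick sanity check of the substring logic (the parameter names below are examples, not taken from the patch):

```python
# Substring matching as used when selecting layer-norm weights to upcast.
LAYERNORM_NAMES = {"norm", "ln"}

param_names = [
    "transformer.h.0.ln_attn.weight",         # Falcon-style layer norm
    "model.layers.0.input_layernorm.weight",  # LLaMA-style layer norm
    "lm_head.weight",                         # should NOT match
]
for name in param_names:
    matches = any(ln_name in name for ln_name in LAYERNORM_NAMES)
    print(name, "->", matches)
# transformer.h.0.ln_attn.weight -> True
# model.layers.0.input_layernorm.weight -> True
# lm_head.weight -> False
```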