From c60e79c12ee97dc87b390e7abe1753034c3aa8b7 Mon Sep 17 00:00:00 2001 From: hiyouga Date: Fri, 1 Dec 2023 22:53:15 +0800 Subject: [PATCH] patch modelscope Former-commit-id: bd42c229b01a0bf3ceadb8cee5ad49a060cc2d13 --- README.md | 37 +-- README_zh.md | 37 +-- src/llmtuner/extras/constants.py | 392 +++++++++++++++++------------ src/llmtuner/extras/misc.py | 21 ++ src/llmtuner/hparams/model_args.py | 4 +- src/llmtuner/model/loader.py | 23 +- src/llmtuner/webui/common.py | 20 +- 7 files changed, 312 insertions(+), 222 deletions(-) diff --git a/README.md b/README.md index 7a4c8ee1..bc20c1ce 100644 --- a/README.md +++ b/README.md @@ -44,17 +44,23 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/ ![benchmark](assets/benchmark.svg) +
Definitions
+
 - **Training Speed**: the number of training samples processed per second during training. (bs=4, cutoff_len=1024)
 - **Rouge Score**: Rouge-2 score on the development set of the [advertising text generation](https://aclanthology.org/D19-1321.pdf) task. (bs=4, cutoff_len=1024)
 - **GPU Memory**: Peak GPU memory usage in 4-bit quantized training. (bs=1, cutoff_len=1024)
 - We adopt `pre_seq_len=128` for ChatGLM's P-Tuning and `lora_rank=32` for LLaMA-Factory's LoRA tuning.
+
+ ## Changelog -[23/12/01] We supported **[ModelScope Hub](https://www.modelscope.cn/models)** to accelerate model downloading. Add environment variable `USE_MODELSCOPE_HUB=1` to your command line, then you can use the model-id of ModelScope Hub. +[23/12/01] We supported downloading pre-trained models from the **[ModelScope Hub](https://modelscope.cn/models)** for Chinese mainland users. See [this tutorial](#use-modelscope-models-optional) for usage. [23/10/21] We supported **[NEFTune](https://arxiv.org/abs/2310.05914)** trick for fine-tuning. Try `--neft_alpha` argument to activate NEFTune, e.g., `--neft_alpha 5`. +
Full Changelog + [23/09/27] We supported **$S^2$-Attn** proposed by [LongLoRA](https://github.com/dvlab-research/LongLoRA) for the LLaMA models. Try `--shift_attn` argument to enable shift short attention. [23/09/23] We integrated MMLU, C-Eval and CMMLU benchmarks in this repo. See [this example](#evaluation) to evaluate your models. @@ -79,6 +85,8 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/ [23/06/03] We supported quantized training and inference (aka **[QLoRA](https://github.com/artidoro/qlora)**). Try `--quantization_bit 4/8` argument to work with quantized models. +
+ ## Supported Models | Model | Model size | Default module | Template | @@ -231,31 +239,26 @@ If you want to enable the quantized LoRA (QLoRA) on the Windows platform, you wi pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.39.1-py3-none-win_amd64.whl ``` -### Use ModelScope Models +### Use ModelScope Models (optional) -If you have trouble with downloading models from HuggingFace, we have supported ModelScope Hub. To use LLaMA-Factory together with ModelScope, please add a environment variable: +If you have trouble with downloading models from Hugging Face, you can use LLaMA-Factory together with ModelScope in the following manner. -```shell -export USE_MODELSCOPE_HUB=1 +```bash +export USE_MODELSCOPE_HUB=1 # `set USE_MODELSCOPE_HUB=1` for Windows ``` -> [!NOTE] -> -> Please use integers only. 0 or not set for using HuggingFace hub. Other values will be treated as use ModelScope hub. +Then you can train the corresponding model by specifying a model ID of the ModelScope Hub. (find a full list of model IDs at [ModelScope Hub](https://modelscope.cn/models)) -Then you can use LLaMA-Factory with ModelScope model-ids: - -```shell +```bash python src/train_bash.py \ - --model_name_or_path ZhipuAI/chatglm3-6b \ - ... other arguments -# You can find all model ids in this link: https://www.modelscope.cn/models + --model_name_or_path modelscope/Llama-2-7b-ms \ + ... # arguments (same as above) ``` -Web demo also supports ModelScope, after setting the environment variable please run with this command: +LLaMA Board also supports using the models on the ModelScope Hub. -```shell -CUDA_VISIBLE_DEVICES=0 python src/train_web.py +```bash +CUDA_VISIBLE_DEVICES=0 USE_MODELSCOPE_HUB=1 python src/train_web.py ``` ### Train on a single GPU diff --git a/README_zh.md b/README_zh.md index 6a68ce30..0aa8ac0f 100644 --- a/README_zh.md +++ b/README_zh.md @@ -44,17 +44,23 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846 ![benchmark](assets/benchmark.svg) +
变量定义 + - **Training Speed**: 训练阶段每秒处理的样本数量。(批处理大小=4,截断长度=1024) - **Rouge Score**: [广告文案生成](https://aclanthology.org/D19-1321.pdf)任务验证集上的 Rouge-2 分数。(批处理大小=4,截断长度=1024) - **GPU Memory**: 4 比特量化训练的 GPU 显存峰值。(批处理大小=1,截断长度=1024) - 我们在 ChatGLM 的 P-Tuning 中采用 `pre_seq_len=128`,在 LLaMA-Factory 的 LoRA 微调中采用 `lora_rank=32`。 +
+ ## 更新日志 -[23/12/01] 我们支持了 **[魔搭ModelHub](https://www.modelscope.cn/models)** 进行模型下载加速。在启动命令前环境变量中增加 `USE_MODELSCOPE_HUB=1` 即可开启。 +[23/12/01] 我们支持了从 **[魔搭社区](https://modelscope.cn/models)** 下载预训练模型。详细用法请参照 [此教程](#使用魔搭社区可跳过)。 [23/10/21] 我们支持了 **[NEFTune](https://arxiv.org/abs/2310.05914)** 训练技巧。请使用 `--neft_alpha` 参数启用 NEFTune,例如 `--neft_alpha 5`。 +
展开日志 + [23/09/27] 我们针对 LLaMA 模型支持了 [LongLoRA](https://github.com/dvlab-research/LongLoRA) 提出的 **$S^2$-Attn**。请使用 `--shift_attn` 参数以启用该功能。 [23/09/23] 我们在项目中集成了 MMLU、C-Eval 和 CMMLU 评估集。使用方法请参阅[此示例](#模型评估)。 @@ -79,6 +85,8 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846 [23/06/03] 我们实现了 4 比特的 LoRA 训练(也称 **[QLoRA](https://github.com/artidoro/qlora)**)。请使用 `--quantization_bit 4` 参数进行 4 比特量化微调。 +
+ ## 模型 | 模型名 | 模型大小 | 默认模块 | Template | @@ -231,31 +239,26 @@ pip install -r requirements.txt pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.39.1-py3-none-win_amd64.whl ``` -### 使用魔搭的模型 +### 使用魔搭社区(可跳过) -如果下载HuggingFace模型存在问题,我们已经支持了魔搭的ModelHub,只需要添加一个环境变量: +如果您在 Hugging Face 模型的下载中遇到了问题,可以通过下述方法使用魔搭社区。 -```shell -export USE_MODELSCOPE_HUB=1 +```bash +export USE_MODELSCOPE_HUB=1 # Windows 使用 `set USE_MODELSCOPE_HUB=1` ``` -> [!NOTE] -> -> 该环境变量仅支持整数,0或者不设置代表使用HuggingFace,其他值代表使用ModelScope +接着即可通过指定模型名称来训练对应的模型。(在[魔搭社区](https://modelscope.cn/models)查看所有可用的模型) -之后就可以在命令行中指定魔搭的模型id: - -```shell +```bash python src/train_bash.py \ - --model_name_or_path ZhipuAI/chatglm3-6b \ - ... other arguments -# 在这个链接中可以看到所有可用模型: https://www.modelscope.cn/models + --model_name_or_path modelscope/Llama-2-7b-ms \ + ... # 参数同上 ``` -Web demo目前也支持了魔搭, 在设置环境变量后即可使用: +LLaMA Board 同样支持魔搭社区的模型下载。 -```shell -CUDA_VISIBLE_DEVICES=0 python src/train_web.py +```bash +CUDA_VISIBLE_DEVICES=0 USE_MODELSCOPE_HUB=1 python src/train_web.py ``` ### 单 GPU 训练 diff --git a/src/llmtuner/extras/constants.py b/src/llmtuner/extras/constants.py index 7e66d1b3..a36102a1 100644 --- a/src/llmtuner/extras/constants.py +++ b/src/llmtuner/extras/constants.py @@ -1,6 +1,7 @@ -import os +from enum import Enum from collections import defaultdict, OrderedDict -from typing import Dict, Optional, Union +from typing import Dict, Optional + CHOICES = ["A", "B", "C", "D"] @@ -20,8 +21,6 @@ SUBJECTS = ["Average", "STEM", "Social Sciences", "Humanities", "Other"] SUPPORTED_MODELS = OrderedDict() -ALL_OFFICIAL_MODELS = OrderedDict() - TRAINING_STAGES = { "Supervised Fine-Tuning": "sft", "Reward Modeling": "rm", @@ -30,9 +29,13 @@ TRAINING_STAGES = { "Pre-Training": "pt" } +class DownloadSource(str, Enum): + DEFAULT = "hf" + MODELSCOPE = "ms" + def register_model_group( - models: Dict[str, Union[str, Dict[str, str]]], + models: Dict[str, Dict[DownloadSource, str]], module: Optional[str] = None, template: Optional[str] = None ) -> None: @@ -42,14 +45,7 @@ def register_model_group( prefix = name.split("-")[0] else: assert prefix == name.split("-")[0], "prefix should be identical." - - ALL_OFFICIAL_MODELS[name] = [path] if isinstance(path, str) else list(path.values()) - if not int(os.environ.get('USE_MODELSCOPE_HUB', '0')): - # If path is a string, we treat it as a huggingface model-id by default. 
- SUPPORTED_MODELS[name] = path["hf"] if isinstance(path, dict) else path - elif isinstance(path, dict) and "ms" in path: - # Use ModelScope modelhub - SUPPORTED_MODELS[name] = path["ms"] + SUPPORTED_MODELS[name] = path if module is not None: DEFAULT_MODULE[prefix] = module if template is not None: @@ -59,16 +55,16 @@ def register_model_group( register_model_group( models={ "Baichuan-7B-Base": { - "hf": "baichuan-inc/Baichuan-7B", - "ms": "baichuan-inc/baichuan-7B", + DownloadSource.DEFAULT: "baichuan-inc/Baichuan-7B", + DownloadSource.MODELSCOPE: "baichuan-inc/baichuan-7B" }, "Baichuan-13B-Base": { - "hf": "baichuan-inc/Baichuan-13B-Base", - "ms": "baichuan-inc/Baichuan-13B-Base", + DownloadSource.DEFAULT: "baichuan-inc/Baichuan-13B-Base", + DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan-13B-Base" }, "Baichuan-13B-Chat": { - "hf": "baichuan-inc/Baichuan-13B-Chat", - "ms": "baichuan-inc/Baichuan-13B-Base", + DownloadSource.DEFAULT: "baichuan-inc/Baichuan-13B-Chat", + DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan-13B-Chat" } }, module="W_pack", @@ -79,20 +75,20 @@ register_model_group( register_model_group( models={ "Baichuan2-7B-Base": { - "hf": "baichuan-inc/Baichuan2-7B-Base", - "ms": "baichuan-inc/Baichuan2-7B-Base", + DownloadSource.DEFAULT: "baichuan-inc/Baichuan2-7B-Base", + DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-7B-Base" }, "Baichuan2-13B-Base": { - "hf": "baichuan-inc/Baichuan2-13B-Base", - "ms": "baichuan-inc/Baichuan2-13B-Base", + DownloadSource.DEFAULT: "baichuan-inc/Baichuan2-13B-Base", + DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-13B-Base" }, "Baichuan2-7B-Chat": { - "hf": "baichuan-inc/Baichuan2-7B-Chat", - "ms": "baichuan-inc/Baichuan2-7B-Chat", + DownloadSource.DEFAULT: "baichuan-inc/Baichuan2-7B-Chat", + DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-7B-Chat" }, "Baichuan2-13B-Chat": { - "hf": "baichuan-inc/Baichuan2-13B-Chat", - "ms": "baichuan-inc/Baichuan2-13B-Chat", + DownloadSource.DEFAULT: "baichuan-inc/Baichuan2-13B-Chat", + DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-13B-Chat" } }, module="W_pack", @@ -103,16 +99,16 @@ register_model_group( register_model_group( models={ "BLOOM-560M": { - "hf": "bigscience/bloom-560m", - "ms": "AI-ModelScope/bloom-560m", + DownloadSource.DEFAULT: "bigscience/bloom-560m", + DownloadSource.MODELSCOPE: "AI-ModelScope/bloom-560m" }, "BLOOM-3B": { - "hf": "bigscience/bloom-3b", - "ms": "AI-ModelScope/bloom-3b", + DownloadSource.DEFAULT: "bigscience/bloom-3b", + DownloadSource.MODELSCOPE: "AI-ModelScope/bloom-3b" }, "BLOOM-7B1": { - "hf": "bigscience/bloom-7b1", - "ms": "AI-ModelScope/bloom-7b1", + DownloadSource.DEFAULT: "bigscience/bloom-7b1", + DownloadSource.MODELSCOPE: "AI-ModelScope/bloom-7b1" } }, module="query_key_value" @@ -122,16 +118,16 @@ register_model_group( register_model_group( models={ "BLOOMZ-560M": { - "hf": "bigscience/bloomz-560m", - "ms": "AI-ModelScope/bloomz-560m", + DownloadSource.DEFAULT: "bigscience/bloomz-560m", + DownloadSource.MODELSCOPE: "AI-ModelScope/bloomz-560m" }, "BLOOMZ-3B": { - "hf": "bigscience/bloomz-3b", - "ms": "AI-ModelScope/bloomz-3b", + DownloadSource.DEFAULT: "bigscience/bloomz-3b", + DownloadSource.MODELSCOPE: "AI-ModelScope/bloomz-3b" }, "BLOOMZ-7B1-mt": { - "hf": "bigscience/bloomz-7b1-mt", - "ms": "AI-ModelScope/bloomz-7b1-mt", + DownloadSource.DEFAULT: "bigscience/bloomz-7b1-mt", + DownloadSource.MODELSCOPE: "AI-ModelScope/bloomz-7b1-mt" } }, module="query_key_value" @@ -141,12 +137,12 @@ register_model_group( register_model_group( models={ 
"BlueLM-7B-Base": { - "hf": "vivo-ai/BlueLM-7B-Base", - "ms": "vivo-ai/BlueLM-7B-Base", + DownloadSource.DEFAULT: "vivo-ai/BlueLM-7B-Base", + DownloadSource.MODELSCOPE: "vivo-ai/BlueLM-7B-Base" }, "BlueLM-7B-Chat": { - "hf": "vivo-ai/BlueLM-7B-Chat", - "ms": "vivo-ai/BlueLM-7B-Chat", + DownloadSource.DEFAULT: "vivo-ai/BlueLM-7B-Chat", + DownloadSource.MODELSCOPE: "vivo-ai/BlueLM-7B-Chat" } }, template="bluelm" @@ -156,8 +152,8 @@ register_model_group( register_model_group( models={ "ChatGLM2-6B-Chat": { - "hf": "THUDM/chatglm2-6b", - "ms": "ZhipuAI/chatglm2-6b", + DownloadSource.DEFAULT: "THUDM/chatglm2-6b", + DownloadSource.MODELSCOPE: "ZhipuAI/chatglm2-6b" } }, module="query_key_value", @@ -168,12 +164,12 @@ register_model_group( register_model_group( models={ "ChatGLM3-6B-Base": { - "hf": "THUDM/chatglm3-6b-base", - "ms": "ZhipuAI/chatglm3-6b-base", + DownloadSource.DEFAULT: "THUDM/chatglm3-6b-base", + DownloadSource.MODELSCOPE: "ZhipuAI/chatglm3-6b-base" }, "ChatGLM3-6B-Chat": { - "hf": "THUDM/chatglm3-6b", - "ms": "ZhipuAI/chatglm3-6b", + DownloadSource.DEFAULT: "THUDM/chatglm3-6b", + DownloadSource.MODELSCOPE: "ZhipuAI/chatglm3-6b" } }, module="query_key_value", @@ -184,59 +180,105 @@ register_model_group( register_model_group( models={ "ChineseLLaMA2-1.3B": { - "hf": "hfl/chinese-llama-2-1.3b", - "ms": "AI-ModelScope/chinese-llama-2-1.3b", + DownloadSource.DEFAULT: "hfl/chinese-llama-2-1.3b", + DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-llama-2-1.3b" }, "ChineseLLaMA2-7B": { - "hf": "hfl/chinese-llama-2-7b", - "ms": "AI-ModelScope/chinese-llama-2-7b", + DownloadSource.DEFAULT: "hfl/chinese-llama-2-7b", + DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-llama-2-7b" }, "ChineseLLaMA2-13B": { - "hf": "hfl/chinese-llama-2-13b", - "ms": "AI-ModelScope/chinese-llama-2-13b", + DownloadSource.DEFAULT: "hfl/chinese-llama-2-13b", + DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-llama-2-13b" }, "ChineseLLaMA2-1.3B-Chat": { - "hf": "hfl/chinese-alpaca-2-1.3b", - "ms": "AI-ModelScope/chinese-alpaca-2-1.3b", + DownloadSource.DEFAULT: "hfl/chinese-alpaca-2-1.3b", + DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-alpaca-2-1.3b" }, "ChineseLLaMA2-7B-Chat": { - "hf": "hfl/chinese-alpaca-2-7b", - "ms": "AI-ModelScope/chinese-alpaca-2-7b", + DownloadSource.DEFAULT: "hfl/chinese-alpaca-2-7b", + DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-alpaca-2-7b" }, "ChineseLLaMA2-13B-Chat": { - "hf": "hfl/chinese-alpaca-2-13b", - "ms": "AI-ModelScope/chinese-alpaca-2-13b", + DownloadSource.DEFAULT: "hfl/chinese-alpaca-2-13b", + DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-alpaca-2-13b" } }, template="llama2_zh" ) +register_model_group( + models={ + "DeepseekLLM-7B-Base": { + DownloadSource.DEFAULT: "deepseek-ai/deepseek-llm-7b-base", + DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-llm-7b-base" + }, + "DeepseekLLM-67B-Base": { + DownloadSource.DEFAULT: "deepseek-ai/deepseek-llm-67b-base", + DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-llm-67b-base" + }, + "DeepseekLLM-7B-Chat": { + DownloadSource.DEFAULT: "deepseek-ai/deepseek-llm-7b-chat", + DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-llm-7b-chat" + }, + "DeepseekLLM-67B-Chat": { + DownloadSource.DEFAULT: "deepseek-ai/deepseek-llm-67b-chat", + DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-llm-67b-chat" + } + }, + template="deepseek" +) + + +register_model_group( + models={ + "DeepseekCoder-6.7B-Base": { + DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-6.7b-base", + DownloadSource.MODELSCOPE: 
"deepseek-ai/deepseek-coder-6.7b-base" + }, + "DeepseekCoder-33B-Base": { + DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-33b-base", + DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-33b-base" + }, + "DeepseekCoder-6.7B-Chat": { + DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-6.7b-instruct", + DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-6.7b-instruct" + }, + "DeepseekCoder-33B-Chat": { + DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-33b-instruct", + DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-33b-instruct" + } + }, + template="deepseekcoder" +) + + register_model_group( models={ "Falcon-7B": { - "hf": "tiiuae/falcon-7b", - "ms": "AI-ModelScope/falcon-7b", + DownloadSource.DEFAULT: "tiiuae/falcon-7b", + DownloadSource.MODELSCOPE: "AI-ModelScope/falcon-7b" }, "Falcon-40B": { - "hf": "tiiuae/falcon-40b", - "ms": "AI-ModelScope/falcon-40b", + DownloadSource.DEFAULT: "tiiuae/falcon-40b", + DownloadSource.MODELSCOPE: "AI-ModelScope/falcon-40b" }, "Falcon-180B": { - "hf": "tiiuae/falcon-180B", - "ms": "AI-ModelScope/falcon-180B", + DownloadSource.DEFAULT: "tiiuae/falcon-180b", + DownloadSource.MODELSCOPE: "modelscope/falcon-180B" }, "Falcon-7B-Chat": { - "hf": "tiiuae/falcon-7b-instruct", - "ms": "AI-ModelScope/falcon-7b-instruct", + DownloadSource.DEFAULT: "tiiuae/falcon-7b-instruct", + DownloadSource.MODELSCOPE: "AI-ModelScope/falcon-7b-instruct" }, "Falcon-40B-Chat": { - "hf": "tiiuae/falcon-40b-instruct", - "ms": "AI-ModelScope/falcon-40b-instruct", + DownloadSource.DEFAULT: "tiiuae/falcon-40b-instruct", + DownloadSource.MODELSCOPE: "AI-ModelScope/falcon-40b-instruct" }, "Falcon-180B-Chat": { - "hf": "tiiuae/falcon-180B-chat", - "ms": "AI-ModelScope/falcon-180B-chat", + DownloadSource.DEFAULT: "tiiuae/falcon-180b-chat", + DownloadSource.MODELSCOPE: "modelscope/falcon-180B-chat" } }, module="query_key_value", @@ -247,20 +289,20 @@ register_model_group( register_model_group( models={ "InternLM-7B": { - "hf": "internlm/internlm-7b", - "ms": "Shanghai_AI_Laboratory/internlm-7b", + DownloadSource.DEFAULT: "internlm/internlm-7b", + DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm-7b" }, "InternLM-20B": { - "hf": "internlm/internlm-20b", - "ms": "Shanghai_AI_Laboratory/internlm-20b", + DownloadSource.DEFAULT: "internlm/internlm-20b", + DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm-20b" }, "InternLM-7B-Chat": { - "hf": "internlm/internlm-chat-7b", - "ms": "Shanghai_AI_Laboratory/internlm-chat-7b", + DownloadSource.DEFAULT: "internlm/internlm-chat-7b", + DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm-chat-7b" }, "InternLM-20B-Chat": { - "hf": "internlm/internlm-chat-20b", - "ms": "Shanghai_AI_Laboratory/internlm-chat-20b", + DownloadSource.DEFAULT: "internlm/internlm-chat-20b", + DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm-chat-20b" } }, template="intern" @@ -270,8 +312,8 @@ register_model_group( register_model_group( models={ "LingoWhale-8B": { - "hf": "deeplang-ai/LingoWhale-8B", - "ms": "DeepLang/LingoWhale-8B", + DownloadSource.DEFAULT: "deeplang-ai/LingoWhale-8B", + DownloadSource.MODELSCOPE: "DeepLang/LingoWhale-8B" } }, module="qkv_proj" @@ -281,20 +323,20 @@ register_model_group( register_model_group( models={ "LLaMA-7B": { - "hf": "huggyllama/llama-7b", - "ms": "skyline2006/llama-7b", + DownloadSource.DEFAULT: "huggyllama/llama-7b", + DownloadSource.MODELSCOPE: "skyline2006/llama-7b" }, "LLaMA-13B": { - "hf": "huggyllama/llama-13b", - "ms": "skyline2006/llama-13b", + DownloadSource.DEFAULT: 
"huggyllama/llama-13b", + DownloadSource.MODELSCOPE: "skyline2006/llama-13b" }, "LLaMA-30B": { - "hf": "huggyllama/llama-30b", - "ms": "skyline2006/llama-30b", + DownloadSource.DEFAULT: "huggyllama/llama-30b", + DownloadSource.MODELSCOPE: "skyline2006/llama-30b" }, "LLaMA-65B": { - "hf": "huggyllama/llama-65b", - "ms": "skyline2006/llama-65b", + DownloadSource.DEFAULT: "huggyllama/llama-65b", + DownloadSource.MODELSCOPE: "skyline2006/llama-65b" } } ) @@ -303,28 +345,28 @@ register_model_group( register_model_group( models={ "LLaMA2-7B": { - "hf": "meta-llama/Llama-2-7b-hf", - "ms": "modelscope/Llama-2-7b-ms", + DownloadSource.DEFAULT: "meta-llama/Llama-2-7b-hf", + DownloadSource.MODELSCOPE: "modelscope/Llama-2-7b-ms" }, "LLaMA2-13B": { - "hf": "meta-llama/Llama-2-13b-hf", - "ms": "modelscope/Llama-2-13b-ms", + DownloadSource.DEFAULT: "meta-llama/Llama-2-13b-hf", + DownloadSource.MODELSCOPE: "modelscope/Llama-2-13b-ms" }, "LLaMA2-70B": { - "hf": "meta-llama/Llama-2-70b-hf", - "ms": "modelscope/Llama-2-70b-ms", + DownloadSource.DEFAULT: "meta-llama/Llama-2-70b-hf", + DownloadSource.MODELSCOPE: "modelscope/Llama-2-70b-ms" }, "LLaMA2-7B-Chat": { - "hf": "meta-llama/Llama-2-7b-chat-hf", - "ms": "modelscope/Llama-2-7b-chat-ms", + DownloadSource.DEFAULT: "meta-llama/Llama-2-7b-chat-hf", + DownloadSource.MODELSCOPE: "modelscope/Llama-2-7b-chat-ms" }, "LLaMA2-13B-Chat": { - "hf": "meta-llama/Llama-2-13b-chat-hf", - "ms": "modelscope/Llama-2-13b-chat-ms", + DownloadSource.DEFAULT: "meta-llama/Llama-2-13b-chat-hf", + DownloadSource.MODELSCOPE: "modelscope/Llama-2-13b-chat-ms" }, "LLaMA2-70B-Chat": { - "hf": "meta-llama/Llama-2-70b-chat-hf", - "ms": "modelscope/Llama-2-70b-chat-ms", + DownloadSource.DEFAULT: "meta-llama/Llama-2-70b-chat-hf", + DownloadSource.MODELSCOPE: "modelscope/Llama-2-70b-chat-ms" } }, template="llama2" @@ -334,12 +376,12 @@ register_model_group( register_model_group( models={ "Mistral-7B": { - "hf": "mistralai/Mistral-7B-v0.1", - "ms": "AI-ModelScope/Mistral-7B-v0.1", + DownloadSource.DEFAULT: "mistralai/Mistral-7B-v0.1", + DownloadSource.MODELSCOPE: "AI-ModelScope/Mistral-7B-v0.1" }, "Mistral-7B-Chat": { - "hf": "mistralai/Mistral-7B-Instruct-v0.1", - "ms": "AI-ModelScope/Mistral-7B-Instruct-v0.1", + DownloadSource.DEFAULT: "mistralai/Mistral-7B-Instruct-v0.1", + DownloadSource.MODELSCOPE: "AI-ModelScope/Mistral-7B-Instruct-v0.1" } }, template="mistral" @@ -349,8 +391,8 @@ register_model_group( register_model_group( models={ "OpenChat3.5-7B-Chat": { - "hf": "openchat/openchat_3.5", - "ms": "myxiongmodel/openchat_3.5", + DownloadSource.DEFAULT: "openchat/openchat_3.5", + DownloadSource.MODELSCOPE: "myxiongmodel/openchat_3.5" } }, template="openchat" @@ -360,8 +402,8 @@ register_model_group( register_model_group( models={ "Phi1.5-1.3B": { - "hf": "microsoft/phi-1_5", - "ms": "allspace/PHI_1-5", + DownloadSource.DEFAULT: "microsoft/phi-1_5", + DownloadSource.MODELSCOPE: "allspace/PHI_1-5" } }, module="Wqkv" @@ -370,37 +412,69 @@ register_model_group( register_model_group( models={ + "Qwen-1.8B": { + DownloadSource.DEFAULT: "Qwen/Qwen-1_8B", + DownloadSource.MODELSCOPE: "qwen/Qwen-1_8B" + }, "Qwen-7B": { - "hf": "Qwen/Qwen-7B", - "ms": "qwen/Qwen-7B", + DownloadSource.DEFAULT: "Qwen/Qwen-7B", + DownloadSource.MODELSCOPE: "qwen/Qwen-7B" }, "Qwen-14B": { - "hf": "Qwen/Qwen-14B", - "ms": "qwen/Qwen-14B", + DownloadSource.DEFAULT: "Qwen/Qwen-14B", + DownloadSource.MODELSCOPE: "qwen/Qwen-14B" + }, + "Qwen-72B": { + DownloadSource.DEFAULT: "Qwen/Qwen-72B", + DownloadSource.MODELSCOPE: 
"qwen/Qwen-72B" + }, + "Qwen-1.8B-Chat": { + DownloadSource.DEFAULT: "Qwen/Qwen-1_8B-Chat", + DownloadSource.MODELSCOPE: "qwen/Qwen-1_8B-Chat" }, "Qwen-7B-Chat": { - "hf": "Qwen/Qwen-7B-Chat", - "ms": "qwen/Qwen-7B-Chat", + DownloadSource.DEFAULT: "Qwen/Qwen-7B-Chat", + DownloadSource.MODELSCOPE: "qwen/Qwen-7B-Chat" }, "Qwen-14B-Chat": { - "hf": "Qwen/Qwen-14B-Chat", - "ms": "qwen/Qwen-14B-Chat", + DownloadSource.DEFAULT: "Qwen/Qwen-14B-Chat", + DownloadSource.MODELSCOPE: "qwen/Qwen-14B-Chat" + }, + "Qwen-72B-Chat": { + DownloadSource.DEFAULT: "Qwen/Qwen-72B-Chat", + DownloadSource.MODELSCOPE: "qwen/Qwen-72B-Chat" + }, + "Qwen-1.8B-int8-Chat": { + DownloadSource.DEFAULT: "Qwen/Qwen-1_8B-Chat-Int8", + DownloadSource.MODELSCOPE: "qwen/Qwen-1_8B-Chat-Int8" + }, + "Qwen-1.8B-int4-Chat": { + DownloadSource.DEFAULT: "Qwen/Qwen-1_8B-Chat-Int4", + DownloadSource.MODELSCOPE: "qwen/Qwen-1_8B-Chat-Int4" }, "Qwen-7B-int8-Chat": { - "hf": "Qwen/Qwen-7B-Chat-Int8", - "ms": "qwen/Qwen-7B-Chat-Int8", + DownloadSource.DEFAULT: "Qwen/Qwen-7B-Chat-Int8", + DownloadSource.MODELSCOPE: "qwen/Qwen-7B-Chat-Int8" }, "Qwen-7B-int4-Chat": { - "hf": "Qwen/Qwen-7B-Chat-Int4", - "ms": "qwen/Qwen-7B-Chat-Int4", + DownloadSource.DEFAULT: "Qwen/Qwen-7B-Chat-Int4", + DownloadSource.MODELSCOPE: "qwen/Qwen-7B-Chat-Int4" }, "Qwen-14B-int8-Chat": { - "hf": "Qwen/Qwen-14B-Chat-Int8", - "ms": "qwen/Qwen-14B-Chat-Int8", + DownloadSource.DEFAULT: "Qwen/Qwen-14B-Chat-Int8", + DownloadSource.MODELSCOPE: "qwen/Qwen-14B-Chat-Int8" }, "Qwen-14B-int4-Chat": { - "hf": "Qwen/Qwen-14B-Chat-Int4", - "ms": "qwen/Qwen-14B-Chat-Int4", + DownloadSource.DEFAULT: "Qwen/Qwen-14B-Chat-Int4", + DownloadSource.MODELSCOPE: "qwen/Qwen-14B-Chat-Int4" + }, + "Qwen-72B-int8-Chat": { + DownloadSource.DEFAULT: "Qwen/Qwen-72B-Chat-Int8", + DownloadSource.MODELSCOPE: "qwen/Qwen-72B-Chat-Int8" + }, + "Qwen-72B-int4-Chat": { + DownloadSource.DEFAULT: "Qwen/Qwen-72B-Chat-Int4", + DownloadSource.MODELSCOPE: "qwen/Qwen-72B-Chat-Int4" } }, module="c_attn", @@ -411,8 +485,8 @@ register_model_group( register_model_group( models={ "Skywork-13B-Base": { - "hf": "Skywork/Skywork-13B-base", - "ms": "skywork/Skywork-13B-base", + DownloadSource.DEFAULT: "Skywork/Skywork-13B-base", + DownloadSource.MODELSCOPE: "skywork/Skywork-13B-base" } } ) @@ -421,12 +495,12 @@ register_model_group( register_model_group( models={ "Vicuna1.5-7B-Chat": { - "hf": "lmsys/vicuna-7b-v1.5", - "ms": "AI-ModelScope/vicuna-7b-v1.5", + DownloadSource.DEFAULT: "lmsys/vicuna-7b-v1.5", + DownloadSource.MODELSCOPE: "Xorbits/vicuna-7b-v1.5" }, "Vicuna1.5-13B-Chat": { - "hf": "lmsys/vicuna-13b-v1.5", - "ms": "Xorbits/vicuna-13b-v1.5", + DownloadSource.DEFAULT: "lmsys/vicuna-13b-v1.5", + DownloadSource.MODELSCOPE: "Xorbits/vicuna-13b-v1.5" } }, template="vicuna" @@ -436,24 +510,24 @@ register_model_group( register_model_group( models={ "XVERSE-7B": { - "hf": "xverse/XVERSE-7B", - "ms": "xverse/XVERSE-7B", + DownloadSource.DEFAULT: "xverse/XVERSE-7B", + DownloadSource.MODELSCOPE: "xverse/XVERSE-7B" }, "XVERSE-13B": { - "hf": "xverse/XVERSE-13B", - "ms": "xverse/XVERSE-13B", + DownloadSource.DEFAULT: "xverse/XVERSE-13B", + DownloadSource.MODELSCOPE: "xverse/XVERSE-13B" }, "XVERSE-65B": { - "hf": "xverse/XVERSE-65B", - "ms": "xverse/XVERSE-65B", + DownloadSource.DEFAULT: "xverse/XVERSE-65B", + DownloadSource.MODELSCOPE: "xverse/XVERSE-65B" }, "XVERSE-7B-Chat": { - "hf": "xverse/XVERSE-7B-Chat", - "ms": "xverse/XVERSE-7B-Chat", + DownloadSource.DEFAULT: "xverse/XVERSE-7B-Chat", + DownloadSource.MODELSCOPE: 
"xverse/XVERSE-7B-Chat" }, "XVERSE-13B-Chat": { - "hf": "xverse/XVERSE-13B-Chat", - "ms": "xverse/XVERSE-13B-Chat", + DownloadSource.DEFAULT: "xverse/XVERSE-13B-Chat", + DownloadSource.MODELSCOPE: "xverse/XVERSE-13B-Chat" } }, template="xverse" @@ -463,12 +537,12 @@ register_model_group( register_model_group( models={ "Yayi-7B": { - "hf": "wenge-research/yayi-7b-llama2", - "ms": "AI-ModelScope/yayi-7b-llama2", + DownloadSource.DEFAULT: "wenge-research/yayi-7b-llama2", + DownloadSource.MODELSCOPE: "AI-ModelScope/yayi-7b-llama2" }, "Yayi-13B": { - "hf": "wenge-research/yayi-13b-llama2", - "ms": "AI-ModelScope/yayi-13b-llama2", + DownloadSource.DEFAULT: "wenge-research/yayi-13b-llama2", + DownloadSource.MODELSCOPE: "AI-ModelScope/yayi-13b-llama2" } }, template="yayi" @@ -478,20 +552,20 @@ register_model_group( register_model_group( models={ "Yi-6B": { - "hf": "01-ai/Yi-6B", - "ms": "01ai/Yi-6B", + DownloadSource.DEFAULT: "01-ai/Yi-6B", + DownloadSource.MODELSCOPE: "01ai/Yi-6B" }, "Yi-34B": { - "hf": "01-ai/Yi-34B", - "ms": "01ai/Yi-34B", + DownloadSource.DEFAULT: "01-ai/Yi-34B", + DownloadSource.MODELSCOPE: "01ai/Yi-34B" }, "Yi-34B-Chat": { - "hf": "01-ai/Yi-34B-Chat", - "ms": "01ai/Yi-34B-Chat", + DownloadSource.DEFAULT: "01-ai/Yi-34B-Chat", + DownloadSource.MODELSCOPE: "01ai/Yi-34B-Chat" }, "Yi-34B-int8-Chat": { - "hf": "01-ai/Yi-34B-Chat-8bits", - "ms": "01ai/Yi-34B-Chat-8bits", + DownloadSource.DEFAULT: "01-ai/Yi-34B-Chat-8bits", + DownloadSource.MODELSCOPE: "01ai/Yi-34B-Chat-8bits" } }, template="yi" @@ -501,12 +575,12 @@ register_model_group( register_model_group( models={ "Zephyr-7B-Alpha-Chat": { - "hf": "HuggingFaceH4/zephyr-7b-alpha", - "ms": "AI-ModelScope/zephyr-7b-alpha", + DownloadSource.DEFAULT: "HuggingFaceH4/zephyr-7b-alpha", + DownloadSource.MODELSCOPE: "AI-ModelScope/zephyr-7b-alpha" }, "Zephyr-7B-Beta-Chat": { - "hf": "HuggingFaceH4/zephyr-7b-beta", - "ms": "modelscope/zephyr-7b-beta", + DownloadSource.DEFAULT: "HuggingFaceH4/zephyr-7b-beta", + DownloadSource.MODELSCOPE: "modelscope/zephyr-7b-beta" } }, template="zephyr" diff --git a/src/llmtuner/extras/misc.py b/src/llmtuner/extras/misc.py index 4f754c5c..33efb7d2 100644 --- a/src/llmtuner/extras/misc.py +++ b/src/llmtuner/extras/misc.py @@ -23,6 +23,7 @@ except ImportError: if TYPE_CHECKING: from transformers import HfArgumentParser + from llmtuner.hparams import ModelArguments class AverageMeter: @@ -117,3 +118,23 @@ def torch_gc() -> None: if torch.cuda.is_available(): torch.cuda.empty_cache() torch.cuda.ipc_collect() + + +def try_download_model_from_ms(model_args: "ModelArguments") -> None: + if not use_modelscope() or os.path.exists(model_args.model_name_or_path): + return + + try: + from modelscope import snapshot_download # type: ignore + revision = "master" if model_args.model_revision == "main" else model_args.model_revision + model_args.model_name_or_path = snapshot_download( + model_args.model_name_or_path, + revision=revision, + cache_dir=model_args.cache_dir + ) + except ImportError: + raise ImportError("Please install modelscope via `pip install modelscope -U`") + + +def use_modelscope() -> bool: + return bool(int(os.environ.get("USE_MODELSCOPE_HUB", "0"))) diff --git a/src/llmtuner/hparams/model_args.py b/src/llmtuner/hparams/model_args.py index ebf6cafa..07903b37 100644 --- a/src/llmtuner/hparams/model_args.py +++ b/src/llmtuner/hparams/model_args.py @@ -8,8 +8,8 @@ class ModelArguments: Arguments pertaining to which model/config/tokenizer we are going to fine-tune. 
""" model_name_or_path: str = field( - metadata={"help": "Path to pretrained model or model identifier " - "from huggingface.co/models or modelscope.cn/models."} + metadata={"help": "Path to pretrained model or model identifier from \ + huggingface.co/models or modelscope.cn/models."} ) cache_dir: Optional[str] = field( default=None, diff --git a/src/llmtuner/model/loader.py b/src/llmtuner/model/loader.py index 122cd7f2..87bad577 100644 --- a/src/llmtuner/model/loader.py +++ b/src/llmtuner/model/loader.py @@ -1,6 +1,4 @@ import math -import os - import torch from types import MethodType from typing import TYPE_CHECKING, Literal, Optional, Tuple @@ -23,8 +21,8 @@ try: except ImportError: # https://github.com/huggingface/transformers/releases/tag/v4.33.1 from transformers.deepspeed import is_deepspeed_zero3_enabled -from llmtuner.extras.logging import reset_logging, get_logger -from llmtuner.extras.misc import count_parameters, get_current_device, infer_optim_dtype +from llmtuner.extras.logging import get_logger +from llmtuner.extras.misc import count_parameters, get_current_device, infer_optim_dtype, try_download_model_from_ms from llmtuner.extras.packages import is_flash_attn2_available from llmtuner.extras.patches import llama_patch as LlamaPatches from llmtuner.hparams import FinetuningArguments @@ -58,6 +56,8 @@ def load_model_and_tokenizer( Support both training and inference. """ + try_download_model_from_ms(model_args) + config_kwargs = { "trust_remote_code": True, "cache_dir": model_args.cache_dir, @@ -65,8 +65,6 @@ def load_model_and_tokenizer( "token": model_args.hf_hub_token } - try_download_model_from_ms(model_args) - tokenizer = AutoTokenizer.from_pretrained( model_args.model_name_or_path, use_fast=model_args.use_fast_tokenizer, @@ -232,16 +230,3 @@ def load_model_and_tokenizer( logger.info("This IS expected that the trainable params is 0 if you are using model for inference only.") return model, tokenizer - - -def try_download_model_from_ms(model_args): - if int(os.environ.get('USE_MODELSCOPE_HUB', '0')) and not os.path.exists(model_args.model_name_or_path): - try: - from modelscope import snapshot_download - revision = model_args.model_revision - if revision == 'main': - revision = 'master' - model_args.model_name_or_path = snapshot_download(model_args.model_name_or_path, revision) - except ImportError as e: - raise ImportError(f'You are using `USE_MODELSCOPE_HUB=1` but you have no modelscope sdk installed. 
' - f'Please install it by `pip install modelscope -U`') from e diff --git a/src/llmtuner/webui/common.py b/src/llmtuner/webui/common.py index dabfab16..ab2502e1 100644 --- a/src/llmtuner/webui/common.py +++ b/src/llmtuner/webui/common.py @@ -11,18 +11,17 @@ from transformers.utils import ( ADAPTER_SAFE_WEIGHTS_NAME ) - from llmtuner.extras.constants import ( DEFAULT_MODULE, DEFAULT_TEMPLATE, SUPPORTED_MODELS, - ALL_OFFICIAL_MODELS, - TRAINING_STAGES + TRAINING_STAGES, + DownloadSource ) +from llmtuner.extras.misc import use_modelscope from llmtuner.hparams.data_args import DATA_CONFIG - DEFAULT_CACHE_DIR = "cache" DEFAULT_DATA_DIR = "data" DEFAULT_SAVE_DIR = "saves" @@ -66,10 +65,15 @@ def save_config(lang: str, model_name: Optional[str] = None, model_path: Optiona def get_model_path(model_name: str) -> str: user_config = load_config() - cached_path = user_config["path_dict"].get(model_name, None) - if cached_path in ALL_OFFICIAL_MODELS.get(model_name, []): - cached_path = None - return cached_path or SUPPORTED_MODELS.get(model_name, "") + path_dict: Dict[DownloadSource, str] = SUPPORTED_MODELS.get(model_name, []) + model_path = user_config["path_dict"].get(model_name, None) or path_dict.get(DownloadSource.DEFAULT, "") + if ( + use_modelscope() + and path_dict.get(DownloadSource.MODELSCOPE) + and model_path == path_dict.get(DownloadSource.DEFAULT) + ): # replace path + model_path = path_dict.get(DownloadSource.MODELSCOPE) + return model_path def get_prefix(model_name: str) -> str:
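
Taken together, the patch reduces model-source resolution to one rule: every registered model keeps a per-source ID table, `USE_MODELSCOPE_HUB=1` flips the lookup from the Hugging Face ID to the ModelScope ID, and the training entry point then swaps a ModelScope ID for a local snapshot path via `modelscope.snapshot_download`. Below is a minimal, self-contained Python sketch of that rule, not part of the patch itself: the two-source `SUPPORTED_MODELS` table is an illustrative stand-in for the registry built by `register_model_group` in `constants.py`, and the user-config cache handled by `load_config` in the Web UI is omitted.

```python
import os
from enum import Enum


class DownloadSource(str, Enum):
    DEFAULT = "hf"      # Hugging Face Hub model ID
    MODELSCOPE = "ms"   # ModelScope Hub model ID


# Illustrative stand-in for the registry that register_model_group() builds;
# this single entry is copied from the patch's LLaMA2 group.
SUPPORTED_MODELS = {
    "LLaMA2-7B": {
        DownloadSource.DEFAULT: "meta-llama/Llama-2-7b-hf",
        DownloadSource.MODELSCOPE: "modelscope/Llama-2-7b-ms",
    }
}


def use_modelscope() -> bool:
    # Mirrors llmtuner.extras.misc: the variable must parse as an integer,
    # so USE_MODELSCOPE_HUB=yes raises ValueError instead of enabling it.
    return bool(int(os.environ.get("USE_MODELSCOPE_HUB", "0")))


def get_model_path(model_name: str) -> str:
    # Mirrors the new webui.common.get_model_path, minus the user-config
    # cache: start from the default (Hugging Face) ID and override it with
    # the ModelScope ID only when the switch is on and such an ID exists.
    path_dict = SUPPORTED_MODELS.get(model_name, {})
    model_path = path_dict.get(DownloadSource.DEFAULT, "")
    if use_modelscope() and path_dict.get(DownloadSource.MODELSCOPE):
        model_path = path_dict[DownloadSource.MODELSCOPE]
    return model_path


if __name__ == "__main__":
    os.environ["USE_MODELSCOPE_HUB"] = "1"
    print(get_model_path("LLaMA2-7B"))  # -> modelscope/Llama-2-7b-ms
```

On the training side, `try_download_model_from_ms` applies the same gate once more: it returns early unless `use_modelscope()` is true and the path does not already exist locally, then replaces `model_args.model_name_or_path` with the directory returned by `snapshot_download` (mapping the `main` revision to ModelScope's `master`), so everything downstream keeps treating the value as an ordinary local checkpoint path.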