From 1bea1ed8680c6fee2122218a36cad0122c5a5aee Mon Sep 17 00:00:00 2001 From: hiyouga Date: Wed, 24 Apr 2024 00:28:53 +0800 Subject: [PATCH 01/16] support phi-3 Former-commit-id: 1a13f0555568c240f9d0bf0d94d95fbd252c07e4 --- README.md | 5 +++-- README_zh.md | 5 +++-- src/llmtuner/data/template.py | 9 +++++++++ src/llmtuner/extras/constants.py | 14 ++++++++++++++ 4 files changed, 29 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index faa1c7d8..0bf9f731 100644 --- a/README.md +++ b/README.md @@ -153,6 +153,7 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/ | [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | q_proj,v_proj | mistral | | [OLMo](https://huggingface.co/allenai) | 1B/7B | att_proj | olmo | | [Phi-1.5/2](https://huggingface.co/microsoft) | 1.3B/2.7B | q_proj,v_proj | - | +| [Phi-3](https://huggingface.co/microsoft) | 3.8B | qkv_proj | phi | | [Qwen](https://huggingface.co/Qwen) | 1.8B/7B/14B/72B | c_attn | qwen | | [Qwen1.5 (Code/MoE)](https://huggingface.co/Qwen) | 0.5B/1.8B/4B/7B/14B/32B/72B | q_proj,v_proj | qwen | | [StarCoder2](https://huggingface.co/bigcode) | 3B/7B/15B | q_proj,v_proj | - | @@ -333,7 +334,7 @@ To enable FlashAttention-2 on the Windows platform, you need to install the prec -### Train with LLaMA Board GUI +### Train with LLaMA Board GUI (powered by [Gradio](https://github.com/gradio-app/gradio)) > [!IMPORTANT] > LLaMA Board GUI only supports training on a single GPU, please use [CLI](#command-line-interface) for distributed training. @@ -458,7 +459,7 @@ If you have a project that should be incorporated, please contact via email or c This repository is licensed under the [Apache-2.0 License](LICENSE). -Please follow the model licenses to use the corresponding model weights: [Baichuan2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command-R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [InternLM2](https://github.com/InternLM/InternLM#license) / [LLaMA](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [LLaMA-2](https://ai.meta.com/llama/license/) / [LLaMA-3](https://llama.meta.com/llama3/license/) / [Mistral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [StarCoder2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yuan](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan) +Please follow the model licenses to use the corresponding model weights: [Baichuan2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command-R](https://cohere.com/c4ai-cc-by-nc-license) / 
[DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [InternLM2](https://github.com/InternLM/InternLM#license) / [LLaMA](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [LLaMA-2](https://ai.meta.com/llama/license/) / [LLaMA-3](https://llama.meta.com/llama3/license/) / [Mistral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Phi-3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/LICENSE) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [StarCoder2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yuan](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan) ## Citation diff --git a/README_zh.md b/README_zh.md index 1b4e3f1a..69ba2562 100644 --- a/README_zh.md +++ b/README_zh.md @@ -153,6 +153,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd | [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | q_proj,v_proj | mistral | | [OLMo](https://huggingface.co/allenai) | 1B/7B | att_proj | olmo | | [Phi-1.5/2](https://huggingface.co/microsoft) | 1.3B/2.7B | q_proj,v_proj | - | +| [Phi-3](https://huggingface.co/microsoft) | 3.8B | qkv_proj | phi | | [Qwen](https://huggingface.co/Qwen) | 1.8B/7B/14B/72B | c_attn | qwen | | [Qwen1.5 (Code/MoE)](https://huggingface.co/Qwen) | 0.5B/1.8B/4B/7B/14B/32B/72B | q_proj,v_proj | qwen | | [StarCoder2](https://huggingface.co/bigcode) | 3B/7B/15B | q_proj,v_proj | - | @@ -333,7 +334,7 @@ pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/downl -### 利用 LLaMA Board 可视化界面训练 +### 利用 LLaMA Board 可视化界面训练(由 [Gradio](https://github.com/gradio-app/gradio) 驱动) > [!IMPORTANT] > LLaMA Board 可视化界面目前仅支持单 GPU 训练,请使用[命令行接口](#命令行接口)来进行多 GPU 分布式训练。 @@ -458,7 +459,7 @@ export USE_MODELSCOPE_HUB=1 # Windows 使用 `set USE_MODELSCOPE_HUB=1` 本仓库的代码依照 [Apache-2.0](LICENSE) 协议开源。 -使用模型权重时,请遵循对应的模型协议:[Baichuan2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command-R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [InternLM2](https://github.com/InternLM/InternLM#license) / [LLaMA](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [LLaMA-2](https://ai.meta.com/llama/license/) / [LLaMA-3](https://llama.meta.com/llama3/license/) / [Mistral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [StarCoder2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / 
[Yuan](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan) +使用模型权重时,请遵循对应的模型协议:[Baichuan2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command-R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [InternLM2](https://github.com/InternLM/InternLM#license) / [LLaMA](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [LLaMA-2](https://ai.meta.com/llama/license/) / [LLaMA-3](https://llama.meta.com/llama3/license/) / [Mistral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Phi-3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/LICENSE) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [StarCoder2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yuan](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan) ## 引用 diff --git a/src/llmtuner/data/template.py b/src/llmtuner/data/template.py index 04538510..cd567a7b 100644 --- a/src/llmtuner/data/template.py +++ b/src/llmtuner/data/template.py @@ -718,6 +718,15 @@ _register_template( ) +_register_template( + name="phi", + format_user=StringFormatter(slots=["<|user|>\n{{content}}<|end|>\n<|assistant|>\n"]), + format_system=StringFormatter(slots=[{"bos_token"}, "<|system|>\n{{content}}<|end|>\n"]), + format_separator=EmptyFormatter(slots=["<|end|>\n"]), + default_system="You are a helpful AI assistant.", +) + + _register_template( name="qwen", format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), diff --git a/src/llmtuner/extras/constants.py b/src/llmtuner/extras/constants.py index a0e51d17..38d715f5 100644 --- a/src/llmtuner/extras/constants.py +++ b/src/llmtuner/extras/constants.py @@ -652,6 +652,20 @@ register_model_group( ) +register_model_group( + models={ + "Phi3-3.8B-4k-Chat": { + DownloadSource.DEFAULT: "microsoft/Phi-3-mini-4k-instruct", + }, + "Phi3-3.8B-128k-Chat": { + DownloadSource.DEFAULT: "microsoft/Phi-3-mini-128k-instruct", + }, + }, + module="qkv_proj", + template="phi", +) + + register_model_group( models={ "Qwen-1.8B": { From 34ecad4af818846078dfb723c86f21e94ac0684e Mon Sep 17 00:00:00 2001 From: hiyouga Date: Wed, 24 Apr 2024 01:30:16 +0800 Subject: [PATCH 02/16] fix #3347 #3387 Former-commit-id: 707f0b1d5d42b8e2c5b783c7783f65dfa9890a68 --- src/llmtuner/chat/vllm_engine.py | 18 ++++++++++++------ src/llmtuner/model/__init__.py | 3 ++- src/llmtuner/model/loader.py | 23 +++++++++++++++++------ 3 files changed, 31 insertions(+), 13 deletions(-) diff --git a/src/llmtuner/chat/vllm_engine.py b/src/llmtuner/chat/vllm_engine.py index e924ef6e..67a19b68 100644 --- a/src/llmtuner/chat/vllm_engine.py +++ b/src/llmtuner/chat/vllm_engine.py @@ -2,9 +2,9 @@ import uuid from typing import TYPE_CHECKING, AsyncGenerator, AsyncIterator, Dict, List, Optional, Sequence from ..data import get_template_and_fix_tokenizer -from ..extras.misc import get_device_count +from 
..extras.misc import get_device_count, infer_optim_dtype from ..extras.packages import is_vllm_available -from ..model import load_tokenizer +from ..model import load_config, load_tokenizer from .base_engine import BaseEngine, Response @@ -23,10 +23,20 @@ class VllmEngine(BaseEngine): finetuning_args: "FinetuningArguments", generating_args: "GeneratingArguments", ) -> None: + config = load_config(model_args) # may download model from ms hub + load_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None)) + self.can_generate = finetuning_args.stage == "sft" + self.tokenizer = load_tokenizer(model_args) + self.tokenizer.padding_side = "left" + self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args.template) + self.generating_args = generating_args.to_dict() + engine_args = AsyncEngineArgs( model=model_args.model_name_or_path, trust_remote_code=True, + download_dir=model_args.cache_dir, + dtype=str(load_dtype).split(".")[-1], max_model_len=model_args.vllm_maxlen, tensor_parallel_size=get_device_count() or 1, gpu_memory_utilization=model_args.vllm_gpu_util, @@ -35,10 +45,6 @@ class VllmEngine(BaseEngine): enforce_eager=model_args.vllm_enforce_eager, ) self.model = AsyncLLMEngine.from_engine_args(engine_args) - self.tokenizer = load_tokenizer(model_args) - self.tokenizer.padding_side = "left" - self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args.template) - self.generating_args = generating_args.to_dict() async def _generate( self, diff --git a/src/llmtuner/model/__init__.py b/src/llmtuner/model/__init__.py index 1eaf4271..e0b1c9cd 100644 --- a/src/llmtuner/model/__init__.py +++ b/src/llmtuner/model/__init__.py @@ -1,8 +1,9 @@ -from .loader import load_model, load_tokenizer +from .loader import load_config, load_model, load_tokenizer from .utils import find_all_linear_modules, load_valuehead_params __all__ = [ + "load_config", "load_model", "load_tokenizer", "load_valuehead_params", diff --git a/src/llmtuner/model/loader.py b/src/llmtuner/model/loader.py index 4935dd52..57f5a763 100644 --- a/src/llmtuner/model/loader.py +++ b/src/llmtuner/model/loader.py @@ -12,7 +12,7 @@ from .utils import load_valuehead_params, register_autoclass if TYPE_CHECKING: - from transformers import PreTrainedModel, PreTrainedTokenizer + from transformers import PretrainedConfig, PreTrainedModel, PreTrainedTokenizer from ..hparams import FinetuningArguments, ModelArguments @@ -21,6 +21,11 @@ logger = get_logger(__name__) def _get_init_kwargs(model_args: "ModelArguments") -> Dict[str, Any]: + r""" + Gets arguments to load config/tokenizer/model. + + Note: including inplace operation of model_args. + """ model_args.model_name_or_path = try_download_model_from_ms(model_args) return { "trust_remote_code": True, @@ -32,9 +37,7 @@ def _get_init_kwargs(model_args: "ModelArguments") -> Dict[str, Any]: def load_tokenizer(model_args: "ModelArguments") -> "PreTrainedTokenizer": r""" - Loads pretrained tokenizer. Must before load_model. - - Note: including inplace operation of model_args. + Loads pretrained tokenizer. """ init_kwargs = _get_init_kwargs(model_args) try: @@ -57,6 +60,14 @@ def load_tokenizer(model_args: "ModelArguments") -> "PreTrainedTokenizer": return tokenizer +def load_config(model_args: "ModelArguments") -> "PretrainedConfig": + r""" + Loads model config. 
+ """ + init_kwargs = _get_init_kwargs(model_args) + return AutoConfig.from_pretrained(model_args.model_name_or_path, **init_kwargs) + + def load_model( tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments", @@ -65,10 +76,10 @@ def load_model( add_valuehead: bool = False, ) -> "PreTrainedModel": r""" - Loads pretrained model. Must after load_tokenizer. + Loads pretrained model. """ init_kwargs = _get_init_kwargs(model_args) - config = AutoConfig.from_pretrained(model_args.model_name_or_path, **init_kwargs) + config = load_config(model_args) patch_config(config, tokenizer, model_args, init_kwargs, is_trainable) model = None From 80c8586534f739993e0675432c847eedc76611b0 Mon Sep 17 00:00:00 2001 From: hiyouga Date: Wed, 24 Apr 2024 02:18:44 +0800 Subject: [PATCH 03/16] reenable sdpa and fast tok by default Former-commit-id: 07737a3d2d026c973ab964f948953d6ce0e1f2a9 --- README.md | 4 +-- README_zh.md | 4 +-- requirements.txt | 1 + src/llmtuner/extras/packages.py | 19 ++++++++--- src/llmtuner/hparams/model_args.py | 8 ++--- src/llmtuner/model/patcher.py | 49 ++++++++++++++++++++++------ src/llmtuner/webui/components/top.py | 2 +- src/llmtuner/webui/runner.py | 4 +-- 8 files changed, 64 insertions(+), 27 deletions(-) diff --git a/README.md b/README.md index 0bf9f731..970dd8fc 100644 --- a/README.md +++ b/README.md @@ -72,8 +72,6 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/ [24/04/21] We supported **[Mixture-of-Depths](https://arxiv.org/abs/2404.02258)** according to [AstraMindAI's implementation](https://github.com/astramind-ai/Mixture-of-depths). See `examples/extras/mod` for usage. -[24/04/19] We supported **Meta Llama 3** model series. - [24/04/16] We supported **[BAdam](https://arxiv.org/abs/2404.02827)**. See `examples/extras/badam` for usage. [24/04/16] We supported **[unsloth](https://github.com/unslothai/unsloth)**'s long-sequence training (Llama-2-7B-56k within 24GB). It achieves **117%** speed and **50%** memory compared with FlashAttention-2, more benchmarks can be found in [this page](https://github.com/hiyouga/LLaMA-Factory/wiki/Performance-comparison). @@ -112,7 +110,7 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/ [23/09/23] We integrated MMLU, C-Eval and CMMLU benchmarks in this repo. See [this example](#evaluation) to evaluate your models. -[23/09/10] We supported **[FlashAttention-2](https://github.com/Dao-AILab/flash-attention)**. Try `--flash_attn` argument to enable FlashAttention-2 if you are using RTX4090, A100 or H100 GPUs. +[23/09/10] We supported **[FlashAttention-2](https://github.com/Dao-AILab/flash-attention)**. Try `--flash_attn fa2` argument to enable FlashAttention-2 if you are using RTX4090, A100 or H100 GPUs. [23/08/12] We supported **RoPE scaling** to extend the context length of the LLaMA models. Try `--rope_scaling linear` argument in training and `--rope_scaling dynamic` argument at inference to extrapolate the position embeddings. 
diff --git a/README_zh.md b/README_zh.md index 69ba2562..583c89ca 100644 --- a/README_zh.md +++ b/README_zh.md @@ -72,8 +72,6 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd [24/04/21] 我们基于 [AstraMindAI 的仓库](https://github.com/astramind-ai/Mixture-of-depths)支持了 **[混合深度训练](https://arxiv.org/abs/2404.02258)**。详细用法请参照 `examples/extras/mod`。 -[24/04/19] 我们支持了 **Meta Llama 3** 系列模型。 - [24/04/16] 我们支持了 **[BAdam](https://arxiv.org/abs/2404.02827)**。详细用法请参照 `examples/extras/badam`。 [24/04/16] 我们支持了 **[unsloth](https://github.com/unslothai/unsloth)** 的长序列训练(24GB 可训练 Llama-2-7B-56k)。该方法相比 FlashAttention-2 提供了 **117%** 的训练速度和 **50%** 的显存节约。更多数据请见[此页面](https://github.com/hiyouga/LLaMA-Factory/wiki/Performance-comparison)。 @@ -112,7 +110,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd [23/09/23] 我们在项目中集成了 MMLU、C-Eval 和 CMMLU 评估集。使用方法请参阅[此示例](#模型评估)。 -[23/09/10] 我们支持了 **[FlashAttention-2](https://github.com/Dao-AILab/flash-attention)**。如果您使用的是 RTX4090、A100 或 H100 GPU,请使用 `--flash_attn` 参数以启用 FlashAttention-2。 +[23/09/10] 我们支持了 **[FlashAttention-2](https://github.com/Dao-AILab/flash-attention)**。如果您使用的是 RTX4090、A100 或 H100 GPU,请使用 `--flash_attn fa2` 参数以启用 FlashAttention-2。 [23/08/12] 我们支持了 **RoPE 插值**来扩展 LLaMA 模型的上下文长度。请使用 `--rope_scaling linear` 参数训练模型或使用 `--rope_scaling dynamic` 参数评估模型。 diff --git a/requirements.txt b/requirements.txt index 3928d28d..ecba3ce1 100644 --- a/requirements.txt +++ b/requirements.txt @@ -15,3 +15,4 @@ fastapi sse-starlette matplotlib fire +packaging diff --git a/src/llmtuner/extras/packages.py b/src/llmtuner/extras/packages.py index 8494cb2c..aeeba084 100644 --- a/src/llmtuner/extras/packages.py +++ b/src/llmtuner/extras/packages.py @@ -1,16 +1,23 @@ import importlib.metadata import importlib.util +from typing import TYPE_CHECKING + +from packaging import version + + +if TYPE_CHECKING: + from packaging.version import Version def _is_package_available(name: str) -> bool: return importlib.util.find_spec(name) is not None -def _get_package_version(name: str) -> str: +def _get_package_version(name: str) -> "Version": try: - return importlib.metadata.version(name) + return version.parse(importlib.metadata.version(name)) except Exception: - return "0.0.0" + return version.parse("0.0.0") def is_fastapi_availble(): @@ -18,7 +25,7 @@ def is_fastapi_availble(): def is_flash_attn2_available(): - return _is_package_available("flash_attn") and _get_package_version("flash_attn").startswith("2") + return _is_package_available("flash_attn") and _get_package_version("flash_attn") > version.parse("2.0.0") def is_galore_available(): @@ -49,6 +56,10 @@ def is_rouge_available(): return _is_package_available("rouge_chinese") +def is_sdpa_available(): + return _get_package_version("torch") > version.parse("2.1.1") + + def is_starlette_available(): return _is_package_available("sse_starlette") diff --git a/src/llmtuner/hparams/model_args.py b/src/llmtuner/hparams/model_args.py index 0e42033f..eb6366d9 100644 --- a/src/llmtuner/hparams/model_args.py +++ b/src/llmtuner/hparams/model_args.py @@ -22,7 +22,7 @@ class ModelArguments: metadata={"help": "Where to store the pre-trained models downloaded from huggingface.co or modelscope.cn."}, ) use_fast_tokenizer: bool = field( - default=False, + default=True, metadata={"help": "Whether or not to use one of the fast tokenizer (backed by the tokenizers library)."}, ) resize_vocab: bool = field( @@ -61,9 +61,9 @@ class ModelArguments: default=None, metadata={"help": "Which scaling 
strategy should be adopted for the RoPE embeddings."}, ) - flash_attn: bool = field( - default=False, - metadata={"help": "Enable FlashAttention for faster training."}, + flash_attn: Literal["off", "sdpa", "fa2", "auto"] = field( + default="auto", + metadata={"help": "Enable FlashAttention for faster training and inference."}, ) shift_attn: bool = field( default=False, diff --git a/src/llmtuner/model/patcher.py b/src/llmtuner/model/patcher.py index 53616dd9..6c79992a 100644 --- a/src/llmtuner/model/patcher.py +++ b/src/llmtuner/model/patcher.py @@ -15,7 +15,7 @@ from transformers.utils.versions import require_version from ..extras.constants import FILEEXT2TYPE, LAYERNORM_NAMES from ..extras.logging import get_logger from ..extras.misc import get_current_device, infer_optim_dtype -from ..extras.packages import is_flash_attn2_available +from ..extras.packages import is_flash_attn2_available, is_sdpa_available from ..extras.patches.llama_patch import apply_llama_patch from .utils import QuantizationMethod, add_z3_leaf_module, gradient_checkpointing_enable @@ -62,18 +62,45 @@ def _get_quantization_dataset(tokenizer: "PreTrainedTokenizer", model_args: "Mod def _configure_attn_implementation(config: "PretrainedConfig", model_args: "ModelArguments") -> None: - if model_args.flash_attn: - if not is_flash_attn2_available(): - logger.warning("FlashAttention2 is not installed.") + if model_args.flash_attn == "auto": + return + + elif model_args.flash_attn == "off": + requested_attn_implementation = "eager" + + elif model_args.flash_attn == "sdpa": + if not is_sdpa_available(): + logger.warning("Torch>=2.1.1 is required for SDPA attention.") return - logger.info("Using FlashAttention-2 for faster training and inference.") - if getattr(config, "model_type", None) == "internlm2": # special case for custom models - setattr(config, "attn_implementation", "flash_attention_2") - else: - setattr(config, "_attn_implementation", "flash_attention_2") + requested_attn_implementation = "sdpa" + elif model_args.flash_attn == "fa2": + if not is_flash_attn2_available(): + logger.warning("FlashAttention-2 is not installed.") + return + + requested_attn_implementation = "flash_attention_2" else: - setattr(config, "_attn_implementation", "eager") + raise NotImplementedError("Unknown attention type: {}".format(model_args.flash_attn)) + + if getattr(config, "model_type", None) == "internlm2": # special case for custom models + setattr(config, "attn_implementation", requested_attn_implementation) + else: + setattr(config, "_attn_implementation", requested_attn_implementation) + + +def _print_attn_implementation(config: "PretrainedConfig") -> None: + if getattr(config, "model_type", None) == "internlm2": # special case for custom models + attn_implementation = getattr(config, "attn_implementation", None) + else: + attn_implementation = getattr(config, "_attn_implementation", None) + + if attn_implementation == "flash_attention_2": + logger.info("Using FlashAttention-2 for faster training and inference.") + elif attn_implementation == "sdpa": + logger.info("Using torch SDPA for faster training and inference.") + else: + logger.info("Using vanilla Attention implementation.") def _configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None: @@ -365,6 +392,8 @@ def patch_model( add_z3_leaf_module(model, Qwen2MoeSparseMoeBlock) + _print_attn_implementation(model.config) + try: model.add_model_tags(["llama-factory"]) except Exception: diff --git a/src/llmtuner/webui/components/top.py 
b/src/llmtuner/webui/components/top.py index 6cbf6e0d..c67d7cc5 100644 --- a/src/llmtuner/webui/components/top.py +++ b/src/llmtuner/webui/components/top.py @@ -33,7 +33,7 @@ def create_top() -> Dict[str, "Component"]: quantization_bit = gr.Dropdown(choices=["none", "8", "4"], value="none") template = gr.Dropdown(choices=list(templates.keys()), value="default") rope_scaling = gr.Radio(choices=["none", "linear", "dynamic"], value="none") - booster = gr.Radio(choices=["none", "flashattn", "unsloth"], value="none") + booster = gr.Radio(choices=["none", "flashattn2", "unsloth"], value="none") model_name.change(list_adapters, [model_name, finetuning_type], [adapter_path], queue=False).then( get_model_path, [model_name], [model_path], queue=False diff --git a/src/llmtuner/webui/runner.py b/src/llmtuner/webui/runner.py index ec493c96..b64a015c 100644 --- a/src/llmtuner/webui/runner.py +++ b/src/llmtuner/webui/runner.py @@ -67,7 +67,7 @@ class Runner: if not model_path: return ALERTS["err_no_path"][lang] - if len(dataset) == 0: + if not dataset: return ALERTS["err_no_dataset"][lang] if not from_preview and self.demo_mode: @@ -122,7 +122,7 @@ class Runner: quantization_bit=int(get("top.quantization_bit")) if get("top.quantization_bit") in ["8", "4"] else None, template=get("top.template"), rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None, - flash_attn=(get("top.booster") == "flashattn"), + flash_attn="fa2" if get("top.booster") == "flashattn2" else "auto", use_unsloth=(get("top.booster") == "unsloth"), dataset_dir=get("train.dataset_dir"), dataset=",".join(get("train.dataset")), From 8465e54d3897ed5c90ba71123d5c628330905faa Mon Sep 17 00:00:00 2001 From: hiyouga Date: Wed, 24 Apr 2024 03:02:23 +0800 Subject: [PATCH 04/16] refactor patcher Former-commit-id: aa2b79eb23c60825e6601b0b8cc6b59e3f566b2d --- src/llmtuner/extras/constants.py | 2 + src/llmtuner/model/__init__.py | 2 +- src/llmtuner/model/adapter.py | 3 +- src/llmtuner/model/loader.py | 2 +- src/llmtuner/model/patcher.py | 325 +----------------- .../patches => model/utils}/__init__.py | 0 src/llmtuner/model/utils/attention.py | 55 +++ src/llmtuner/model/utils/checkpointing.py | 94 +++++ src/llmtuner/model/utils/embedding.py | 56 +++ .../utils/longlora.py} | 151 +++++++- .../model/{utils.py => utils/misc.py} | 74 +--- src/llmtuner/model/utils/moe.py | 39 +++ src/llmtuner/model/utils/quantization.py | 146 ++++++++ src/llmtuner/model/utils/rope.py | 43 +++ 14 files changed, 598 insertions(+), 394 deletions(-) rename src/llmtuner/{extras/patches => model/utils}/__init__.py (100%) create mode 100644 src/llmtuner/model/utils/attention.py create mode 100644 src/llmtuner/model/utils/checkpointing.py create mode 100644 src/llmtuner/model/utils/embedding.py rename src/llmtuner/{extras/patches/llama_patch.py => model/utils/longlora.py} (58%) rename src/llmtuner/model/{utils.py => utils/misc.py} (61%) create mode 100644 src/llmtuner/model/utils/moe.py create mode 100644 src/llmtuner/model/utils/quantization.py create mode 100644 src/llmtuner/model/utils/rope.py diff --git a/src/llmtuner/extras/constants.py b/src/llmtuner/extras/constants.py index 38d715f5..0a29f971 100644 --- a/src/llmtuner/extras/constants.py +++ b/src/llmtuner/extras/constants.py @@ -47,6 +47,8 @@ TRAINING_STAGES = { STAGES_USE_PAIR_DATA = ["rm", "dpo", "orpo"] +SUPPORTED_CLASS_FOR_S2ATTN = ["llama"] + V_HEAD_WEIGHTS_NAME = "value_head.bin" V_HEAD_SAFE_WEIGHTS_NAME = "value_head.safetensors" diff --git a/src/llmtuner/model/__init__.py 
b/src/llmtuner/model/__init__.py index e0b1c9cd..1824f084 100644 --- a/src/llmtuner/model/__init__.py +++ b/src/llmtuner/model/__init__.py @@ -1,5 +1,5 @@ from .loader import load_config, load_model, load_tokenizer -from .utils import find_all_linear_modules, load_valuehead_params +from .utils.misc import find_all_linear_modules, load_valuehead_params __all__ = [ diff --git a/src/llmtuner/model/adapter.py b/src/llmtuner/model/adapter.py index f73666d5..efc63cde 100644 --- a/src/llmtuner/model/adapter.py +++ b/src/llmtuner/model/adapter.py @@ -5,7 +5,8 @@ from peft import LoraConfig, LoraModel, PeftModel, TaskType, get_peft_model from transformers.integrations import is_deepspeed_zero3_enabled from ..extras.logging import get_logger -from .utils import QuantizationMethod, find_all_linear_modules, find_expanded_modules +from .utils.misc import find_all_linear_modules, find_expanded_modules +from .utils.quantization import QuantizationMethod if TYPE_CHECKING: diff --git a/src/llmtuner/model/loader.py b/src/llmtuner/model/loader.py index 57f5a763..b8558542 100644 --- a/src/llmtuner/model/loader.py +++ b/src/llmtuner/model/loader.py @@ -8,7 +8,7 @@ from ..extras.logging import get_logger from ..extras.misc import count_parameters, get_current_device, try_download_model_from_ms from .adapter import init_adapter from .patcher import patch_config, patch_model, patch_tokenizer, patch_valuehead_model -from .utils import load_valuehead_params, register_autoclass +from .utils.misc import load_valuehead_params, register_autoclass if TYPE_CHECKING: diff --git a/src/llmtuner/model/patcher.py b/src/llmtuner/model/patcher.py index 6c79992a..c0166a8a 100644 --- a/src/llmtuner/model/patcher.py +++ b/src/llmtuner/model/patcher.py @@ -1,23 +1,20 @@ -import math -import os -import random -from contextlib import nullcontext from types import MethodType -from typing import TYPE_CHECKING, Any, Dict, List, Tuple +from typing import TYPE_CHECKING, Any, Dict import torch -from datasets import load_dataset from peft import PeftModel -from transformers import BitsAndBytesConfig, GPTQConfig, PreTrainedModel, PreTrainedTokenizerBase +from transformers import PreTrainedModel, PreTrainedTokenizerBase from transformers.integrations import is_deepspeed_zero3_enabled -from transformers.utils.versions import require_version -from ..extras.constants import FILEEXT2TYPE, LAYERNORM_NAMES from ..extras.logging import get_logger -from ..extras.misc import get_current_device, infer_optim_dtype -from ..extras.packages import is_flash_attn2_available, is_sdpa_available -from ..extras.patches.llama_patch import apply_llama_patch -from .utils import QuantizationMethod, add_z3_leaf_module, gradient_checkpointing_enable +from ..extras.misc import infer_optim_dtype +from .utils.attention import configure_attn_implementation, print_attn_implementation +from .utils.checkpointing import prepare_model_for_training +from .utils.embedding import resize_embedding_layer +from .utils.longlora import configure_longlora +from .utils.moe import add_z3_leaf_module +from .utils.quantization import configure_quantization +from .utils.rope import configure_rope if TYPE_CHECKING: @@ -28,282 +25,6 @@ if TYPE_CHECKING: logger = get_logger(__name__) -SUPPORTED_CLASS_FOR_S2ATTN = ["llama"] - - -def _get_quantization_dataset(tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments") -> List[str]: - r""" - Inspired by: https://github.com/huggingface/optimum/blob/v1.16.0/optimum/gptq/data.py#L133 - TODO: remove tokenizer.decode() 
https://github.com/huggingface/optimum/pull/1600 - """ - if os.path.isfile(model_args.export_quantization_dataset): - data_path = FILEEXT2TYPE.get(model_args.export_quantization_dataset.split(".")[-1], None) - data_files = model_args.export_quantization_dataset - else: - data_path = model_args.export_quantization_dataset - data_files = None - - dataset = load_dataset(path=data_path, data_files=data_files, split="train", cache_dir=model_args.cache_dir) - maxlen = model_args.export_quantization_maxlen - - samples = [] - for _ in range(model_args.export_quantization_nsamples): - while True: - sample_idx = random.randint(0, len(dataset) - 1) - sample: Dict[str, torch.Tensor] = tokenizer(dataset[sample_idx]["text"], return_tensors="pt") - if sample["input_ids"].size(1) >= maxlen: - break # TODO: fix large maxlen - - word_idx = random.randint(0, sample["input_ids"].size(1) - maxlen - 1) - input_ids = sample["input_ids"][:, word_idx : word_idx + maxlen] - samples.append(tokenizer.decode(input_ids[0].tolist(), skip_special_tokens=True)) - - return samples - - -def _configure_attn_implementation(config: "PretrainedConfig", model_args: "ModelArguments") -> None: - if model_args.flash_attn == "auto": - return - - elif model_args.flash_attn == "off": - requested_attn_implementation = "eager" - - elif model_args.flash_attn == "sdpa": - if not is_sdpa_available(): - logger.warning("Torch>=2.1.1 is required for SDPA attention.") - return - - requested_attn_implementation = "sdpa" - elif model_args.flash_attn == "fa2": - if not is_flash_attn2_available(): - logger.warning("FlashAttention-2 is not installed.") - return - - requested_attn_implementation = "flash_attention_2" - else: - raise NotImplementedError("Unknown attention type: {}".format(model_args.flash_attn)) - - if getattr(config, "model_type", None) == "internlm2": # special case for custom models - setattr(config, "attn_implementation", requested_attn_implementation) - else: - setattr(config, "_attn_implementation", requested_attn_implementation) - - -def _print_attn_implementation(config: "PretrainedConfig") -> None: - if getattr(config, "model_type", None) == "internlm2": # special case for custom models - attn_implementation = getattr(config, "attn_implementation", None) - else: - attn_implementation = getattr(config, "_attn_implementation", None) - - if attn_implementation == "flash_attention_2": - logger.info("Using FlashAttention-2 for faster training and inference.") - elif attn_implementation == "sdpa": - logger.info("Using torch SDPA for faster training and inference.") - else: - logger.info("Using vanilla Attention implementation.") - - -def _configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None: - if model_args.rope_scaling is None: - return - - if not hasattr(config, "rope_scaling"): - logger.warning("Current model does not support RoPE scaling.") - return - - if is_trainable: - if model_args.rope_scaling == "dynamic": - logger.warning( - "Dynamic NTK scaling may not work well with fine-tuning. " - "See: https://github.com/huggingface/transformers/pull/24653" - ) - - current_max_length = getattr(config, "max_position_embeddings", None) - if current_max_length and model_args.model_max_length > current_max_length: - scaling_factor = float(math.ceil(model_args.model_max_length / current_max_length)) - else: - logger.warning("Input length is smaller than max length. 
Consider increase input length.") - scaling_factor = 1.0 - else: - scaling_factor = 2.0 - - setattr(config, "rope_scaling", {"type": model_args.rope_scaling, "factor": scaling_factor}) - logger.info( - "Using {} scaling strategy and setting scaling factor to {}".format(model_args.rope_scaling, scaling_factor) - ) - - -def _configure_longlora(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None: - if not is_trainable or not model_args.shift_attn: - return - - if getattr(config, "model_type", None) in SUPPORTED_CLASS_FOR_S2ATTN: - setattr(config, "group_size_ratio", 0.25) - apply_llama_patch() - logger.info("Using shift short attention with group_size_ratio=1/4.") - else: - logger.warning("Current model does not support shift short attention.") - - -def _configure_quantization( - config: "PretrainedConfig", - tokenizer: "PreTrainedTokenizer", - model_args: "ModelArguments", - init_kwargs: Dict[str, Any], -) -> None: - r""" - Priority: PTQ-quantized (training) > AutoGPTQ (export) > Bitsandbytes (training) - """ - if getattr(config, "quantization_config", None): # ptq - if is_deepspeed_zero3_enabled(): - raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantized models.") - - if model_args.quantization_device_map != "auto": - init_kwargs["device_map"] = {"": get_current_device()} - - quantization_config: Dict[str, Any] = getattr(config, "quantization_config", None) - quant_method = quantization_config.get("quant_method", "") - - if quant_method == QuantizationMethod.GPTQ: - require_version("auto_gptq>=0.5.0", "To fix: pip install auto_gptq>=0.5.0") - quantization_config.pop("disable_exllama", None) # remove deprecated args - quantization_config["use_exllama"] = False # disable exllama - - if quant_method == QuantizationMethod.AWQ: - require_version("autoawq", "To fix: pip install autoawq") - - if quant_method == QuantizationMethod.AQLM: - require_version("transformers>=4.39.0", "To fix: pip install transformers>=4.39.0") - require_version("aqlm>=1.1.0", "To fix: pip install aqlm[gpu]>=1.1.0") - quantization_config["bits"] = 2 - - quant_bits = quantization_config.get("bits", "?") - logger.info("Loading {}-bit {}-quantized model.".format(quant_bits, quant_method.upper())) - - elif model_args.export_quantization_bit is not None: # auto-gptq - require_version("optimum>=1.16.0", "To fix: pip install optimum>=1.16.0") - require_version("auto_gptq>=0.5.0", "To fix: pip install auto_gptq>=0.5.0") - from accelerate.utils import get_max_memory - - if getattr(config, "model_type", None) == "chatglm": - raise ValueError("ChatGLM model is not supported.") - - init_kwargs["quantization_config"] = GPTQConfig( - bits=model_args.export_quantization_bit, - tokenizer=tokenizer, - dataset=_get_quantization_dataset(tokenizer, model_args), - ) - init_kwargs["device_map"] = "auto" - init_kwargs["max_memory"] = get_max_memory() - logger.info("Quantizing model to {} bit.".format(model_args.export_quantization_bit)) - - elif model_args.quantization_bit is not None: # bnb - if model_args.quantization_bit == 8: - require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0") - init_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True) - - elif model_args.quantization_bit == 4: - require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0") - init_kwargs["quantization_config"] = BitsAndBytesConfig( - load_in_4bit=True, - bnb_4bit_compute_dtype=model_args.compute_dtype, - 
bnb_4bit_use_double_quant=model_args.double_quantization, - bnb_4bit_quant_type=model_args.quantization_type, - bnb_4bit_quant_storage=model_args.compute_dtype, # crucial for fsdp qlora - ) - - if is_deepspeed_zero3_enabled() or model_args.quantization_device_map == "auto": - if model_args.quantization_bit != 4: - raise ValueError("Only 4-bit quantized model can use auto device map.") - - require_version("transformers>=4.39.0", "To fix: pip install transformers>=4.39.0") - require_version("accelerate>=0.28.0", "To fix: pip install accelerate>=0.28.0") - require_version("bitsandbytes>=0.43.0", "To fix: pip install bitsandbytes>=0.43.0") - else: - init_kwargs["device_map"] = {"": get_current_device()} - - logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit)) - - -def _noisy_mean_initialization(embed_weight: torch.Tensor, num_new_tokens: int): - embedding_dim = embed_weight.size(1) - avg_weight = embed_weight[:-num_new_tokens].mean(dim=0, keepdim=True) - noise_weight = torch.empty_like(embed_weight[-num_new_tokens:]) - noise_weight.normal_(mean=0, std=(1.0 / math.sqrt(embedding_dim))) - embed_weight[-num_new_tokens:] = avg_weight + noise_weight - - -def _resize_embedding_layer(model: "PreTrainedModel", tokenizer: "PreTrainedTokenizer") -> None: - r""" - Resize token embeddings. - """ - if is_deepspeed_zero3_enabled(): - import deepspeed # type: ignore - - params = [model.get_input_embeddings().weight] - if model.get_output_embeddings() is not None and not model.config.tie_word_embeddings: - params.append(model.get_output_embeddings().weight) - - context_maybe_zero3 = deepspeed.zero.GatheredParameters(params, modifier_rank=0) - else: - context_maybe_zero3 = nullcontext() - - with context_maybe_zero3: - current_embedding_size = model.get_input_embeddings().weight.size(0) - - if len(tokenizer) > current_embedding_size: - if not isinstance(model.get_output_embeddings(), torch.nn.Linear): - logger.warning("Current model does not support resizing token embeddings.") - return - - model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=64) - with context_maybe_zero3: - new_embedding_size = model.get_input_embeddings().weight.size(0) - num_new_tokens = new_embedding_size - current_embedding_size - _noisy_mean_initialization(model.get_input_embeddings().weight.data, num_new_tokens) - _noisy_mean_initialization(model.get_output_embeddings().weight.data, num_new_tokens) - - logger.info("Resized token embeddings from {} to {}.".format(current_embedding_size, new_embedding_size)) - - -def _fp32_forward_post_hook( - module: "torch.nn.Module", args: Tuple["torch.Tensor"], output: "torch.Tensor" -) -> "torch.Tensor": - return output.to(torch.float32) - - -def _prepare_model_for_training( - model: "PreTrainedModel", model_args: "ModelArguments", output_layer_name: str = "lm_head" -) -> None: - r""" - Includes: - (1) cast the layernorm in fp32 - (2) make output embedding layer require grads - (3) add the upcasting of the lm_head in fp32 - Inspired by: https://github.com/huggingface/peft/blob/v0.7.1/src/peft/utils/other.py#L72 - """ - if model_args.upcast_layernorm: - logger.info("Upcasting layernorm weights in float32.") - for name, param in model.named_parameters(): - if param.ndim == 1 and any(ln_name in name for ln_name in LAYERNORM_NAMES): - param.data = param.data.to(torch.float32) - - if not model_args.disable_gradient_checkpointing: - if not getattr(model, "supports_gradient_checkpointing", False): - logger.warning("Current model does not support gradient 
checkpointing.") - else: - # use_reentrant=False might increase VRAM usage (have not been empirically verified yet) - # According to: https://github.com/huggingface/transformers/issues/28339 - model.gradient_checkpointing_enable = MethodType(gradient_checkpointing_enable, model) - model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": True}) - setattr(model.config, "use_cache", False) # turn off when gradient checkpointing is enabled - logger.info("Gradient checkpointing enabled.") - - if hasattr(model, output_layer_name) and model_args.upcast_lmhead_output: - logger.info("Upcasting lm_head outputs in float32.") - output_layer = getattr(model, output_layer_name) - if isinstance(output_layer, torch.nn.Linear) and output_layer.weight.dtype != torch.float32: - output_layer.register_forward_hook(_fp32_forward_post_hook) def patch_tokenizer(tokenizer: "PreTrainedTokenizer") -> None: @@ -321,10 +42,10 @@ def patch_config( if model_args.compute_dtype is None: # priority: bf16 > fp16 > fp32 model_args.compute_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None)) - _configure_attn_implementation(config, model_args) - _configure_rope(config, model_args, is_trainable) - _configure_longlora(config, model_args, is_trainable) - _configure_quantization(config, tokenizer, model_args, init_kwargs) + configure_attn_implementation(config, model_args) + configure_rope(config, model_args, is_trainable) + configure_longlora(config, model_args, is_trainable) + configure_quantization(config, tokenizer, model_args, init_kwargs) if model_args.use_cache and not is_trainable: setattr(config, "use_cache", True) @@ -377,22 +98,14 @@ def patch_model( setattr(model, "_keys_to_ignore_on_save", ["lm_head.weight"]) if model_args.resize_vocab: - _resize_embedding_layer(model, tokenizer) + resize_embedding_layer(model, tokenizer) if is_trainable: - _prepare_model_for_training(model, model_args) + prepare_model_for_training(model, model_args) + add_z3_leaf_module(model) - if getattr(model.config, "model_type", None) == "mixtral": - from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock - - add_z3_leaf_module(model, MixtralSparseMoeBlock) - - if getattr(model.config, "model_type", None) == "qwen2moe": - from transformers.models.qwen2_moe.modeling_qwen2_moe import Qwen2MoeSparseMoeBlock - - add_z3_leaf_module(model, Qwen2MoeSparseMoeBlock) - - _print_attn_implementation(model.config) + if not model_args.use_unsloth: + print_attn_implementation(model.config) try: model.add_model_tags(["llama-factory"]) diff --git a/src/llmtuner/extras/patches/__init__.py b/src/llmtuner/model/utils/__init__.py similarity index 100% rename from src/llmtuner/extras/patches/__init__.py rename to src/llmtuner/model/utils/__init__.py diff --git a/src/llmtuner/model/utils/attention.py b/src/llmtuner/model/utils/attention.py new file mode 100644 index 00000000..f4686489 --- /dev/null +++ b/src/llmtuner/model/utils/attention.py @@ -0,0 +1,55 @@ +from typing import TYPE_CHECKING + +from ...extras.logging import get_logger +from ...extras.packages import is_flash_attn2_available, is_sdpa_available + + +if TYPE_CHECKING: + from transformers import PretrainedConfig + + from ...hparams import ModelArguments + + +logger = get_logger(__name__) + + +def configure_attn_implementation(config: "PretrainedConfig", model_args: "ModelArguments") -> None: + if model_args.flash_attn == "auto": + return + + elif model_args.flash_attn == "off": + requested_attn_implementation = "eager" + + elif 
model_args.flash_attn == "sdpa": + if not is_sdpa_available(): + logger.warning("Torch>=2.1.1 is required for SDPA attention.") + return + + requested_attn_implementation = "sdpa" + elif model_args.flash_attn == "fa2": + if not is_flash_attn2_available(): + logger.warning("FlashAttention-2 is not installed.") + return + + requested_attn_implementation = "flash_attention_2" + else: + raise NotImplementedError("Unknown attention type: {}".format(model_args.flash_attn)) + + if getattr(config, "model_type", None) == "internlm2": # special case for custom models + setattr(config, "attn_implementation", requested_attn_implementation) + else: + setattr(config, "_attn_implementation", requested_attn_implementation) + + +def print_attn_implementation(config: "PretrainedConfig") -> None: + if getattr(config, "model_type", None) == "internlm2": # special case for custom models + attn_implementation = getattr(config, "attn_implementation", None) + else: + attn_implementation = getattr(config, "_attn_implementation", None) + + if attn_implementation == "flash_attention_2": + logger.info("Using FlashAttention-2 for faster training and inference.") + elif attn_implementation == "sdpa": + logger.info("Using torch SDPA for faster training and inference.") + else: + logger.info("Using vanilla Attention implementation.") diff --git a/src/llmtuner/model/utils/checkpointing.py b/src/llmtuner/model/utils/checkpointing.py new file mode 100644 index 00000000..e0657be8 --- /dev/null +++ b/src/llmtuner/model/utils/checkpointing.py @@ -0,0 +1,94 @@ +import inspect +from functools import partial +from types import MethodType +from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple + +import torch + +from ...extras.constants import LAYERNORM_NAMES +from ...extras.logging import get_logger + + +if TYPE_CHECKING: + from transformers import PreTrainedModel + + from ...hparams import ModelArguments + + +logger = get_logger(__name__) + + +def _gradient_checkpointing_enable( + self: "PreTrainedModel", gradient_checkpointing_kwargs: Optional[Dict[str, Any]] = None +) -> None: + r""" + Activates gradient checkpointing for the current model. + + Modification of the original method to enable gradient checkpointing for block-wise optimizer. + """ + from torch.utils.checkpoint import checkpoint + + if not self.supports_gradient_checkpointing: + raise ValueError("{} does not support gradient checkpointing.".format(self.__class__.__name__)) + + if gradient_checkpointing_kwargs is None: + gradient_checkpointing_kwargs = {"use_reentrant": True} + + gradient_checkpointing_func = partial(checkpoint, **gradient_checkpointing_kwargs) + + def custom_gradient_checkpointing_func(func, *args, **kwargs): + module: "torch.nn.Module" = func.__self__ + + if any(param.requires_grad for param in module.parameters()): + for arg in args: + if torch.is_tensor(arg) and torch.is_floating_point(arg): + arg.requires_grad_(True) + + return gradient_checkpointing_func(func, *args, **kwargs) + + if "value" in inspect.signature(self._set_gradient_checkpointing).parameters: # old GC format + self.apply(partial(self._set_gradient_checkpointing, value=True)) + self.enable_input_require_grads() + logger.warning("You are using the old GC format, some features (e.g. 
BAdam) will be invalid.") + else: # have already enabled input require gradients + self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=custom_gradient_checkpointing_func) + + +def _fp32_forward_post_hook( + module: "torch.nn.Module", args: Tuple["torch.Tensor"], output: "torch.Tensor" +) -> "torch.Tensor": + return output.to(torch.float32) + + +def prepare_model_for_training( + model: "PreTrainedModel", model_args: "ModelArguments", output_layer_name: str = "lm_head" +) -> None: + r""" + Includes: + (1) cast the layernorm in fp32 + (2) make output embedding layer require grads + (3) add the upcasting of the lm_head in fp32 + Inspired by: https://github.com/huggingface/peft/blob/v0.7.1/src/peft/utils/other.py#L72 + """ + if model_args.upcast_layernorm: + logger.info("Upcasting layernorm weights in float32.") + for name, param in model.named_parameters(): + if param.ndim == 1 and any(ln_name in name for ln_name in LAYERNORM_NAMES): + param.data = param.data.to(torch.float32) + + if not model_args.disable_gradient_checkpointing: + if not getattr(model, "supports_gradient_checkpointing", False): + logger.warning("Current model does not support gradient checkpointing.") + else: + # use_reentrant=False might increase VRAM usage (have not been empirically verified yet) + # According to: https://github.com/huggingface/transformers/issues/28339 + model.gradient_checkpointing_enable = MethodType(_gradient_checkpointing_enable, model) + model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": True}) + setattr(model.config, "use_cache", False) # turn off when gradient checkpointing is enabled + logger.info("Gradient checkpointing enabled.") + + if hasattr(model, output_layer_name) and model_args.upcast_lmhead_output: + logger.info("Upcasting lm_head outputs in float32.") + output_layer = getattr(model, output_layer_name) + if isinstance(output_layer, torch.nn.Linear) and output_layer.weight.dtype != torch.float32: + output_layer.register_forward_hook(_fp32_forward_post_hook) diff --git a/src/llmtuner/model/utils/embedding.py b/src/llmtuner/model/utils/embedding.py new file mode 100644 index 00000000..7759fc0f --- /dev/null +++ b/src/llmtuner/model/utils/embedding.py @@ -0,0 +1,56 @@ +import math +from contextlib import nullcontext +from typing import TYPE_CHECKING + +import torch +from transformers.integrations import is_deepspeed_zero3_enabled + +from ...extras.logging import get_logger + + +if TYPE_CHECKING: + from transformers import PreTrainedModel, PreTrainedTokenizer + + +logger = get_logger(__name__) + + +def _noisy_mean_initialization(embed_weight: torch.Tensor, num_new_tokens: int) -> None: + embedding_dim = embed_weight.size(1) + avg_weight = embed_weight[:-num_new_tokens].mean(dim=0, keepdim=True) + noise_weight = torch.empty_like(embed_weight[-num_new_tokens:]) + noise_weight.normal_(mean=0, std=(1.0 / math.sqrt(embedding_dim))) + embed_weight[-num_new_tokens:] = avg_weight + noise_weight + + +def resize_embedding_layer(model: "PreTrainedModel", tokenizer: "PreTrainedTokenizer") -> None: + r""" + Resize token embeddings. 
+ """ + if is_deepspeed_zero3_enabled(): + import deepspeed # type: ignore + + params = [model.get_input_embeddings().weight] + if model.get_output_embeddings() is not None and not model.config.tie_word_embeddings: + params.append(model.get_output_embeddings().weight) + + context_maybe_zero3 = deepspeed.zero.GatheredParameters(params, modifier_rank=0) + else: + context_maybe_zero3 = nullcontext() + + with context_maybe_zero3: + current_embedding_size = model.get_input_embeddings().weight.size(0) + + if len(tokenizer) > current_embedding_size: + if not isinstance(model.get_output_embeddings(), torch.nn.Linear): + logger.warning("Current model does not support resizing token embeddings.") + return + + model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=64) + with context_maybe_zero3: + new_embedding_size = model.get_input_embeddings().weight.size(0) + num_new_tokens = new_embedding_size - current_embedding_size + _noisy_mean_initialization(model.get_input_embeddings().weight.data, num_new_tokens) + _noisy_mean_initialization(model.get_output_embeddings().weight.data, num_new_tokens) + + logger.info("Resized token embeddings from {} to {}.".format(current_embedding_size, new_embedding_size)) diff --git a/src/llmtuner/extras/patches/llama_patch.py b/src/llmtuner/model/utils/longlora.py similarity index 58% rename from src/llmtuner/extras/patches/llama_patch.py rename to src/llmtuner/model/utils/longlora.py index 6a90c41a..c3740a73 100644 --- a/src/llmtuner/extras/patches/llama_patch.py +++ b/src/llmtuner/model/utils/longlora.py @@ -1,5 +1,5 @@ import math -from typing import Optional, Tuple +from typing import TYPE_CHECKING, Optional, Tuple import torch import torch.nn as nn @@ -7,19 +7,28 @@ from transformers.models.llama.modeling_llama import ( Cache, LlamaAttention, LlamaFlashAttention2, + LlamaSdpaAttention, apply_rotary_pos_emb, repeat_kv, ) from transformers.utils import logging from transformers.utils.versions import require_version +from ...extras.constants import SUPPORTED_CLASS_FOR_S2ATTN + + +if TYPE_CHECKING: + from transformers import PretrainedConfig + + from ...hparams import ModelArguments + logger = logging.get_logger(__name__) # Modified from: -# https://github.com/huggingface/transformers/blob/v4.39.1/src/transformers/models/llama/modeling_llama.py -def llama_torch_attn_forward( +# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llama/modeling_llama.py +def llama_attention_forward( self: "LlamaAttention", hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, @@ -39,10 +48,11 @@ def llama_torch_attn_forward( key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - past_key_value = getattr(self, "past_key_value", past_key_value) cos, sin = self.rotary_emb(value_states, position_ids) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + past_key_value = getattr(self, "past_key_value", past_key_value) + if past_key_value is not None: cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) @@ -69,8 +79,9 @@ def llama_torch_attn_forward( attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) - if attention_mask is not None: - attn_weights = attn_weights + attention_mask + if 
attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask # upcast attention to fp32 attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) @@ -97,8 +108,8 @@ def llama_torch_attn_forward( # Modified from: -# https://github.com/huggingface/transformers/blob/v4.39.1/src/transformers/models/llama/modeling_llama.py -def llama_flash_attn_forward( +# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llama/modeling_llama.py +def llama_flash_attention_2_forward( self: "LlamaFlashAttention2", hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, @@ -117,7 +128,6 @@ def llama_flash_attn_forward( key_states = self.k_proj(hidden_states) value_states = self.v_proj(hidden_states) - # FlashAttention requires the input to have the shape (bsz, seq_len, n_heads, head_dim) query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) @@ -134,9 +144,10 @@ def llama_flash_attn_forward( key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) - query_states = query_states.transpose(1, 2) # (bsz, seq_len, n_heads, head_dim) - key_states = key_states.transpose(1, 2) # (bsz, seq_len, n_heads, head_dim) - value_states = value_states.transpose(1, 2) # (bsz, seq_len, n_heads, head_dim) + # FlashAttention requires the input to have the shape (bsz, seq_len, n_heads, head_dim) + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) dropout_rate = self.attention_dropout if self.training else 0.0 @@ -192,7 +203,115 @@ def llama_flash_attn_forward( return attn_output, attn_weights, past_key_value -def apply_llama_patch() -> None: - require_version("transformers==4.39.3", "To fix: pip install transformers==4.39.3") - LlamaAttention.forward = llama_torch_attn_forward - LlamaFlashAttention2.forward = llama_flash_attn_forward +# Modified from: +# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llama/modeling_llama.py +def llama_sdpa_attention_forward( + self: "LlamaSdpaAttention", + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional["Cache"] = None, + output_attentions: bool = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if output_attentions: + logger.warning_once("SDPA does not support `output_attentions=True`. 
Falling back to the vanilla attention") + return llama_attention_forward( + self, + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + cache_position=cache_position, + **kwargs, + ) + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + if getattr(self.config, "group_size_ratio", None) and self.training: # shift + groupsz = int(q_len * getattr(self.config, "group_size_ratio")) + assert q_len % groupsz == 0, "q_len {} should be divisible by group size {}.".format(q_len, groupsz) + num_groups = q_len // groupsz + + def shift(state: torch.Tensor) -> torch.Tensor: + state = state.transpose(1, 2) # output: (bsz, seq_len, n_heads, head_dim) + state = torch.cat( + (state[:, :, : self.num_heads // 2], state[:, :, self.num_heads // 2 :].roll(-groupsz // 2, dims=1)), + dim=2, + ) + return state.reshape(bsz * num_groups, groupsz, self.num_heads, self.head_dim).transpose(1, 2) + + query_states, key_states, value_states = shift(query_states), shift(key_states), shift(value_states) + if attention_mask is not None: + attention_mask = attention_mask[:, :, :groupsz, :groupsz].repeat(num_groups, 1, 1, 1) + + causal_mask = attention_mask + if attention_mask is not None: + causal_mask = causal_mask[:, :, :, :groupsz] + + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=causal_mask, + dropout_p=self.attention_dropout if self.training else 0.0, + is_causal=causal_mask is None and q_len > 1, + ) + attn_output = attn_output.transpose(1, 2).contiguous() + + if getattr(self.config, "group_size_ratio", None) and self.training: # shift back + attn_output.reshape(bsz, q_len, self.num_heads, self.head_dim) + attn_output = torch.cat( + ( + attn_output[:, :, : self.num_heads // 2], + attn_output[:, :, self.num_heads // 2 :].roll(groupsz // 2, dims=1), + ) + ) + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + attn_output = self.o_proj(attn_output) + + return attn_output, None, past_key_value + + +def _apply_llama_patch() -> None: + require_version("transformers==4.40.0", "To fix: pip install transformers==4.40.0") + LlamaAttention.forward = llama_attention_forward + LlamaFlashAttention2.forward = llama_flash_attention_2_forward + LlamaSdpaAttention.forward = llama_sdpa_attention_forward + + +def configure_longlora(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> 
None: + if not is_trainable or not model_args.shift_attn: + return + + if getattr(config, "model_type", None) in SUPPORTED_CLASS_FOR_S2ATTN: + setattr(config, "group_size_ratio", 0.25) + _apply_llama_patch() + logger.info("Using shift short attention with group_size_ratio=1/4.") + else: + logger.warning("Current model does not support shift short attention.") diff --git a/src/llmtuner/model/utils.py b/src/llmtuner/model/utils/misc.py similarity index 61% rename from src/llmtuner/model/utils.py rename to src/llmtuner/model/utils/misc.py index 51dbca8e..57e772f7 100644 --- a/src/llmtuner/model/utils.py +++ b/src/llmtuner/model/utils/misc.py @@ -1,51 +1,23 @@ -import inspect -from enum import Enum, unique -from functools import partial -from typing import TYPE_CHECKING, Any, Dict, List, Optional +from typing import TYPE_CHECKING, Dict, List import torch from transformers import PreTrainedModel -from transformers.integrations import is_deepspeed_zero3_enabled from transformers.utils import cached_file -from transformers.utils.versions import require_version -from ..extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME -from ..extras.logging import get_logger +from ...extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME +from ...extras.logging import get_logger +from .quantization import QuantizationMethod if TYPE_CHECKING: from transformers import PretrainedConfig, PreTrainedTokenizer - from ..hparams import ModelArguments + from ...hparams import ModelArguments logger = get_logger(__name__) -@unique -class QuantizationMethod(str, Enum): - r""" - Borrowed from `transformers.utils.quantization_config.QuantizationMethod`. - """ - - BITS_AND_BYTES = "bitsandbytes" - GPTQ = "gptq" - AWQ = "awq" - AQLM = "aqlm" - QUANTO = "quanto" - - -def add_z3_leaf_module(model: "PreTrainedModel", module: "torch.nn.Module") -> None: - r""" - Sets module as a leaf module to skip partitioning in deepspeed zero3. - """ - if is_deepspeed_zero3_enabled(): - require_version("deepspeed>=0.13.0", "To fix: pip install deepspeed>=0.13.0") - from deepspeed.utils import set_z3_leaf_modules # type: ignore - - set_z3_leaf_modules(model, [module]) - - def find_all_linear_modules(model: "PreTrainedModel") -> List[str]: r""" Finds all available modules to apply lora or galore. @@ -102,42 +74,6 @@ def find_expanded_modules(model: "PreTrainedModel", target_modules: List[str], n return module_names -def gradient_checkpointing_enable( - self: "PreTrainedModel", gradient_checkpointing_kwargs: Optional[Dict[str, Any]] = None -) -> None: - r""" - Activates gradient checkpointing for the current model. - - Modification of the original method to enable gradient checkpointing for block-wise optimizer. 
- """ - from torch.utils.checkpoint import checkpoint - - if not self.supports_gradient_checkpointing: - raise ValueError("{} does not support gradient checkpointing.".format(self.__class__.__name__)) - - if gradient_checkpointing_kwargs is None: - gradient_checkpointing_kwargs = {"use_reentrant": True} - - gradient_checkpointing_func = partial(checkpoint, **gradient_checkpointing_kwargs) - - def custom_gradient_checkpointing_func(func, *args, **kwargs): - module: "torch.nn.Module" = func.__self__ - - if any(param.requires_grad for param in module.parameters()): - for arg in args: - if torch.is_tensor(arg) and torch.is_floating_point(arg): - arg.requires_grad_(True) - - return gradient_checkpointing_func(func, *args, **kwargs) - - if "value" in inspect.signature(self._set_gradient_checkpointing).parameters: # old GC format - self.apply(partial(self._set_gradient_checkpointing, value=True)) - self.enable_input_require_grads() - logger.warning("You are using the old GC format, some features (e.g. BAdam) will be invalid.") - else: # have already enabled input require gradients - self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=custom_gradient_checkpointing_func) - - def load_valuehead_params(path_or_repo_id: str, model_args: "ModelArguments") -> Dict[str, torch.Tensor]: r""" Loads value head parameters from Hugging Face Hub or local disk. diff --git a/src/llmtuner/model/utils/moe.py b/src/llmtuner/model/utils/moe.py new file mode 100644 index 00000000..020a8f55 --- /dev/null +++ b/src/llmtuner/model/utils/moe.py @@ -0,0 +1,39 @@ +from typing import TYPE_CHECKING + +from transformers.integrations import is_deepspeed_zero3_enabled +from transformers.utils.versions import require_version + + +if TYPE_CHECKING: + from transformers import PreTrainedModel + + +def add_z3_leaf_module(model: "PreTrainedModel") -> None: + r""" + Sets module as a leaf module to skip partitioning in deepspeed zero3. 
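+
+    This is a no-op unless DeepSpeed ZeRO-3 is enabled; currently the sparse MoE blocks of
+    Mixtral, Qwen2-MoE, Jamba and DBRX models are registered as leaf modules.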
+ """ + if not is_deepspeed_zero3_enabled(): + return + + require_version("deepspeed>=0.13.0", "To fix: pip install deepspeed>=0.13.0") + from deepspeed.utils import set_z3_leaf_modules # type: ignore + + if getattr(model.config, "model_type", None) == "mixtral": + from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock + + set_z3_leaf_modules(model, [MixtralSparseMoeBlock]) + + if getattr(model.config, "model_type", None) == "qwen2moe": + from transformers.models.qwen2_moe.modeling_qwen2_moe import Qwen2MoeSparseMoeBlock + + set_z3_leaf_modules(model, [Qwen2MoeSparseMoeBlock]) + + if getattr(model.config, "model_type", None) == "jamba": + from transformers.models.jamba.modeling_jamba import JambaSparseMoeBlock + + set_z3_leaf_modules(model, [JambaSparseMoeBlock]) + + if getattr(model.config, "model_type", None) == "dbrx": + from transformers.models.dbrx.modeling_dbrx import DbrxFFN + + set_z3_leaf_modules(model, [DbrxFFN]) diff --git a/src/llmtuner/model/utils/quantization.py b/src/llmtuner/model/utils/quantization.py new file mode 100644 index 00000000..3cf159c1 --- /dev/null +++ b/src/llmtuner/model/utils/quantization.py @@ -0,0 +1,146 @@ +import os +import random +from enum import Enum, unique +from typing import TYPE_CHECKING, Any, Dict, List + +import torch +from datasets import load_dataset +from transformers import BitsAndBytesConfig, GPTQConfig +from transformers.integrations import is_deepspeed_zero3_enabled +from transformers.utils.versions import require_version + +from ...extras.constants import FILEEXT2TYPE +from ...extras.logging import get_logger +from ...extras.misc import get_current_device + + +if TYPE_CHECKING: + from transformers import PretrainedConfig, PreTrainedTokenizer + + from ...hparams import ModelArguments + + +logger = get_logger(__name__) + + +@unique +class QuantizationMethod(str, Enum): + r""" + Borrowed from `transformers.utils.quantization_config.QuantizationMethod`. 
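+    These values are matched against `quantization_config["quant_method"]` in
+    `configure_quantization` below.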
+ """ + + BITS_AND_BYTES = "bitsandbytes" + GPTQ = "gptq" + AWQ = "awq" + AQLM = "aqlm" + QUANTO = "quanto" + + +def _get_quantization_dataset(tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments") -> List[str]: + r""" + Inspired by: https://github.com/huggingface/optimum/blob/v1.16.0/optimum/gptq/data.py#L133 + TODO: remove tokenizer.decode() https://github.com/huggingface/optimum/pull/1600 + """ + if os.path.isfile(model_args.export_quantization_dataset): + data_path = FILEEXT2TYPE.get(model_args.export_quantization_dataset.split(".")[-1], None) + data_files = model_args.export_quantization_dataset + else: + data_path = model_args.export_quantization_dataset + data_files = None + + dataset = load_dataset(path=data_path, data_files=data_files, split="train", cache_dir=model_args.cache_dir) + maxlen = model_args.export_quantization_maxlen + + samples = [] + for _ in range(model_args.export_quantization_nsamples): + while True: + sample_idx = random.randint(0, len(dataset) - 1) + sample: Dict[str, torch.Tensor] = tokenizer(dataset[sample_idx]["text"], return_tensors="pt") + if sample["input_ids"].size(1) >= maxlen: + break # TODO: fix large maxlen + + word_idx = random.randint(0, sample["input_ids"].size(1) - maxlen - 1) + input_ids = sample["input_ids"][:, word_idx : word_idx + maxlen] + samples.append(tokenizer.decode(input_ids[0].tolist(), skip_special_tokens=True)) + + return samples + + +def configure_quantization( + config: "PretrainedConfig", + tokenizer: "PreTrainedTokenizer", + model_args: "ModelArguments", + init_kwargs: Dict[str, Any], +) -> None: + r""" + Priority: PTQ-quantized (training) > AutoGPTQ (export) > Bitsandbytes (training) + """ + if getattr(config, "quantization_config", None): # ptq + if is_deepspeed_zero3_enabled(): + raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantized models.") + + if model_args.quantization_device_map != "auto": + init_kwargs["device_map"] = {"": get_current_device()} + + quantization_config: Dict[str, Any] = getattr(config, "quantization_config", None) + quant_method = quantization_config.get("quant_method", "") + + if quant_method == QuantizationMethod.GPTQ: + require_version("auto_gptq>=0.5.0", "To fix: pip install auto_gptq>=0.5.0") + quantization_config.pop("disable_exllama", None) # remove deprecated args + quantization_config["use_exllama"] = False # disable exllama + + if quant_method == QuantizationMethod.AWQ: + require_version("autoawq", "To fix: pip install autoawq") + + if quant_method == QuantizationMethod.AQLM: + require_version("transformers>=4.39.0", "To fix: pip install transformers>=4.39.0") + require_version("aqlm>=1.1.0", "To fix: pip install aqlm[gpu]>=1.1.0") + quantization_config["bits"] = 2 + + quant_bits = quantization_config.get("bits", "?") + logger.info("Loading {}-bit {}-quantized model.".format(quant_bits, quant_method.upper())) + + elif model_args.export_quantization_bit is not None: # auto-gptq + require_version("optimum>=1.16.0", "To fix: pip install optimum>=1.16.0") + require_version("auto_gptq>=0.5.0", "To fix: pip install auto_gptq>=0.5.0") + from accelerate.utils import get_max_memory + + if getattr(config, "model_type", None) == "chatglm": + raise ValueError("ChatGLM model is not supported.") + + init_kwargs["quantization_config"] = GPTQConfig( + bits=model_args.export_quantization_bit, + tokenizer=tokenizer, + dataset=_get_quantization_dataset(tokenizer, model_args), + ) + init_kwargs["device_map"] = "auto" + init_kwargs["max_memory"] = get_max_memory() + logger.info("Quantizing 
model to {} bit.".format(model_args.export_quantization_bit)) + + elif model_args.quantization_bit is not None: # bnb + if model_args.quantization_bit == 8: + require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0") + init_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True) + + elif model_args.quantization_bit == 4: + require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0") + init_kwargs["quantization_config"] = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_compute_dtype=model_args.compute_dtype, + bnb_4bit_use_double_quant=model_args.double_quantization, + bnb_4bit_quant_type=model_args.quantization_type, + bnb_4bit_quant_storage=model_args.compute_dtype, # crucial for fsdp qlora + ) + + if is_deepspeed_zero3_enabled() or model_args.quantization_device_map == "auto": + if model_args.quantization_bit != 4: + raise ValueError("Only 4-bit quantized model can use auto device map.") + + require_version("transformers>=4.39.0", "To fix: pip install transformers>=4.39.0") + require_version("accelerate>=0.28.0", "To fix: pip install accelerate>=0.28.0") + require_version("bitsandbytes>=0.43.0", "To fix: pip install bitsandbytes>=0.43.0") + else: + init_kwargs["device_map"] = {"": get_current_device()} + + logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit)) diff --git a/src/llmtuner/model/utils/rope.py b/src/llmtuner/model/utils/rope.py new file mode 100644 index 00000000..2a4cce7a --- /dev/null +++ b/src/llmtuner/model/utils/rope.py @@ -0,0 +1,43 @@ +import math +from typing import TYPE_CHECKING + +from ...extras.logging import get_logger + + +if TYPE_CHECKING: + from transformers import PretrainedConfig + + from ...hparams import ModelArguments + + +logger = get_logger(__name__) + + +def configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None: + if model_args.rope_scaling is None: + return + + if not hasattr(config, "rope_scaling"): + logger.warning("Current model does not support RoPE scaling.") + return + + if is_trainable: + if model_args.rope_scaling == "dynamic": + logger.warning( + "Dynamic NTK scaling may not work well with fine-tuning. " + "See: https://github.com/huggingface/transformers/pull/24653" + ) + + current_max_length = getattr(config, "max_position_embeddings", None) + if current_max_length and model_args.model_max_length > current_max_length: + scaling_factor = float(math.ceil(model_args.model_max_length / current_max_length)) + else: + logger.warning("Input length is smaller than max length. 
Consider increase input length.") + scaling_factor = 1.0 + else: + scaling_factor = 2.0 + + setattr(config, "rope_scaling", {"type": model_args.rope_scaling, "factor": scaling_factor}) + logger.info( + "Using {} scaling strategy and setting scaling factor to {}".format(model_args.rope_scaling, scaling_factor) + ) From c0afc4074f794261711ad81f406cfd9b4d6b922c Mon Sep 17 00:00:00 2001 From: hiyouga Date: Wed, 24 Apr 2024 04:46:53 +0800 Subject: [PATCH 05/16] support unsloth generate Former-commit-id: b1deb0a0b920645884e58f8206b1842c144c1c52 --- src/llmtuner/model/adapter.py | 35 +++++++---- src/llmtuner/model/loader.py | 55 +++++------------ src/llmtuner/model/utils/mod.py | 28 +++++++++ src/llmtuner/model/utils/unsloth.py | 85 ++++++++++++++++++++++++++ src/llmtuner/train/utils.py | 3 + src/llmtuner/webui/components/train.py | 2 +- 6 files changed, 155 insertions(+), 53 deletions(-) create mode 100644 src/llmtuner/model/utils/mod.py create mode 100644 src/llmtuner/model/utils/unsloth.py diff --git a/src/llmtuner/model/adapter.py b/src/llmtuner/model/adapter.py index efc63cde..d8d8eaf0 100644 --- a/src/llmtuner/model/adapter.py +++ b/src/llmtuner/model/adapter.py @@ -7,10 +7,11 @@ from transformers.integrations import is_deepspeed_zero3_enabled from ..extras.logging import get_logger from .utils.misc import find_all_linear_modules, find_expanded_modules from .utils.quantization import QuantizationMethod +from .utils.unsloth import get_unsloth_peft_model, load_unsloth_peft_model if TYPE_CHECKING: - from transformers.modeling_utils import PreTrainedModel + from transformers import PretrainedConfig, PreTrainedModel from ..hparams import FinetuningArguments, ModelArguments @@ -19,7 +20,11 @@ logger = get_logger(__name__) def init_adapter( - model: "PreTrainedModel", model_args: "ModelArguments", finetuning_args: "FinetuningArguments", is_trainable: bool + config: "PretrainedConfig", + model: "PreTrainedModel", + model_args: "ModelArguments", + finetuning_args: "FinetuningArguments", + is_trainable: bool, ) -> "PreTrainedModel": r""" Initializes the adapters. @@ -106,6 +111,10 @@ def init_adapter( assert len(model_args.adapter_name_or_path) == 1, "Cannot use multiple adapters in DeepSpeed ZeRO-3." is_mergeable = False + if model_args.use_unsloth: + assert len(model_args.adapter_name_or_path) == 1, "Unsloth model only accepts a single adapter." 
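+            # unsloth loads the adapter itself via `load_unsloth_peft_model`, so skip merging here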
+ is_mergeable = False + if (is_trainable and not finetuning_args.create_new_adapter) or (not is_mergeable): adapter_to_merge = model_args.adapter_name_or_path[:-1] adapter_to_resume = model_args.adapter_name_or_path[-1] @@ -122,9 +131,15 @@ def init_adapter( logger.info("Merged {} adapter(s).".format(len(adapter_to_merge))) if adapter_to_resume is not None: # resume lora training - model = PeftModel.from_pretrained( - model, adapter_to_resume, is_trainable=is_trainable, offload_folder=model_args.offload_folder - ) + if model_args.use_unsloth: + model = load_unsloth_peft_model(config, model_args, is_trainable=is_trainable) + else: + model = PeftModel.from_pretrained( + model, + adapter_to_resume, + is_trainable=is_trainable, + offload_folder=model_args.offload_folder, + ) if is_trainable and adapter_to_resume is None: # create new lora weights while training if len(finetuning_args.lora_target) == 1 and finetuning_args.lora_target[0] == "all": @@ -152,14 +167,8 @@ def init_adapter( } if model_args.use_unsloth: - from unsloth import FastLanguageModel # type: ignore - - unsloth_peft_kwargs = { - "model": model, - "max_seq_length": model_args.model_max_length, - "use_gradient_checkpointing": "unsloth", - } - model = FastLanguageModel.get_peft_model(**peft_kwargs, **unsloth_peft_kwargs) + print(model) + model = get_unsloth_peft_model(model, model_args, peft_kwargs) else: lora_config = LoraConfig( task_type=TaskType.CAUSAL_LM, diff --git a/src/llmtuner/model/loader.py b/src/llmtuner/model/loader.py index b8558542..06405219 100644 --- a/src/llmtuner/model/loader.py +++ b/src/llmtuner/model/loader.py @@ -3,12 +3,13 @@ from typing import TYPE_CHECKING, Any, Dict from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer from trl import AutoModelForCausalLMWithValueHead -from ..extras.constants import MOD_SUPPORTED_MODELS from ..extras.logging import get_logger -from ..extras.misc import count_parameters, get_current_device, try_download_model_from_ms +from ..extras.misc import count_parameters, try_download_model_from_ms from .adapter import init_adapter from .patcher import patch_config, patch_model, patch_tokenizer, patch_valuehead_model from .utils.misc import load_valuehead_params, register_autoclass +from .utils.mod import convert_pretrained_model_to_mod, load_mod_pretrained_model +from .utils.unsloth import load_unsloth_pretrained_model if TYPE_CHECKING: @@ -83,54 +84,30 @@ def load_model( patch_config(config, tokenizer, model_args, init_kwargs, is_trainable) model = None - if is_trainable and model_args.use_unsloth: - from unsloth import FastLanguageModel # type: ignore + lazy_load = False + if model_args.use_unsloth: + if model_args.adapter_name_or_path is not None: + lazy_load = True + elif is_trainable: + model = load_unsloth_pretrained_model(config, model_args) - unsloth_kwargs = { - "model_name": model_args.model_name_or_path, - "max_seq_length": model_args.model_max_length, - "dtype": model_args.compute_dtype, - "load_in_4bit": model_args.quantization_bit == 4, - "token": model_args.hf_hub_token, - "device_map": {"": get_current_device()}, - "rope_scaling": getattr(config, "rope_scaling", None), - "fix_tokenizer": False, - "trust_remote_code": True, - } - try: - model, _ = FastLanguageModel.from_pretrained(**unsloth_kwargs) - except NotImplementedError: - logger.warning("Unsloth does not support model type {}.".format(getattr(config, "model_type", None))) - model_args.use_unsloth = False - - if model_args.adapter_name_or_path: - model_args.adapter_name_or_path = None - 
logger.warning("Unsloth does not support loading adapters.") - - if model is None: + if model is None and not lazy_load: init_kwargs["config"] = config init_kwargs["pretrained_model_name_or_path"] = model_args.model_name_or_path if model_args.mixture_of_depths == "load": - from MoD import AutoMoDModelForCausalLM - - model = AutoMoDModelForCausalLM.from_pretrained(**init_kwargs) + model = load_mod_pretrained_model(**init_kwargs) else: model = AutoModelForCausalLM.from_pretrained(**init_kwargs) if model_args.mixture_of_depths == "convert": - from MoD import apply_mod_to_hf + model = convert_pretrained_model_to_mod(model, config, model_args) - if getattr(config, "model_type", None) not in MOD_SUPPORTED_MODELS: - raise ValueError("Current model is not supported by mixture-of-depth.") + if not lazy_load: + patch_model(model, tokenizer, model_args, is_trainable) + register_autoclass(config, model, tokenizer) - model = apply_mod_to_hf(model) - model = model.to(model_args.compute_dtype) - - patch_model(model, tokenizer, model_args, is_trainable) - register_autoclass(config, model, tokenizer) - - model = init_adapter(model, model_args, finetuning_args, is_trainable) + model = init_adapter(config, model, model_args, finetuning_args, is_trainable) if add_valuehead: model = AutoModelForCausalLMWithValueHead.from_pretrained(model) diff --git a/src/llmtuner/model/utils/mod.py b/src/llmtuner/model/utils/mod.py new file mode 100644 index 00000000..5708a1a8 --- /dev/null +++ b/src/llmtuner/model/utils/mod.py @@ -0,0 +1,28 @@ +from typing import TYPE_CHECKING + +from ...extras.constants import MOD_SUPPORTED_MODELS + + +if TYPE_CHECKING: + from transformers import PretrainedConfig, PreTrainedModel + + from ...hparams import ModelArguments + + +def load_mod_pretrained_model(**init_kwargs) -> "PreTrainedModel": + from MoD import AutoMoDModelForCausalLM + + return AutoMoDModelForCausalLM.from_pretrained(**init_kwargs) + + +def convert_pretrained_model_to_mod( + model: "PreTrainedModel", config: "PretrainedConfig", model_args: "ModelArguments" +) -> "PreTrainedModel": + from MoD import apply_mod_to_hf + + if getattr(config, "model_type", None) not in MOD_SUPPORTED_MODELS: + raise ValueError("Current model is not supported by mixture-of-depth.") + + model = apply_mod_to_hf(model) + model = model.to(model_args.compute_dtype) + return model diff --git a/src/llmtuner/model/utils/unsloth.py b/src/llmtuner/model/utils/unsloth.py new file mode 100644 index 00000000..6c5f506f --- /dev/null +++ b/src/llmtuner/model/utils/unsloth.py @@ -0,0 +1,85 @@ +from typing import TYPE_CHECKING, Any, Dict, Optional + +from ...extras.logging import get_logger +from ...extras.misc import get_current_device + + +if TYPE_CHECKING: + from transformers import PretrainedConfig, PreTrainedModel + + from ...hparams import ModelArguments + + +logger = get_logger(__name__) + + +def _get_unsloth_kwargs( + config: "PretrainedConfig", model_name_or_path: str, model_args: "ModelArguments" +) -> Dict[str, Any]: + return { + "model_name": model_name_or_path, + "max_seq_length": model_args.model_max_length, + "dtype": model_args.compute_dtype, + "load_in_4bit": model_args.quantization_bit == 4, + "token": model_args.hf_hub_token, + "device_map": {"": get_current_device()}, + "rope_scaling": getattr(config, "rope_scaling", None), + "fix_tokenizer": False, + "trust_remote_code": True, + "use_gradient_checkpointing": "unsloth", + } + + +def load_unsloth_pretrained_model( + config: "PretrainedConfig", model_args: "ModelArguments" +) -> 
Optional["PreTrainedModel"]: + r""" + Optionally loads pretrained model with unsloth. + """ + from unsloth import FastLanguageModel + + unsloth_kwargs = _get_unsloth_kwargs(config, model_args.model_name_or_path, model_args) + try: + model, _ = FastLanguageModel.from_pretrained(**unsloth_kwargs) + except NotImplementedError: + logger.warning("Unsloth does not support model type {}.".format(getattr(config, "model_type", None))) + model = None + model_args.use_unsloth = False + + return model + + +def get_unsloth_peft_model( + model: "PreTrainedModel", model_args: "ModelArguments", peft_kwargs: Dict[str, Any] +) -> "PreTrainedModel": + r""" + Gets the peft model for the pretrained model with unsloth. + """ + from unsloth import FastLanguageModel + + unsloth_peft_kwargs = { + "model": model, + "max_seq_length": model_args.model_max_length, + "use_gradient_checkpointing": "unsloth", + } + return FastLanguageModel.get_peft_model(**peft_kwargs, **unsloth_peft_kwargs) + + +def load_unsloth_peft_model( + config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool +) -> "PreTrainedModel": + r""" + Loads peft model with unsloth. + """ + from unsloth import FastLanguageModel + + unsloth_kwargs = _get_unsloth_kwargs(config, model_args.adapter_name_or_path, model_args) + try: + model, _ = FastLanguageModel.from_pretrained(**unsloth_kwargs) + except NotImplementedError: + raise ValueError("Unsloth does not support model type {}.".format(getattr(config, "model_type", None))) + + if not is_trainable: + FastLanguageModel.for_inference(model) + + return model diff --git a/src/llmtuner/train/utils.py b/src/llmtuner/train/utils.py index fa9e36e5..27dc8eb3 100644 --- a/src/llmtuner/train/utils.py +++ b/src/llmtuner/train/utils.py @@ -61,6 +61,9 @@ def create_modelcard_and_push( if data_args.dataset is not None: kwargs["dataset"] = [dataset.strip() for dataset in data_args.dataset.split(",")] + if model_args.use_unsloth: + kwargs["tags"] = kwargs["tags"] + ["unsloth"] + if not training_args.do_train: pass elif training_args.push_to_hub: diff --git a/src/llmtuner/webui/components/train.py b/src/llmtuner/webui/components/train.py index 0f425bc9..7dc324af 100644 --- a/src/llmtuner/webui/components/train.py +++ b/src/llmtuner/webui/components/train.py @@ -138,7 +138,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]: with gr.Row(): lora_rank = gr.Slider(value=8, minimum=1, maximum=1024, step=1) lora_alpha = gr.Slider(value=16, minimum=1, maximum=2048, step=1) - lora_dropout = gr.Slider(value=0.1, minimum=0, maximum=1, step=0.01) + lora_dropout = gr.Slider(value=0, minimum=0, maximum=1, step=0.01) loraplus_lr_ratio = gr.Slider(value=0, minimum=0, maximum=64, step=0.01) create_new_adapter = gr.Checkbox() From 1f99c367b38cba69058ad598357348a6b67ee714 Mon Sep 17 00:00:00 2001 From: hiyouga Date: Wed, 24 Apr 2024 05:02:18 +0800 Subject: [PATCH 06/16] remove redundant code Former-commit-id: 667ce08b27df9452faee87348419f5f1f0c0cb2f --- src/llmtuner/model/adapter.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/llmtuner/model/adapter.py b/src/llmtuner/model/adapter.py index d8d8eaf0..af58b514 100644 --- a/src/llmtuner/model/adapter.py +++ b/src/llmtuner/model/adapter.py @@ -167,7 +167,6 @@ def init_adapter( } if model_args.use_unsloth: - print(model) model = get_unsloth_peft_model(model, model_args, peft_kwargs) else: lora_config = LoraConfig( From 612ba26c4ccc80da260eab05340dd650834ff431 Mon Sep 17 00:00:00 2001 From: hiyouga Date: Wed, 24 Apr 2024 05:10:07 +0800 Subject: [PATCH 
07/16] fix bug Former-commit-id: 8f44dce08aa809bb7d4ea0bd5f48ca1c56436044 --- src/llmtuner/model/utils/unsloth.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/llmtuner/model/utils/unsloth.py b/src/llmtuner/model/utils/unsloth.py index 6c5f506f..974b41c0 100644 --- a/src/llmtuner/model/utils/unsloth.py +++ b/src/llmtuner/model/utils/unsloth.py @@ -73,7 +73,7 @@ def load_unsloth_peft_model( """ from unsloth import FastLanguageModel - unsloth_kwargs = _get_unsloth_kwargs(config, model_args.adapter_name_or_path, model_args) + unsloth_kwargs = _get_unsloth_kwargs(config, model_args.adapter_name_or_path[0], model_args) try: model, _ = FastLanguageModel.from_pretrained(**unsloth_kwargs) except NotImplementedError: From 7d89abb1fd34a716c792eb61ea416cdd3fb8b060 Mon Sep 17 00:00:00 2001 From: hiyouga Date: Wed, 24 Apr 2024 05:21:18 +0800 Subject: [PATCH 08/16] fix bug Former-commit-id: 73ff9c834b069bf8b1bde75cc4daf996746050fa --- src/llmtuner/model/utils/unsloth.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/src/llmtuner/model/utils/unsloth.py b/src/llmtuner/model/utils/unsloth.py index 974b41c0..8a16409d 100644 --- a/src/llmtuner/model/utils/unsloth.py +++ b/src/llmtuner/model/utils/unsloth.py @@ -18,7 +18,7 @@ def _get_unsloth_kwargs( ) -> Dict[str, Any]: return { "model_name": model_name_or_path, - "max_seq_length": model_args.model_max_length, + "max_seq_length": model_args.model_max_length or 4096, "dtype": model_args.compute_dtype, "load_in_4bit": model_args.quantization_bit == 4, "token": model_args.hf_hub_token, @@ -34,7 +34,7 @@ def load_unsloth_pretrained_model( config: "PretrainedConfig", model_args: "ModelArguments" ) -> Optional["PreTrainedModel"]: r""" - Optionally loads pretrained model with unsloth. + Optionally loads pretrained model with unsloth. Used in training. """ from unsloth import FastLanguageModel @@ -53,7 +53,7 @@ def get_unsloth_peft_model( model: "PreTrainedModel", model_args: "ModelArguments", peft_kwargs: Dict[str, Any] ) -> "PreTrainedModel": r""" - Gets the peft model for the pretrained model with unsloth. + Gets the peft model for the pretrained model with unsloth. Used in training. """ from unsloth import FastLanguageModel @@ -69,12 +69,15 @@ def load_unsloth_peft_model( config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool ) -> "PreTrainedModel": r""" - Loads peft model with unsloth. + Loads peft model with unsloth. Used in both training and inference. 
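+
+    Expects `model_args.adapter_name_or_path` to contain exactly one adapter; when
+    `is_trainable` is False, the model is switched to generation mode via
+    `FastLanguageModel.for_inference`.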
""" from unsloth import FastLanguageModel unsloth_kwargs = _get_unsloth_kwargs(config, model_args.adapter_name_or_path[0], model_args) try: + if not is_trainable: + unsloth_kwargs["use_gradient_checkpointing"] = False + model, _ = FastLanguageModel.from_pretrained(**unsloth_kwargs) except NotImplementedError: raise ValueError("Unsloth does not support model type {}.".format(getattr(config, "model_type", None))) From 80f0a63f73d61fd1287846eb310f499a0fb557bf Mon Sep 17 00:00:00 2001 From: hiyouga Date: Wed, 24 Apr 2024 05:39:52 +0800 Subject: [PATCH 09/16] add dbrx and jamba models Former-commit-id: 69eb03a8feee530000d290cd00aef28fca6d1e84 --- src/llmtuner/data/template.py | 25 +++++++++++++++++++++++++ src/llmtuner/extras/constants.py | 26 ++++++++++++++++++++++++++ 2 files changed, 51 insertions(+) diff --git a/src/llmtuner/data/template.py b/src/llmtuner/data/template.py index cd567a7b..efdd44f3 100644 --- a/src/llmtuner/data/template.py +++ b/src/llmtuner/data/template.py @@ -550,6 +550,31 @@ _register_template( ) +_register_template( + name="dbrx", + format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), + format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]), + format_separator=EmptyFormatter(slots=["\n"]), + default_system=( + "You are DBRX, created by Databricks. You were last updated in December 2023. " + "You answer questions based on information available up to that point.\n" + "YOU PROVIDE SHORT RESPONSES TO SHORT QUESTIONS OR STATEMENTS, but provide thorough " + "responses to more complex and open-ended questions.\nYou assist with various tasks, " + "from writing to coding (using markdown for code blocks — remember to use ``` with " + "code, JSON, and tables).\n(You do not have real-time data access or code execution " + "capabilities. You avoid stereotyping and provide balanced perspectives on " + "controversial topics. You do not provide song lyrics, poems, or news articles and " + "do not divulge details of your training data.)\nThis is your system prompt, " + "guiding your responses. Do not reference it, just respond to the user. If you find " + "yourself talking about this message, stop. You should be responding appropriately " + "and usually that means not mentioning this.\nYOU DO NOT MENTION ANY OF THIS INFORMATION " + "ABOUT YOURSELF UNLESS THE INFORMATION IS DIRECTLY PERTINENT TO THE USER'S QUERY." 
+ ), + stop_words=["<|im_end|>"], + replace_eos=True, +) + + _register_template( name="deepseek", format_user=StringFormatter(slots=["User: {{content}}\n\nAssistant:"]), diff --git a/src/llmtuner/extras/constants.py b/src/llmtuner/extras/constants.py index 0a29f971..031e3e81 100644 --- a/src/llmtuner/extras/constants.py +++ b/src/llmtuner/extras/constants.py @@ -268,6 +268,22 @@ register_model_group( ) +register_model_group( + models={ + "DBRX-132B-Base": { + DownloadSource.DEFAULT: "databricks/dbrx-base", + DownloadSource.MODELSCOPE: "AI-ModelScope/dbrx-base", + }, + "DBRX-132B-Chat": { + DownloadSource.DEFAULT: "databricks/dbrx-instruct", + DownloadSource.MODELSCOPE: "AI-ModelScope/dbrx-instruct", + }, + }, + module="Wqkv", + template="dbrx", +) + + register_model_group( models={ "DeepSeek-LLM-7B-Base": { @@ -453,6 +469,16 @@ register_model_group( ) +register_model_group( + models={ + "Jambda-v0.1": { + DownloadSource.DEFAULT: "ai21labs/Jamba-v0.1", + DownloadSource.MODELSCOPE: "AI-ModelScope/Jamba-v0.1", + } + }, +) + + register_model_group( models={ "LingoWhale-8B": { From fff1fb12328048f705a9950d25b139deff4b7571 Mon Sep 17 00:00:00 2001 From: hiyouga Date: Wed, 24 Apr 2024 05:50:50 +0800 Subject: [PATCH 10/16] add olmo 1.7 Former-commit-id: 44a43ee152ba746f79bf7ef38520b29cd6a5cc2b --- README.md | 2 +- README_zh.md | 2 +- src/llmtuner/extras/constants.py | 11 ++++------- 3 files changed, 6 insertions(+), 9 deletions(-) diff --git a/README.md b/README.md index 970dd8fc..4e87e369 100644 --- a/README.md +++ b/README.md @@ -149,7 +149,7 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/ | [LLaMA-2](https://huggingface.co/meta-llama) | 7B/13B/70B | q_proj,v_proj | llama2 | | [LLaMA-3](https://huggingface.co/meta-llama) | 8B/70B | q_proj,v_proj | llama3 | | [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | q_proj,v_proj | mistral | -| [OLMo](https://huggingface.co/allenai) | 1B/7B | att_proj | olmo | +| [OLMo](https://huggingface.co/allenai) | 1B/7B | q_proj,v_proj | - | | [Phi-1.5/2](https://huggingface.co/microsoft) | 1.3B/2.7B | q_proj,v_proj | - | | [Phi-3](https://huggingface.co/microsoft) | 3.8B | qkv_proj | phi | | [Qwen](https://huggingface.co/Qwen) | 1.8B/7B/14B/72B | c_attn | qwen | diff --git a/README_zh.md b/README_zh.md index 583c89ca..599af301 100644 --- a/README_zh.md +++ b/README_zh.md @@ -149,7 +149,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd | [LLaMA-2](https://huggingface.co/meta-llama) | 7B/13B/70B | q_proj,v_proj | llama2 | | [LLaMA-3](https://huggingface.co/meta-llama) | 8B/70B | q_proj,v_proj | llama3 | | [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | q_proj,v_proj | mistral | -| [OLMo](https://huggingface.co/allenai) | 1B/7B | att_proj | olmo | +| [OLMo](https://huggingface.co/allenai) | 1B/7B | q_proj,v_proj | - | | [Phi-1.5/2](https://huggingface.co/microsoft) | 1.3B/2.7B | q_proj,v_proj | - | | [Phi-3](https://huggingface.co/microsoft) | 3.8B | qkv_proj | phi | | [Qwen](https://huggingface.co/Qwen) | 1.8B/7B/14B/72B | c_attn | qwen | diff --git a/src/llmtuner/extras/constants.py b/src/llmtuner/extras/constants.py index 031e3e81..9f7d5c46 100644 --- a/src/llmtuner/extras/constants.py +++ b/src/llmtuner/extras/constants.py @@ -613,18 +613,15 @@ register_model_group( register_model_group( models={ "OLMo-1B": { - DownloadSource.DEFAULT: "allenai/OLMo-1B", + DownloadSource.DEFAULT: "allenai/OLMo-1B-hf", }, "OLMo-7B": { - DownloadSource.DEFAULT: 
"allenai/OLMo-7B", - DownloadSource.MODELSCOPE: "AI-ModelScope/OLMo-7B", + DownloadSource.DEFAULT: "allenai/OLMo-7B-hf", }, - "OLMo-7B-Chat": { - DownloadSource.DEFAULT: "allenai/OLMo-7B-Instruct", + "OLMo-1.7-7B": { + DownloadSource.DEFAULT: "allenai/OLMo-1.7-7B-hf", }, }, - module="att_proj", - template="olmo", ) From 4a854dfe273ef3519e926e084f66bbf20b3e7e2c Mon Sep 17 00:00:00 2001 From: hoshi-hiyouga Date: Wed, 24 Apr 2024 13:53:39 +0800 Subject: [PATCH 11/16] fix inference in llamaboard Former-commit-id: f36057ea0300ab089ded568fa170682e9e19c4ee --- src/llmtuner/webui/runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/llmtuner/webui/runner.py b/src/llmtuner/webui/runner.py index b64a015c..77d5ea98 100644 --- a/src/llmtuner/webui/runner.py +++ b/src/llmtuner/webui/runner.py @@ -222,7 +222,7 @@ class Runner: quantization_bit=int(get("top.quantization_bit")) if get("top.quantization_bit") in ["8", "4"] else None, template=get("top.template"), rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None, - flash_attn=(get("top.booster") == "flashattn"), + flash_attn="fa2" if get("top.booster") == "flashattn2" else "auto", use_unsloth=(get("top.booster") == "unsloth"), dataset_dir=get("eval.dataset_dir"), dataset=",".join(get("eval.dataset")), From ce36e316bc5f66f908204d75252cc35908bdc36a Mon Sep 17 00:00:00 2001 From: hoshi-hiyouga Date: Wed, 24 Apr 2024 13:54:21 +0800 Subject: [PATCH 12/16] fix webchatmodel Former-commit-id: 4c71c314fb23e0e1eb294abf0d6c4cccfb531716 --- src/llmtuner/webui/chatter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/llmtuner/webui/chatter.py b/src/llmtuner/webui/chatter.py index dac7dd67..ee28603e 100644 --- a/src/llmtuner/webui/chatter.py +++ b/src/llmtuner/webui/chatter.py @@ -72,7 +72,7 @@ class WebChatModel(ChatModel): finetuning_type=get("top.finetuning_type"), quantization_bit=int(get("top.quantization_bit")) if get("top.quantization_bit") in ["8", "4"] else None, template=get("top.template"), - flash_attn=(get("top.booster") == "flash_attn"), + flash_attn="fa2" if get("top.booster") == "flashattn2" else "auto", use_unsloth=(get("top.booster") == "unsloth"), rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None, infer_backend=get("infer.infer_backend"), From 7592d981b208bf613e6fe3555c31869f7d412547 Mon Sep 17 00:00:00 2001 From: hoshi-hiyouga Date: Wed, 24 Apr 2024 13:55:14 +0800 Subject: [PATCH 13/16] fix phi template Former-commit-id: e5d23c053a6d239097f87515b8cd5611b7cfa3cf --- src/llmtuner/data/template.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/llmtuner/data/template.py b/src/llmtuner/data/template.py index efdd44f3..dd355e97 100644 --- a/src/llmtuner/data/template.py +++ b/src/llmtuner/data/template.py @@ -747,8 +747,10 @@ _register_template( name="phi", format_user=StringFormatter(slots=["<|user|>\n{{content}}<|end|>\n<|assistant|>\n"]), format_system=StringFormatter(slots=[{"bos_token"}, "<|system|>\n{{content}}<|end|>\n"]), - format_separator=EmptyFormatter(slots=["<|end|>\n"]), + format_separator=EmptyFormatter(slots=["\n"]), default_system="You are a helpful AI assistant.", + stop_words=["<|end|>"], + replace_eos=True, ) From ce490c65ae87896ecc29ffc3416b6a48c4a3f845 Mon Sep 17 00:00:00 2001 From: hiyouga Date: Wed, 24 Apr 2024 23:39:31 +0800 Subject: [PATCH 14/16] support new special token #3420 Former-commit-id: 297fb8ead3daf154152d9826b49bb4d769fbaaa9 --- 
src/llmtuner/hparams/data_args.py | 4 ++-- src/llmtuner/hparams/generating_args.py | 4 ++-- src/llmtuner/hparams/model_args.py | 7 +++++++ src/llmtuner/hparams/parser.py | 6 +++++- src/llmtuner/model/adapter.py | 11 +++++++++++ src/llmtuner/model/loader.py | 12 ++++++++++++ src/llmtuner/model/utils/embedding.py | 6 ++++-- src/llmtuner/model/utils/rope.py | 4 ++++ 8 files changed, 47 insertions(+), 7 deletions(-) diff --git a/src/llmtuner/hparams/data_args.py b/src/llmtuner/hparams/data_args.py index f5f75c77..1e0cd08c 100644 --- a/src/llmtuner/hparams/data_args.py +++ b/src/llmtuner/hparams/data_args.py @@ -26,11 +26,11 @@ class DataArguments: ) cutoff_len: int = field( default=1024, - metadata={"help": "The cutoff length of the model inputs after tokenization."}, + metadata={"help": "The cutoff length of the tokenized inputs in the dataset."}, ) reserved_label_len: int = field( default=1, - metadata={"help": "The minimum cutoff length reserved for label after tokenization."}, + metadata={"help": "The minimum cutoff length reserved for the tokenized labels in the dataset."}, ) train_on_prompt: bool = field( default=False, diff --git a/src/llmtuner/hparams/generating_args.py b/src/llmtuner/hparams/generating_args.py index 70dabb3e..e792c003 100644 --- a/src/llmtuner/hparams/generating_args.py +++ b/src/llmtuner/hparams/generating_args.py @@ -31,11 +31,11 @@ class GeneratingArguments: metadata={"help": "Number of beams for beam search. 1 means no beam search."}, ) max_length: int = field( - default=512, + default=1024, metadata={"help": "The maximum length the generated tokens can have. It can be overridden by max_new_tokens."}, ) max_new_tokens: int = field( - default=512, + default=1024, metadata={"help": "The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt."}, ) repetition_penalty: float = field( diff --git a/src/llmtuner/hparams/model_args.py b/src/llmtuner/hparams/model_args.py index eb6366d9..b60492a0 100644 --- a/src/llmtuner/hparams/model_args.py +++ b/src/llmtuner/hparams/model_args.py @@ -33,6 +33,10 @@ class ModelArguments: default=False, metadata={"help": "Whether or not the special tokens should be split during the tokenization process."}, ) + new_special_tokens: Optional[str] = field( + default=None, + metadata={"help": "Special tokens to be added into the tokenizer."}, + ) model_revision: str = field( default="main", metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}, @@ -177,6 +181,9 @@ class ModelArguments: if self.adapter_name_or_path is not None: # support merging multiple lora weights self.adapter_name_or_path = [path.strip() for path in self.adapter_name_or_path.split(",")] + if self.new_special_tokens is not None: # support multiple special tokens + self.new_special_tokens = [token.strip() for token in self.new_special_tokens.split(",")] + assert self.quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization." assert self.export_quantization_bit in [None, 8, 4, 3, 2], "We only accept 2/3/4/8-bit quantization." 
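A minimal sketch of the intended `new_special_tokens` flow once the comma-separated argument is split, assuming a generic Hugging Face tokenizer; the base model and token strings below are illustrative only:

    from transformers import AutoTokenizer

    new_special_tokens = "<|tool_call|>, <|tool_result|>"  # comma-separated value as passed on the CLI
    tokens = [token.strip() for token in new_special_tokens.split(",")]  # mirrors ModelArguments.__post_init__

    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen1.5-0.5B")  # illustrative base model
    num_added_tokens = tokenizer.add_special_tokens(
        dict(additional_special_tokens=tokens), replace_additional_special_tokens=False
    )
    if num_added_tokens > 0:
        # the embedding layers must then be resized to cover the enlarged vocabulary
        print("Added {} special token(s); resize the token embeddings.".format(num_added_tokens))

The same splitting logic and `add_special_tokens` call appear in the `loader.py` hunk later in this patch, where `resize_vocab` is forced to True whenever new tokens were actually added.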
diff --git a/src/llmtuner/hparams/parser.py b/src/llmtuner/hparams/parser.py index 0d286819..a7d0a17f 100644 --- a/src/llmtuner/hparams/parser.py +++ b/src/llmtuner/hparams/parser.py @@ -67,6 +67,9 @@ def _verify_model_args(model_args: "ModelArguments", finetuning_args: "Finetunin if finetuning_args.finetuning_type != "lora": raise ValueError("Quantization is only compatible with the LoRA method.") + if model_args.resize_vocab: + raise ValueError("Cannot resize embedding layers of a quantized model.") + if model_args.adapter_name_or_path is not None and finetuning_args.create_new_adapter: raise ValueError("Cannot create new adapter upon a quantized model.") @@ -199,10 +202,11 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS: if ( training_args.do_train and finetuning_args.finetuning_type == "lora" + and model_args.quantization_bit is None and model_args.resize_vocab and finetuning_args.additional_target is None ): - logger.warning("Add token embeddings to `additional_target` to make the added tokens trainable.") + logger.warning("Remember to add embedding layers to `additional_target` to make the added tokens trainable.") if training_args.do_train and model_args.quantization_bit is not None and (not model_args.upcast_layernorm): logger.warning("We recommend enable `upcast_layernorm` in quantized training.") diff --git a/src/llmtuner/model/adapter.py b/src/llmtuner/model/adapter.py index af58b514..d43e00f0 100644 --- a/src/llmtuner/model/adapter.py +++ b/src/llmtuner/model/adapter.py @@ -157,6 +157,17 @@ def init_adapter( ): raise ValueError("DoRA is not compatible with PTQ-quantized models.") + if model_args.resize_vocab and finetuning_args.additional_target is None: + input_embeddings = model.get_input_embeddings() + output_embeddings = model.get_output_embeddings() + module_names = set() + for name, module in model.named_modules(): + if module in [input_embeddings, output_embeddings]: + module_names.add(name.split(".")[-1]) + + finetuning_args.additional_target = module_names + logger.warning("Vocab has been resized, add {} to trainable params.".format(",".join(module_names))) + peft_kwargs = { "r": finetuning_args.lora_rank, "target_modules": target_modules, diff --git a/src/llmtuner/model/loader.py b/src/llmtuner/model/loader.py index 06405219..54048cc5 100644 --- a/src/llmtuner/model/loader.py +++ b/src/llmtuner/model/loader.py @@ -39,6 +39,8 @@ def _get_init_kwargs(model_args: "ModelArguments") -> Dict[str, Any]: def load_tokenizer(model_args: "ModelArguments") -> "PreTrainedTokenizer": r""" Loads pretrained tokenizer. + + Note: including inplace operation of model_args. 
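+    (for example, adding `new_special_tokens` may flip `model_args.resize_vocab` to True).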
""" init_kwargs = _get_init_kwargs(model_args) try: @@ -57,6 +59,16 @@ def load_tokenizer(model_args: "ModelArguments") -> "PreTrainedTokenizer": **init_kwargs, ) + if model_args.new_special_tokens is not None: + num_added_tokens = tokenizer.add_special_tokens( + dict(additional_special_tokens=model_args.new_special_tokens), + replace_additional_special_tokens=False, + ) + logger.info("Add {} to special tokens.".format(",".join(model_args.new_special_tokens))) + if num_added_tokens > 0 and not model_args.resize_vocab: + model_args.resize_vocab = True + logger.warning("New tokens have been added, changed `resize_vocab` to True.") + patch_tokenizer(tokenizer) return tokenizer diff --git a/src/llmtuner/model/utils/embedding.py b/src/llmtuner/model/utils/embedding.py index 7759fc0f..357c9cc0 100644 --- a/src/llmtuner/model/utils/embedding.py +++ b/src/llmtuner/model/utils/embedding.py @@ -42,9 +42,11 @@ def resize_embedding_layer(model: "PreTrainedModel", tokenizer: "PreTrainedToken current_embedding_size = model.get_input_embeddings().weight.size(0) if len(tokenizer) > current_embedding_size: + if getattr(model, "quantization_method", None): + raise ValueError("Cannot resize embedding layers of a quantized model.") + if not isinstance(model.get_output_embeddings(), torch.nn.Linear): - logger.warning("Current model does not support resizing token embeddings.") - return + raise ValueError("Current model does not support resizing embedding layers.") model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=64) with context_maybe_zero3: diff --git a/src/llmtuner/model/utils/rope.py b/src/llmtuner/model/utils/rope.py index 2a4cce7a..9163253b 100644 --- a/src/llmtuner/model/utils/rope.py +++ b/src/llmtuner/model/utils/rope.py @@ -30,6 +30,10 @@ def configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is_ current_max_length = getattr(config, "max_position_embeddings", None) if current_max_length and model_args.model_max_length > current_max_length: + logger.warning( + "Enlarge max model length from {} to {}.".format(current_max_length, model_args.model_max_length) + ) + setattr(config, "max_position_embeddings", model_args.model_max_length) scaling_factor = float(math.ceil(model_args.model_max_length / current_max_length)) else: logger.warning("Input length is smaller than max length. 
Consider increase input length.") From 9a2178539621495f7db5b3ef6e57901c7212111a Mon Sep 17 00:00:00 2001 From: hiyouga Date: Wed, 24 Apr 2024 23:42:59 +0800 Subject: [PATCH 15/16] fix log level Former-commit-id: 7fbe8add8f449358c9815c5ba8a2052a2d874dab --- src/llmtuner/model/utils/rope.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/llmtuner/model/utils/rope.py b/src/llmtuner/model/utils/rope.py index 9163253b..93ab8929 100644 --- a/src/llmtuner/model/utils/rope.py +++ b/src/llmtuner/model/utils/rope.py @@ -30,7 +30,7 @@ def configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is_ current_max_length = getattr(config, "max_position_embeddings", None) if current_max_length and model_args.model_max_length > current_max_length: - logger.warning( + logger.info( "Enlarge max model length from {} to {}.".format(current_max_length, model_args.model_max_length) ) setattr(config, "max_position_embeddings", model_args.model_max_length) From b9fd197a37931ae38957bb26442121e98c85d54c Mon Sep 17 00:00:00 2001 From: hiyouga Date: Thu, 25 Apr 2024 00:21:34 +0800 Subject: [PATCH 16/16] update tool template Former-commit-id: aa16ff6205c48c00fe5792be1177c9aa88ed848b --- src/llmtuner/data/template.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/src/llmtuner/data/template.py b/src/llmtuner/data/template.py index dd355e97..73b22eb7 100644 --- a/src/llmtuner/data/template.py +++ b/src/llmtuner/data/template.py @@ -503,6 +503,7 @@ _register_template( name="chatml", format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]), + format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), format_separator=EmptyFormatter(slots=["\n"]), stop_words=["<|im_end|>", "<|im_start|>"], replace_eos=True, @@ -513,6 +514,7 @@ _register_template( name="chatml_de", format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]), + format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), format_separator=EmptyFormatter(slots=["\n"]), default_system="Du bist ein freundlicher und hilfsbereiter KI-Assistent.", stop_words=["<|im_end|>", "<|im_start|>"], @@ -554,6 +556,7 @@ _register_template( name="dbrx", format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]), + format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), format_separator=EmptyFormatter(slots=["\n"]), default_system=( "You are DBRX, created by Databricks. You were last updated in December 2023. 
" @@ -633,6 +636,9 @@ _register_template( name="gemma", format_user=StringFormatter(slots=["user\n{{content}}\nmodel\n"]), format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]), + format_observation=StringFormatter( + slots=["tool\n{{content}}\nmodel\n"] + ), format_separator=EmptyFormatter(slots=["\n"]), efficient_eos=True, force_system=True, @@ -703,6 +709,14 @@ _register_template( format_system=StringFormatter( slots=[{"bos_token"}, "<|start_header_id|>system<|end_header_id|>\n\n{{content}}<|eot_id|>"] ), + format_observation=StringFormatter( + slots=[ + ( + "<|start_header_id|>tool<|end_header_id|>\n\n{{content}}<|eot_id|>" + "<|start_header_id|>assistant<|end_header_id|>\n\n" + ) + ] + ), default_system="You are a helpful assistant.", stop_words=["<|eot_id|>"], replace_eos=True, @@ -747,6 +761,7 @@ _register_template( name="phi", format_user=StringFormatter(slots=["<|user|>\n{{content}}<|end|>\n<|assistant|>\n"]), format_system=StringFormatter(slots=[{"bos_token"}, "<|system|>\n{{content}}<|end|>\n"]), + format_observation=StringFormatter(slots=["<|function_output|>\n{{content}}<|end|>\n<|assistant|>\n"]), format_separator=EmptyFormatter(slots=["\n"]), default_system="You are a helpful AI assistant.", stop_words=["<|end|>"], @@ -758,6 +773,7 @@ _register_template( name="qwen", format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]), + format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), format_separator=EmptyFormatter(slots=["\n"]), default_system="You are a helpful assistant.", stop_words=["<|im_end|>"],