From a62cba3d05040024cc985e60785d80f61ef0deec Mon Sep 17 00:00:00 2001 From: hoshi-hiyouga Date: Tue, 22 Apr 2025 00:25:51 +0800 Subject: [PATCH] [example] add bash usage (#7794) --- README.md | 25 +++++----- README_zh.md | 25 +++++----- examples/README.md | 20 +++++--- examples/README_zh.md | 20 +++++--- examples/train_lora/llama3_lora_sft.sh | 36 +++++++++++++++ src/llamafactory/hparams/model_args.py | 20 ++++---- src/llamafactory/hparams/parser.py | 2 +- .../model/model_utils/quantization.py | 1 + src/llamafactory/model/patcher.py | 26 +++++------ src/llamafactory/third_party/__init__.py | 0 src/llamafactory/third_party/muon/muon.py | 30 +++++------- src/llamafactory/train/trainer_utils.py | 31 +++++-------- tests/model/model_utils/test_add_tokens.py | 46 +++++++++++++++++++ 13 files changed, 184 insertions(+), 98 deletions(-) create mode 100644 examples/train_lora/llama3_lora_sft.sh create mode 100644 src/llamafactory/third_party/__init__.py create mode 100644 tests/model/model_utils/test_add_tokens.py diff --git a/README.md b/README.md index 296c6f79..f11eb057 100644 --- a/README.md +++ b/README.md @@ -77,7 +77,7 @@ Choose your path: - **Various models**: LLaMA, LLaVA, Mistral, Mixtral-MoE, Qwen, Qwen2-VL, DeepSeek, Yi, Gemma, ChatGLM, Phi, etc. - **Integrated methods**: (Continuous) pre-training, (multimodal) supervised fine-tuning, reward modeling, PPO, DPO, KTO, ORPO, etc. - **Scalable resources**: 16-bit full-tuning, freeze-tuning, LoRA and 2/3/4/5/6/8-bit QLoRA via AQLM/AWQ/GPTQ/LLM.int8/HQQ/EETQ. -- **Advanced algorithms**: [Muon](https://github.com/KellerJordan/Muon), [GaLore](https://github.com/jiaweizzhao/GaLore), [BAdam](https://github.com/Ledzy/BAdam), [APOLLO](https://github.com/zhuhanqing/APOLLO), [Adam-mini](https://github.com/zyushun/Adam-mini), DoRA, LongLoRA, LLaMA Pro, Mixture-of-Depths, LoRA+, LoftQ and PiSSA. 
+- **Advanced algorithms**: [GaLore](https://github.com/jiaweizzhao/GaLore), [BAdam](https://github.com/Ledzy/BAdam), [APOLLO](https://github.com/zhuhanqing/APOLLO), [Adam-mini](https://github.com/zyushun/Adam-mini), [Muon](https://github.com/KellerJordan/Muon), DoRA, LongLoRA, LLaMA Pro, Mixture-of-Depths, LoRA+, LoftQ and PiSSA. - **Practical tricks**: [FlashAttention-2](https://github.com/Dao-AILab/flash-attention), [Unsloth](https://github.com/unslothai/unsloth), [Liger Kernel](https://github.com/linkedin/Liger-Kernel), RoPE scaling, NEFTune and rsLoRA. - **Wide tasks**: Multi-turn dialogue, tool using, image understanding, visual grounding, video recognition, audio understanding, etc. - **Experiment monitors**: LlamaBoard, TensorBoard, Wandb, MLflow, [SwanLab](https://github.com/SwanHubX/SwanLab), etc. @@ -107,7 +107,9 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/ ## Changelog -[25/04/16] We supported **[Muon](https://github.com/KellerJordan/Muon)** optimizer. See [examples](examples/README.md) for usage. Thank [Juanxi Tian](https://tianshijing.github.io)'s PR. +[25/04/21] We supported the **[Muon](https://github.com/KellerJordan/Muon)** optimizer. See [examples](examples/README.md) for usage. Thank [@tianshijing](https://github.com/tianshijing)'s PR. + +[25/04/16] We supported fine-tuning the **[InternVL3](https://huggingface.co/OpenGVLab/InternVL3-8B)** model. See [PR #7258](https://github.com/hiyouga/LLaMA-Factory/pull/7258) to get started. [25/04/14] We supported fine-tuning the **[GLM-Z1](https://huggingface.co/THUDM/GLM-Z1-9B-0414)** and **[Kimi-VL](https://huggingface.co/moonshotai/Kimi-VL-A3B-Instruct)** models. @@ -115,14 +117,14 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/ [25/03/31] We supported fine-tuning the **[Qwen2.5 Omni](https://qwenlm.github.io/blog/qwen2.5-omni/)** model. See [PR #7537](https://github.com/hiyouga/LLaMA-Factory/pull/7537) to get started. +
<details><summary>Full Changelog</summary> + [25/03/15] We supported **[SGLang](https://github.com/sgl-project/sglang)** as inference backend. Try `infer_backend: sglang` to accelerate inference. [25/03/12] We supported fine-tuning the **[Gemma 3](https://huggingface.co/blog/gemma3)** model. [25/02/24] Announcing **[EasyR1](https://github.com/hiyouga/EasyR1)**, an efficient, scalable and multi-modality RL training framework for efficient GRPO training. -
<details><summary>Full Changelog</summary> - [25/02/11] We supported saving the **[Ollama](https://github.com/ollama/ollama)** modelfile when exporting the model checkpoints. See [examples](examples/README.md) for usage. [25/02/05] We supported fine-tuning the **[Qwen2-Audio](Qwen/Qwen2-Audio-7B-Instruct)** and **[MiniCPM-o-2.6](https://huggingface.co/openbmb/MiniCPM-o-2_6)** on audio understanding tasks. @@ -245,11 +247,11 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/ | [Gemma 3](https://huggingface.co/google) | 1B/4B/12B/27B | gemma3/gemma (1B) | | [GLM-4/GLM-4-0414/GLM-Z1](https://huggingface.co/THUDM) | 9B/32B | glm4 | | [GPT-2](https://huggingface.co/openai-community) | 0.1B/0.4B/0.8B/1.5B | - | -| [Granite 3.0-3.1](https://huggingface.co/ibm-granite) | 1B/2B/3B/8B | granite3 | +| [Granite 3.0-3.3](https://huggingface.co/ibm-granite) | 1B/2B/3B/8B | granite3 | | [Hunyuan](https://huggingface.co/tencent/) | 7B | hunyuan | | [Index](https://huggingface.co/IndexTeam) | 1.9B | index | | [InternLM 2-3](https://huggingface.co/internlm) | 7B/8B/20B | intern2 | -| [InternVL2_5-3](https://huggingface.co/OpenGVLab/InternVL) | 1B/2B/4B/8B/9B/14B/26B/38B/78B | intern_vl | +| [InternVL 2.5-3](https://huggingface.co/OpenGVLab)\*\* | 1B/2B/4B/8B/9B/14B/26B/38B/78B | intern_vl | | [Kimi-VL](https://huggingface.co/moonshotai) | 16B | kimi_vl | | [Llama](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | - | | [Llama 2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 | @@ -417,11 +419,11 @@ huggingface-cli login | Mandatory | Minimum | Recommend | | ------------ | ------- | --------- | | python | 3.9 | 3.10 | -| torch | 1.13.1 | 2.6.0 | -| transformers | 4.41.2 | 4.50.0 | +| torch | 2.0.0 | 2.6.0 | +| transformers | 4.45.0 | 4.50.0 | | datasets | 2.16.0 | 3.2.0 | | accelerate | 0.34.0 | 1.2.1 | -| peft | 0.14.0 | 0.15.0 | +| peft | 0.14.0 | 0.15.1 | | trl | 0.8.6 | 0.9.6 | | Optional | Minimum | Recommend | @@ -430,7 +432,7 @@ huggingface-cli login 
| deepspeed | 0.10.0 | 0.16.4 | | bitsandbytes | 0.39.0 | 0.43.1 | | vllm | 0.4.3 | 0.8.2 | -| flash-attn | 2.3.0 | 2.7.2 | +| flash-attn | 2.5.6 | 2.7.2 | ### Hardware Requirement @@ -458,7 +460,7 @@ cd LLaMA-Factory pip install -e ".[torch,metrics]" ``` -Extra dependencies available: torch, torch-npu, metrics, deepspeed, liger-kernel, bitsandbytes, hqq, eetq, gptq, awq, aqlm, vllm, sglang, muon, galore, apollo, badam, adam-mini, qwen, minicpm_v, modelscope, openmind, swanlab, quality +Extra dependencies available: torch, torch-npu, metrics, deepspeed, liger-kernel, bitsandbytes, hqq, eetq, gptq, awq, aqlm, vllm, sglang, galore, apollo, badam, adam-mini, qwen, minicpm_v, modelscope, openmind, swanlab, quality > [!TIP] > Use `pip install --no-deps -e .` to resolve package conflicts. @@ -519,6 +521,7 @@ source /usr/local/Ascend/ascend-toolkit/set_env.sh | torch | 2.1.0 | 2.4.0 | | torch-npu | 2.1.0 | 2.4.0.post2 | | deepspeed | 0.13.2 | 0.13.2 | +| vllm-ascend | - | 0.7.3 | Remember to use `ASCEND_RT_VISIBLE_DEVICES` instead of `CUDA_VISIBLE_DEVICES` to specify the device to use. 
diff --git a/README_zh.md b/README_zh.md index 61576964..7fbef1da 100644 --- a/README_zh.md +++ b/README_zh.md @@ -80,7 +80,7 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc - **多种模型**:LLaMA、LLaVA、Mistral、Mixtral-MoE、Qwen、Qwen2-VL、DeepSeek、Yi、Gemma、ChatGLM、Phi 等等。 - **集成方法**:(增量)预训练、(多模态)指令监督微调、奖励模型训练、PPO 训练、DPO 训练、KTO 训练、ORPO 训练等等。 - **多种精度**:16 比特全参数微调、冻结微调、LoRA 微调和基于 AQLM/AWQ/GPTQ/LLM.int8/HQQ/EETQ 的 2/3/4/5/6/8 比特 QLoRA 微调。 -- **先进算法**:[Muon](https://github.com/KellerJordan/Muon), [GaLore](https://github.com/jiaweizzhao/GaLore)、[BAdam](https://github.com/Ledzy/BAdam)、[APOLLO](https://github.com/zhuhanqing/APOLLO)、[Adam-mini](https://github.com/zyushun/Adam-mini)、DoRA、LongLoRA、LLaMA Pro、Mixture-of-Depths、LoRA+、LoftQ 和 PiSSA。 +- **先进算法**:[GaLore](https://github.com/jiaweizzhao/GaLore)、[BAdam](https://github.com/Ledzy/BAdam)、[APOLLO](https://github.com/zhuhanqing/APOLLO)、[Adam-mini](https://github.com/zyushun/Adam-mini)、[Muon](https://github.com/KellerJordan/Muon)、DoRA、LongLoRA、LLaMA Pro、Mixture-of-Depths、LoRA+、LoftQ 和 PiSSA。 - **实用技巧**:[FlashAttention-2](https://github.com/Dao-AILab/flash-attention)、[Unsloth](https://github.com/unslothai/unsloth)、[Liger Kernel](https://github.com/linkedin/Liger-Kernel)、RoPE scaling、NEFTune 和 rsLoRA。 - **广泛任务**:多轮对话、工具调用、图像理解、视觉定位、视频识别和语音理解等等。 - **实验监控**:LlamaBoard、TensorBoard、Wandb、MLflow、[SwanLab](https://github.com/SwanHubX/SwanLab) 等等。 @@ -110,7 +110,9 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc ## 更新日志 -[25/04/16] 我们支持了 **[Muon](https://github.com/KellerJordan/Muon)** 优化器。详细用法请参照 [examples](examples/README_zh.md)。感谢 [@Juanxi Tian](https://tianshijing.github.io) 的 PR。 +[25/04/21] 我们支持了 **[Muon](https://github.com/KellerJordan/Muon)** 优化器。详细用法请参照 [examples](examples/README_zh.md)。感谢 [@tianshijing](https://github.com/tianshijing) 的 PR。 + +[25/04/16] 我们支持了 **[InternVL3](https://huggingface.co/OpenGVLab/InternVL3-8B)** 模型的微调。查看 [PR 
#7258](https://github.com/hiyouga/LLaMA-Factory/pull/7258) 以使用。 [25/04/14] 我们支持了 **[GLM-Z1](https://huggingface.co/THUDM/GLM-Z1-9B-0414)** 和 **[Kimi-VL](https://huggingface.co/moonshotai/Kimi-VL-A3B-Instruct)** 模型的微调。 @@ -118,14 +120,14 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc [25/03/31] 我们支持了 **[Qwen2.5 Omni](https://qwenlm.github.io/blog/qwen2.5-omni/)** 模型的微调。查看 [PR #7537](https://github.com/hiyouga/LLaMA-Factory/pull/7537) 以使用。 +
<details><summary>展开日志</summary> + [25/03/15] 我们支持了 **[SGLang](https://github.com/sgl-project/sglang)** 推理后端,请使用 `infer_backend: sglang` 启用。 [25/03/12] 我们支持了 **[Gemma 3](https://huggingface.co/blog/gemma3)** 模型的微调。 [25/02/24] 我们宣布开源 **[EasyR1](https://github.com/hiyouga/EasyR1)**,一个高效可扩展的多模态强化学习框架,支持高效的 GRPO 训练。 -
<details><summary>展开日志</summary> - [25/02/11] 我们支持了在导出模型时保存 **[Ollama](https://github.com/ollama/ollama)** 配置文件。详细用法请参照 [examples](examples/README_zh.md)。 [25/02/05] 我们支持了在语音理解任务上微调 **[Qwen2-Audio](Qwen/Qwen2-Audio-7B-Instruct)** 和 **[MiniCPM-o-2.6](https://huggingface.co/openbmb/MiniCPM-o-2_6)** 模型。 @@ -248,11 +250,11 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc | [Gemma 3](https://huggingface.co/google) | 1B/4B/12B/27B | gemma3/gemma (1B) | | [GLM-4/GLM-4-0414/GLM-Z1](https://huggingface.co/THUDM) | 9B/32B | glm4 | | [GPT-2](https://huggingface.co/openai-community) | 0.1B/0.4B/0.8B/1.5B | - | -| [Granite 3.0-3.1](https://huggingface.co/ibm-granite) | 1B/2B/3B/8B | granite3 | +| [Granite 3.0-3.3](https://huggingface.co/ibm-granite) | 1B/2B/3B/8B | granite3 | | [Hunyuan](https://huggingface.co/tencent/) | 7B | hunyuan | | [Index](https://huggingface.co/IndexTeam) | 1.9B | index | | [InternLM 2-3](https://huggingface.co/internlm) | 7B/8B/20B | intern2 | -| [InternVL2_5-3](https://huggingface.co/OpenGVLab/InternVL) | 1B/2B/4B/8B/9B/14B/26B/38B/78B | intern_vl | +| [InternVL 2.5-3](https://huggingface.co/OpenGVLab)\*\* | 1B/2B/4B/8B/9B/14B/26B/38B/78B | intern_vl | | [Kimi-VL](https://huggingface.co/moonshotai) | 16B | kimi_vl | | [Llama](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | - | | [Llama 2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 | @@ -420,11 +422,11 @@ huggingface-cli login | 必需项 | 至少 | 推荐 | | ------------ | ------- | --------- | | python | 3.9 | 3.10 | -| torch | 1.13.1 | 2.6.0 | -| transformers | 4.41.2 | 4.50.0 | +| torch | 2.0.0 | 2.6.0 | +| transformers | 4.45.0 | 4.50.0 | | datasets | 2.16.0 | 3.2.0 | | accelerate | 0.34.0 | 1.2.1 | -| peft | 0.14.0 | 0.15.0 | +| peft | 0.14.0 | 0.15.1 | | trl | 0.8.6 | 0.9.6 | | 可选项 | 至少 | 推荐 | @@ -433,7 +435,7 @@ huggingface-cli login | deepspeed | 0.10.0 | 0.16.4 | | bitsandbytes | 0.39.0 | 0.43.1 | | vllm | 0.4.3 | 0.8.2 | -| flash-attn | 2.3.0 | 2.7.2 | +| flash-attn | 
2.5.6 | 2.7.2 | ### 硬件依赖 @@ -461,7 +463,7 @@ cd LLaMA-Factory pip install -e ".[torch,metrics]" ``` -可选的额外依赖项:torch、torch-npu、metrics、deepspeed、liger-kernel、bitsandbytes、hqq、eetq、gptq、awq、aqlm、vllm、sglang、muon, galore、apollo、badam、adam-mini、qwen、minicpm_v、modelscope、openmind、swanlab、quality +可选的额外依赖项:torch、torch-npu、metrics、deepspeed、liger-kernel、bitsandbytes、hqq、eetq、gptq、awq、aqlm、vllm、sglang、galore、apollo、badam、adam-mini、qwen、minicpm_v、modelscope、openmind、swanlab、quality > [!TIP] > 遇到包冲突时,可使用 `pip install --no-deps -e .` 解决。 @@ -523,6 +525,7 @@ source /usr/local/Ascend/ascend-toolkit/set_env.sh | torch | 2.1.0 | 2.4.0 | | torch-npu | 2.1.0 | 2.4.0.post2 | | deepspeed | 0.13.2 | 0.13.2 | +| vllm-ascend | - | 0.7.3 | 请使用 `ASCEND_RT_VISIBLE_DEVICES` 而非 `CUDA_VISIBLE_DEVICES` 来指定运算设备。 diff --git a/examples/README.md b/examples/README.md index a59045a4..d4bdf1e3 100644 --- a/examples/README.md +++ b/examples/README.md @@ -24,7 +24,13 @@ llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml Advanced usage: ```bash -CUDA_VISIBLE_DEVICES=0,1 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml learning_rate=1e-5 logging_steps=1 +CUDA_VISIBLE_DEVICES=0,1 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml \ + learning_rate=1e-5 \ + logging_steps=1 +``` + +```bash +bash examples/train_lora/llama3_lora_sft.sh ``` ## Examples @@ -215,12 +221,6 @@ llamafactory-cli api examples/inference/llama3_lora_sft.yaml ### Extras -#### Full-Parameter Fine-Tuning using Muon - -```bash -llamafactory-cli train examples/extras/muon/qwen2_full_sft.yaml -``` - #### Full-Parameter Fine-Tuning using GaLore ```bash @@ -245,6 +245,12 @@ llamafactory-cli train examples/extras/badam/llama3_full_sft.yaml llamafactory-cli train examples/extras/adam_mini/qwen2_full_sft.yaml ``` +#### Full-Parameter Fine-Tuning using Muon + +```bash +llamafactory-cli train examples/extras/muon/qwen2_full_sft.yaml +``` + #### LoRA+ Fine-Tuning ```bash diff --git a/examples/README_zh.md 
b/examples/README_zh.md index b3b93d9d..727d6593 100644 --- a/examples/README_zh.md +++ b/examples/README_zh.md @@ -24,7 +24,13 @@ llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml 高级用法: ```bash -CUDA_VISIBLE_DEVICES=0,1 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml learning_rate=1e-5 logging_steps=1 +CUDA_VISIBLE_DEVICES=0,1 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml \ + learning_rate=1e-5 \ + logging_steps=1 +``` + +```bash +bash examples/train_lora/llama3_lora_sft.sh ``` ## 示例 @@ -215,12 +221,6 @@ llamafactory-cli api examples/inference/llama3_lora_sft.yaml ### 杂项 -#### 使用 Muon 进行全参数训练 - -```bash -llamafactory-cli train examples/extras/muon/qwen2_full_sft.yaml -``` - #### 使用 GaLore 进行全参数训练 ```bash @@ -245,6 +245,12 @@ llamafactory-cli train examples/extras/badam/llama3_full_sft.yaml llamafactory-cli train examples/extras/adam_mini/qwen2_full_sft.yaml ``` +#### 使用 Muon 进行全参数训练 + +```bash +llamafactory-cli train examples/extras/muon/qwen2_full_sft.yaml +``` + #### LoRA+ 微调 ```bash diff --git a/examples/train_lora/llama3_lora_sft.sh b/examples/train_lora/llama3_lora_sft.sh new file mode 100644 index 00000000..59db2c58 --- /dev/null +++ b/examples/train_lora/llama3_lora_sft.sh @@ -0,0 +1,36 @@ +#!/bin/bash + +set -x + +MODEL_PATH=meta-llama/Meta-Llama-3-8B-Instruct + +llamafactory-cli train \ + --model_name_or_path ${MODEL_PATH} \ + --trust_remote_code \ + --stage sft \ + --do_train \ + --finetuning_type lora \ + --lora_rank 8 \ + --lora_target all \ + --dataset identity,alpaca_en_demo \ + --template llama3 \ + --cutoff_len 2048 \ + --max_samples 1000 \ + --overwrite_cache \ + --preprocessing_num_workers 16 \ + --dataloader_num_workers 4 \ + --output_dir saves/llama3-8b/lora/sft \ + --logging_steps 10 \ + --save_steps 500 \ + --plot_loss \ + --overwrite_output_dir \ + --save_only_model false \ + --report_to none \ + --per_device_train_batch_size 1 \ + --gradient_accumulation_steps 8 \ + --learning_rate 1e-4 \ + 
--num_train_epochs 3.0 \ + --lr_scheduler_type cosine \ + --warmup_ratio 0.1 \ + --bf16 \ + --ddp_timeout 180000000 diff --git a/src/llamafactory/hparams/model_args.py b/src/llamafactory/hparams/model_args.py index 8fcfcb42..07319ca8 100644 --- a/src/llamafactory/hparams/model_args.py +++ b/src/llamafactory/hparams/model_args.py @@ -65,14 +65,16 @@ class BaseModelArguments: default=False, metadata={"help": "Whether or not the special tokens should be split during the tokenization process."}, ) - new_special_tokens: Optional[str] = field( + add_tokens: Optional[str] = field( + default=None, + metadata={ + "help": "Non-special tokens to be added into the tokenizer. Use commas to separate multiple tokens." + }, + ) + add_special_tokens: Optional[str] = field( default=None, metadata={"help": "Special tokens to be added into the tokenizer. Use commas to separate multiple tokens."}, ) - new_normal_tokens: Optional[str] = field( - default=None, - metadata={"help": "Normal tokens to be added into the tokenizer. 
Use commas to separate multiple tokens."}, - ) model_revision: str = field( default="main", metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}, @@ -180,11 +182,11 @@ class BaseModelArguments: if self.adapter_name_or_path is not None: # support merging multiple lora weights self.adapter_name_or_path = [path.strip() for path in self.adapter_name_or_path.split(",")] - if self.new_normal_tokens is not None: # support multiple normal tokens - self.new_normal_tokens = [token.strip() for token in self.new_normal_tokens.split(",")] + if self.add_tokens is not None: # support multiple tokens + self.add_tokens = [token.strip() for token in self.add_tokens.split(",")] - if self.new_special_tokens is not None: # support multiple special tokens - self.new_special_tokens = [token.strip() for token in self.new_special_tokens.split(",")] + if self.add_special_tokens is not None: # support multiple special tokens + self.add_special_tokens = [token.strip() for token in self.add_special_tokens.split(",")] @dataclass diff --git a/src/llamafactory/hparams/parser.py b/src/llamafactory/hparams/parser.py index 7209e216..cfe71498 100644 --- a/src/llamafactory/hparams/parser.py +++ b/src/llamafactory/hparams/parser.py @@ -153,7 +153,7 @@ def _check_extra_dependencies( elif model_args.infer_backend == EngineName.SGLANG: check_version("sglang>=0.4.4") check_version("sglang", mandatory=True) - + if finetuning_args.use_galore: check_version("galore_torch", mandatory=True) diff --git a/src/llamafactory/model/model_utils/quantization.py b/src/llamafactory/model/model_utils/quantization.py index 8d888ccb..cb288af1 100644 --- a/src/llamafactory/model/model_utils/quantization.py +++ b/src/llamafactory/model/model_utils/quantization.py @@ -124,6 +124,7 @@ def configure_quantization( try: from optimum.gptq import utils as gq_utils + if "language_model.model.layers" not in gq_utils.BLOCK_PATTERNS: gq_utils.BLOCK_PATTERNS.insert(0, 
"language_model.model.layers") except ImportError: diff --git a/src/llamafactory/model/patcher.py b/src/llamafactory/model/patcher.py index 52b25110..e0418bc1 100644 --- a/src/llamafactory/model/patcher.py +++ b/src/llamafactory/model/patcher.py @@ -54,26 +54,22 @@ def patch_tokenizer(tokenizer: "PreTrainedTokenizer", model_args: "ModelArgument if model_args.model_max_length is not None and tokenizer.model_max_length < model_args.model_max_length: tokenizer.model_max_length = model_args.model_max_length # enlarge the tokenizer max length - if model_args.new_special_tokens is not None: - num_added_special_tokens = tokenizer.add_special_tokens( - dict(additional_special_tokens=model_args.new_special_tokens), - replace_additional_special_tokens=False, + if model_args.add_tokens is not None: + num_added_tokens = tokenizer.add_tokens(new_tokens=model_args.add_tokens, special_tokens=False) + logger.info_rank0("Add tokens {} to tokenizer's vocabulary.".format(",".join(model_args.add_tokens))) + if num_added_tokens > 0 and not model_args.resize_vocab: + model_args.resize_vocab = True + logger.warning_rank0("New tokens have been added, changed `resize_vocab` to True.") + + if model_args.add_special_tokens is not None: + num_added_special_tokens = tokenizer.add_tokens(new_tokens=model_args.add_special_tokens, special_tokens=True) + logger.info_rank0( + "Add special tokens {} to tokenizer's vocabulary.".format(",".join(model_args.add_special_tokens)) ) - logger.info_rank0("Add special tokens {} to vocab.".format(",".join(model_args.new_special_tokens))) if num_added_special_tokens > 0 and not model_args.resize_vocab: model_args.resize_vocab = True logger.warning_rank0("New special tokens have been added, changed `resize_vocab` to True.") - if model_args.new_normal_tokens is not None: - num_added_normal_tokens = tokenizer.add_tokens( - new_tokens=model_args.new_normal_tokens, - special_tokens=False, - ) - logger.info_rank0("Add normal tokens {} to 
vocab.".format(",".join(model_args.new_normal_tokens))) - if num_added_normal_tokens > 0 and not model_args.resize_vocab: - model_args.resize_vocab = True - logger.warning_rank0("New normal tokens have been added, changed `resize_vocab` to True.") - def patch_processor( processor: "ProcessorMixin", diff --git a/src/llamafactory/third_party/__init__.py b/src/llamafactory/third_party/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/llamafactory/third_party/muon/muon.py b/src/llamafactory/third_party/muon/muon.py index e50097f5..d7482c36 100644 --- a/src/llamafactory/third_party/muon/muon.py +++ b/src/llamafactory/third_party/muon/muon.py @@ -2,6 +2,8 @@ # # This code is based on the MoonshotAI's Moonlight library. # https://github.com/MoonshotAI/Moonlight/blob/master/examples/toy_train.py +# and the Keller Jordan's Muon library. +# https://github.com/KellerJordan/Muon/blob/master/muon.py # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -18,6 +20,7 @@ # MIT License # # Copyright (c) 2025 Moonshot AI +# Copyright (c) 2024 Keller Jordan # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -36,22 +39,20 @@ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. + import math import torch -# This code snippet is a modified version adapted from the following GitHub repository: -# https://github.com/KellerJordan/Muon/blob/master/muon.py -@torch.compile -def zeropower_via_newtonschulz5(G, steps): +def zeropower_via_newtonschulz5(G: "torch.Tensor", steps: int) -> "torch.Tensor": """Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. 
We opt to use a quintic iteration whose coefficients are selected to maximize the slope at zero. - For the purpose of minimizing steps, it turns out to be empirically effective to keep increasing the slope at - zero even beyond the point where the iteration no longer converges all the way to one everywhere - on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T - where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model + For the purpose of minimizing steps, it turns out to be empirically effective to keep increasing + the slope at zero even beyond the point where the iteration no longer converges all the way to + one everywhere on the interval. This iteration therefore does not produce UV^T but rather something + like US'V^T where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model performance at all relative to UV^T, where USV^T = G is the SVD. """ assert len(G.shape) == 2 @@ -133,7 +134,7 @@ class Muon(torch.optim.Optimizer): # Do not use Muon for parameters in adamw_params self.state[p]["use_muon"] = False - def adjust_lr_for_muon(self, lr, param_shape): + def adjust_lr_for_muon(self, lr: float, param_shape: list[int]) -> float: A, B = param_shape[:2] # We adjust the learning rate and weight decay based on the size of the parameter matrix # as describted in the paper @@ -154,12 +155,8 @@ class Muon(torch.optim.Optimizer): loss = closure() for group in self.param_groups: - ############################ - # Muon # - ############################ - + # Muon loop params = [p for p in group["params"] if self.state[p]["use_muon"]] - # import pdb; pdb.set_trace() lr = group["lr"] wd = group["wd"] momentum = group["momentum"] @@ -195,10 +192,7 @@ class Muon(torch.optim.Optimizer): # apply update p.data.add_(u, alpha=-adjusted_lr) - ############################ - # AdamW backup # - ############################ - + # Adam backup params = [p for p in group["params"] if 
not self.state[p]["use_muon"]] lr = group["lr"] beta1, beta2 = group["adamw_betas"] diff --git a/src/llamafactory/train/trainer_utils.py b/src/llamafactory/train/trainer_utils.py index 34b417a3..12ee6bb3 100644 --- a/src/llamafactory/train/trainer_utils.py +++ b/src/llamafactory/train/trainer_utils.py @@ -489,16 +489,14 @@ def _create_adam_mini_optimizer( logger.info_rank0("Using Adam-mini optimizer.") return optimizer + def _create_muon_optimizer( model: "PreTrainedModel", training_args: "TrainingArguments", ) -> "torch.optim.Optimizer": - from llamafactory.third_party.muon import Muon # type: ignore - - # Separate parameters for Muon (2D parameters) and AdamW (others) - muon_params = [] - adamw_params = [] - + from ..third_party.muon import Muon + + muon_params, adamw_params = [], [] for name, param in model.named_parameters(): if param.requires_grad: # Use Muon for 2D parameters that aren't embeddings or heads @@ -506,34 +504,26 @@ def _create_muon_optimizer( muon_params.append(param) else: adamw_params.append(param) - - # Get optimizer settings from training_args - ns_steps = getattr(training_args, "ns_steps", 5) - - # Create Muon optimizer + optimizer = Muon( lr=training_args.learning_rate, wd=training_args.weight_decay, muon_params=muon_params, - momentum=0.95, # default momentum for Muon - nesterov=True, # default nesterov for Muon - ns_steps=ns_steps, adamw_params=adamw_params, adamw_betas=(training_args.adam_beta1, training_args.adam_beta2), adamw_eps=training_args.adam_epsilon, ) - - logger.info_rank0(f"Using Muon optimizer with {len(muon_params)} Muon params and {len(adamw_params)} AdamW params.") + logger.info_rank0( + f"Using Muon optimizer with {len(muon_params)} Muon params and {len(adamw_params)} AdamW params." 
+ ) return optimizer + def create_custom_optimizer( model: "PreTrainedModel", training_args: "TrainingArguments", finetuning_args: "FinetuningArguments", ) -> Optional["torch.optim.Optimizer"]: - if finetuning_args.use_muon: - return _create_muon_optimizer(model, training_args) - if finetuning_args.use_galore: return _create_galore_optimizer(model, training_args, finetuning_args) @@ -549,6 +539,9 @@ def create_custom_optimizer( if finetuning_args.use_adam_mini: return _create_adam_mini_optimizer(model, training_args) + if finetuning_args.use_muon: + return _create_muon_optimizer(model, training_args) + def create_custom_scheduler( training_args: "TrainingArguments", diff --git a/tests/model/model_utils/test_add_tokens.py b/tests/model/model_utils/test_add_tokens.py new file mode 100644 index 00000000..cb1c414a --- /dev/null +++ b/tests/model/model_utils/test_add_tokens.py @@ -0,0 +1,46 @@ +# Copyright 2025 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import os + +import pytest + +from llamafactory.hparams import ModelArguments +from llamafactory.model import load_tokenizer + + +TINY_LLAMA3 = os.getenv("TINY_LLAMA3", "llamafactory/tiny-random-Llama-3") + +UNUSED_TOKEN = "<|UNUSED_TOKEN|>" + + +@pytest.mark.parametrize("special_tokens", [False, True]) +def test_add_tokens(special_tokens: bool): + if special_tokens: + model_args = ModelArguments(model_name_or_path=TINY_LLAMA3, add_special_tokens=UNUSED_TOKEN) + else: + model_args = ModelArguments(model_name_or_path=TINY_LLAMA3, add_tokens=UNUSED_TOKEN) + + tokenizer = load_tokenizer(model_args)["tokenizer"] + encoded_ids = tokenizer.encode(UNUSED_TOKEN, add_special_tokens=False) + assert len(encoded_ids) == 1 + decoded_str = tokenizer.decode(encoded_ids, skip_special_tokens=True) + if special_tokens: + assert decoded_str == "" + else: + assert decoded_str == UNUSED_TOKEN + + +if __name__ == "__main__": + pytest.main([__file__])