mirror of
				https://github.com/hiyouga/LLaMA-Factory.git
				synced 2025-11-04 18:02:19 +08:00 
			
		
		
		
	[example] add bash usage (#7794)
This commit is contained in:
		
							parent
							
								
									12ada72ed4
								
							
						
					
					
						commit
						b07628dea5
					
				
							
								
								
									
										25
									
								
								README.md
									
									
									
									
									
								
							
							
						
						
									
										25
									
								
								README.md
									
									
									
									
									
								
							@ -77,7 +77,7 @@ Choose your path:
 | 
			
		||||
- **Various models**: LLaMA, LLaVA, Mistral, Mixtral-MoE, Qwen, Qwen2-VL, DeepSeek, Yi, Gemma, ChatGLM, Phi, etc.
 | 
			
		||||
- **Integrated methods**: (Continuous) pre-training, (multimodal) supervised fine-tuning, reward modeling, PPO, DPO, KTO, ORPO, etc.
 | 
			
		||||
- **Scalable resources**: 16-bit full-tuning, freeze-tuning, LoRA and 2/3/4/5/6/8-bit QLoRA via AQLM/AWQ/GPTQ/LLM.int8/HQQ/EETQ.
 | 
			
		||||
- **Advanced algorithms**: [Muon](https://github.com/KellerJordan/Muon), [GaLore](https://github.com/jiaweizzhao/GaLore), [BAdam](https://github.com/Ledzy/BAdam), [APOLLO](https://github.com/zhuhanqing/APOLLO), [Adam-mini](https://github.com/zyushun/Adam-mini), DoRA, LongLoRA, LLaMA Pro, Mixture-of-Depths, LoRA+, LoftQ and PiSSA.
 | 
			
		||||
- **Advanced algorithms**: [GaLore](https://github.com/jiaweizzhao/GaLore), [BAdam](https://github.com/Ledzy/BAdam), [APOLLO](https://github.com/zhuhanqing/APOLLO), [Adam-mini](https://github.com/zyushun/Adam-mini), [Muon](https://github.com/KellerJordan/Muon), DoRA, LongLoRA, LLaMA Pro, Mixture-of-Depths, LoRA+, LoftQ and PiSSA.
 | 
			
		||||
- **Practical tricks**: [FlashAttention-2](https://github.com/Dao-AILab/flash-attention), [Unsloth](https://github.com/unslothai/unsloth), [Liger Kernel](https://github.com/linkedin/Liger-Kernel), RoPE scaling, NEFTune and rsLoRA.
 | 
			
		||||
- **Wide tasks**: Multi-turn dialogue, tool using, image understanding, visual grounding, video recognition, audio understanding, etc.
 | 
			
		||||
- **Experiment monitors**: LlamaBoard, TensorBoard, Wandb, MLflow, [SwanLab](https://github.com/SwanHubX/SwanLab), etc.
 | 
			
		||||
@ -107,7 +107,9 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
 | 
			
		||||
 | 
			
		||||
## Changelog
 | 
			
		||||
 | 
			
		||||
[25/04/16] We supported **[Muon](https://github.com/KellerJordan/Muon)** optimizer. See [examples](examples/README.md) for usage. Thank [Juanxi Tian](https://tianshijing.github.io)'s PR.
 | 
			
		||||
[25/04/21] We supported the **[Muon](https://github.com/KellerJordan/Muon)** optimizer. See [examples](examples/README.md) for usage. Thank [@tianshijing](https://github.com/tianshijing)'s PR.
 | 
			
		||||
 | 
			
		||||
[25/04/16] We supported fine-tuning the **[InternVL3](https://huggingface.co/OpenGVLab/InternVL3-8B)** model. See [PR #7258](https://github.com/hiyouga/LLaMA-Factory/pull/7258) to get started.
 | 
			
		||||
 | 
			
		||||
[25/04/14] We supported fine-tuning the **[GLM-Z1](https://huggingface.co/THUDM/GLM-Z1-9B-0414)** and **[Kimi-VL](https://huggingface.co/moonshotai/Kimi-VL-A3B-Instruct)** models.
 | 
			
		||||
 | 
			
		||||
@ -115,14 +117,14 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
 | 
			
		||||
 | 
			
		||||
[25/03/31] We supported fine-tuning the **[Qwen2.5 Omni](https://qwenlm.github.io/blog/qwen2.5-omni/)** model. See [PR #7537](https://github.com/hiyouga/LLaMA-Factory/pull/7537) to get started.
 | 
			
		||||
 | 
			
		||||
<details><summary>Full Changelog</summary>
 | 
			
		||||
 | 
			
		||||
[25/03/15] We supported **[SGLang](https://github.com/sgl-project/sglang)** as inference backend. Try `infer_backend: sglang` to accelerate inference.
 | 
			
		||||
 | 
			
		||||
[25/03/12] We supported fine-tuning the **[Gemma 3](https://huggingface.co/blog/gemma3)** model.
 | 
			
		||||
 | 
			
		||||
[25/02/24] Announcing **[EasyR1](https://github.com/hiyouga/EasyR1)**, an efficient, scalable and multi-modality RL training framework for efficient GRPO training.
 | 
			
		||||
 | 
			
		||||
<details><summary>Full Changelog</summary>
 | 
			
		||||
 | 
			
		||||
[25/02/11] We supported saving the **[Ollama](https://github.com/ollama/ollama)** modelfile when exporting the model checkpoints. See [examples](examples/README.md) for usage.
 | 
			
		||||
 | 
			
		||||
[25/02/05] We supported fine-tuning the **[Qwen2-Audio](Qwen/Qwen2-Audio-7B-Instruct)** and **[MiniCPM-o-2.6](https://huggingface.co/openbmb/MiniCPM-o-2_6)** on audio understanding tasks.
 | 
			
		||||
@ -245,11 +247,11 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
 | 
			
		||||
| [Gemma 3](https://huggingface.co/google)                          | 1B/4B/12B/27B                    | gemma3/gemma (1B)   |
 | 
			
		||||
| [GLM-4/GLM-4-0414/GLM-Z1](https://huggingface.co/THUDM)           | 9B/32B                           | glm4                |
 | 
			
		||||
| [GPT-2](https://huggingface.co/openai-community)                  | 0.1B/0.4B/0.8B/1.5B              | -                   |
 | 
			
		||||
| [Granite 3.0-3.1](https://huggingface.co/ibm-granite)             | 1B/2B/3B/8B                      | granite3            |
 | 
			
		||||
| [Granite 3.0-3.3](https://huggingface.co/ibm-granite)             | 1B/2B/3B/8B                      | granite3            |
 | 
			
		||||
| [Hunyuan](https://huggingface.co/tencent/)                        | 7B                               | hunyuan             |
 | 
			
		||||
| [Index](https://huggingface.co/IndexTeam)                         | 1.9B                             | index               |
 | 
			
		||||
| [InternLM 2-3](https://huggingface.co/internlm)                   | 7B/8B/20B                        | intern2             |
 | 
			
		||||
| [InternVL2_5-3](https://huggingface.co/OpenGVLab/InternVL)        | 1B/2B/4B/8B/9B/14B/26B/38B/78B   | intern_vl           |
 | 
			
		||||
| [InternVL 2.5-3](https://huggingface.co/OpenGVLab)\*\*            | 1B/2B/4B/8B/9B/14B/26B/38B/78B   | intern_vl           |
 | 
			
		||||
| [Kimi-VL](https://huggingface.co/moonshotai)                      | 16B                              | kimi_vl             |
 | 
			
		||||
| [Llama](https://github.com/facebookresearch/llama)                | 7B/13B/33B/65B                   | -                   |
 | 
			
		||||
| [Llama 2](https://huggingface.co/meta-llama)                      | 7B/13B/70B                       | llama2              |
 | 
			
		||||
@ -417,11 +419,11 @@ huggingface-cli login
 | 
			
		||||
| Mandatory    | Minimum | Recommend |
 | 
			
		||||
| ------------ | ------- | --------- |
 | 
			
		||||
| python       | 3.9     | 3.10      |
 | 
			
		||||
| torch        | 1.13.1  | 2.6.0     |
 | 
			
		||||
| transformers | 4.41.2  | 4.50.0    |
 | 
			
		||||
| torch        | 2.0.0   | 2.6.0     |
 | 
			
		||||
| transformers | 4.45.0  | 4.50.0    |
 | 
			
		||||
| datasets     | 2.16.0  | 3.2.0     |
 | 
			
		||||
| accelerate   | 0.34.0  | 1.2.1     |
 | 
			
		||||
| peft         | 0.14.0  | 0.15.0    |
 | 
			
		||||
| peft         | 0.14.0  | 0.15.1    |
 | 
			
		||||
| trl          | 0.8.6   | 0.9.6     |
 | 
			
		||||
 | 
			
		||||
| Optional     | Minimum | Recommend |
 | 
			
		||||
@ -430,7 +432,7 @@ huggingface-cli login
 | 
			
		||||
| deepspeed    | 0.10.0  | 0.16.4    |
 | 
			
		||||
| bitsandbytes | 0.39.0  | 0.43.1    |
 | 
			
		||||
| vllm         | 0.4.3   | 0.8.2     |
 | 
			
		||||
| flash-attn   | 2.3.0   | 2.7.2     |
 | 
			
		||||
| flash-attn   | 2.5.6   | 2.7.2     |
 | 
			
		||||
 | 
			
		||||
### Hardware Requirement
 | 
			
		||||
 | 
			
		||||
@ -458,7 +460,7 @@ cd LLaMA-Factory
 | 
			
		||||
pip install -e ".[torch,metrics]"
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
Extra dependencies available: torch, torch-npu, metrics, deepspeed, liger-kernel, bitsandbytes, hqq, eetq, gptq, awq, aqlm, vllm, sglang, muon, galore, apollo, badam, adam-mini, qwen, minicpm_v, modelscope, openmind, swanlab, quality
 | 
			
		||||
Extra dependencies available: torch, torch-npu, metrics, deepspeed, liger-kernel, bitsandbytes, hqq, eetq, gptq, awq, aqlm, vllm, sglang, galore, apollo, badam, adam-mini, qwen, minicpm_v, modelscope, openmind, swanlab, quality
 | 
			
		||||
 | 
			
		||||
> [!TIP]
 | 
			
		||||
> Use `pip install --no-deps -e .` to resolve package conflicts.
 | 
			
		||||
@ -519,6 +521,7 @@ source /usr/local/Ascend/ascend-toolkit/set_env.sh
 | 
			
		||||
| torch        | 2.1.0   | 2.4.0          |
 | 
			
		||||
| torch-npu    | 2.1.0   | 2.4.0.post2    |
 | 
			
		||||
| deepspeed    | 0.13.2  | 0.13.2         |
 | 
			
		||||
| vllm-ascend  | -       | 0.7.3          |
 | 
			
		||||
 | 
			
		||||
Remember to use `ASCEND_RT_VISIBLE_DEVICES` instead of `CUDA_VISIBLE_DEVICES` to specify the device to use.
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										25
									
								
								README_zh.md
									
									
									
									
									
								
							
							
						
						
									
										25
									
								
								README_zh.md
									
									
									
									
									
								
							@ -80,7 +80,7 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
 | 
			
		||||
- **多种模型**:LLaMA、LLaVA、Mistral、Mixtral-MoE、Qwen、Qwen2-VL、DeepSeek、Yi、Gemma、ChatGLM、Phi 等等。
 | 
			
		||||
- **集成方法**:(增量)预训练、(多模态)指令监督微调、奖励模型训练、PPO 训练、DPO 训练、KTO 训练、ORPO 训练等等。
 | 
			
		||||
- **多种精度**:16 比特全参数微调、冻结微调、LoRA 微调和基于 AQLM/AWQ/GPTQ/LLM.int8/HQQ/EETQ 的 2/3/4/5/6/8 比特 QLoRA 微调。
 | 
			
		||||
- **先进算法**:[Muon](https://github.com/KellerJordan/Muon), [GaLore](https://github.com/jiaweizzhao/GaLore)、[BAdam](https://github.com/Ledzy/BAdam)、[APOLLO](https://github.com/zhuhanqing/APOLLO)、[Adam-mini](https://github.com/zyushun/Adam-mini)、DoRA、LongLoRA、LLaMA Pro、Mixture-of-Depths、LoRA+、LoftQ 和 PiSSA。
 | 
			
		||||
- **先进算法**:[GaLore](https://github.com/jiaweizzhao/GaLore)、[BAdam](https://github.com/Ledzy/BAdam)、[APOLLO](https://github.com/zhuhanqing/APOLLO)、[Adam-mini](https://github.com/zyushun/Adam-mini)、[Muon](https://github.com/KellerJordan/Muon)、DoRA、LongLoRA、LLaMA Pro、Mixture-of-Depths、LoRA+、LoftQ 和 PiSSA。
 | 
			
		||||
- **实用技巧**:[FlashAttention-2](https://github.com/Dao-AILab/flash-attention)、[Unsloth](https://github.com/unslothai/unsloth)、[Liger Kernel](https://github.com/linkedin/Liger-Kernel)、RoPE scaling、NEFTune 和 rsLoRA。
 | 
			
		||||
- **广泛任务**:多轮对话、工具调用、图像理解、视觉定位、视频识别和语音理解等等。
 | 
			
		||||
- **实验监控**:LlamaBoard、TensorBoard、Wandb、MLflow、[SwanLab](https://github.com/SwanHubX/SwanLab) 等等。
 | 
			
		||||
@ -110,7 +110,9 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
 | 
			
		||||
 | 
			
		||||
## 更新日志
 | 
			
		||||
 | 
			
		||||
[25/04/16] 我们支持了 **[Muon](https://github.com/KellerJordan/Muon)** 优化器。详细用法请参照 [examples](examples/README_zh.md)。感谢 [@Juanxi Tian](https://tianshijing.github.io) 的 PR。
 | 
			
		||||
[25/04/21] 我们支持了 **[Muon](https://github.com/KellerJordan/Muon)** 优化器。详细用法请参照 [examples](examples/README_zh.md)。感谢 [@tianshijing](https://github.com/tianshijing) 的 PR。
 | 
			
		||||
 | 
			
		||||
[25/04/16] 我们支持了 **[InternVL3](https://huggingface.co/OpenGVLab/InternVL3-8B)** 模型的微调。查看 [PR #7258](https://github.com/hiyouga/LLaMA-Factory/pull/7258) 以使用。
 | 
			
		||||
 | 
			
		||||
[25/04/14] 我们支持了 **[GLM-Z1](https://huggingface.co/THUDM/GLM-Z1-9B-0414)** 和 **[Kimi-VL](https://huggingface.co/moonshotai/Kimi-VL-A3B-Instruct)** 模型的微调。
 | 
			
		||||
 | 
			
		||||
@ -118,14 +120,14 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
 | 
			
		||||
 | 
			
		||||
[25/03/31] 我们支持了 **[Qwen2.5 Omni](https://qwenlm.github.io/blog/qwen2.5-omni/)** 模型的微调。查看 [PR #7537](https://github.com/hiyouga/LLaMA-Factory/pull/7537) 以使用。
 | 
			
		||||
 | 
			
		||||
<details><summary>展开日志</summary>
 | 
			
		||||
 | 
			
		||||
[25/03/15] 我们支持了 **[SGLang](https://github.com/sgl-project/sglang)** 推理后端,请使用 `infer_backend: sglang` 启用。
 | 
			
		||||
 | 
			
		||||
[25/03/12] 我们支持了 **[Gemma 3](https://huggingface.co/blog/gemma3)** 模型的微调。
 | 
			
		||||
 | 
			
		||||
[25/02/24] 我们宣布开源 **[EasyR1](https://github.com/hiyouga/EasyR1)**,一个高效可扩展的多模态强化学习框架,支持高效的 GRPO 训练。
 | 
			
		||||
 | 
			
		||||
<details><summary>展开日志</summary>
 | 
			
		||||
 | 
			
		||||
[25/02/11] 我们支持了在导出模型时保存 **[Ollama](https://github.com/ollama/ollama)** 配置文件。详细用法请参照 [examples](examples/README_zh.md)。
 | 
			
		||||
 | 
			
		||||
[25/02/05] 我们支持了在语音理解任务上微调 **[Qwen2-Audio](Qwen/Qwen2-Audio-7B-Instruct)** 和 **[MiniCPM-o-2.6](https://huggingface.co/openbmb/MiniCPM-o-2_6)** 模型。
 | 
			
		||||
@ -248,11 +250,11 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
 | 
			
		||||
| [Gemma 3](https://huggingface.co/google)                          | 1B/4B/12B/27B                    | gemma3/gemma (1B)   |
 | 
			
		||||
| [GLM-4/GLM-4-0414/GLM-Z1](https://huggingface.co/THUDM)           | 9B/32B                           | glm4                |
 | 
			
		||||
| [GPT-2](https://huggingface.co/openai-community)                  | 0.1B/0.4B/0.8B/1.5B              | -                   |
 | 
			
		||||
| [Granite 3.0-3.1](https://huggingface.co/ibm-granite)             | 1B/2B/3B/8B                      | granite3            |
 | 
			
		||||
| [Granite 3.0-3.3](https://huggingface.co/ibm-granite)             | 1B/2B/3B/8B                      | granite3            |
 | 
			
		||||
| [Hunyuan](https://huggingface.co/tencent/)                        | 7B                               | hunyuan             |
 | 
			
		||||
| [Index](https://huggingface.co/IndexTeam)                         | 1.9B                             | index               |
 | 
			
		||||
| [InternLM 2-3](https://huggingface.co/internlm)                   | 7B/8B/20B                        | intern2             |
 | 
			
		||||
| [InternVL2_5-3](https://huggingface.co/OpenGVLab/InternVL)        | 1B/2B/4B/8B/9B/14B/26B/38B/78B   | intern_vl           |
 | 
			
		||||
| [InternVL 2.5-3](https://huggingface.co/OpenGVLab)\*\*            | 1B/2B/4B/8B/9B/14B/26B/38B/78B   | intern_vl           |
 | 
			
		||||
| [Kimi-VL](https://huggingface.co/moonshotai)                      | 16B                              | kimi_vl             |
 | 
			
		||||
| [Llama](https://github.com/facebookresearch/llama)                | 7B/13B/33B/65B                   | -                   |
 | 
			
		||||
| [Llama 2](https://huggingface.co/meta-llama)                      | 7B/13B/70B                       | llama2              |
 | 
			
		||||
@ -420,11 +422,11 @@ huggingface-cli login
 | 
			
		||||
| 必需项        | 至少     | 推荐      |
 | 
			
		||||
| ------------ | ------- | --------- |
 | 
			
		||||
| python       | 3.9     | 3.10      |
 | 
			
		||||
| torch        | 1.13.1  | 2.6.0     |
 | 
			
		||||
| transformers | 4.41.2  | 4.50.0    |
 | 
			
		||||
| torch        | 2.0.0   | 2.6.0     |
 | 
			
		||||
| transformers | 4.45.0  | 4.50.0    |
 | 
			
		||||
| datasets     | 2.16.0  | 3.2.0     |
 | 
			
		||||
| accelerate   | 0.34.0  | 1.2.1     |
 | 
			
		||||
| peft         | 0.14.0  | 0.15.0    |
 | 
			
		||||
| peft         | 0.14.0  | 0.15.1    |
 | 
			
		||||
| trl          | 0.8.6   | 0.9.6     |
 | 
			
		||||
 | 
			
		||||
| 可选项        | 至少     | 推荐      |
 | 
			
		||||
@ -433,7 +435,7 @@ huggingface-cli login
 | 
			
		||||
| deepspeed    | 0.10.0  | 0.16.4    |
 | 
			
		||||
| bitsandbytes | 0.39.0  | 0.43.1    |
 | 
			
		||||
| vllm         | 0.4.3   | 0.8.2     |
 | 
			
		||||
| flash-attn   | 2.3.0   | 2.7.2     |
 | 
			
		||||
| flash-attn   | 2.5.6   | 2.7.2     |
 | 
			
		||||
 | 
			
		||||
### 硬件依赖
 | 
			
		||||
 | 
			
		||||
@ -461,7 +463,7 @@ cd LLaMA-Factory
 | 
			
		||||
pip install -e ".[torch,metrics]"
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
可选的额外依赖项:torch、torch-npu、metrics、deepspeed、liger-kernel、bitsandbytes、hqq、eetq、gptq、awq、aqlm、vllm、sglang、muon, galore、apollo、badam、adam-mini、qwen、minicpm_v、modelscope、openmind、swanlab、quality
 | 
			
		||||
可选的额外依赖项:torch、torch-npu、metrics、deepspeed、liger-kernel、bitsandbytes、hqq、eetq、gptq、awq、aqlm、vllm、sglang、galore、apollo、badam、adam-mini、qwen、minicpm_v、modelscope、openmind、swanlab、quality
 | 
			
		||||
 | 
			
		||||
> [!TIP]
 | 
			
		||||
> 遇到包冲突时,可使用 `pip install --no-deps -e .` 解决。
 | 
			
		||||
@ -523,6 +525,7 @@ source /usr/local/Ascend/ascend-toolkit/set_env.sh
 | 
			
		||||
| torch        | 2.1.0   | 2.4.0          |
 | 
			
		||||
| torch-npu    | 2.1.0   | 2.4.0.post2    |
 | 
			
		||||
| deepspeed    | 0.13.2  | 0.13.2         |
 | 
			
		||||
| vllm-ascend  | -       | 0.7.3          |
 | 
			
		||||
 | 
			
		||||
请使用 `ASCEND_RT_VISIBLE_DEVICES` 而非 `CUDA_VISIBLE_DEVICES` 来指定运算设备。
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -24,7 +24,13 @@ llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
 | 
			
		||||
Advanced usage:
 | 
			
		||||
 | 
			
		||||
```bash
 | 
			
		||||
CUDA_VISIBLE_DEVICES=0,1 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml learning_rate=1e-5 logging_steps=1
 | 
			
		||||
CUDA_VISIBLE_DEVICES=0,1 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml \
 | 
			
		||||
    learning_rate=1e-5 \
 | 
			
		||||
    logging_steps=1
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
```bash
 | 
			
		||||
bash examples/train_lora/llama3_lora_sft.sh
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
## Examples
 | 
			
		||||
@ -215,12 +221,6 @@ llamafactory-cli api examples/inference/llama3_lora_sft.yaml
 | 
			
		||||
 | 
			
		||||
### Extras
 | 
			
		||||
 | 
			
		||||
#### Full-Parameter Fine-Tuning using Muon
 | 
			
		||||
 | 
			
		||||
```bash
 | 
			
		||||
llamafactory-cli train examples/extras/muon/qwen2_full_sft.yaml
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
#### Full-Parameter Fine-Tuning using GaLore
 | 
			
		||||
 | 
			
		||||
```bash
 | 
			
		||||
@ -245,6 +245,12 @@ llamafactory-cli train examples/extras/badam/llama3_full_sft.yaml
 | 
			
		||||
llamafactory-cli train examples/extras/adam_mini/qwen2_full_sft.yaml
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
#### Full-Parameter Fine-Tuning using Muon
 | 
			
		||||
 | 
			
		||||
```bash
 | 
			
		||||
llamafactory-cli train examples/extras/muon/qwen2_full_sft.yaml
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
#### LoRA+ Fine-Tuning
 | 
			
		||||
 | 
			
		||||
```bash
 | 
			
		||||
 | 
			
		||||
@ -24,7 +24,13 @@ llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
 | 
			
		||||
高级用法:
 | 
			
		||||
 | 
			
		||||
```bash
 | 
			
		||||
CUDA_VISIBLE_DEVICES=0,1 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml learning_rate=1e-5 logging_steps=1
 | 
			
		||||
CUDA_VISIBLE_DEVICES=0,1 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml \
 | 
			
		||||
    learning_rate=1e-5 \
 | 
			
		||||
    logging_steps=1
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
```bash
 | 
			
		||||
bash examples/train_lora/llama3_lora_sft.sh
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
## 示例
 | 
			
		||||
@ -215,12 +221,6 @@ llamafactory-cli api examples/inference/llama3_lora_sft.yaml
 | 
			
		||||
 | 
			
		||||
### 杂项
 | 
			
		||||
 | 
			
		||||
#### 使用 Muon 进行全参数训练
 | 
			
		||||
 | 
			
		||||
```bash
 | 
			
		||||
llamafactory-cli train examples/extras/muon/qwen2_full_sft.yaml
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
#### 使用 GaLore 进行全参数训练
 | 
			
		||||
 | 
			
		||||
```bash
 | 
			
		||||
@ -245,6 +245,12 @@ llamafactory-cli train examples/extras/badam/llama3_full_sft.yaml
 | 
			
		||||
llamafactory-cli train examples/extras/adam_mini/qwen2_full_sft.yaml
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
#### 使用 Muon 进行全参数训练
 | 
			
		||||
 | 
			
		||||
```bash
 | 
			
		||||
llamafactory-cli train examples/extras/muon/qwen2_full_sft.yaml
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
#### LoRA+ 微调
 | 
			
		||||
 | 
			
		||||
```bash
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										36
									
								
								examples/train_lora/llama3_lora_sft.sh
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										36
									
								
								examples/train_lora/llama3_lora_sft.sh
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,36 @@
 | 
			
		||||
#!/bin/bash
 | 
			
		||||
 | 
			
		||||
set -x
 | 
			
		||||
 | 
			
		||||
MODEL_PATH=meta-llama/Meta-Llama-3-8B-Instruct
 | 
			
		||||
 | 
			
		||||
llamafactory-cli train \
 | 
			
		||||
    --model_name_or_path ${MODEL_PATH} \
 | 
			
		||||
    --trust_remote_code \
 | 
			
		||||
    --stage sft \
 | 
			
		||||
    --do_train \
 | 
			
		||||
    --finetuning_type lora \
 | 
			
		||||
    --lora_rank 8 \
 | 
			
		||||
    --lora_target all \
 | 
			
		||||
    --dataset identity,alpaca_en_demo \
 | 
			
		||||
    --template llama3 \
 | 
			
		||||
    --cutoff_len 2048 \
 | 
			
		||||
    --max_samples 1000 \
 | 
			
		||||
    --overwrite_cache \
 | 
			
		||||
    --preprocessing_num_workers 16 \
 | 
			
		||||
    --dataloader_num_workers 4 \
 | 
			
		||||
    --output_dir saves/llama3-8b/lora/sft \
 | 
			
		||||
    --logging_steps 10 \
 | 
			
		||||
    --save_steps 500 \
 | 
			
		||||
    --plot_loss \
 | 
			
		||||
    --overwrite_output_dir \
 | 
			
		||||
    --save_only_model false \
 | 
			
		||||
    --report_to none \
 | 
			
		||||
    --per_device_train_batch_size 1 \
 | 
			
		||||
    --gradient_accumulation_steps 8 \
 | 
			
		||||
    --learning_rate 1e-4 \
 | 
			
		||||
    --num_train_epochs 3.0 \
 | 
			
		||||
    --lr_scheduler_type cosine \
 | 
			
		||||
    --warmup_ratio 0.1 \
 | 
			
		||||
    --bf16 \
 | 
			
		||||
    --ddp_timeout 180000000
 | 
			
		||||
@ -65,14 +65,16 @@ class BaseModelArguments:
 | 
			
		||||
        default=False,
 | 
			
		||||
        metadata={"help": "Whether or not the special tokens should be split during the tokenization process."},
 | 
			
		||||
    )
 | 
			
		||||
    new_special_tokens: Optional[str] = field(
 | 
			
		||||
    add_tokens: Optional[str] = field(
 | 
			
		||||
        default=None,
 | 
			
		||||
        metadata={
 | 
			
		||||
            "help": "Non-special tokens to be added into the tokenizer. Use commas to separate multiple tokens."
 | 
			
		||||
        },
 | 
			
		||||
    )
 | 
			
		||||
    add_special_tokens: Optional[str] = field(
 | 
			
		||||
        default=None,
 | 
			
		||||
        metadata={"help": "Special tokens to be added into the tokenizer. Use commas to separate multiple tokens."},
 | 
			
		||||
    )
 | 
			
		||||
    new_normal_tokens: Optional[str] = field(
 | 
			
		||||
        default=None,
 | 
			
		||||
        metadata={"help": "Normal tokens to be added into the tokenizer. Use commas to separate multiple tokens."},
 | 
			
		||||
    )
 | 
			
		||||
    model_revision: str = field(
 | 
			
		||||
        default="main",
 | 
			
		||||
        metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
 | 
			
		||||
@ -180,11 +182,11 @@ class BaseModelArguments:
 | 
			
		||||
        if self.adapter_name_or_path is not None:  # support merging multiple lora weights
 | 
			
		||||
            self.adapter_name_or_path = [path.strip() for path in self.adapter_name_or_path.split(",")]
 | 
			
		||||
 | 
			
		||||
        if self.new_normal_tokens is not None:  # support multiple normal tokens
 | 
			
		||||
            self.new_normal_tokens = [token.strip() for token in self.new_normal_tokens.split(",")]
 | 
			
		||||
        if self.add_tokens is not None:  # support multiple tokens
 | 
			
		||||
            self.add_tokens = [token.strip() for token in self.add_tokens.split(",")]
 | 
			
		||||
 | 
			
		||||
        if self.new_special_tokens is not None:  # support multiple special tokens
 | 
			
		||||
            self.new_special_tokens = [token.strip() for token in self.new_special_tokens.split(",")]
 | 
			
		||||
        if self.add_special_tokens is not None:  # support multiple special tokens
 | 
			
		||||
            self.add_special_tokens = [token.strip() for token in self.add_special_tokens.split(",")]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@dataclass
 | 
			
		||||
 | 
			
		||||
@ -153,7 +153,7 @@ def _check_extra_dependencies(
 | 
			
		||||
    elif model_args.infer_backend == EngineName.SGLANG:
 | 
			
		||||
        check_version("sglang>=0.4.4")
 | 
			
		||||
        check_version("sglang", mandatory=True)
 | 
			
		||||
    
 | 
			
		||||
 | 
			
		||||
    if finetuning_args.use_galore:
 | 
			
		||||
        check_version("galore_torch", mandatory=True)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -124,6 +124,7 @@ def configure_quantization(
 | 
			
		||||
 | 
			
		||||
        try:
 | 
			
		||||
            from optimum.gptq import utils as gq_utils
 | 
			
		||||
 | 
			
		||||
            if "language_model.model.layers" not in gq_utils.BLOCK_PATTERNS:
 | 
			
		||||
                gq_utils.BLOCK_PATTERNS.insert(0, "language_model.model.layers")
 | 
			
		||||
        except ImportError:
 | 
			
		||||
 | 
			
		||||
@ -54,26 +54,22 @@ def patch_tokenizer(tokenizer: "PreTrainedTokenizer", model_args: "ModelArgument
 | 
			
		||||
    if model_args.model_max_length is not None and tokenizer.model_max_length < model_args.model_max_length:
 | 
			
		||||
        tokenizer.model_max_length = model_args.model_max_length  # enlarge the tokenizer max length
 | 
			
		||||
 | 
			
		||||
    if model_args.new_special_tokens is not None:
 | 
			
		||||
        num_added_special_tokens = tokenizer.add_special_tokens(
 | 
			
		||||
            dict(additional_special_tokens=model_args.new_special_tokens),
 | 
			
		||||
            replace_additional_special_tokens=False,
 | 
			
		||||
    if model_args.add_tokens is not None:
 | 
			
		||||
        num_added_tokens = tokenizer.add_tokens(new_tokens=model_args.add_tokens, special_tokens=False)
 | 
			
		||||
        logger.info_rank0("Add tokens {} to tokenizer's vocabulary.".format(",".join(model_args.add_tokens)))
 | 
			
		||||
        if num_added_tokens > 0 and not model_args.resize_vocab:
 | 
			
		||||
            model_args.resize_vocab = True
 | 
			
		||||
            logger.warning_rank0("New tokens have been added, changed `resize_vocab` to True.")
 | 
			
		||||
 | 
			
		||||
    if model_args.add_special_tokens is not None:
 | 
			
		||||
        num_added_special_tokens = tokenizer.add_tokens(new_tokens=model_args.add_special_tokens, special_tokens=True)
 | 
			
		||||
        logger.info_rank0(
 | 
			
		||||
            "Add special tokens {} to tokenizer's vocabulary.".format(",".join(model_args.add_special_tokens))
 | 
			
		||||
        )
 | 
			
		||||
        logger.info_rank0("Add special tokens {} to vocab.".format(",".join(model_args.new_special_tokens)))
 | 
			
		||||
        if num_added_special_tokens > 0 and not model_args.resize_vocab:
 | 
			
		||||
            model_args.resize_vocab = True
 | 
			
		||||
            logger.warning_rank0("New special tokens have been added, changed `resize_vocab` to True.")
 | 
			
		||||
 | 
			
		||||
    if model_args.new_normal_tokens is not None:
 | 
			
		||||
        num_added_normal_tokens = tokenizer.add_tokens(
 | 
			
		||||
            new_tokens=model_args.new_normal_tokens,
 | 
			
		||||
            special_tokens=False,
 | 
			
		||||
        )
 | 
			
		||||
        logger.info_rank0("Add normal tokens {} to vocab.".format(",".join(model_args.new_normal_tokens)))
 | 
			
		||||
        if num_added_normal_tokens > 0 and not model_args.resize_vocab:
 | 
			
		||||
            model_args.resize_vocab = True
 | 
			
		||||
            logger.warning_rank0("New normal tokens have been added, changed `resize_vocab` to True.")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def patch_processor(
 | 
			
		||||
    processor: "ProcessorMixin",
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										0
									
								
								src/llamafactory/third_party/__init__.py
									
									
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										0
									
								
								src/llamafactory/third_party/__init__.py
									
									
									
									
										vendored
									
									
										Normal file
									
								
							
							
								
								
									
										30
									
								
								src/llamafactory/third_party/muon/muon.py
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										30
									
								
								src/llamafactory/third_party/muon/muon.py
									
									
									
									
										vendored
									
									
								
							@ -2,6 +2,8 @@
 | 
			
		||||
#
 | 
			
		||||
# This code is based on the MoonshotAI's Moonlight library.
 | 
			
		||||
# https://github.com/MoonshotAI/Moonlight/blob/master/examples/toy_train.py
 | 
			
		||||
# and the Keller Jordan's Muon library.
 | 
			
		||||
# https://github.com/KellerJordan/Muon/blob/master/muon.py
 | 
			
		||||
#
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
# you may not use this file except in compliance with the License.
 | 
			
		||||
@ -18,6 +20,7 @@
 | 
			
		||||
# MIT License
 | 
			
		||||
#
 | 
			
		||||
# Copyright (c) 2025 Moonshot AI
 | 
			
		||||
# Copyright (c) 2024 Keller Jordan
 | 
			
		||||
#
 | 
			
		||||
# Permission is hereby granted, free of charge, to any person obtaining a copy
 | 
			
		||||
# of this software and associated documentation files (the "Software"), to deal
 | 
			
		||||
@ -36,22 +39,20 @@
 | 
			
		||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 | 
			
		||||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 | 
			
		||||
# SOFTWARE.
 | 
			
		||||
 | 
			
		||||
import math
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# This code snippet is a modified version adapted from the following GitHub repository:
 | 
			
		||||
# https://github.com/KellerJordan/Muon/blob/master/muon.py
 | 
			
		||||
@torch.compile
 | 
			
		||||
def zeropower_via_newtonschulz5(G, steps):
 | 
			
		||||
def zeropower_via_newtonschulz5(G: "torch.Tensor", steps: int) -> "torch.Tensor":
 | 
			
		||||
    """Newton-Schulz iteration to compute the zeroth power / orthogonalization of G.
 | 
			
		||||
 | 
			
		||||
    We opt to use a quintic iteration whose coefficients are selected to maximize the slope at zero.
 | 
			
		||||
    For the purpose of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
 | 
			
		||||
    zero even beyond the point where the iteration no longer converges all the way to one everywhere
 | 
			
		||||
    on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
 | 
			
		||||
    where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
 | 
			
		||||
    For the purpose of minimizing steps, it turns out to be empirically effective to keep increasing
 | 
			
		||||
    the slope at zero even beyond the point where the iteration no longer converges all the way to
 | 
			
		||||
    one everywhere on the interval. This iteration therefore does not produce UV^T but rather something
 | 
			
		||||
    like US'V^T where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
 | 
			
		||||
    performance at all relative to UV^T, where USV^T = G is the SVD.
 | 
			
		||||
    """
 | 
			
		||||
    assert len(G.shape) == 2
 | 
			
		||||
@ -133,7 +134,7 @@ class Muon(torch.optim.Optimizer):
 | 
			
		||||
            # Do not use Muon for parameters in adamw_params
 | 
			
		||||
            self.state[p]["use_muon"] = False
 | 
			
		||||
 | 
			
		||||
    def adjust_lr_for_muon(self, lr, param_shape):
 | 
			
		||||
    def adjust_lr_for_muon(self, lr: float, param_shape: list[int]) -> float:
 | 
			
		||||
        A, B = param_shape[:2]
 | 
			
		||||
        # We adjust the learning rate and weight decay based on the size of the parameter matrix
 | 
			
		||||
        # as describted in the paper
 | 
			
		||||
@ -154,12 +155,8 @@ class Muon(torch.optim.Optimizer):
 | 
			
		||||
                loss = closure()
 | 
			
		||||
 | 
			
		||||
        for group in self.param_groups:
 | 
			
		||||
            ############################
 | 
			
		||||
            #           Muon           #
 | 
			
		||||
            ############################
 | 
			
		||||
 | 
			
		||||
            # Muon loop
 | 
			
		||||
            params = [p for p in group["params"] if self.state[p]["use_muon"]]
 | 
			
		||||
            # import pdb; pdb.set_trace()
 | 
			
		||||
            lr = group["lr"]
 | 
			
		||||
            wd = group["wd"]
 | 
			
		||||
            momentum = group["momentum"]
 | 
			
		||||
@ -195,10 +192,7 @@ class Muon(torch.optim.Optimizer):
 | 
			
		||||
                # apply update
 | 
			
		||||
                p.data.add_(u, alpha=-adjusted_lr)
 | 
			
		||||
 | 
			
		||||
            ############################
 | 
			
		||||
            #       AdamW backup       #
 | 
			
		||||
            ############################
 | 
			
		||||
 | 
			
		||||
            # Adam backup
 | 
			
		||||
            params = [p for p in group["params"] if not self.state[p]["use_muon"]]
 | 
			
		||||
            lr = group["lr"]
 | 
			
		||||
            beta1, beta2 = group["adamw_betas"]
 | 
			
		||||
 | 
			
		||||
@ -489,16 +489,14 @@ def _create_adam_mini_optimizer(
 | 
			
		||||
    logger.info_rank0("Using Adam-mini optimizer.")
 | 
			
		||||
    return optimizer
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _create_muon_optimizer(
 | 
			
		||||
    model: "PreTrainedModel",
 | 
			
		||||
    training_args: "TrainingArguments",
 | 
			
		||||
) -> "torch.optim.Optimizer":
 | 
			
		||||
    from llamafactory.third_party.muon import Muon  # type: ignore
 | 
			
		||||
    
 | 
			
		||||
    # Separate parameters for Muon (2D parameters) and AdamW (others)
 | 
			
		||||
    muon_params = []
 | 
			
		||||
    adamw_params = []
 | 
			
		||||
    
 | 
			
		||||
    from ..third_party.muon import Muon
 | 
			
		||||
 | 
			
		||||
    muon_params, adamw_params = [], []
 | 
			
		||||
    for name, param in model.named_parameters():
 | 
			
		||||
        if param.requires_grad:
 | 
			
		||||
            # Use Muon for 2D parameters that aren't embeddings or heads
 | 
			
		||||
@ -506,34 +504,26 @@ def _create_muon_optimizer(
 | 
			
		||||
                muon_params.append(param)
 | 
			
		||||
            else:
 | 
			
		||||
                adamw_params.append(param)
 | 
			
		||||
    
 | 
			
		||||
    # Get optimizer settings from training_args
 | 
			
		||||
    ns_steps = getattr(training_args, "ns_steps", 5)
 | 
			
		||||
    
 | 
			
		||||
    # Create Muon optimizer
 | 
			
		||||
 | 
			
		||||
    optimizer = Muon(
 | 
			
		||||
        lr=training_args.learning_rate,
 | 
			
		||||
        wd=training_args.weight_decay,
 | 
			
		||||
        muon_params=muon_params,
 | 
			
		||||
        momentum=0.95,  # default momentum for Muon
 | 
			
		||||
        nesterov=True,  # default nesterov for Muon
 | 
			
		||||
        ns_steps=ns_steps,
 | 
			
		||||
        adamw_params=adamw_params,
 | 
			
		||||
        adamw_betas=(training_args.adam_beta1, training_args.adam_beta2),
 | 
			
		||||
        adamw_eps=training_args.adam_epsilon,
 | 
			
		||||
    )
 | 
			
		||||
    
 | 
			
		||||
    logger.info_rank0(f"Using Muon optimizer with {len(muon_params)} Muon params and {len(adamw_params)} AdamW params.")
 | 
			
		||||
    logger.info_rank0(
 | 
			
		||||
        f"Using Muon optimizer with {len(muon_params)} Muon params and {len(adamw_params)} AdamW params."
 | 
			
		||||
    )
 | 
			
		||||
    return optimizer
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def create_custom_optimizer(
 | 
			
		||||
    model: "PreTrainedModel",
 | 
			
		||||
    training_args: "TrainingArguments",
 | 
			
		||||
    finetuning_args: "FinetuningArguments",
 | 
			
		||||
) -> Optional["torch.optim.Optimizer"]:
 | 
			
		||||
    if finetuning_args.use_muon:
 | 
			
		||||
        return _create_muon_optimizer(model, training_args)
 | 
			
		||||
    
 | 
			
		||||
    if finetuning_args.use_galore:
 | 
			
		||||
        return _create_galore_optimizer(model, training_args, finetuning_args)
 | 
			
		||||
 | 
			
		||||
@ -549,6 +539,9 @@ def create_custom_optimizer(
 | 
			
		||||
    if finetuning_args.use_adam_mini:
 | 
			
		||||
        return _create_adam_mini_optimizer(model, training_args)
 | 
			
		||||
 | 
			
		||||
    if finetuning_args.use_muon:
 | 
			
		||||
        return _create_muon_optimizer(model, training_args)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def create_custom_scheduler(
 | 
			
		||||
    training_args: "TrainingArguments",
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										46
									
								
								tests/model/model_utils/test_add_tokens.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										46
									
								
								tests/model/model_utils/test_add_tokens.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,46 @@
 | 
			
		||||
# Copyright 2025 the LlamaFactory team.
 | 
			
		||||
#
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
# you may not use this file except in compliance with the License.
 | 
			
		||||
# You may obtain a copy of the License at
 | 
			
		||||
#
 | 
			
		||||
#     http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
#
 | 
			
		||||
# Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
# distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
# See the License for the specific language governing permissions and
 | 
			
		||||
# limitations under the License.
 | 
			
		||||
 | 
			
		||||
import os
 | 
			
		||||
 | 
			
		||||
import pytest
 | 
			
		||||
 | 
			
		||||
from llamafactory.hparams import ModelArguments
 | 
			
		||||
from llamafactory.model import load_tokenizer
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
TINY_LLAMA3 = os.getenv("TINY_LLAMA3", "llamafactory/tiny-random-Llama-3")
 | 
			
		||||
 | 
			
		||||
UNUSED_TOKEN = "<|UNUSED_TOKEN|>"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@pytest.mark.parametrize("special_tokens", [False, True])
 | 
			
		||||
def test_add_tokens(special_tokens: bool):
 | 
			
		||||
    if special_tokens:
 | 
			
		||||
        model_args = ModelArguments(model_name_or_path=TINY_LLAMA3, add_special_tokens=UNUSED_TOKEN)
 | 
			
		||||
    else:
 | 
			
		||||
        model_args = ModelArguments(model_name_or_path=TINY_LLAMA3, add_tokens=UNUSED_TOKEN)
 | 
			
		||||
 | 
			
		||||
    tokenizer = load_tokenizer(model_args)["tokenizer"]
 | 
			
		||||
    encoded_ids = tokenizer.encode(UNUSED_TOKEN, add_special_tokens=False)
 | 
			
		||||
    assert len(encoded_ids) == 1
 | 
			
		||||
    decoded_str = tokenizer.decode(encoded_ids, skip_special_tokens=True)
 | 
			
		||||
    if special_tokens:
 | 
			
		||||
        assert decoded_str == ""
 | 
			
		||||
    else:
 | 
			
		||||
        assert decoded_str == UNUSED_TOKEN
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
    pytest.main([__file__])
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user