From d128382d3c7719f771ff13b8aac4b4ffba84f26f Mon Sep 17 00:00:00 2001 From: Juanxi Tian <103416111+tianshijing@users.noreply.github.com> Date: Mon, 21 Apr 2025 23:38:37 +0800 Subject: [PATCH] [trainer] Add Muon Optimizer (#7749) Co-authored-by: hoshi-hiyouga --- README.md | 23 +- README_zh.md | 23 +- examples/README.md | 6 + examples/README_zh.md | 6 + examples/extras/muon/qwen2_full_sft.yaml | 43 ++++ src/llamafactory/hparams/finetuning_args.py | 4 + src/llamafactory/hparams/parser.py | 2 +- src/llamafactory/third_party/muon/__init__.py | 18 ++ src/llamafactory/third_party/muon/muon.py | 232 ++++++++++++++++++ src/llamafactory/train/trainer_utils.py | 39 +++ 10 files changed, 371 insertions(+), 25 deletions(-) create mode 100644 examples/extras/muon/qwen2_full_sft.yaml create mode 100644 src/llamafactory/third_party/muon/__init__.py create mode 100644 src/llamafactory/third_party/muon/muon.py diff --git a/README.md b/README.md index e2abe336..296c6f79 100644 --- a/README.md +++ b/README.md @@ -77,7 +77,7 @@ Choose your path: - **Various models**: LLaMA, LLaVA, Mistral, Mixtral-MoE, Qwen, Qwen2-VL, DeepSeek, Yi, Gemma, ChatGLM, Phi, etc. - **Integrated methods**: (Continuous) pre-training, (multimodal) supervised fine-tuning, reward modeling, PPO, DPO, KTO, ORPO, etc. - **Scalable resources**: 16-bit full-tuning, freeze-tuning, LoRA and 2/3/4/5/6/8-bit QLoRA via AQLM/AWQ/GPTQ/LLM.int8/HQQ/EETQ. -- **Advanced algorithms**: [GaLore](https://github.com/jiaweizzhao/GaLore), [BAdam](https://github.com/Ledzy/BAdam), [APOLLO](https://github.com/zhuhanqing/APOLLO), [Adam-mini](https://github.com/zyushun/Adam-mini), DoRA, LongLoRA, LLaMA Pro, Mixture-of-Depths, LoRA+, LoftQ and PiSSA. +- **Advanced algorithms**: [Muon](https://github.com/KellerJordan/Muon), [GaLore](https://github.com/jiaweizzhao/GaLore), [BAdam](https://github.com/Ledzy/BAdam), [APOLLO](https://github.com/zhuhanqing/APOLLO), [Adam-mini](https://github.com/zyushun/Adam-mini), DoRA, LongLoRA, LLaMA Pro, Mixture-of-Depths, LoRA+, LoftQ and PiSSA. - **Practical tricks**: [FlashAttention-2](https://github.com/Dao-AILab/flash-attention), [Unsloth](https://github.com/unslothai/unsloth), [Liger Kernel](https://github.com/linkedin/Liger-Kernel), RoPE scaling, NEFTune and rsLoRA. - **Wide tasks**: Multi-turn dialogue, tool using, image understanding, visual grounding, video recognition, audio understanding, etc. - **Experiment monitors**: LlamaBoard, TensorBoard, Wandb, MLflow, [SwanLab](https://github.com/SwanHubX/SwanLab), etc. @@ -107,7 +107,7 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/ ## Changelog -[25/04/16] We supported fine-tuning the **[InternVL3](https://huggingface.co/OpenGVLab/InternVL3-8B)** model. See [PR #7258](https://github.com/hiyouga/LLaMA-Factory/pull/7258) to get started. +[25/04/16] We supported **[Muon](https://github.com/KellerJordan/Muon)** optimizer. See [examples](examples/README.md) for usage. Thank [Juanxi Tian](https://tianshijing.github.io)'s PR. [25/04/14] We supported fine-tuning the **[GLM-Z1](https://huggingface.co/THUDM/GLM-Z1-9B-0414)** and **[Kimi-VL](https://huggingface.co/moonshotai/Kimi-VL-A3B-Instruct)** models. @@ -115,14 +115,14 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/ [25/03/31] We supported fine-tuning the **[Qwen2.5 Omni](https://qwenlm.github.io/blog/qwen2.5-omni/)** model. See [PR #7537](https://github.com/hiyouga/LLaMA-Factory/pull/7537) to get started. -
Full Changelog - [25/03/15] We supported **[SGLang](https://github.com/sgl-project/sglang)** as inference backend. Try `infer_backend: sglang` to accelerate inference. [25/03/12] We supported fine-tuning the **[Gemma 3](https://huggingface.co/blog/gemma3)** model. [25/02/24] Announcing **[EasyR1](https://github.com/hiyouga/EasyR1)**, an efficient, scalable and multi-modality RL training framework for efficient GRPO training. +
Full Changelog + [25/02/11] We supported saving the **[Ollama](https://github.com/ollama/ollama)** modelfile when exporting the model checkpoints. See [examples](examples/README.md) for usage. [25/02/05] We supported fine-tuning the **[Qwen2-Audio](Qwen/Qwen2-Audio-7B-Instruct)** and **[MiniCPM-o-2.6](https://huggingface.co/openbmb/MiniCPM-o-2_6)** on audio understanding tasks. @@ -245,11 +245,11 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/ | [Gemma 3](https://huggingface.co/google) | 1B/4B/12B/27B | gemma3/gemma (1B) | | [GLM-4/GLM-4-0414/GLM-Z1](https://huggingface.co/THUDM) | 9B/32B | glm4 | | [GPT-2](https://huggingface.co/openai-community) | 0.1B/0.4B/0.8B/1.5B | - | -| [Granite 3.0-3.3](https://huggingface.co/ibm-granite) | 1B/2B/3B/8B | granite3 | +| [Granite 3.0-3.1](https://huggingface.co/ibm-granite) | 1B/2B/3B/8B | granite3 | | [Hunyuan](https://huggingface.co/tencent/) | 7B | hunyuan | | [Index](https://huggingface.co/IndexTeam) | 1.9B | index | | [InternLM 2-3](https://huggingface.co/internlm) | 7B/8B/20B | intern2 | -| [InternVL 2.5-3](https://huggingface.co/OpenGVLab)\*\* | 1B/2B/4B/8B/9B/14B/26B/38B/78B | intern_vl | +| [InternVL2_5-3](https://huggingface.co/OpenGVLab/InternVL) | 1B/2B/4B/8B/9B/14B/26B/38B/78B | intern_vl | | [Kimi-VL](https://huggingface.co/moonshotai) | 16B | kimi_vl | | [Llama](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | - | | [Llama 2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 | @@ -417,11 +417,11 @@ huggingface-cli login | Mandatory | Minimum | Recommend | | ------------ | ------- | --------- | | python | 3.9 | 3.10 | -| torch | 2.0.0 | 2.6.0 | -| transformers | 4.45.0 | 4.50.0 | +| torch | 1.13.1 | 2.6.0 | +| transformers | 4.41.2 | 4.50.0 | | datasets | 2.16.0 | 3.2.0 | | accelerate | 0.34.0 | 1.2.1 | -| peft | 0.14.0 | 0.15.1 | +| peft | 0.14.0 | 0.15.0 | | trl | 0.8.6 | 0.9.6 | | Optional | Minimum | Recommend | @@ -430,7 +430,7 @@ huggingface-cli login | deepspeed | 0.10.0 | 0.16.4 | | bitsandbytes | 0.39.0 | 0.43.1 | | vllm | 0.4.3 | 0.8.2 | -| flash-attn | 2.5.6 | 2.7.2 | +| flash-attn | 2.3.0 | 2.7.2 | ### Hardware Requirement @@ -458,7 +458,7 @@ cd LLaMA-Factory pip install -e ".[torch,metrics]" ``` -Extra dependencies available: torch, torch-npu, metrics, deepspeed, liger-kernel, bitsandbytes, hqq, eetq, gptq, awq, aqlm, vllm, sglang, galore, apollo, badam, adam-mini, qwen, minicpm_v, modelscope, openmind, swanlab, quality +Extra dependencies available: torch, torch-npu, metrics, deepspeed, liger-kernel, bitsandbytes, hqq, eetq, gptq, awq, aqlm, vllm, sglang, muon, galore, apollo, badam, adam-mini, qwen, minicpm_v, modelscope, openmind, swanlab, quality > [!TIP] > Use `pip install --no-deps -e .` to resolve package conflicts. @@ -519,7 +519,6 @@ source /usr/local/Ascend/ascend-toolkit/set_env.sh | torch | 2.1.0 | 2.4.0 | | torch-npu | 2.1.0 | 2.4.0.post2 | | deepspeed | 0.13.2 | 0.13.2 | -| vllm-ascend | - | 0.7.3 | Remember to use `ASCEND_RT_VISIBLE_DEVICES` instead of `CUDA_VISIBLE_DEVICES` to specify the device to use. diff --git a/README_zh.md b/README_zh.md index d052d1f6..61576964 100644 --- a/README_zh.md +++ b/README_zh.md @@ -80,7 +80,7 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc - **多种模型**:LLaMA、LLaVA、Mistral、Mixtral-MoE、Qwen、Qwen2-VL、DeepSeek、Yi、Gemma、ChatGLM、Phi 等等。 - **集成方法**:(增量)预训练、(多模态)指令监督微调、奖励模型训练、PPO 训练、DPO 训练、KTO 训练、ORPO 训练等等。 - **多种精度**:16 比特全参数微调、冻结微调、LoRA 微调和基于 AQLM/AWQ/GPTQ/LLM.int8/HQQ/EETQ 的 2/3/4/5/6/8 比特 QLoRA 微调。 -- **先进算法**:[GaLore](https://github.com/jiaweizzhao/GaLore)、[BAdam](https://github.com/Ledzy/BAdam)、[APOLLO](https://github.com/zhuhanqing/APOLLO)、[Adam-mini](https://github.com/zyushun/Adam-mini)、DoRA、LongLoRA、LLaMA Pro、Mixture-of-Depths、LoRA+、LoftQ 和 PiSSA。 +- **先进算法**:[Muon](https://github.com/KellerJordan/Muon), [GaLore](https://github.com/jiaweizzhao/GaLore)、[BAdam](https://github.com/Ledzy/BAdam)、[APOLLO](https://github.com/zhuhanqing/APOLLO)、[Adam-mini](https://github.com/zyushun/Adam-mini)、DoRA、LongLoRA、LLaMA Pro、Mixture-of-Depths、LoRA+、LoftQ 和 PiSSA。 - **实用技巧**:[FlashAttention-2](https://github.com/Dao-AILab/flash-attention)、[Unsloth](https://github.com/unslothai/unsloth)、[Liger Kernel](https://github.com/linkedin/Liger-Kernel)、RoPE scaling、NEFTune 和 rsLoRA。 - **广泛任务**:多轮对话、工具调用、图像理解、视觉定位、视频识别和语音理解等等。 - **实验监控**:LlamaBoard、TensorBoard、Wandb、MLflow、[SwanLab](https://github.com/SwanHubX/SwanLab) 等等。 @@ -110,7 +110,7 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc ## 更新日志 -[25/04/16] 我们支持了 **[InternVL3](https://huggingface.co/OpenGVLab/InternVL3-8B)** 模型的微调。查看 [PR #7258](https://github.com/hiyouga/LLaMA-Factory/pull/7258) 以使用。 +[25/04/16] 我们支持了 **[Muon](https://github.com/KellerJordan/Muon)** 优化器。详细用法请参照 [examples](examples/README_zh.md)。感谢 [@Juanxi Tian](https://tianshijing.github.io) 的 PR。 [25/04/14] 我们支持了 **[GLM-Z1](https://huggingface.co/THUDM/GLM-Z1-9B-0414)** 和 **[Kimi-VL](https://huggingface.co/moonshotai/Kimi-VL-A3B-Instruct)** 模型的微调。 @@ -118,14 +118,14 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc [25/03/31] 我们支持了 **[Qwen2.5 Omni](https://qwenlm.github.io/blog/qwen2.5-omni/)** 模型的微调。查看 [PR #7537](https://github.com/hiyouga/LLaMA-Factory/pull/7537) 以使用。 -
展开日志 - [25/03/15] 我们支持了 **[SGLang](https://github.com/sgl-project/sglang)** 推理后端,请使用 `infer_backend: sglang` 启用。 [25/03/12] 我们支持了 **[Gemma 3](https://huggingface.co/blog/gemma3)** 模型的微调。 [25/02/24] 我们宣布开源 **[EasyR1](https://github.com/hiyouga/EasyR1)**,一个高效可扩展的多模态强化学习框架,支持高效的 GRPO 训练。 +
展开日志 + [25/02/11] 我们支持了在导出模型时保存 **[Ollama](https://github.com/ollama/ollama)** 配置文件。详细用法请参照 [examples](examples/README_zh.md)。 [25/02/05] 我们支持了在语音理解任务上微调 **[Qwen2-Audio](Qwen/Qwen2-Audio-7B-Instruct)** 和 **[MiniCPM-o-2.6](https://huggingface.co/openbmb/MiniCPM-o-2_6)** 模型。 @@ -248,11 +248,11 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc | [Gemma 3](https://huggingface.co/google) | 1B/4B/12B/27B | gemma3/gemma (1B) | | [GLM-4/GLM-4-0414/GLM-Z1](https://huggingface.co/THUDM) | 9B/32B | glm4 | | [GPT-2](https://huggingface.co/openai-community) | 0.1B/0.4B/0.8B/1.5B | - | -| [Granite 3.0-3.3](https://huggingface.co/ibm-granite) | 1B/2B/3B/8B | granite3 | +| [Granite 3.0-3.1](https://huggingface.co/ibm-granite) | 1B/2B/3B/8B | granite3 | | [Hunyuan](https://huggingface.co/tencent/) | 7B | hunyuan | | [Index](https://huggingface.co/IndexTeam) | 1.9B | index | | [InternLM 2-3](https://huggingface.co/internlm) | 7B/8B/20B | intern2 | -| [InternVL 2.5-3](https://huggingface.co/OpenGVLab)\*\* | 1B/2B/4B/8B/9B/14B/26B/38B/78B | intern_vl | +| [InternVL2_5-3](https://huggingface.co/OpenGVLab/InternVL) | 1B/2B/4B/8B/9B/14B/26B/38B/78B | intern_vl | | [Kimi-VL](https://huggingface.co/moonshotai) | 16B | kimi_vl | | [Llama](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | - | | [Llama 2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 | @@ -420,11 +420,11 @@ huggingface-cli login | 必需项 | 至少 | 推荐 | | ------------ | ------- | --------- | | python | 3.9 | 3.10 | -| torch | 2.0.0 | 2.6.0 | -| transformers | 4.45.0 | 4.50.0 | +| torch | 1.13.1 | 2.6.0 | +| transformers | 4.41.2 | 4.50.0 | | datasets | 2.16.0 | 3.2.0 | | accelerate | 0.34.0 | 1.2.1 | -| peft | 0.14.0 | 0.15.1 | +| peft | 0.14.0 | 0.15.0 | | trl | 0.8.6 | 0.9.6 | | 可选项 | 至少 | 推荐 | @@ -433,7 +433,7 @@ huggingface-cli login | deepspeed | 0.10.0 | 0.16.4 | | bitsandbytes | 0.39.0 | 0.43.1 | | vllm | 0.4.3 | 0.8.2 | -| flash-attn | 2.5.6 | 2.7.2 | +| flash-attn | 2.3.0 | 2.7.2 | ### 硬件依赖 @@ -461,7 +461,7 @@ cd LLaMA-Factory pip install -e ".[torch,metrics]" ``` -可选的额外依赖项:torch、torch-npu、metrics、deepspeed、liger-kernel、bitsandbytes、hqq、eetq、gptq、awq、aqlm、vllm、sglang、galore、apollo、badam、adam-mini、qwen、minicpm_v、modelscope、openmind、swanlab、quality +可选的额外依赖项:torch、torch-npu、metrics、deepspeed、liger-kernel、bitsandbytes、hqq、eetq、gptq、awq、aqlm、vllm、sglang、muon, galore、apollo、badam、adam-mini、qwen、minicpm_v、modelscope、openmind、swanlab、quality > [!TIP] > 遇到包冲突时,可使用 `pip install --no-deps -e .` 解决。 @@ -523,7 +523,6 @@ source /usr/local/Ascend/ascend-toolkit/set_env.sh | torch | 2.1.0 | 2.4.0 | | torch-npu | 2.1.0 | 2.4.0.post2 | | deepspeed | 0.13.2 | 0.13.2 | -| vllm-ascend | - | 0.7.3 | 请使用 `ASCEND_RT_VISIBLE_DEVICES` 而非 `CUDA_VISIBLE_DEVICES` 来指定运算设备。 diff --git a/examples/README.md b/examples/README.md index f421d47d..a59045a4 100644 --- a/examples/README.md +++ b/examples/README.md @@ -215,6 +215,12 @@ llamafactory-cli api examples/inference/llama3_lora_sft.yaml ### Extras +#### Full-Parameter Fine-Tuning using Muon + +```bash +llamafactory-cli train examples/extras/muon/qwen2_full_sft.yaml +``` + #### Full-Parameter Fine-Tuning using GaLore ```bash diff --git a/examples/README_zh.md b/examples/README_zh.md index 9ff0e994..b3b93d9d 100644 --- a/examples/README_zh.md +++ b/examples/README_zh.md @@ -215,6 +215,12 @@ llamafactory-cli api examples/inference/llama3_lora_sft.yaml ### 杂项 +#### 使用 Muon 进行全参数训练 + +```bash +llamafactory-cli train examples/extras/muon/qwen2_full_sft.yaml +``` + #### 使用 GaLore 进行全参数训练 ```bash diff --git a/examples/extras/muon/qwen2_full_sft.yaml b/examples/extras/muon/qwen2_full_sft.yaml new file mode 100644 index 00000000..4380846a --- /dev/null +++ b/examples/extras/muon/qwen2_full_sft.yaml @@ -0,0 +1,43 @@ +### model +model_name_or_path: Qwen/Qwen2-1.5B-Instruct +trust_remote_code: true + +### method +stage: sft +do_train: true +finetuning_type: full +use_muon: true + +### dataset +dataset: identity,alpaca_en_demo +template: qwen +cutoff_len: 2048 +max_samples: 1000 +overwrite_cache: true +preprocessing_num_workers: 16 +dataloader_num_workers: 4 + +### output +output_dir: saves/qwen2-1_5b/full/sft +logging_steps: 10 +save_steps: 500 +plot_loss: true +overwrite_output_dir: true +save_only_model: false +report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow] + +### train +per_device_train_batch_size: 1 +gradient_accumulation_steps: 8 +learning_rate: 1.0e-5 +num_train_epochs: 3.0 +lr_scheduler_type: cosine +warmup_ratio: 0.1 +bf16: true +ddp_timeout: 180000000 + +### eval +# val_size: 0.1 +# per_device_eval_batch_size: 1 +# eval_strategy: steps +# eval_steps: 500 diff --git a/src/llamafactory/hparams/finetuning_args.py b/src/llamafactory/hparams/finetuning_args.py index 483e9f4a..79e69e7b 100644 --- a/src/llamafactory/hparams/finetuning_args.py +++ b/src/llamafactory/hparams/finetuning_args.py @@ -411,6 +411,10 @@ class FinetuningArguments( default=False, metadata={"help": "Whether or not to use the Adam-mini optimizer."}, ) + use_muon: bool = field( + default=False, + metadata={"help": "Whether or not to use the Muon optimizer."}, + ) freeze_vision_tower: bool = field( default=True, metadata={"help": "Whether ot not to freeze the vision tower in MLLM training."}, diff --git a/src/llamafactory/hparams/parser.py b/src/llamafactory/hparams/parser.py index cfe71498..7209e216 100644 --- a/src/llamafactory/hparams/parser.py +++ b/src/llamafactory/hparams/parser.py @@ -153,7 +153,7 @@ def _check_extra_dependencies( elif model_args.infer_backend == EngineName.SGLANG: check_version("sglang>=0.4.4") check_version("sglang", mandatory=True) - + if finetuning_args.use_galore: check_version("galore_torch", mandatory=True) diff --git a/src/llamafactory/third_party/muon/__init__.py b/src/llamafactory/third_party/muon/__init__.py new file mode 100644 index 00000000..afa615d0 --- /dev/null +++ b/src/llamafactory/third_party/muon/__init__.py @@ -0,0 +1,18 @@ +# Copyright 2025 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .muon import Muon + + +__all__ = ["Muon"] diff --git a/src/llamafactory/third_party/muon/muon.py b/src/llamafactory/third_party/muon/muon.py new file mode 100644 index 00000000..e50097f5 --- /dev/null +++ b/src/llamafactory/third_party/muon/muon.py @@ -0,0 +1,232 @@ +# Copyright 2025 Moonshot AI and the LlamaFactory team. +# +# This code is based on the MoonshotAI's Moonlight library. +# https://github.com/MoonshotAI/Moonlight/blob/master/examples/toy_train.py +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# MIT License +# +# Copyright (c) 2025 Moonshot AI +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +import math + +import torch + + +# This code snippet is a modified version adapted from the following GitHub repository: +# https://github.com/KellerJordan/Muon/blob/master/muon.py +@torch.compile +def zeropower_via_newtonschulz5(G, steps): + """Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. + + We opt to use a quintic iteration whose coefficients are selected to maximize the slope at zero. + For the purpose of minimizing steps, it turns out to be empirically effective to keep increasing the slope at + zero even beyond the point where the iteration no longer converges all the way to one everywhere + on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T + where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model + performance at all relative to UV^T, where USV^T = G is the SVD. + """ + assert len(G.shape) == 2 + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + if G.size(0) > G.size(1): + X = X.T + # Ensure spectral norm is at most 1 + X = X / (X.norm() + 1e-7) + # Perform the NS iterations + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A # adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng + X = a * X + B @ X + + if G.size(0) > G.size(1): + X = X.T + return X + + +class Muon(torch.optim.Optimizer): + """Muon - MomentUm Orthogonalized by Newton-schulz. + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Some warnings: + - We believe this optimizer is unlikely to work well for training with small batch size. + - We believe it may not work well for finetuning pretrained models, but we haven't tested this. + + Arguments: + muon_params: The parameters to be optimized by Muon. + lr: The learning rate. The updates will have spectral norm of `lr`. (0.02 is a good default) + momentum: The momentum used by the internal SGD. (0.95 is a good default) + nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended) + ns_steps: The number of Newton-Schulz iterations to run. (6 is probably always enough) + adamw_params: The parameters to be optimized by AdamW. Any parameters in `muon_params` which are + {0, 1}-D or are detected as being the embed or lm_head will be optimized by AdamW as well. + adamw_lr: The learning rate for the internal AdamW. + adamw_betas: The betas for the internal AdamW. + adamw_eps: The epsilon for the internal AdamW. + adamw_wd: The weight decay for the internal AdamW. + """ + + def __init__( + self, + lr=1e-3, + wd=0.1, + muon_params=None, + momentum=0.95, + nesterov=True, + ns_steps=5, + adamw_params=None, + adamw_betas=(0.9, 0.95), + adamw_eps=1e-8, + ): + defaults = dict( + lr=lr, + wd=wd, + momentum=momentum, + nesterov=nesterov, + ns_steps=ns_steps, + adamw_betas=adamw_betas, + adamw_eps=adamw_eps, + ) + + params = list(muon_params) + adamw_params = list(adamw_params) if adamw_params is not None else [] + params.extend(adamw_params) + super().__init__(params, defaults) + # Sort parameters into those for which we will use Muon, and those for which we will not + for p in muon_params: + # Use Muon for every parameter in muon_params which is >= 2D and doesn't look like an embedding or head layer + assert p.ndim == 2, p.ndim + self.state[p]["use_muon"] = True + for p in adamw_params: + # Do not use Muon for parameters in adamw_params + self.state[p]["use_muon"] = False + + def adjust_lr_for_muon(self, lr, param_shape): + A, B = param_shape[:2] + # We adjust the learning rate and weight decay based on the size of the parameter matrix + # as describted in the paper + adjusted_ratio = 0.2 * math.sqrt(max(A, B)) + adjusted_lr = lr * adjusted_ratio + return adjusted_lr + + def step(self, closure=None): + """Perform a single optimization step. + + Args: + closure (Callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + ############################ + # Muon # + ############################ + + params = [p for p in group["params"] if self.state[p]["use_muon"]] + # import pdb; pdb.set_trace() + lr = group["lr"] + wd = group["wd"] + momentum = group["momentum"] + + # generate weight updates in distributed fashion + for p in params: + # sanity check + g = p.grad + if g is None: + continue + if g.ndim > 2: + g = g.view(g.size(0), -1) + assert g is not None + + # calc update + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if group["nesterov"]: + g = g.add(buf, alpha=momentum) + else: + g = buf + u = zeropower_via_newtonschulz5(g, steps=group["ns_steps"]) + + # scale update + adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) + + # apply weight decay + p.data.mul_(1 - lr * wd) + + # apply update + p.data.add_(u, alpha=-adjusted_lr) + + ############################ + # AdamW backup # + ############################ + + params = [p for p in group["params"] if not self.state[p]["use_muon"]] + lr = group["lr"] + beta1, beta2 = group["adamw_betas"] + eps = group["adamw_eps"] + weight_decay = group["wd"] + + for p in params: + g = p.grad + if g is None: + continue + state = self.state[p] + if "step" not in state: + state["step"] = 0 + state["moment1"] = torch.zeros_like(g) + state["moment2"] = torch.zeros_like(g) + state["step"] += 1 + step = state["step"] + buf1 = state["moment1"] + buf2 = state["moment2"] + buf1.lerp_(g, 1 - beta1) + buf2.lerp_(g.square(), 1 - beta2) + + g = buf1 / (eps + buf2.sqrt()) + + bias_correction1 = 1 - beta1**step + bias_correction2 = 1 - beta2**step + scale = bias_correction1 / bias_correction2**0.5 + p.data.mul_(1 - lr * weight_decay) + p.data.add_(g, alpha=-lr / scale) + + return loss diff --git a/src/llamafactory/train/trainer_utils.py b/src/llamafactory/train/trainer_utils.py index 89459a82..34b417a3 100644 --- a/src/llamafactory/train/trainer_utils.py +++ b/src/llamafactory/train/trainer_utils.py @@ -489,12 +489,51 @@ def _create_adam_mini_optimizer( logger.info_rank0("Using Adam-mini optimizer.") return optimizer +def _create_muon_optimizer( + model: "PreTrainedModel", + training_args: "TrainingArguments", +) -> "torch.optim.Optimizer": + from llamafactory.third_party.muon import Muon # type: ignore + + # Separate parameters for Muon (2D parameters) and AdamW (others) + muon_params = [] + adamw_params = [] + + for name, param in model.named_parameters(): + if param.requires_grad: + # Use Muon for 2D parameters that aren't embeddings or heads + if param.ndim == 2 and "embed" not in name and "lm_head" not in name: + muon_params.append(param) + else: + adamw_params.append(param) + + # Get optimizer settings from training_args + ns_steps = getattr(training_args, "ns_steps", 5) + + # Create Muon optimizer + optimizer = Muon( + lr=training_args.learning_rate, + wd=training_args.weight_decay, + muon_params=muon_params, + momentum=0.95, # default momentum for Muon + nesterov=True, # default nesterov for Muon + ns_steps=ns_steps, + adamw_params=adamw_params, + adamw_betas=(training_args.adam_beta1, training_args.adam_beta2), + adamw_eps=training_args.adam_epsilon, + ) + + logger.info_rank0(f"Using Muon optimizer with {len(muon_params)} Muon params and {len(adamw_params)} AdamW params.") + return optimizer def create_custom_optimizer( model: "PreTrainedModel", training_args: "TrainingArguments", finetuning_args: "FinetuningArguments", ) -> Optional["torch.optim.Optimizer"]: + if finetuning_args.use_muon: + return _create_muon_optimizer(model, training_args) + if finetuning_args.use_galore: return _create_galore_optimizer(model, training_args, finetuning_args)