From 9ef85f8fc401fc028d38b561efdf2bcc37ca4796 Mon Sep 17 00:00:00 2001 From: hoshi-hiyouga Date: Wed, 15 Jan 2025 01:42:50 +0800 Subject: [PATCH] [optim] clean apollo (#6645) * clean apollo code * update readme Former-commit-id: 7a04021d0461caea2c7b82169839340b7f51f463 --- README.md | 32 +++++++------- README_zh.md | 34 ++++++++------- examples/README.md | 6 +++ examples/README_zh.md | 6 +++ examples/extras/apollo/llama3_full_sft.yaml | 8 ++-- examples/extras/galore/llama3_full_sft.yaml | 6 +-- src/llamafactory/extras/packages.py | 2 + src/llamafactory/hparams/finetuning_args.py | 16 +++---- src/llamafactory/hparams/parser.py | 48 ++++++++------------- src/llamafactory/model/model_utils/misc.py | 2 +- src/llamafactory/train/trainer_utils.py | 26 ++++++----- src/llamafactory/webui/components/train.py | 9 ++-- src/llamafactory/webui/locales.py | 16 +++---- src/llamafactory/webui/runner.py | 2 +- 14 files changed, 110 insertions(+), 103 deletions(-) diff --git a/README.md b/README.md index fa1e0f38..d447fc65 100644 --- a/README.md +++ b/README.md @@ -66,7 +66,7 @@ Choose your path: - **Various models**: LLaMA, LLaVA, Mistral, Mixtral-MoE, Qwen, Qwen2-VL, Yi, Gemma, Baichuan, ChatGLM, Phi, etc. - **Integrated methods**: (Continuous) pre-training, (multimodal) supervised fine-tuning, reward modeling, PPO, DPO, KTO, ORPO, etc. - **Scalable resources**: 16-bit full-tuning, freeze-tuning, LoRA and 2/3/4/5/6/8-bit QLoRA via AQLM/AWQ/GPTQ/LLM.int8/HQQ/EETQ. -- **Advanced algorithms**: [GaLore](https://github.com/jiaweizzhao/GaLore), [BAdam](https://github.com/Ledzy/BAdam), [Adam-mini](https://github.com/zyushun/Adam-mini), DoRA, LongLoRA, LLaMA Pro, Mixture-of-Depths, LoRA+, LoftQ, PiSSA and Agent tuning. +- **Advanced algorithms**: [GaLore](https://github.com/jiaweizzhao/GaLore), [BAdam](https://github.com/Ledzy/BAdam), [APOLLO](https://github.com/zhuhanqing/APOLLO), [Adam-mini](https://github.com/zyushun/Adam-mini), DoRA, LongLoRA, LLaMA Pro, Mixture-of-Depths, LoRA+, LoftQ, PiSSA and Agent tuning. - **Practical tricks**: [FlashAttention-2](https://github.com/Dao-AILab/flash-attention), [Unsloth](https://github.com/unslothai/unsloth), [Liger Kernel](https://github.com/linkedin/Liger-Kernel), RoPE scaling, NEFTune and rsLoRA. - **Experiment monitors**: LlamaBoard, TensorBoard, Wandb, MLflow, SwanLab, etc. - **Faster inference**: OpenAI-style API, Gradio UI and CLI with vLLM worker. @@ -88,18 +88,20 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/ ## Changelog -[25/01/15] We supported fine-tuning the **[MiniCPM-o-2.6](https://huggingface.co/openbmb/MiniCPM-o-2_6)** and **[MiniCPM-V-2.6](https://huggingface.co/openbmb/MiniCPM-V-2_6)** models. Thank [@BUAADreamer](https://github.com/BUAADreamer)'s PR. +[25/01/15] We supported **[APOLLO](https://arxiv.org/abs/2412.05270)** optimizer. See [examples](examples/README.md) for usage. -[25/01/15] We supported fine-tuning the **[InternLM3](https://huggingface.co/collections/internlm/)** models. Thank [@hhaAndroid](https://github.com/hhaAndroid)'s PR. +[25/01/14] We supported fine-tuning the **[MiniCPM-o-2.6](https://huggingface.co/openbmb/MiniCPM-o-2_6)** and **[MiniCPM-V-2.6](https://huggingface.co/openbmb/MiniCPM-V-2_6)** models. Thank [@BUAADreamer](https://github.com/BUAADreamer)'s PR. + +[25/01/14] We supported fine-tuning the **[InternLM3](https://huggingface.co/collections/internlm/)** models. Thank [@hhaAndroid](https://github.com/hhaAndroid)'s PR. 
[25/01/10] We supported fine-tuning the **[Phi-4](https://huggingface.co/microsoft/phi-4)** model. +
Full Changelog + [24/12/21] We supported using **[SwanLab](https://github.com/SwanHubX/SwanLab)** for experiment tracking and visualization. See [this section](#use-swanlab-logger) for details. [24/11/27] We supported fine-tuning the **[Skywork-o1](https://huggingface.co/Skywork/Skywork-o1-Open-Llama-3.1-8B)** model and the **[OpenO1](https://huggingface.co/datasets/O1-OPEN/OpenO1-SFT)** dataset. -
Full Changelog - [24/10/09] We supported downloading pre-trained models and datasets from the **[Modelers Hub](https://modelers.cn/models)**. See [this tutorial](#download-from-modelers-hub) for usage. [24/09/19] We supported fine-tuning the **[Qwen2.5](https://qwenlm.github.io/blog/qwen2.5/)** models. @@ -378,15 +380,15 @@ huggingface-cli login \* *estimated* -| Method | Bits | 7B | 13B | 30B | 70B | 110B | 8x7B | 8x22B | -| ----------------- | ---- | ----- | ----- | ----- | ------ | ------ | ----- | ------ | -| Full | AMP | 120GB | 240GB | 600GB | 1200GB | 2000GB | 900GB | 2400GB | -| Full | 16 | 60GB | 120GB | 300GB | 600GB | 900GB | 400GB | 1200GB | -| Freeze | 16 | 20GB | 40GB | 80GB | 200GB | 360GB | 160GB | 400GB | -| LoRA/GaLore/BAdam | 16 | 16GB | 32GB | 64GB | 160GB | 240GB | 120GB | 320GB | -| QLoRA | 8 | 10GB | 20GB | 40GB | 80GB | 140GB | 60GB | 160GB | -| QLoRA | 4 | 6GB | 12GB | 24GB | 48GB | 72GB | 30GB | 96GB | -| QLoRA | 2 | 4GB | 8GB | 16GB | 24GB | 48GB | 18GB | 48GB | +| Method | Bits | 7B | 13B | 30B | 70B | 110B | 8x7B | 8x22B | +| ------------------------ | ---- | ----- | ----- | ----- | ------ | ------ | ----- | ------ | +| Full | 32 | 120GB | 240GB | 600GB | 1200GB | 2000GB | 900GB | 2400GB | +| Full | 16 | 60GB | 120GB | 300GB | 600GB | 900GB | 400GB | 1200GB | +| Freeze | 16 | 20GB | 40GB | 80GB | 200GB | 360GB | 160GB | 400GB | +| LoRA/GaLore/APOLLO/BAdam | 16 | 16GB | 32GB | 64GB | 160GB | 240GB | 120GB | 320GB | +| QLoRA | 8 | 10GB | 20GB | 40GB | 80GB | 140GB | 60GB | 160GB | +| QLoRA | 4 | 6GB | 12GB | 24GB | 48GB | 72GB | 30GB | 96GB | +| QLoRA | 2 | 4GB | 8GB | 16GB | 24GB | 48GB | 18GB | 48GB | ## Getting Started @@ -401,7 +403,7 @@ cd LLaMA-Factory pip install -e ".[torch,metrics]" ``` -Extra dependencies available: torch, torch-npu, metrics, deepspeed, liger-kernel, bitsandbytes, hqq, eetq, gptq, awq, aqlm, vllm, galore, badam, adam-mini, qwen, modelscope, openmind, swanlab, quality +Extra dependencies available: torch, torch-npu, metrics, deepspeed, liger-kernel, bitsandbytes, hqq, eetq, gptq, awq, aqlm, vllm, galore, apollo, badam, adam-mini, qwen, minicpm_v, modelscope, openmind, swanlab, quality > [!TIP] > Use `pip install --no-deps -e .` to resolve package conflicts. 
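A note on the requirements table above: LoRA, GaLore, APOLLO and BAdam share a row because none of them keeps full-size Adam states for every trainable weight. The sketch below is not part of the patch; it is a rough, illustrative calculation (the 4096 x 4096 matrix, the fp32 assumption and the rank-based formula are assumptions, not measurements) of why a rank-128 low-rank method, as configured in the GaLore/APOLLO example YAMLs, lands near LoRA in memory.

```python
# Back-of-envelope comparison of Adam-style optimizer-state memory for one
# hypothetical 4096 x 4096 weight matrix (illustrative estimate, not a measurement).
m, n, rank = 4096, 4096, 128   # rank mirrors apollo_rank / galore_rank in the example configs
bytes_per_value = 4            # fp32 optimizer states

full_state = 2 * m * n * bytes_per_value               # exp_avg + exp_avg_sq at full weight size
low_rank_state = 2 * rank * (m + n) * bytes_per_value   # states kept at a rank-governed size (rough model)

print(f"full-size AdamW states: {full_state / 2**20:.0f} MiB")      # ~128 MiB
print(f"rank-128 states:        {low_rank_state / 2**20:.0f} MiB")  # ~8 MiB
```

Exact savings depend on each method's state layout, so treat this only as intuition for the grouping in the table, not as a claim about the GB figures it lists.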
diff --git a/README_zh.md b/README_zh.md index 0f42fd44..b4e5f21d 100644 --- a/README_zh.md +++ b/README_zh.md @@ -67,7 +67,7 @@ https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272 - **多种模型**:LLaMA、LLaVA、Mistral、Mixtral-MoE、Qwen、Qwen2-VL、Yi、Gemma、Baichuan、ChatGLM、Phi 等等。 - **集成方法**:(增量)预训练、(多模态)指令监督微调、奖励模型训练、PPO 训练、DPO 训练、KTO 训练、ORPO 训练等等。 - **多种精度**:16 比特全参数微调、冻结微调、LoRA 微调和基于 AQLM/AWQ/GPTQ/LLM.int8/HQQ/EETQ 的 2/3/4/5/6/8 比特 QLoRA 微调。 -- **先进算法**:[GaLore](https://github.com/jiaweizzhao/GaLore)、[BAdam](https://github.com/Ledzy/BAdam)、[Adam-mini](https://github.com/zyushun/Adam-mini)、DoRA、LongLoRA、LLaMA Pro、Mixture-of-Depths、LoRA+、LoftQ、PiSSA 和 Agent 微调。 +- **先进算法**:[GaLore](https://github.com/jiaweizzhao/GaLore)、[BAdam](https://github.com/Ledzy/BAdam)、[APOLLO](https://github.com/zhuhanqing/APOLLO)、[Adam-mini](https://github.com/zyushun/Adam-mini)、DoRA、LongLoRA、LLaMA Pro、Mixture-of-Depths、LoRA+、LoftQ、PiSSA 和 Agent 微调。 - **实用技巧**:[FlashAttention-2](https://github.com/Dao-AILab/flash-attention)、[Unsloth](https://github.com/unslothai/unsloth)、[Liger Kernel](https://github.com/linkedin/Liger-Kernel)、RoPE scaling、NEFTune 和 rsLoRA。 - **实验监控**:LlamaBoard、TensorBoard、Wandb、MLflow、SwanLab 等等。 - **极速推理**:基于 vLLM 的 OpenAI 风格 API、浏览器界面和命令行接口。 @@ -89,18 +89,20 @@ https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272 ## 更新日志 -[25/01/15] 我们支持了 **[MiniCPM-o-2.6](https://huggingface.co/openbmb/MiniCPM-o-2_6)** 和 **[MiniCPM-V-2.6](https://huggingface.co/openbmb/MiniCPM-V-2_6)** 模型的微调。 感谢 [@BUAADreamer](https://github.com/BUAADreamer) 的 PR. +[25/01/15] 我们支持了 **[APOLLO](https://arxiv.org/abs/2412.05270)** 优化器。详细用法请参照 [examples](examples/README_zh.md)。 -[25/01/15] 我们支持了 **[InternLM3](https://huggingface.co/collections/internlm/)** 模型的微调。感谢 [@hhaAndroid](https://github.com/hhaAndroid) 的 PR。 +[25/01/14] 我们支持了 **[MiniCPM-o-2.6](https://huggingface.co/openbmb/MiniCPM-o-2_6)** 和 **[MiniCPM-V-2.6](https://huggingface.co/openbmb/MiniCPM-V-2_6)** 模型的微调。 感谢 [@BUAADreamer](https://github.com/BUAADreamer) 的 PR. + +[25/01/14] 我们支持了 **[InternLM3](https://huggingface.co/collections/internlm/)** 模型的微调。感谢 [@hhaAndroid](https://github.com/hhaAndroid) 的 PR。 [25/01/10] 我们支持了 **[Phi-4](https://huggingface.co/microsoft/phi-4)** 模型的微调。 +
展开日志 + [24/12/21] 我们支持了使用 **[SwanLab](https://github.com/SwanHubX/SwanLab)** 跟踪与可视化实验。详细用法请参考 [此部分](#使用-swanlab-面板)。 [24/11/27] 我们支持了 **[Skywork-o1](https://huggingface.co/Skywork/Skywork-o1-Open-Llama-3.1-8B)** 模型的微调和 **[OpenO1](https://huggingface.co/datasets/O1-OPEN/OpenO1-SFT)** 数据集。 -
展开日志 - [24/10/09] 我们支持了从 **[魔乐社区](https://modelers.cn/models)** 下载预训练模型和数据集。详细用法请参照 [此教程](#从魔乐社区下载)。 [24/09/19] 我们支持了 **[Qwen2.5](https://qwenlm.github.io/blog/qwen2.5/)** 模型的微调。 @@ -246,7 +248,7 @@ https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272 ## 训练方法 | 方法 | 全参数训练 | 部分参数训练 | LoRA | QLoRA | -| ---------------------- | ------------------ | ------------------ | ------------------ | ------------------ | +| --------------------- | ------------------ | ------------------ | ------------------ | ------------------ | | 预训练 | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | | 指令监督微调 | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | | 奖励模型训练 | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | @@ -379,15 +381,15 @@ huggingface-cli login \* *估算值* -| 方法 | 精度 | 7B | 13B | 30B | 70B | 110B | 8x7B | 8x22B | -| ----------------- | ---- | ----- | ----- | ----- | ------ | ------ | ----- | ------ | -| Full | AMP | 120GB | 240GB | 600GB | 1200GB | 2000GB | 900GB | 2400GB | -| Full | 16 | 60GB | 120GB | 300GB | 600GB | 900GB | 400GB | 1200GB | -| Freeze | 16 | 20GB | 40GB | 80GB | 200GB | 360GB | 160GB | 400GB | -| LoRA/GaLore/BAdam | 16 | 16GB | 32GB | 64GB | 160GB | 240GB | 120GB | 320GB | -| QLoRA | 8 | 10GB | 20GB | 40GB | 80GB | 140GB | 60GB | 160GB | -| QLoRA | 4 | 6GB | 12GB | 24GB | 48GB | 72GB | 30GB | 96GB | -| QLoRA | 2 | 4GB | 8GB | 16GB | 24GB | 48GB | 18GB | 48GB | +| 方法 | 精度 | 7B | 13B | 30B | 70B | 110B | 8x7B | 8x22B | +| ------------------------ | ---- | ----- | ----- | ----- | ------ | ------ | ----- | ------ | +| Full | 32 | 120GB | 240GB | 600GB | 1200GB | 2000GB | 900GB | 2400GB | +| Full | 16 | 60GB | 120GB | 300GB | 600GB | 900GB | 400GB | 1200GB | +| Freeze | 16 | 20GB | 40GB | 80GB | 200GB | 360GB | 160GB | 400GB | +| LoRA/GaLore/APOLLO/BAdam | 16 | 16GB | 32GB | 64GB | 160GB | 240GB | 120GB | 320GB | +| QLoRA | 8 | 10GB | 20GB | 40GB | 80GB | 140GB | 60GB | 160GB | +| QLoRA | 4 | 6GB | 12GB | 24GB | 48GB | 72GB | 30GB | 96GB | +| QLoRA | 2 | 4GB | 8GB | 16GB | 24GB | 48GB | 18GB | 48GB | ## 如何使用 @@ -402,7 +404,7 @@ cd LLaMA-Factory pip install -e ".[torch,metrics]" ``` -可选的额外依赖项:torch、torch-npu、metrics、deepspeed、liger-kernel、bitsandbytes、hqq、eetq、gptq、awq、aqlm、vllm、galore、badam、adam-mini、qwen、modelscope、openmind、swanlab、quality +可选的额外依赖项:torch、torch-npu、metrics、deepspeed、liger-kernel、bitsandbytes、hqq、eetq、gptq、awq、aqlm、vllm、galore、apollo、badam、adam-mini、qwen、minicpm_v、modelscope、openmind、swanlab、quality > [!TIP] > 遇到包冲突时,可使用 `pip install --no-deps -e .` 解决。 diff --git a/examples/README.md b/examples/README.md index e589b980..1b944122 100644 --- a/examples/README.md +++ b/examples/README.md @@ -204,6 +204,12 @@ llamafactory-cli api examples/inference/llama3_lora_sft.yaml llamafactory-cli train examples/extras/galore/llama3_full_sft.yaml ``` +#### Full-Parameter Fine-Tuning using APOLLO + +```bash +llamafactory-cli train examples/extras/apollo/llama3_full_sft.yaml +``` + #### Full-Parameter Fine-Tuning using BAdam ```bash diff --git a/examples/README_zh.md b/examples/README_zh.md index b75a6239..31d3eda2 100644 --- a/examples/README_zh.md +++ b/examples/README_zh.md @@ -204,6 +204,12 @@ llamafactory-cli api examples/inference/llama3_lora_sft.yaml llamafactory-cli train examples/extras/galore/llama3_full_sft.yaml ``` +#### 使用 APOLLO 进行全参数训练 + +```bash +llamafactory-cli train examples/extras/apollo/llama3_full_sft.yaml +``` + #### 使用 BAdam 进行全参数训练 
```bash diff --git a/examples/extras/apollo/llama3_full_sft.yaml b/examples/extras/apollo/llama3_full_sft.yaml index c90a0147..520c528e 100644 --- a/examples/extras/apollo/llama3_full_sft.yaml +++ b/examples/extras/apollo/llama3_full_sft.yaml @@ -7,8 +7,8 @@ stage: sft do_train: true finetuning_type: full use_apollo: true -apollo_layerwise: true -apollo_target: mlp,self_attn +apollo_layerwise: true # choices: [true, false], use false for DDP training +apollo_target: all apollo_rank: 128 apollo_scale: 32.0 apollo_scale_type: channel @@ -22,7 +22,7 @@ overwrite_cache: true preprocessing_num_workers: 16 ### output -output_dir: saves/llama3-8b/apollo_full-scale32/sft +output_dir: saves/llama3-8b/full/sft logging_steps: 10 save_steps: 500 plot_loss: true @@ -30,7 +30,7 @@ overwrite_output_dir: true ### train per_device_train_batch_size: 1 -gradient_accumulation_steps: 1 +gradient_accumulation_steps: 1 # use 1 for layerwise apollo learning_rate: 1.0e-5 num_train_epochs: 3.0 lr_scheduler_type: cosine diff --git a/examples/extras/galore/llama3_full_sft.yaml b/examples/extras/galore/llama3_full_sft.yaml index 02fcf89f..f46f9ae6 100644 --- a/examples/extras/galore/llama3_full_sft.yaml +++ b/examples/extras/galore/llama3_full_sft.yaml @@ -7,8 +7,8 @@ stage: sft do_train: true finetuning_type: full use_galore: true -galore_layerwise: true -galore_target: mlp,self_attn +galore_layerwise: true # choices: [true, false], use false for DDP training +galore_target: all galore_rank: 128 galore_scale: 2.0 @@ -29,7 +29,7 @@ overwrite_output_dir: true ### train per_device_train_batch_size: 1 -gradient_accumulation_steps: 1 +gradient_accumulation_steps: 1 # use 1 for layerwise galore learning_rate: 1.0e-5 num_train_epochs: 3.0 lr_scheduler_type: cosine diff --git a/src/llamafactory/extras/packages.py b/src/llamafactory/extras/packages.py index 3dda9560..e3ddbbe9 100644 --- a/src/llamafactory/extras/packages.py +++ b/src/llamafactory/extras/packages.py @@ -49,9 +49,11 @@ def is_fastapi_available(): def is_galore_available(): return _is_package_available("galore_torch") + def is_apollo_available(): return _is_package_available("apollo_torch") + def is_gradio_available(): return _is_package_available("gradio") diff --git a/src/llamafactory/hparams/finetuning_args.py b/src/llamafactory/hparams/finetuning_args.py index 9996ef51..835c9785 100644 --- a/src/llamafactory/hparams/finetuning_args.py +++ b/src/llamafactory/hparams/finetuning_args.py @@ -286,7 +286,7 @@ class ApolloArguments: default="random", metadata={"help": "Type of APOLLO low-rank projection algorithm (svd or random)."}, ) - apollo_proj_type: Literal["std", "right", "left",] = field( + apollo_proj_type: Literal["std", "right", "left"] = field( default="std", metadata={"help": "Type of APOLLO projection."}, ) @@ -475,17 +475,11 @@ class FinetuningArguments( if self.use_llama_pro and self.finetuning_type == "full": raise ValueError("`use_llama_pro` is only valid for Freeze or LoRA training.") - if self.finetuning_type == "lora" and (self.use_galore or self.use_badam or self.use_apollo): - raise ValueError("Cannot use LoRA with GaLore or BAdam together.") + if self.finetuning_type == "lora" and (self.use_galore or self.use_apollo or self.use_badam): + raise ValueError("Cannot use LoRA with GaLore, APOLLO or BAdam together.") - if self.use_galore and self.use_badam: - raise ValueError("Cannot use GaLore with BAdam together.") - - if self.use_galore and self.use_apollo: - raise ValueError("Cannot use GaLore with APOLLO together.") - - if self.use_badam and 
self.use_apollo:
-            raise ValueError("Cannot use BAdam with APOLLO together.")
+        if int(self.use_galore) + int(self.use_apollo) + int(self.use_badam) > 1:
+            raise ValueError("Cannot use GaLore, APOLLO or BAdam together.")
 
         if self.pissa_init and (self.stage in ["ppo", "kto"] or self.use_ref_model):
             raise ValueError("Cannot use PiSSA for current training stage.")
diff --git a/src/llamafactory/hparams/parser.py b/src/llamafactory/hparams/parser.py
index a7c2f3ec..3ffaa1f1 100644
--- a/src/llamafactory/hparams/parser.py
+++ b/src/llamafactory/hparams/parser.py
@@ -258,31 +258,21 @@ def get_train_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _
         if is_deepspeed_zero3_enabled():
             raise ValueError("`pure_bf16` is incompatible with DeepSpeed ZeRO-3.")
 
-    if (
-        finetuning_args.use_galore
-        and finetuning_args.galore_layerwise
-        and training_args.parallel_mode == ParallelMode.DISTRIBUTED
-    ):
-        raise ValueError("Distributed training does not support layer-wise GaLore.")
+    if training_args.parallel_mode == ParallelMode.DISTRIBUTED:
+        if finetuning_args.use_galore and finetuning_args.galore_layerwise:
+            raise ValueError("Distributed training does not support layer-wise GaLore.")
 
-    if (
-        finetuning_args.use_apollo
-        and finetuning_args.apollo_layerwise
-        and training_args.parallel_mode == ParallelMode.DISTRIBUTED
-    ):
-        raise ValueError("Distributed training does not support layer-wise APOLLO.")
+        if finetuning_args.use_apollo and finetuning_args.apollo_layerwise:
+            raise ValueError("Distributed training does not support layer-wise APOLLO.")
 
-    if finetuning_args.use_badam and training_args.parallel_mode == ParallelMode.DISTRIBUTED:
-        if finetuning_args.badam_mode == "ratio":
-            raise ValueError("Radio-based BAdam does not yet support distributed training, use layer-wise BAdam.")
-        elif not is_deepspeed_zero3_enabled():
-            raise ValueError("Layer-wise BAdam only supports DeepSpeed ZeRO-3 training.")
+        if finetuning_args.use_badam:
+            if finetuning_args.badam_mode == "ratio":
+                raise ValueError("Ratio-based BAdam does not yet support distributed training, use layer-wise BAdam.")
+            elif not is_deepspeed_zero3_enabled():
+                raise ValueError("Layer-wise BAdam only supports DeepSpeed ZeRO-3 training.")
 
-    if finetuning_args.use_galore and training_args.deepspeed is not None:
-        raise ValueError("GaLore is incompatible with DeepSpeed yet.")
-
-    if finetuning_args.use_apollo and training_args.deepspeed is not None:
-        raise ValueError("APOLLO is incompatible with DeepSpeed yet.")
+    if training_args.deepspeed is not None and (finetuning_args.use_galore or finetuning_args.use_apollo):
+        raise ValueError("GaLore and APOLLO are not compatible with DeepSpeed yet.")
 
     if model_args.infer_backend == "vllm":
         raise ValueError("vLLM backend is only available for API, CLI and Web.")
@@ -314,14 +304,13 @@ def get_train_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _
     if training_args.do_train and (not training_args.fp16) and (not training_args.bf16):
         logger.warning_rank0("We recommend enable mixed precision training.")
 
-    if training_args.do_train and finetuning_args.use_galore and not finetuning_args.pure_bf16:
+    if (
+        training_args.do_train
+        and (finetuning_args.use_galore or finetuning_args.use_apollo)
+        and not finetuning_args.pure_bf16
+    ):
         logger.warning_rank0(
-            "Using GaLore with mixed precision training may significantly increases GPU memory usage."
-        )
-
-    if training_args.do_train and finetuning_args.use_apollo and not finetuning_args.pure_bf16:
-        logger.warning_rank0(
-            "Using APOLLO with mixed precision training may significantly increases GPU memory usage."
+            "Using GaLore or APOLLO with mixed precision training may significantly increase GPU memory usage."
         )
 
     if (not training_args.do_train) and model_args.quantization_bit is not None:
@@ -397,7 +386,6 @@ def get_train_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _
             str(model_args.compute_dtype),
         )
     )
-
     transformers.set_seed(training_args.seed)
 
     return model_args, data_args, training_args, finetuning_args, generating_args
diff --git a/src/llamafactory/model/model_utils/misc.py b/src/llamafactory/model/model_utils/misc.py
index 6d626f33..f3228638 100644
--- a/src/llamafactory/model/model_utils/misc.py
+++ b/src/llamafactory/model/model_utils/misc.py
@@ -27,7 +27,7 @@ logger = logging.get_logger(__name__)
 
 def find_all_linear_modules(model: "PreTrainedModel", freeze_vision_tower: bool) -> List[str]:
     r"""
-    Finds all available modules to apply lora or galore or apollo.
+    Finds all available modules to apply LoRA, GaLore or APOLLO.
     """
     model_type = getattr(model.config, "model_type", None)
     forbidden_modules = {"lm_head"}
diff --git a/src/llamafactory/train/trainer_utils.py b/src/llamafactory/train/trainer_utils.py
index 832b084e..a7d89d8f 100644
--- a/src/llamafactory/train/trainer_utils.py
+++ b/src/llamafactory/train/trainer_utils.py
@@ -32,7 +32,7 @@ from typing_extensions import override
 
 from ..extras import logging
 from ..extras.constants import IGNORE_INDEX
-from ..extras.packages import is_galore_available, is_ray_available, is_apollo_available
+from ..extras.packages import is_apollo_available, is_galore_available, is_ray_available
 from ..hparams import FinetuningArguments, ModelArguments
 from ..model import find_all_linear_modules, load_model, load_tokenizer, load_valuehead_params
 
@@ -40,9 +40,11 @@ from ..model import find_all_linear_modules, load_model, load_tokenizer, load_va
 if is_galore_available():
     from galore_torch import GaLoreAdafactor, GaLoreAdamW, GaLoreAdamW8bit  # type: ignore
 
+
 if is_apollo_available():
     from apollo_torch import APOLLOAdamW  # type: ignore
 
+
 if is_ray_available():
     from ray.train import RunConfig, ScalingConfig
     from ray.train.torch import TorchTrainer
@@ -240,9 +242,10 @@ def _create_galore_optimizer(
     elif training_args.optim == "adafactor":
         optim_class = GaLoreAdafactor
     else:
-        raise NotImplementedError(f"Unknow optim: {training_args.optim}")
+        raise NotImplementedError(f"Unknown optim: {training_args.optim}.")
 
     if finetuning_args.galore_layerwise:
+        logger.warning_rank0("The displayed gradient norm will be all zeros in layerwise GaLore.")
         if training_args.gradient_accumulation_steps != 1:
             raise ValueError("Per-layer GaLore does not support gradient accumulation.")
 
@@ -274,9 +277,13 @@ def _create_galore_optimizer(
     ]
     optimizer = optim_class(param_groups, **optim_kwargs)
 
-    logger.info_rank0("Using GaLore optimizer, may cause hanging at the start of training, wait patiently.")
+    logger.info_rank0(
+        f"Using GaLore optimizer with args: {galore_kwargs}. "
+        "It may cause hanging at the start of training, wait patiently."
+ ) return optimizer + def _create_apollo_optimizer( model: "PreTrainedModel", training_args: "TrainingArguments", @@ -304,11 +311,9 @@ def _create_apollo_optimizer( "scale_front": finetuning_args.apollo_scale_front, } - print(apollo_kwargs) - id_apollo_params = {id(param) for param in apollo_params} - decay_params, nodecay_params = [], [] # they are non-galore parameters - trainable_params: List["torch.nn.Parameter"] = [] # galore_params + decay_params + nodecay_params + decay_params, nodecay_params = [], [] # they are non-apollo parameters + trainable_params: List["torch.nn.Parameter"] = [] # apollo_params + decay_params + nodecay_params decay_param_names = _get_decay_parameter_names(model) for name, param in model.named_parameters(): if param.requires_grad: @@ -324,9 +329,10 @@ def _create_apollo_optimizer( if training_args.optim == "adamw_torch": optim_class = APOLLOAdamW else: - raise NotImplementedError(f"Unknow optim: {training_args.optim}") + raise NotImplementedError(f"Unknown optim: {training_args.optim}.") if finetuning_args.apollo_layerwise: + logger.warning_rank0("The displayed gradient norm will be all zeros in layerwise APOLLO.") if training_args.gradient_accumulation_steps != 1: raise ValueError("Per-layer APOLLO does not support gradient accumulation.") @@ -337,7 +343,7 @@ def _create_apollo_optimizer( for param in decay_params: param_groups = [dict(params=[param], weight_decay=training_args.weight_decay)] optimizer_dict[param] = optim_class(param_groups, **optim_kwargs) - for param in apollo_params: # galore params have weight decay + for param in apollo_params: # apollo params have weight decay param_groups = [dict(params=[param], weight_decay=training_args.weight_decay, **apollo_kwargs)] optimizer_dict[param] = optim_class(param_groups, **optim_kwargs) @@ -358,7 +364,7 @@ def _create_apollo_optimizer( ] optimizer = optim_class(param_groups, **optim_kwargs) - logger.info_rank0("Using APOLLO optimizer.") + logger.info_rank0(f"Using APOLLO optimizer with args: {apollo_kwargs}.") return optimizer diff --git a/src/llamafactory/webui/components/train.py b/src/llamafactory/webui/components/train.py index 62bd66c5..ae3c416c 100644 --- a/src/llamafactory/webui/components/train.py +++ b/src/llamafactory/webui/components/train.py @@ -234,8 +234,8 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]: with gr.Row(): use_galore = gr.Checkbox() galore_rank = gr.Slider(minimum=1, maximum=1024, value=16, step=1) - galore_update_interval = gr.Slider(minimum=1, maximum=1024, value=200, step=1) - galore_scale = gr.Slider(minimum=0, maximum=1, value=0.25, step=0.01) + galore_update_interval = gr.Slider(minimum=1, maximum=2048, value=200, step=1) + galore_scale = gr.Slider(minimum=0, maximum=100, value=2.0, step=0.1) galore_target = gr.Textbox(value="all") input_elems.update({use_galore, galore_rank, galore_update_interval, galore_scale, galore_target}) @@ -254,9 +254,10 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]: with gr.Row(): use_apollo = gr.Checkbox() apollo_rank = gr.Slider(minimum=1, maximum=1024, value=16, step=1) - apollo_update_interval = gr.Slider(minimum=1, maximum=1024, value=200, step=1) - apollo_scale = gr.Slider(minimum=0, maximum=1, value=0.25, step=0.01) + apollo_update_interval = gr.Slider(minimum=1, maximum=2048, value=200, step=1) + apollo_scale = gr.Slider(minimum=0, maximum=100, value=32.0, step=0.1) apollo_target = gr.Textbox(value="all") + input_elems.update({use_apollo, apollo_rank, apollo_update_interval, apollo_scale, 
apollo_target}) elem_dict.update( dict( diff --git a/src/llamafactory/webui/locales.py b/src/llamafactory/webui/locales.py index 1dd1810b..51662092 100644 --- a/src/llamafactory/webui/locales.py +++ b/src/llamafactory/webui/locales.py @@ -1162,19 +1162,19 @@ LOCALES = { "use_galore": { "en": { "label": "Use GaLore", - "info": "Enable gradient low-Rank projection.", + "info": "Use GaLore optimizer.", }, "ru": { "label": "Использовать GaLore", - "info": "Включить проекцию градиента на низкоранговое пространство.", + "info": "Используйте оптимизатор GaLore.", }, "zh": { "label": "使用 GaLore", - "info": "使用梯度低秩投影。", + "info": "使用 GaLore 优化器。", }, "ko": { "label": "GaLore 사용", - "info": "그레디언트 로우 랭크 프로젝션을 활성화합니다.", + "info": "GaLore 최적화를 사용하세요.", }, }, "galore_rank": { @@ -1266,19 +1266,19 @@ LOCALES = { "use_apollo": { "en": { "label": "Use APOLLO", - "info": "Enable gradient low-Rank projection.", + "info": "Use APOLLO optimizer.", }, "ru": { "label": "Использовать APOLLO", - "info": "Включить проекцию градиента на низкоранговое пространство.", + "info": "Используйте оптимизатор APOLLO.", }, "zh": { "label": "使用 APOLLO", - "info": "使用梯度低秩投影。", + "info": "使用 APOLLO 优化器。", }, "ko": { "label": "APOLLO 사용", - "info": "그레디언트 로우 랭크 프로젝션을 활성화합니다.", + "info": "APOLLO 최적화를 사용하세요.", }, }, "apollo_rank": { diff --git a/src/llamafactory/webui/runner.py b/src/llamafactory/webui/runner.py index bc010992..c397416d 100644 --- a/src/llamafactory/webui/runner.py +++ b/src/llamafactory/webui/runner.py @@ -224,7 +224,7 @@ class Runner: args["galore_update_interval"] = get("train.galore_update_interval") args["galore_scale"] = get("train.galore_scale") args["galore_target"] = get("train.galore_target") - + # apollo config if args["use_apollo"]: args["apollo_rank"] = get("train.apollo_rank")
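Two constraints recur throughout this patch: the `gradient_accumulation_steps: 1` comments ("use 1 for layerwise apollo/galore") in the example configs, and the new "displayed gradient norm will be all zeros" warnings. Both follow from the layerwise mode, in which `_create_galore_optimizer` and `_create_apollo_optimizer` build one optimizer per parameter. The sketch below is a minimal illustration of that pattern, not the project's code: it substitutes plain `torch.optim.AdamW` for `GaLoreAdamW`/`APOLLOAdamW`, uses a toy model, and assumes the per-parameter optimizers are driven from PyTorch's `Tensor.register_post_accumulate_grad_hook`, as in the GaLore layerwise recipe; the actual wiring lives in parts of `trainer_utils.py` outside the hunks shown above.

```python
import torch
from torch import nn
import torch.nn.functional as F

# Toy stand-in for the fine-tuned model; any module with trainable parameters works.
model = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 8))

# One optimizer per parameter, mirroring the `optimizer_dict` populated by the
# layerwise branches of _create_galore_optimizer / _create_apollo_optimizer.
optimizer_dict = {
    param: torch.optim.AdamW([param], lr=1e-5)  # the real code would use GaLoreAdamW / APOLLOAdamW
    for param in model.parameters()
    if param.requires_grad
}


def optimizer_hook(param: torch.nn.Parameter) -> None:
    # Step this parameter's optimizer as soon as its gradient is ready, then clear it.
    # Gradients are consumed immediately, so they cannot be accumulated across
    # micro-batches (hence gradient_accumulation_steps must be 1), and nothing is left
    # for the trainer to compute a global gradient norm from (hence the new warnings).
    optimizer_dict[param].step()
    optimizer_dict[param].zero_grad()


for param in optimizer_dict:
    param.register_post_accumulate_grad_hook(optimizer_hook)

# One training step: backward() alone triggers every per-parameter update.
inputs, targets = torch.randn(4, 64), torch.randn(4, 8)
loss = F.mse_loss(model(inputs), targets)
loss.backward()
```

In the non-layerwise path a single optimizer is built over grouped parameters, which is why the updated YAML comments recommend `*_layerwise: false` for DDP training.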