[optim] clean apollo (#6645)

* clean apollo code

* update readme

Former-commit-id: 7a04021d0461caea2c7b82169839340b7f51f463
This commit is contained in:
hoshi-hiyouga 2025-01-15 01:42:50 +08:00 committed by GitHub
parent 763f9b9df0
commit 9ef85f8fc4
14 changed files with 110 additions and 103 deletions

View File

@ -66,7 +66,7 @@ Choose your path:
- **Various models**: LLaMA, LLaVA, Mistral, Mixtral-MoE, Qwen, Qwen2-VL, Yi, Gemma, Baichuan, ChatGLM, Phi, etc. - **Various models**: LLaMA, LLaVA, Mistral, Mixtral-MoE, Qwen, Qwen2-VL, Yi, Gemma, Baichuan, ChatGLM, Phi, etc.
- **Integrated methods**: (Continuous) pre-training, (multimodal) supervised fine-tuning, reward modeling, PPO, DPO, KTO, ORPO, etc. - **Integrated methods**: (Continuous) pre-training, (multimodal) supervised fine-tuning, reward modeling, PPO, DPO, KTO, ORPO, etc.
- **Scalable resources**: 16-bit full-tuning, freeze-tuning, LoRA and 2/3/4/5/6/8-bit QLoRA via AQLM/AWQ/GPTQ/LLM.int8/HQQ/EETQ. - **Scalable resources**: 16-bit full-tuning, freeze-tuning, LoRA and 2/3/4/5/6/8-bit QLoRA via AQLM/AWQ/GPTQ/LLM.int8/HQQ/EETQ.
- **Advanced algorithms**: [GaLore](https://github.com/jiaweizzhao/GaLore), [BAdam](https://github.com/Ledzy/BAdam), [Adam-mini](https://github.com/zyushun/Adam-mini), DoRA, LongLoRA, LLaMA Pro, Mixture-of-Depths, LoRA+, LoftQ, PiSSA and Agent tuning. - **Advanced algorithms**: [GaLore](https://github.com/jiaweizzhao/GaLore), [BAdam](https://github.com/Ledzy/BAdam), [APOLLO](https://github.com/zhuhanqing/APOLLO), [Adam-mini](https://github.com/zyushun/Adam-mini), DoRA, LongLoRA, LLaMA Pro, Mixture-of-Depths, LoRA+, LoftQ, PiSSA and Agent tuning.
- **Practical tricks**: [FlashAttention-2](https://github.com/Dao-AILab/flash-attention), [Unsloth](https://github.com/unslothai/unsloth), [Liger Kernel](https://github.com/linkedin/Liger-Kernel), RoPE scaling, NEFTune and rsLoRA. - **Practical tricks**: [FlashAttention-2](https://github.com/Dao-AILab/flash-attention), [Unsloth](https://github.com/unslothai/unsloth), [Liger Kernel](https://github.com/linkedin/Liger-Kernel), RoPE scaling, NEFTune and rsLoRA.
- **Experiment monitors**: LlamaBoard, TensorBoard, Wandb, MLflow, SwanLab, etc. - **Experiment monitors**: LlamaBoard, TensorBoard, Wandb, MLflow, SwanLab, etc.
- **Faster inference**: OpenAI-style API, Gradio UI and CLI with vLLM worker. - **Faster inference**: OpenAI-style API, Gradio UI and CLI with vLLM worker.
@ -88,18 +88,20 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
## Changelog ## Changelog
[25/01/15] We supported fine-tuning the **[MiniCPM-o-2.6](https://huggingface.co/openbmb/MiniCPM-o-2_6)** and **[MiniCPM-V-2.6](https://huggingface.co/openbmb/MiniCPM-V-2_6)** models. Thank [@BUAADreamer](https://github.com/BUAADreamer)'s PR. [25/01/15] We supported **[APOLLO](https://arxiv.org/abs/2412.05270)** optimizer. See [examples](examples/README.md) for usage.
[25/01/15] We supported fine-tuning the **[InternLM3](https://huggingface.co/collections/internlm/)** models. Thank [@hhaAndroid](https://github.com/hhaAndroid)'s PR. [25/01/14] We supported fine-tuning the **[MiniCPM-o-2.6](https://huggingface.co/openbmb/MiniCPM-o-2_6)** and **[MiniCPM-V-2.6](https://huggingface.co/openbmb/MiniCPM-V-2_6)** models. Thank [@BUAADreamer](https://github.com/BUAADreamer)'s PR.
[25/01/14] We supported fine-tuning the **[InternLM3](https://huggingface.co/collections/internlm/)** models. Thank [@hhaAndroid](https://github.com/hhaAndroid)'s PR.
[25/01/10] We supported fine-tuning the **[Phi-4](https://huggingface.co/microsoft/phi-4)** model. [25/01/10] We supported fine-tuning the **[Phi-4](https://huggingface.co/microsoft/phi-4)** model.
<details><summary>Full Changelog</summary>
[24/12/21] We supported using **[SwanLab](https://github.com/SwanHubX/SwanLab)** for experiment tracking and visualization. See [this section](#use-swanlab-logger) for details. [24/12/21] We supported using **[SwanLab](https://github.com/SwanHubX/SwanLab)** for experiment tracking and visualization. See [this section](#use-swanlab-logger) for details.
[24/11/27] We supported fine-tuning the **[Skywork-o1](https://huggingface.co/Skywork/Skywork-o1-Open-Llama-3.1-8B)** model and the **[OpenO1](https://huggingface.co/datasets/O1-OPEN/OpenO1-SFT)** dataset. [24/11/27] We supported fine-tuning the **[Skywork-o1](https://huggingface.co/Skywork/Skywork-o1-Open-Llama-3.1-8B)** model and the **[OpenO1](https://huggingface.co/datasets/O1-OPEN/OpenO1-SFT)** dataset.
<details><summary>Full Changelog</summary>
[24/10/09] We supported downloading pre-trained models and datasets from the **[Modelers Hub](https://modelers.cn/models)**. See [this tutorial](#download-from-modelers-hub) for usage. [24/10/09] We supported downloading pre-trained models and datasets from the **[Modelers Hub](https://modelers.cn/models)**. See [this tutorial](#download-from-modelers-hub) for usage.
[24/09/19] We supported fine-tuning the **[Qwen2.5](https://qwenlm.github.io/blog/qwen2.5/)** models. [24/09/19] We supported fine-tuning the **[Qwen2.5](https://qwenlm.github.io/blog/qwen2.5/)** models.
@ -378,15 +380,15 @@ huggingface-cli login
\* *estimated* \* *estimated*
| Method | Bits | 7B | 13B | 30B | 70B | 110B | 8x7B | 8x22B | | Method | Bits | 7B | 13B | 30B | 70B | 110B | 8x7B | 8x22B |
| ----------------- | ---- | ----- | ----- | ----- | ------ | ------ | ----- | ------ | | ------------------------ | ---- | ----- | ----- | ----- | ------ | ------ | ----- | ------ |
| Full | AMP | 120GB | 240GB | 600GB | 1200GB | 2000GB | 900GB | 2400GB | | Full | 32 | 120GB | 240GB | 600GB | 1200GB | 2000GB | 900GB | 2400GB |
| Full | 16 | 60GB | 120GB | 300GB | 600GB | 900GB | 400GB | 1200GB | | Full | 16 | 60GB | 120GB | 300GB | 600GB | 900GB | 400GB | 1200GB |
| Freeze | 16 | 20GB | 40GB | 80GB | 200GB | 360GB | 160GB | 400GB | | Freeze | 16 | 20GB | 40GB | 80GB | 200GB | 360GB | 160GB | 400GB |
| LoRA/GaLore/BAdam | 16 | 16GB | 32GB | 64GB | 160GB | 240GB | 120GB | 320GB | | LoRA/GaLore/APOLLO/BAdam | 16 | 16GB | 32GB | 64GB | 160GB | 240GB | 120GB | 320GB |
| QLoRA | 8 | 10GB | 20GB | 40GB | 80GB | 140GB | 60GB | 160GB | | QLoRA | 8 | 10GB | 20GB | 40GB | 80GB | 140GB | 60GB | 160GB |
| QLoRA | 4 | 6GB | 12GB | 24GB | 48GB | 72GB | 30GB | 96GB | | QLoRA | 4 | 6GB | 12GB | 24GB | 48GB | 72GB | 30GB | 96GB |
| QLoRA | 2 | 4GB | 8GB | 16GB | 24GB | 48GB | 18GB | 48GB | | QLoRA | 2 | 4GB | 8GB | 16GB | 24GB | 48GB | 18GB | 48GB |
## Getting Started ## Getting Started
@ -401,7 +403,7 @@ cd LLaMA-Factory
pip install -e ".[torch,metrics]" pip install -e ".[torch,metrics]"
``` ```
Extra dependencies available: torch, torch-npu, metrics, deepspeed, liger-kernel, bitsandbytes, hqq, eetq, gptq, awq, aqlm, vllm, galore, badam, adam-mini, qwen, modelscope, openmind, swanlab, quality Extra dependencies available: torch, torch-npu, metrics, deepspeed, liger-kernel, bitsandbytes, hqq, eetq, gptq, awq, aqlm, vllm, galore, apollo, badam, adam-mini, qwen, minicpm_v, modelscope, openmind, swanlab, quality
> [!TIP] > [!TIP]
> Use `pip install --no-deps -e .` to resolve package conflicts. > Use `pip install --no-deps -e .` to resolve package conflicts.

View File

@ -67,7 +67,7 @@ https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272
- **多种模型**LLaMA、LLaVA、Mistral、Mixtral-MoE、Qwen、Qwen2-VL、Yi、Gemma、Baichuan、ChatGLM、Phi 等等。 - **多种模型**LLaMA、LLaVA、Mistral、Mixtral-MoE、Qwen、Qwen2-VL、Yi、Gemma、Baichuan、ChatGLM、Phi 等等。
- **集成方法**增量预训练、多模态指令监督微调、奖励模型训练、PPO 训练、DPO 训练、KTO 训练、ORPO 训练等等。 - **集成方法**增量预训练、多模态指令监督微调、奖励模型训练、PPO 训练、DPO 训练、KTO 训练、ORPO 训练等等。
- **多种精度**16 比特全参数微调、冻结微调、LoRA 微调和基于 AQLM/AWQ/GPTQ/LLM.int8/HQQ/EETQ 的 2/3/4/5/6/8 比特 QLoRA 微调。 - **多种精度**16 比特全参数微调、冻结微调、LoRA 微调和基于 AQLM/AWQ/GPTQ/LLM.int8/HQQ/EETQ 的 2/3/4/5/6/8 比特 QLoRA 微调。
- **先进算法**[GaLore](https://github.com/jiaweizzhao/GaLore)、[BAdam](https://github.com/Ledzy/BAdam)、[Adam-mini](https://github.com/zyushun/Adam-mini)、DoRA、LongLoRA、LLaMA Pro、Mixture-of-Depths、LoRA+、LoftQ、PiSSA 和 Agent 微调。 - **先进算法**[GaLore](https://github.com/jiaweizzhao/GaLore)、[BAdam](https://github.com/Ledzy/BAdam)、[APOLLO](https://github.com/zhuhanqing/APOLLO)、[Adam-mini](https://github.com/zyushun/Adam-mini)、DoRA、LongLoRA、LLaMA Pro、Mixture-of-Depths、LoRA+、LoftQ、PiSSA 和 Agent 微调。
- **实用技巧**[FlashAttention-2](https://github.com/Dao-AILab/flash-attention)、[Unsloth](https://github.com/unslothai/unsloth)、[Liger Kernel](https://github.com/linkedin/Liger-Kernel)、RoPE scaling、NEFTune 和 rsLoRA。 - **实用技巧**[FlashAttention-2](https://github.com/Dao-AILab/flash-attention)、[Unsloth](https://github.com/unslothai/unsloth)、[Liger Kernel](https://github.com/linkedin/Liger-Kernel)、RoPE scaling、NEFTune 和 rsLoRA。
- **实验监控**LlamaBoard、TensorBoard、Wandb、MLflow、SwanLab 等等。 - **实验监控**LlamaBoard、TensorBoard、Wandb、MLflow、SwanLab 等等。
- **极速推理**:基于 vLLM 的 OpenAI 风格 API、浏览器界面和命令行接口。 - **极速推理**:基于 vLLM 的 OpenAI 风格 API、浏览器界面和命令行接口。
@ -89,18 +89,20 @@ https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272
## 更新日志 ## 更新日志
[25/01/15] 我们支持了 **[MiniCPM-o-2.6](https://huggingface.co/openbmb/MiniCPM-o-2_6)** 和 **[MiniCPM-V-2.6](https://huggingface.co/openbmb/MiniCPM-V-2_6)** 模型的微调。 感谢 [@BUAADreamer](https://github.com/BUAADreamer) 的 PR. [25/01/15] 我们支持了 **[APOLLO](https://arxiv.org/abs/2412.05270)** 优化器。详细用法请参照 [examples](examples/README_zh.md)。
[25/01/15] 我们支持了 **[InternLM3](https://huggingface.co/collections/internlm/)** 模型的微调。感谢 [@hhaAndroid](https://github.com/hhaAndroid) 的 PR。 [25/01/14] 我们支持了 **[MiniCPM-o-2.6](https://huggingface.co/openbmb/MiniCPM-o-2_6)** 和 **[MiniCPM-V-2.6](https://huggingface.co/openbmb/MiniCPM-V-2_6)** 模型的微调。 感谢 [@BUAADreamer](https://github.com/BUAADreamer) 的 PR.
[25/01/14] 我们支持了 **[InternLM3](https://huggingface.co/collections/internlm/)** 模型的微调。感谢 [@hhaAndroid](https://github.com/hhaAndroid) 的 PR。
[25/01/10] 我们支持了 **[Phi-4](https://huggingface.co/microsoft/phi-4)** 模型的微调。 [25/01/10] 我们支持了 **[Phi-4](https://huggingface.co/microsoft/phi-4)** 模型的微调。
<details><summary>展开日志</summary>
[24/12/21] 我们支持了使用 **[SwanLab](https://github.com/SwanHubX/SwanLab)** 跟踪与可视化实验。详细用法请参考 [此部分](#使用-swanlab-面板)。 [24/12/21] 我们支持了使用 **[SwanLab](https://github.com/SwanHubX/SwanLab)** 跟踪与可视化实验。详细用法请参考 [此部分](#使用-swanlab-面板)。
[24/11/27] 我们支持了 **[Skywork-o1](https://huggingface.co/Skywork/Skywork-o1-Open-Llama-3.1-8B)** 模型的微调和 **[OpenO1](https://huggingface.co/datasets/O1-OPEN/OpenO1-SFT)** 数据集。 [24/11/27] 我们支持了 **[Skywork-o1](https://huggingface.co/Skywork/Skywork-o1-Open-Llama-3.1-8B)** 模型的微调和 **[OpenO1](https://huggingface.co/datasets/O1-OPEN/OpenO1-SFT)** 数据集。
<details><summary>展开日志</summary>
[24/10/09] 我们支持了从 **[魔乐社区](https://modelers.cn/models)** 下载预训练模型和数据集。详细用法请参照 [此教程](#从魔乐社区下载)。 [24/10/09] 我们支持了从 **[魔乐社区](https://modelers.cn/models)** 下载预训练模型和数据集。详细用法请参照 [此教程](#从魔乐社区下载)。
[24/09/19] 我们支持了 **[Qwen2.5](https://qwenlm.github.io/blog/qwen2.5/)** 模型的微调。 [24/09/19] 我们支持了 **[Qwen2.5](https://qwenlm.github.io/blog/qwen2.5/)** 模型的微调。
@ -246,7 +248,7 @@ https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272
## 训练方法 ## 训练方法
| 方法 | 全参数训练 | 部分参数训练 | LoRA | QLoRA | | 方法 | 全参数训练 | 部分参数训练 | LoRA | QLoRA |
| ---------------------- | ------------------ | ------------------ | ------------------ | ------------------ | | --------------------- | ------------------ | ------------------ | ------------------ | ------------------ |
| 预训练 | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | | 预训练 | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
| 指令监督微调 | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | | 指令监督微调 | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
| 奖励模型训练 | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | | 奖励模型训练 | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
@ -379,15 +381,15 @@ huggingface-cli login
\* *估算值* \* *估算值*
| 方法 | 精度 | 7B | 13B | 30B | 70B | 110B | 8x7B | 8x22B | | 方法 | 精度 | 7B | 13B | 30B | 70B | 110B | 8x7B | 8x22B |
| ----------------- | ---- | ----- | ----- | ----- | ------ | ------ | ----- | ------ | | ------------------------ | ---- | ----- | ----- | ----- | ------ | ------ | ----- | ------ |
| Full | AMP | 120GB | 240GB | 600GB | 1200GB | 2000GB | 900GB | 2400GB | | Full | 32 | 120GB | 240GB | 600GB | 1200GB | 2000GB | 900GB | 2400GB |
| Full | 16 | 60GB | 120GB | 300GB | 600GB | 900GB | 400GB | 1200GB | | Full | 16 | 60GB | 120GB | 300GB | 600GB | 900GB | 400GB | 1200GB |
| Freeze | 16 | 20GB | 40GB | 80GB | 200GB | 360GB | 160GB | 400GB | | Freeze | 16 | 20GB | 40GB | 80GB | 200GB | 360GB | 160GB | 400GB |
| LoRA/GaLore/BAdam | 16 | 16GB | 32GB | 64GB | 160GB | 240GB | 120GB | 320GB | | LoRA/GaLore/APOLLO/BAdam | 16 | 16GB | 32GB | 64GB | 160GB | 240GB | 120GB | 320GB |
| QLoRA | 8 | 10GB | 20GB | 40GB | 80GB | 140GB | 60GB | 160GB | | QLoRA | 8 | 10GB | 20GB | 40GB | 80GB | 140GB | 60GB | 160GB |
| QLoRA | 4 | 6GB | 12GB | 24GB | 48GB | 72GB | 30GB | 96GB | | QLoRA | 4 | 6GB | 12GB | 24GB | 48GB | 72GB | 30GB | 96GB |
| QLoRA | 2 | 4GB | 8GB | 16GB | 24GB | 48GB | 18GB | 48GB | | QLoRA | 2 | 4GB | 8GB | 16GB | 24GB | 48GB | 18GB | 48GB |
## 如何使用 ## 如何使用
@ -402,7 +404,7 @@ cd LLaMA-Factory
pip install -e ".[torch,metrics]" pip install -e ".[torch,metrics]"
``` ```
可选的额外依赖项torch、torch-npu、metrics、deepspeed、liger-kernel、bitsandbytes、hqq、eetq、gptq、awq、aqlm、vllm、galore、badam、adam-mini、qwen、modelscope、openmind、swanlab、quality 可选的额外依赖项torch、torch-npu、metrics、deepspeed、liger-kernel、bitsandbytes、hqq、eetq、gptq、awq、aqlm、vllm、galore、apollo、badam、adam-mini、qwen、minicpm_v、modelscope、openmind、swanlab、quality
> [!TIP] > [!TIP]
> 遇到包冲突时,可使用 `pip install --no-deps -e .` 解决。 > 遇到包冲突时,可使用 `pip install --no-deps -e .` 解决。

View File

@ -204,6 +204,12 @@ llamafactory-cli api examples/inference/llama3_lora_sft.yaml
llamafactory-cli train examples/extras/galore/llama3_full_sft.yaml llamafactory-cli train examples/extras/galore/llama3_full_sft.yaml
``` ```
#### Full-Parameter Fine-Tuning using APOLLO
```bash
llamafactory-cli train examples/extras/apollo/llama3_full_sft.yaml
```
#### Full-Parameter Fine-Tuning using BAdam #### Full-Parameter Fine-Tuning using BAdam
```bash ```bash

View File

@ -204,6 +204,12 @@ llamafactory-cli api examples/inference/llama3_lora_sft.yaml
llamafactory-cli train examples/extras/galore/llama3_full_sft.yaml llamafactory-cli train examples/extras/galore/llama3_full_sft.yaml
``` ```
#### 使用 APOLLO 进行全参数训练
```bash
llamafactory-cli train examples/extras/apollo/llama3_full_sft.yaml
```
#### 使用 BAdam 进行全参数训练 #### 使用 BAdam 进行全参数训练
```bash ```bash

View File

@ -7,8 +7,8 @@ stage: sft
do_train: true do_train: true
finetuning_type: full finetuning_type: full
use_apollo: true use_apollo: true
apollo_layerwise: true apollo_layerwise: true # choices: [true, false], use false for DDP training
apollo_target: mlp,self_attn apollo_target: all
apollo_rank: 128 apollo_rank: 128
apollo_scale: 32.0 apollo_scale: 32.0
apollo_scale_type: channel apollo_scale_type: channel
@ -22,7 +22,7 @@ overwrite_cache: true
preprocessing_num_workers: 16 preprocessing_num_workers: 16
### output ### output
output_dir: saves/llama3-8b/apollo_full-scale32/sft output_dir: saves/llama3-8b/full/sft
logging_steps: 10 logging_steps: 10
save_steps: 500 save_steps: 500
plot_loss: true plot_loss: true
@ -30,7 +30,7 @@ overwrite_output_dir: true
### train ### train
per_device_train_batch_size: 1 per_device_train_batch_size: 1
gradient_accumulation_steps: 1 gradient_accumulation_steps: 1 # use 1 for layerwise apollo
learning_rate: 1.0e-5 learning_rate: 1.0e-5
num_train_epochs: 3.0 num_train_epochs: 3.0
lr_scheduler_type: cosine lr_scheduler_type: cosine

View File

@ -7,8 +7,8 @@ stage: sft
do_train: true do_train: true
finetuning_type: full finetuning_type: full
use_galore: true use_galore: true
galore_layerwise: true galore_layerwise: true # choices: [true, false], use false for DDP training
galore_target: mlp,self_attn galore_target: all
galore_rank: 128 galore_rank: 128
galore_scale: 2.0 galore_scale: 2.0
@ -29,7 +29,7 @@ overwrite_output_dir: true
### train ### train
per_device_train_batch_size: 1 per_device_train_batch_size: 1
gradient_accumulation_steps: 1 gradient_accumulation_steps: 1 # use 1 for layerwise galore
learning_rate: 1.0e-5 learning_rate: 1.0e-5
num_train_epochs: 3.0 num_train_epochs: 3.0
lr_scheduler_type: cosine lr_scheduler_type: cosine

View File

@ -49,9 +49,11 @@ def is_fastapi_available():
def is_galore_available(): def is_galore_available():
return _is_package_available("galore_torch") return _is_package_available("galore_torch")
def is_apollo_available(): def is_apollo_available():
return _is_package_available("apollo_torch") return _is_package_available("apollo_torch")
def is_gradio_available(): def is_gradio_available():
return _is_package_available("gradio") return _is_package_available("gradio")

View File

@ -286,7 +286,7 @@ class ApolloArguments:
default="random", default="random",
metadata={"help": "Type of APOLLO low-rank projection algorithm (svd or random)."}, metadata={"help": "Type of APOLLO low-rank projection algorithm (svd or random)."},
) )
apollo_proj_type: Literal["std", "right", "left",] = field( apollo_proj_type: Literal["std", "right", "left"] = field(
default="std", default="std",
metadata={"help": "Type of APOLLO projection."}, metadata={"help": "Type of APOLLO projection."},
) )
@ -475,17 +475,11 @@ class FinetuningArguments(
if self.use_llama_pro and self.finetuning_type == "full": if self.use_llama_pro and self.finetuning_type == "full":
raise ValueError("`use_llama_pro` is only valid for Freeze or LoRA training.") raise ValueError("`use_llama_pro` is only valid for Freeze or LoRA training.")
if self.finetuning_type == "lora" and (self.use_galore or self.use_badam or self.use_apollo): if self.finetuning_type == "lora" and (self.use_galore or self.use_apollo or self.use_badam):
raise ValueError("Cannot use LoRA with GaLore or BAdam together.") raise ValueError("Cannot use LoRA with GaLore, APOLLO or BAdam together.")
if self.use_galore and self.use_badam: if int(self.use_galore) + int(self.use_apollo) + (self.use_badam) > 1:
raise ValueError("Cannot use GaLore with BAdam together.") raise ValueError("Cannot use GaLore, APOLLO or BAdam together.")
if self.use_galore and self.use_apollo:
raise ValueError("Cannot use GaLore with APOLLO together.")
if self.use_badam and self.use_apollo:
raise ValueError("Cannot use BAdam with APOLLO together.")
if self.pissa_init and (self.stage in ["ppo", "kto"] or self.use_ref_model): if self.pissa_init and (self.stage in ["ppo", "kto"] or self.use_ref_model):
raise ValueError("Cannot use PiSSA for current training stage.") raise ValueError("Cannot use PiSSA for current training stage.")

View File

@ -258,31 +258,21 @@ def get_train_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _
if is_deepspeed_zero3_enabled(): if is_deepspeed_zero3_enabled():
raise ValueError("`pure_bf16` is incompatible with DeepSpeed ZeRO-3.") raise ValueError("`pure_bf16` is incompatible with DeepSpeed ZeRO-3.")
if ( if training_args.parallel_mode == ParallelMode.DISTRIBUTED:
finetuning_args.use_galore if finetuning_args.use_galore and finetuning_args.galore_layerwise:
and finetuning_args.galore_layerwise raise ValueError("Distributed training does not support layer-wise GaLore.")
and training_args.parallel_mode == ParallelMode.DISTRIBUTED
):
raise ValueError("Distributed training does not support layer-wise GaLore.")
if ( if finetuning_args.use_apollo and finetuning_args.apollo_layerwise:
finetuning_args.use_apollo raise ValueError("Distributed training does not support layer-wise APOLLO.")
and finetuning_args.apollo_layerwise
and training_args.parallel_mode == ParallelMode.DISTRIBUTED
):
raise ValueError("Distributed training does not support layer-wise APOLLO.")
if finetuning_args.use_badam and training_args.parallel_mode == ParallelMode.DISTRIBUTED: if finetuning_args.use_badam:
if finetuning_args.badam_mode == "ratio": if finetuning_args.badam_mode == "ratio":
raise ValueError("Radio-based BAdam does not yet support distributed training, use layer-wise BAdam.") raise ValueError("Radio-based BAdam does not yet support distributed training, use layer-wise BAdam.")
elif not is_deepspeed_zero3_enabled(): elif not is_deepspeed_zero3_enabled():
raise ValueError("Layer-wise BAdam only supports DeepSpeed ZeRO-3 training.") raise ValueError("Layer-wise BAdam only supports DeepSpeed ZeRO-3 training.")
if finetuning_args.use_galore and training_args.deepspeed is not None: if training_args.deepspeed is not None and (finetuning_args.use_galore or finetuning_args.use_apollo):
raise ValueError("GaLore is incompatible with DeepSpeed yet.") raise ValueError("GaLore and APOLLO are incompatible with DeepSpeed yet.")
if finetuning_args.use_apollo and training_args.deepspeed is not None:
raise ValueError("APOLLO is incompatible with DeepSpeed yet.")
if model_args.infer_backend == "vllm": if model_args.infer_backend == "vllm":
raise ValueError("vLLM backend is only available for API, CLI and Web.") raise ValueError("vLLM backend is only available for API, CLI and Web.")
@ -314,14 +304,13 @@ def get_train_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _
if training_args.do_train and (not training_args.fp16) and (not training_args.bf16): if training_args.do_train and (not training_args.fp16) and (not training_args.bf16):
logger.warning_rank0("We recommend enable mixed precision training.") logger.warning_rank0("We recommend enable mixed precision training.")
if training_args.do_train and finetuning_args.use_galore and not finetuning_args.pure_bf16: if (
training_args.do_train
and (finetuning_args.use_galore or finetuning_args.use_apollo)
and not finetuning_args.pure_bf16
):
logger.warning_rank0( logger.warning_rank0(
"Using GaLore with mixed precision training may significantly increases GPU memory usage." "Using GaLore or APOLLO with mixed precision training may significantly increases GPU memory usage."
)
if training_args.do_train and finetuning_args.use_apollo and not finetuning_args.pure_bf16:
logger.warning_rank0(
"Using APOLLO with mixed precision training may significantly increases GPU memory usage."
) )
if (not training_args.do_train) and model_args.quantization_bit is not None: if (not training_args.do_train) and model_args.quantization_bit is not None:
@ -397,7 +386,6 @@ def get_train_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _
str(model_args.compute_dtype), str(model_args.compute_dtype),
) )
) )
transformers.set_seed(training_args.seed) transformers.set_seed(training_args.seed)
return model_args, data_args, training_args, finetuning_args, generating_args return model_args, data_args, training_args, finetuning_args, generating_args

View File

@ -27,7 +27,7 @@ logger = logging.get_logger(__name__)
def find_all_linear_modules(model: "PreTrainedModel", freeze_vision_tower: bool) -> List[str]: def find_all_linear_modules(model: "PreTrainedModel", freeze_vision_tower: bool) -> List[str]:
r""" r"""
Finds all available modules to apply lora or galore or apollo. Finds all available modules to apply LoRA, GaLore or APOLLO.
""" """
model_type = getattr(model.config, "model_type", None) model_type = getattr(model.config, "model_type", None)
forbidden_modules = {"lm_head"} forbidden_modules = {"lm_head"}

View File

@ -32,7 +32,7 @@ from typing_extensions import override
from ..extras import logging from ..extras import logging
from ..extras.constants import IGNORE_INDEX from ..extras.constants import IGNORE_INDEX
from ..extras.packages import is_galore_available, is_ray_available, is_apollo_available from ..extras.packages import is_apollo_available, is_galore_available, is_ray_available
from ..hparams import FinetuningArguments, ModelArguments from ..hparams import FinetuningArguments, ModelArguments
from ..model import find_all_linear_modules, load_model, load_tokenizer, load_valuehead_params from ..model import find_all_linear_modules, load_model, load_tokenizer, load_valuehead_params
@ -40,9 +40,11 @@ from ..model import find_all_linear_modules, load_model, load_tokenizer, load_va
if is_galore_available(): if is_galore_available():
from galore_torch import GaLoreAdafactor, GaLoreAdamW, GaLoreAdamW8bit # type: ignore from galore_torch import GaLoreAdafactor, GaLoreAdamW, GaLoreAdamW8bit # type: ignore
if is_apollo_available(): if is_apollo_available():
from apollo_torch import APOLLOAdamW # type: ignore from apollo_torch import APOLLOAdamW # type: ignore
if is_ray_available(): if is_ray_available():
from ray.train import RunConfig, ScalingConfig from ray.train import RunConfig, ScalingConfig
from ray.train.torch import TorchTrainer from ray.train.torch import TorchTrainer
@ -240,9 +242,10 @@ def _create_galore_optimizer(
elif training_args.optim == "adafactor": elif training_args.optim == "adafactor":
optim_class = GaLoreAdafactor optim_class = GaLoreAdafactor
else: else:
raise NotImplementedError(f"Unknow optim: {training_args.optim}") raise NotImplementedError(f"Unknown optim: {training_args.optim}.")
if finetuning_args.galore_layerwise: if finetuning_args.galore_layerwise:
logger.warning_rank0("The displayed gradient norm will be all zeros in layerwise GaLore.")
if training_args.gradient_accumulation_steps != 1: if training_args.gradient_accumulation_steps != 1:
raise ValueError("Per-layer GaLore does not support gradient accumulation.") raise ValueError("Per-layer GaLore does not support gradient accumulation.")
@ -274,9 +277,13 @@ def _create_galore_optimizer(
] ]
optimizer = optim_class(param_groups, **optim_kwargs) optimizer = optim_class(param_groups, **optim_kwargs)
logger.info_rank0("Using GaLore optimizer, may cause hanging at the start of training, wait patiently.") logger.info_rank0(
f"Using GaLore optimizer with args: {galore_kwargs}. "
"It may cause hanging at the start of training, wait patiently."
)
return optimizer return optimizer
def _create_apollo_optimizer( def _create_apollo_optimizer(
model: "PreTrainedModel", model: "PreTrainedModel",
training_args: "TrainingArguments", training_args: "TrainingArguments",
@ -304,11 +311,9 @@ def _create_apollo_optimizer(
"scale_front": finetuning_args.apollo_scale_front, "scale_front": finetuning_args.apollo_scale_front,
} }
print(apollo_kwargs)
id_apollo_params = {id(param) for param in apollo_params} id_apollo_params = {id(param) for param in apollo_params}
decay_params, nodecay_params = [], [] # they are non-galore parameters decay_params, nodecay_params = [], [] # they are non-apollo parameters
trainable_params: List["torch.nn.Parameter"] = [] # galore_params + decay_params + nodecay_params trainable_params: List["torch.nn.Parameter"] = [] # apollo_params + decay_params + nodecay_params
decay_param_names = _get_decay_parameter_names(model) decay_param_names = _get_decay_parameter_names(model)
for name, param in model.named_parameters(): for name, param in model.named_parameters():
if param.requires_grad: if param.requires_grad:
@ -324,9 +329,10 @@ def _create_apollo_optimizer(
if training_args.optim == "adamw_torch": if training_args.optim == "adamw_torch":
optim_class = APOLLOAdamW optim_class = APOLLOAdamW
else: else:
raise NotImplementedError(f"Unknow optim: {training_args.optim}") raise NotImplementedError(f"Unknown optim: {training_args.optim}.")
if finetuning_args.apollo_layerwise: if finetuning_args.apollo_layerwise:
logger.warning_rank0("The displayed gradient norm will be all zeros in layerwise APOLLO.")
if training_args.gradient_accumulation_steps != 1: if training_args.gradient_accumulation_steps != 1:
raise ValueError("Per-layer APOLLO does not support gradient accumulation.") raise ValueError("Per-layer APOLLO does not support gradient accumulation.")
@ -337,7 +343,7 @@ def _create_apollo_optimizer(
for param in decay_params: for param in decay_params:
param_groups = [dict(params=[param], weight_decay=training_args.weight_decay)] param_groups = [dict(params=[param], weight_decay=training_args.weight_decay)]
optimizer_dict[param] = optim_class(param_groups, **optim_kwargs) optimizer_dict[param] = optim_class(param_groups, **optim_kwargs)
for param in apollo_params: # galore params have weight decay for param in apollo_params: # apollo params have weight decay
param_groups = [dict(params=[param], weight_decay=training_args.weight_decay, **apollo_kwargs)] param_groups = [dict(params=[param], weight_decay=training_args.weight_decay, **apollo_kwargs)]
optimizer_dict[param] = optim_class(param_groups, **optim_kwargs) optimizer_dict[param] = optim_class(param_groups, **optim_kwargs)
@ -358,7 +364,7 @@ def _create_apollo_optimizer(
] ]
optimizer = optim_class(param_groups, **optim_kwargs) optimizer = optim_class(param_groups, **optim_kwargs)
logger.info_rank0("Using APOLLO optimizer.") logger.info_rank0(f"Using APOLLO optimizer with args: {apollo_kwargs}.")
return optimizer return optimizer

View File

@ -234,8 +234,8 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
with gr.Row(): with gr.Row():
use_galore = gr.Checkbox() use_galore = gr.Checkbox()
galore_rank = gr.Slider(minimum=1, maximum=1024, value=16, step=1) galore_rank = gr.Slider(minimum=1, maximum=1024, value=16, step=1)
galore_update_interval = gr.Slider(minimum=1, maximum=1024, value=200, step=1) galore_update_interval = gr.Slider(minimum=1, maximum=2048, value=200, step=1)
galore_scale = gr.Slider(minimum=0, maximum=1, value=0.25, step=0.01) galore_scale = gr.Slider(minimum=0, maximum=100, value=2.0, step=0.1)
galore_target = gr.Textbox(value="all") galore_target = gr.Textbox(value="all")
input_elems.update({use_galore, galore_rank, galore_update_interval, galore_scale, galore_target}) input_elems.update({use_galore, galore_rank, galore_update_interval, galore_scale, galore_target})
@ -254,9 +254,10 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
with gr.Row(): with gr.Row():
use_apollo = gr.Checkbox() use_apollo = gr.Checkbox()
apollo_rank = gr.Slider(minimum=1, maximum=1024, value=16, step=1) apollo_rank = gr.Slider(minimum=1, maximum=1024, value=16, step=1)
apollo_update_interval = gr.Slider(minimum=1, maximum=1024, value=200, step=1) apollo_update_interval = gr.Slider(minimum=1, maximum=2048, value=200, step=1)
apollo_scale = gr.Slider(minimum=0, maximum=1, value=0.25, step=0.01) apollo_scale = gr.Slider(minimum=0, maximum=100, value=32.0, step=0.1)
apollo_target = gr.Textbox(value="all") apollo_target = gr.Textbox(value="all")
input_elems.update({use_apollo, apollo_rank, apollo_update_interval, apollo_scale, apollo_target}) input_elems.update({use_apollo, apollo_rank, apollo_update_interval, apollo_scale, apollo_target})
elem_dict.update( elem_dict.update(
dict( dict(

View File

@ -1162,19 +1162,19 @@ LOCALES = {
"use_galore": { "use_galore": {
"en": { "en": {
"label": "Use GaLore", "label": "Use GaLore",
"info": "Enable gradient low-Rank projection.", "info": "Use GaLore optimizer.",
}, },
"ru": { "ru": {
"label": "Использовать GaLore", "label": "Использовать GaLore",
"info": "Включить проекцию градиента на низкоранговое пространство.", "info": "Используйте оптимизатор GaLore.",
}, },
"zh": { "zh": {
"label": "使用 GaLore", "label": "使用 GaLore",
"info": "使用梯度低秩投影", "info": "使用 GaLore 优化器",
}, },
"ko": { "ko": {
"label": "GaLore 사용", "label": "GaLore 사용",
"info": "그레디언트 로우 랭크 프로젝션을 활성화합니다.", "info": "GaLore 최적화를 사용하세요.",
}, },
}, },
"galore_rank": { "galore_rank": {
@ -1266,19 +1266,19 @@ LOCALES = {
"use_apollo": { "use_apollo": {
"en": { "en": {
"label": "Use APOLLO", "label": "Use APOLLO",
"info": "Enable gradient low-Rank projection.", "info": "Use APOLLO optimizer.",
}, },
"ru": { "ru": {
"label": "Использовать APOLLO", "label": "Использовать APOLLO",
"info": "Включить проекцию градиента на низкоранговое пространство.", "info": "Используйте оптимизатор APOLLO.",
}, },
"zh": { "zh": {
"label": "使用 APOLLO", "label": "使用 APOLLO",
"info": "使用梯度低秩投影", "info": "使用 APOLLO 优化器",
}, },
"ko": { "ko": {
"label": "APOLLO 사용", "label": "APOLLO 사용",
"info": "그레디언트 로우 랭크 프로젝션을 활성화합니다.", "info": "APOLLO 최적화를 사용하세요.",
}, },
}, },
"apollo_rank": { "apollo_rank": {

View File

@ -224,7 +224,7 @@ class Runner:
args["galore_update_interval"] = get("train.galore_update_interval") args["galore_update_interval"] = get("train.galore_update_interval")
args["galore_scale"] = get("train.galore_scale") args["galore_scale"] = get("train.galore_scale")
args["galore_target"] = get("train.galore_target") args["galore_target"] = get("train.galore_target")
# apollo config # apollo config
if args["use_apollo"]: if args["use_apollo"]:
args["apollo_rank"] = get("train.apollo_rank") args["apollo_rank"] = get("train.apollo_rank")