Mirror of https://github.com/hiyouga/LLaMA-Factory.git
[feature] adding orthogonal finetuning (OFT) to LLaMA-Factory (#8623)

Co-authored-by: Zeju <zqiu@g003.internal.cluster.is.localnet>
Co-authored-by: Zeju <zqiu@login2.is.localnet>
Co-authored-by: Yaowei Zheng <hiyouga@buaa.edu.cn>
This commit is contained in:

parent 1ada15981a
commit 003a2acb1a

README.md (38 lines changed)
@@ -86,7 +86,7 @@ Choose your path:
 - **Various models**: LLaMA, LLaVA, Mistral, Mixtral-MoE, Qwen, Qwen2-VL, DeepSeek, Yi, Gemma, ChatGLM, Phi, etc.
 - **Integrated methods**: (Continuous) pre-training, (multimodal) supervised fine-tuning, reward modeling, PPO, DPO, KTO, ORPO, etc.
 - **Scalable resources**: 16-bit full-tuning, freeze-tuning, LoRA and 2/3/4/5/6/8-bit QLoRA via AQLM/AWQ/GPTQ/LLM.int8/HQQ/EETQ.
-- **Advanced algorithms**: [GaLore](https://github.com/jiaweizzhao/GaLore), [BAdam](https://github.com/Ledzy/BAdam), [APOLLO](https://github.com/zhuhanqing/APOLLO), [Adam-mini](https://github.com/zyushun/Adam-mini), [Muon](https://github.com/KellerJordan/Muon), DoRA, LongLoRA, LLaMA Pro, Mixture-of-Depths, LoRA+, LoftQ and PiSSA.
+- **Advanced algorithms**: [GaLore](https://github.com/jiaweizzhao/GaLore), [BAdam](https://github.com/Ledzy/BAdam), [APOLLO](https://github.com/zhuhanqing/APOLLO), [Adam-mini](https://github.com/zyushun/Adam-mini), [Muon](https://github.com/KellerJordan/Muon), [OFT](https://github.com/huggingface/peft/tree/main/src/peft/tuners/oft), DoRA, LongLoRA, LLaMA Pro, Mixture-of-Depths, LoRA+, LoftQ and PiSSA.
 - **Practical tricks**: [FlashAttention-2](https://github.com/Dao-AILab/flash-attention), [Unsloth](https://github.com/unslothai/unsloth), [Liger Kernel](https://github.com/linkedin/Liger-Kernel), RoPE scaling, NEFTune and rsLoRA.
 - **Wide tasks**: Multi-turn dialogue, tool using, image understanding, visual grounding, video recognition, audio understanding, etc.
 - **Experiment monitors**: LlamaBoard, TensorBoard, Wandb, MLflow, [SwanLab](https://github.com/SwanHubX/SwanLab), etc.
@@ -329,16 +329,16 @@ You also can add a custom chat template to [template.py](src/llamafactory/data/t

 ## Supported Training Approaches

-| Approach               |     Full-tuning    |    Freeze-tuning   |       LoRA         |       QLoRA        |
-| ---------------------- | ------------------ | ------------------ | ------------------ | ------------------ |
-| Pre-Training           | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
-| Supervised Fine-Tuning | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
-| Reward Modeling        | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
-| PPO Training           | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
-| DPO Training           | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
-| KTO Training           | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
-| ORPO Training          | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
-| SimPO Training         | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
+| Approach               |     Full-tuning    |    Freeze-tuning   |       LoRA         |       QLoRA        |        OFT         |        QOFT        |
+| ---------------------- | ------------------ | ------------------ | ------------------ | ------------------ | ------------------ | ------------------ |
+| Pre-Training           | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
+| Supervised Fine-Tuning | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
+| Reward Modeling        | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
+| PPO Training           | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
+| DPO Training           | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
+| KTO Training           | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
+| ORPO Training          | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
+| SimPO Training         | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |

 > [!TIP]
 > The implementation details of PPO can be found in [this blog](https://newfacade.github.io/notes-on-reinforcement-learning/17-ppo-trl.html).
@@ -469,14 +469,14 @@ huggingface-cli login

 \* *estimated*

-| Method                          | Bits |   7B  |  14B  |  30B  |   70B  |   `x`B  |
-| ------------------------------- | ---- | ----- | ----- | ----- | ------ | ------- |
-| Full (`bf16` or `fp16`)         |  32  | 120GB | 240GB | 600GB | 1200GB | `18x`GB |
-| Full (`pure_bf16`)              |  16  |  60GB | 120GB | 300GB |  600GB |  `8x`GB |
-| Freeze/LoRA/GaLore/APOLLO/BAdam |  16  |  16GB |  32GB |  64GB |  160GB |  `2x`GB |
-| QLoRA                           |   8  |  10GB |  20GB |  40GB |   80GB |   `x`GB |
-| QLoRA                           |   4  |   6GB |  12GB |  24GB |   48GB | `x/2`GB |
-| QLoRA                           |   2  |   4GB |   8GB |  16GB |   24GB | `x/4`GB |
+| Method                              | Bits |   7B  |  14B  |  30B  |   70B  |   `x`B  |
+| ----------------------------------- | ---- | ----- | ----- | ----- | ------ | ------- |
+| Full (`bf16` or `fp16`)             |  32  | 120GB | 240GB | 600GB | 1200GB | `18x`GB |
+| Full (`pure_bf16`)                  |  16  |  60GB | 120GB | 300GB |  600GB |  `8x`GB |
+| Freeze/LoRA/GaLore/APOLLO/BAdam/OFT |  16  |  16GB |  32GB |  64GB |  160GB |  `2x`GB |
+| QLoRA / QOFT                        |   8  |  10GB |  20GB |  40GB |   80GB |   `x`GB |
+| QLoRA / QOFT                        |   4  |   6GB |  12GB |  24GB |   48GB | `x/2`GB |
+| QLoRA / QOFT                        |   2  |   4GB |   8GB |  16GB |   24GB | `x/4`GB |

 ## Getting Started
@@ -290,3 +290,15 @@ llamafactory-cli train examples/extras/llama_pro/llama3_freeze_sft.yaml
 ```bash
 bash examples/extras/fsdp_qlora/train.sh
 ```
+
+#### OFT Fine-Tuning
+
+```bash
+llamafactory-cli train examples/extras/oft/llama3_oft_sft.yaml
+```
+
+#### QOFT Fine-Tuning
+
+```bash
+llamafactory-cli train examples/extras/qoft/llama3_oft_sft_bnb_npu.yaml
+```
The corresponding addition in the Chinese-language README (微调 = fine-tuning):

@@ -290,3 +290,15 @@ llamafactory-cli train examples/extras/llama_pro/llama3_freeze_sft.yaml
 ```bash
 bash examples/extras/fsdp_qlora/train.sh
 ```
+
+#### OFT 微调
+
+```bash
+llamafactory-cli train examples/extras/oft/llama3_oft_sft.yaml
+```
+
+#### QOFT 微调
+
+```bash
+llamafactory-cli train examples/extras/qoft/llama3_oft_sft_bnb_npu.yaml
+```
examples/extras/oft/llama3_oft_sft.yaml (new file, 46 lines)

@@ -0,0 +1,46 @@
### model
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
trust_remote_code: true

### method
stage: sft
do_train: true
finetuning_type: oft
oft_block_size: 32
oft_target: all

### dataset
dataset: identity,alpaca_en_demo
template: llama3
cutoff_len: 2048
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16
dataloader_num_workers: 4

### output
output_dir: saves/llama3-8b/oft/sft
logging_steps: 10
save_steps: 500
plot_loss: true
overwrite_output_dir: true
save_only_model: false
report_to: none  # choices: [none, wandb, tensorboard, swanlab, mlflow]

### train
per_device_train_batch_size: 1
gradient_accumulation_steps: 8
learning_rate: 1.0e-4
num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
bf16: true
ddp_timeout: 180000000
resume_from_checkpoint: null

### eval
# eval_dataset: alpaca_en_demo
# val_size: 0.1
# per_device_eval_batch_size: 1
# eval_strategy: steps
# eval_steps: 500
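For orientation, the sketch below shows roughly what this recipe amounts to in plain peft terms: the adapter wiring added further down in this commit maps `oft_rank`, `oft_block_size`, `oft_target` and `module_dropout` onto peft's `OFTConfig`. The explicit module list stands in for `oft_target: all` (all linear modules) and is illustrative only, not the exact resolution logic used by LLaMA-Factory.

```python
import torch
from transformers import AutoModelForCausalLM
from peft import OFTConfig, TaskType, get_peft_model

# Base model named in the YAML above.
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B-Instruct", torch_dtype=torch.bfloat16
)

# oft_block_size: 32 with oft_rank left at 0 -- exactly one of the two is used.
# The target list below illustrates what `oft_target: all` expands to for Llama-style blocks.
oft_config = OFTConfig(
    task_type=TaskType.CAUSAL_LM,
    r=0,
    oft_block_size=32,
    module_dropout=0.0,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
)

model = get_peft_model(model, oft_config)
model.print_trainable_parameters()  # only the orthogonal OFT parameters are trainable
```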
examples/extras/oft/qwen2_5vl_oft_sft.yaml (new file, 47 lines)

@@ -0,0 +1,47 @@
### model
model_name_or_path: Qwen/Qwen2.5-VL-7B-Instruct
image_max_pixels: 262144
video_max_pixels: 16384
trust_remote_code: true

### method
stage: sft
do_train: true
finetuning_type: oft
oft_block_size: 32
oft_target: all

### dataset
dataset: mllm_demo,identity,alpaca_en_demo  # video: mllm_video_demo
template: qwen2_vl
cutoff_len: 2048
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16
dataloader_num_workers: 4

### output
output_dir: saves/qwen2_5vl-7b/oft/sft
logging_steps: 10
save_steps: 500
plot_loss: true
overwrite_output_dir: true
save_only_model: false
report_to: none  # choices: [none, wandb, tensorboard, swanlab, mlflow]

### train
per_device_train_batch_size: 1
gradient_accumulation_steps: 8
learning_rate: 1.0e-4
num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
bf16: true
ddp_timeout: 180000000
resume_from_checkpoint: null

### eval
# val_size: 0.1
# per_device_eval_batch_size: 1
# eval_strategy: steps
# eval_steps: 500
examples/extras/qoft/llama3_oft_sft_awq.yaml (new file, 44 lines)

@@ -0,0 +1,44 @@
### model
model_name_or_path: TechxGenus/Meta-Llama-3-8B-Instruct-AWQ
trust_remote_code: true

### method
stage: sft
do_train: true
finetuning_type: oft
oft_block_size: 32
oft_target: all

### dataset
dataset: identity,alpaca_en_demo
template: llama3
cutoff_len: 2048
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16
dataloader_num_workers: 4

### output
output_dir: saves/llama3-8b/oft/sft
logging_steps: 10
save_steps: 500
plot_loss: true
overwrite_output_dir: true
save_only_model: false
report_to: none  # choices: [none, wandb, tensorboard, swanlab, mlflow]

### train
per_device_train_batch_size: 1
gradient_accumulation_steps: 8
learning_rate: 1.0e-4
num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
bf16: true
ddp_timeout: 180000000

### eval
# val_size: 0.1
# per_device_eval_batch_size: 1
# eval_strategy: steps
# eval_steps: 500
examples/extras/qoft/llama3_oft_sft_bnb_npu.yaml (new file, 47 lines)

@@ -0,0 +1,47 @@
### model
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
quantization_bit: 4
quantization_method: bnb
double_quantization: false
trust_remote_code: true

### method
stage: sft
do_train: true
finetuning_type: oft
oft_block_size: 32
oft_target: all

### dataset
dataset: identity,alpaca_en_demo
template: llama3
cutoff_len: 2048
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16
dataloader_num_workers: 4

### output
output_dir: saves/llama3-8b/oft/sft
logging_steps: 10
save_steps: 500
plot_loss: true
overwrite_output_dir: true
save_only_model: false
report_to: none  # choices: [none, wandb, tensorboard, swanlab, mlflow]

### train
per_device_train_batch_size: 1
gradient_accumulation_steps: 8
learning_rate: 1.0e-4
num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
bf16: true
ddp_timeout: 180000000

### eval
# val_size: 0.1
# per_device_eval_batch_size: 1
# eval_strategy: steps
# eval_steps: 500
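This QOFT recipe pairs on-the-fly 4-bit bitsandbytes quantization with an OFT adapter. A minimal sketch of the equivalent manual setup, mirroring `quantization_bit: 4`, `quantization_method: bnb` and `double_quantization: false` (the module list again stands in for `oft_target: all` and is illustrative):

```python
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import OFTConfig, TaskType, get_peft_model

# 4-bit bitsandbytes quantization matching the YAML above (double quantization disabled).
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=False,
    bnb_4bit_compute_dtype=torch.bfloat16,
)
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B-Instruct",
    quantization_config=bnb_config,
    torch_dtype=torch.bfloat16,
)

# OFT adapter on top of the quantized base, i.e. QOFT.
oft_config = OFTConfig(
    task_type=TaskType.CAUSAL_LM,
    r=0,
    oft_block_size=32,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
)
model = get_peft_model(model, oft_config)
model.print_trainable_parameters()
```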
examples/extras/qoft/llama3_oft_sft_gptq.yaml (new file, 44 lines)

@@ -0,0 +1,44 @@
### model
model_name_or_path: TechxGenus/Meta-Llama-3-8B-Instruct-GPTQ
trust_remote_code: true

### method
stage: sft
do_train: true
finetuning_type: oft
oft_block_size: 32
oft_target: all

### dataset
dataset: identity,alpaca_en_demo
template: llama3
cutoff_len: 2048
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16
dataloader_num_workers: 4

### output
output_dir: saves/llama3-8b/oft/sft
logging_steps: 10
save_steps: 500
plot_loss: true
overwrite_output_dir: true
save_only_model: false
report_to: none  # choices: [none, wandb, tensorboard, swanlab, mlflow]

### train
per_device_train_batch_size: 1
gradient_accumulation_steps: 8
learning_rate: 1.0e-4
num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
bf16: true
ddp_timeout: 180000000

### eval
# val_size: 0.1
# per_device_eval_batch_size: 1
# eval_strategy: steps
# eval_steps: 500
@@ -56,13 +56,13 @@ LAYERNORM_NAMES = {"norm", "ln"}

 LLAMABOARD_CONFIG = "llamaboard_config.yaml"

-METHODS = ["full", "freeze", "lora"]
+METHODS = ["full", "freeze", "lora", "oft"]

 MOD_SUPPORTED_MODELS = {"bloom", "falcon", "gemma", "llama", "mistral", "mixtral", "phi", "starcoder2"}

 MULTIMODAL_SUPPORTED_MODELS = set()

-PEFT_METHODS = {"lora"}
+PEFT_METHODS = {"lora", "oft"}

 RUNNING_LOG = "running_log.txt"
@@ -122,6 +122,48 @@ class LoraArguments:
     )


+@dataclass
+class OFTArguments:
+    r"""Arguments pertaining to the OFT training."""
+
+    additional_target: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": (
+                "Name(s) of modules apart from OFT layers to be set as trainable "
+                "and saved in the final checkpoint. "
+                "Use commas to separate multiple modules."
+            )
+        },
+    )
+    module_dropout: float = field(
+        default=0.0,
+        metadata={"help": "Dropout rate for the OFT fine-tuning."},
+    )
+    oft_rank: int = field(
+        default=0,
+        metadata={"help": "The intrinsic rank for OFT fine-tuning."},
+    )
+    oft_block_size: int = field(
+        default=32,
+        metadata={"help": "The block size for OFT fine-tuning."},
+    )
+    oft_target: str = field(
+        default="all",
+        metadata={
+            "help": (
+                "Name(s) of target modules to apply OFT. "
+                "Use commas to separate multiple modules. "
+                "Use `all` to specify all the linear modules."
+            )
+        },
+    )
+    create_new_adapter: bool = field(
+        default=False,
+        metadata={"help": "Whether or not to create a new adapter with randomly initialized weight."},
+    )
+
+
 @dataclass
 class RLHFArguments:
     r"""Arguments pertaining to the PPO, DPO and KTO training."""
@@ -400,7 +442,14 @@ class SwanLabArguments:

 @dataclass
 class FinetuningArguments(
-    SwanLabArguments, BAdamArgument, ApolloArguments, GaloreArguments, RLHFArguments, LoraArguments, FreezeArguments
+    SwanLabArguments,
+    BAdamArgument,
+    ApolloArguments,
+    GaloreArguments,
+    RLHFArguments,
+    LoraArguments,
+    OFTArguments,
+    FreezeArguments,
 ):
     r"""Arguments pertaining to which techniques we are going to fine-tuning with."""
@@ -475,12 +524,13 @@ class FinetuningArguments(
         self.freeze_extra_modules: Optional[list[str]] = split_arg(self.freeze_extra_modules)
         self.lora_alpha: int = self.lora_alpha or self.lora_rank * 2
         self.lora_target: list[str] = split_arg(self.lora_target)
+        self.oft_target: list[str] = split_arg(self.oft_target)
         self.additional_target: Optional[list[str]] = split_arg(self.additional_target)
         self.galore_target: list[str] = split_arg(self.galore_target)
         self.apollo_target: list[str] = split_arg(self.apollo_target)
         self.use_ref_model = self.stage == "dpo" and self.pref_loss not in ["orpo", "simpo"]

-        assert self.finetuning_type in ["lora", "freeze", "full"], "Invalid fine-tuning method."
+        assert self.finetuning_type in ["lora", "oft", "freeze", "full"], "Invalid fine-tuning method."
         assert self.ref_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
         assert self.reward_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
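The new `oft_target` string is split into a list in `__post_init__`, just like `lora_target`. A small sketch of how the new fields surface through argument parsing; it assumes `FinetuningArguments` is importable from `llamafactory.hparams` and that it is parsed with `transformers.HfArgumentParser` (the repo's actual hparams parser is not part of this diff):

```python
from transformers import HfArgumentParser

# Assumption: FinetuningArguments (which now mixes in OFTArguments) is exported here.
from llamafactory.hparams import FinetuningArguments

parser = HfArgumentParser(FinetuningArguments)
(finetuning_args,) = parser.parse_dict(
    {"finetuning_type": "oft", "oft_block_size": 32, "oft_target": "q_proj,v_proj"}
)

print(finetuning_args.finetuning_type)  # "oft" -- accepted by the extended assert above
print(finetuning_args.oft_target)       # ["q_proj", "v_proj"] after split_arg in __post_init__
```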
@@ -490,6 +540,9 @@ class FinetuningArguments(
         if self.stage == "ppo" and self.reward_model_type == "lora" and self.finetuning_type != "lora":
             raise ValueError("`reward_model_type` cannot be lora for Freeze/Full PPO training.")

+        if self.stage == "ppo" and self.reward_model_type == "oft" and self.finetuning_type != "oft":
+            raise ValueError("`reward_model_type` cannot be oft for Freeze/Full PPO training.")
+
         if self.stage == "dpo" and self.pref_loss != "sigmoid" and self.dpo_label_smoothing > 1e-6:
             raise ValueError("`dpo_label_smoothing` is only valid for sigmoid loss function.")
@@ -111,8 +111,8 @@ def _verify_model_args(
         raise ValueError("Adapter is only valid for the LoRA method.")

     if model_args.quantization_bit is not None:
-        if finetuning_args.finetuning_type != "lora":
-            raise ValueError("Quantization is only compatible with the LoRA method.")
+        if finetuning_args.finetuning_type not in ["lora", "oft"]:
+            raise ValueError("Quantization is only compatible with the LoRA or OFT method.")

         if finetuning_args.pissa_init:
             raise ValueError("Please use scripts/pissa_init.py to initialize PiSSA for a quantized model.")
@@ -16,10 +16,11 @@ import re
 from typing import TYPE_CHECKING

 import torch
-from peft import LoraConfig, LoraModel, PeftModel, TaskType, get_peft_model
+from peft import LoraConfig, LoraModel, OFTConfig, OFTModel, PeftModel, TaskType, get_peft_model
 from transformers.integrations import is_deepspeed_zero3_enabled

 from ..extras import logging
 from ..extras.misc import check_version
 from .model_utils.misc import find_all_linear_modules, find_expanded_modules
 from .model_utils.quantization import QuantizationMethod
 from .model_utils.unsloth import get_unsloth_peft_model, load_unsloth_peft_model
@@ -147,7 +148,10 @@ def _setup_lora_tuning(
     cast_trainable_params_to_fp32: bool,
 ) -> "PeftModel":
     if is_trainable:
-        logger.info_rank0("Fine-tuning method: {}".format("DoRA" if finetuning_args.use_dora else "LoRA"))
+        if finetuning_args.finetuning_type == "oft":
+            logger.info_rank0("Fine-tuning method: OFT")
+        else:
+            logger.info_rank0("Fine-tuning method: {}".format("DoRA" if finetuning_args.use_dora else "LoRA"))

     adapter_to_resume = None
@@ -223,17 +227,29 @@ def _setup_lora_tuning(
             finetuning_args.additional_target = module_names
             logger.warning_rank0("Vocab has been resized, add {} to trainable params.".format(",".join(module_names)))

-        peft_kwargs = {
-            "r": finetuning_args.lora_rank,
-            "target_modules": target_modules,
-            "lora_alpha": finetuning_args.lora_alpha,
-            "lora_dropout": finetuning_args.lora_dropout,
-            "use_rslora": finetuning_args.use_rslora,
-            "use_dora": finetuning_args.use_dora,
-            "modules_to_save": finetuning_args.additional_target,
-        }
+        if finetuning_args.finetuning_type == "lora":
+            peft_kwargs = {
+                "r": finetuning_args.lora_rank,
+                "target_modules": target_modules,
+                "lora_alpha": finetuning_args.lora_alpha,
+                "lora_dropout": finetuning_args.lora_dropout,
+                "use_rslora": finetuning_args.use_rslora,
+                "use_dora": finetuning_args.use_dora,
+                "modules_to_save": finetuning_args.additional_target,
+            }
+        elif finetuning_args.finetuning_type == "oft":
+            peft_kwargs = {
+                "r": finetuning_args.oft_rank,
+                "oft_block_size": finetuning_args.oft_block_size,
+                "target_modules": target_modules,
+                "module_dropout": finetuning_args.module_dropout,
+                "modules_to_save": finetuning_args.additional_target,
+            }

         if model_args.use_unsloth:
+            if finetuning_args.finetuning_type == "oft":
+                raise ValueError("Unsloth is currently not supported for OFT.")
+
             model = get_unsloth_peft_model(model, model_args, peft_kwargs)
         else:
             if finetuning_args.pissa_init:
@@ -244,12 +260,19 @@ def _setup_lora_tuning(
                     logger.info_rank0(f"Using PiSSA initialization with FSVD steps {finetuning_args.pissa_iter}.")
                     peft_kwargs["init_lora_weights"] = f"pissa_niter_{finetuning_args.pissa_iter}"

-            lora_config = LoraConfig(
-                task_type=TaskType.CAUSAL_LM,
-                inference_mode=False,
-                **peft_kwargs,
-            )
-            model = get_peft_model(model, lora_config)
+            if finetuning_args.finetuning_type == "lora":
+                peft_config = LoraConfig(
+                    task_type=TaskType.CAUSAL_LM,
+                    inference_mode=False,
+                    **peft_kwargs,
+                )
+            elif finetuning_args.finetuning_type == "oft":
+                peft_config = OFTConfig(
+                    task_type=TaskType.CAUSAL_LM,
+                    inference_mode=False,
+                    **peft_kwargs,
+                )
+            model = get_peft_model(model, peft_config)

     if is_trainable and cast_trainable_params_to_fp32:
         for param in filter(lambda p: p.requires_grad, model.parameters()):
@@ -272,8 +295,8 @@ def init_adapter(
     Note that the trainable parameters must be cast to float32.
     """
     if is_trainable and getattr(model, "quantization_method", None) is not None:
-        if finetuning_args.finetuning_type != "lora":
-            raise ValueError("Quantized models can only be used for the LoRA tuning.")
+        if finetuning_args.finetuning_type not in ["lora", "oft"]:
+            raise ValueError("Quantized models can only be used for the LoRA or OFT tuning.")

         if finetuning_args.pissa_init:
             raise ValueError("Cannot initialize PiSSA adapter on quantized models.")
@@ -296,7 +319,7 @@ def init_adapter(
         _setup_full_tuning(model, finetuning_args, is_trainable, cast_trainable_params_to_fp32)
     elif finetuning_args.finetuning_type == "freeze":
         _setup_freeze_tuning(model, finetuning_args, is_trainable, cast_trainable_params_to_fp32)
-    elif finetuning_args.finetuning_type == "lora":
+    elif finetuning_args.finetuning_type in ["lora", "oft"]:
         model = _setup_lora_tuning(
             config, model, model_args, finetuning_args, is_trainable, cast_trainable_params_to_fp32
         )
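Because OFT goes through the same `_setup_lora_tuning` path, a trained OFT adapter saved under `output_dir` can be attached for inference like any other PEFT adapter. A minimal sketch; the checkpoint path is the `output_dir` from the example configs above and is illustrative only:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

base = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B-Instruct", torch_dtype=torch.bfloat16, device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")

# Attach the trained OFT adapter (path taken from the example configs above).
model = PeftModel.from_pretrained(base, "saves/llama3-8b/oft/sft")
model.eval()

inputs = tokenizer("Who are you?", return_tensors="pt").to(model.device)
output = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```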
@@ -390,7 +390,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
         batch: dict[str, torch.Tensor] = self.prepare_model_inputs(queries, responses)
         unwrapped_model: AutoModelForCausalLMWithValueHead = self.accelerator.unwrap_model(self.model)

-        if self.finetuning_args.reward_model_type == "lora":
+        if self.finetuning_args.reward_model_type in ["lora", "oft"]:
             replace_model(unwrapped_model, target="reward")
             reward_model = self.model
         else:
@@ -399,7 +399,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
         with unwrap_model_for_generation(reward_model, self.accelerator), self.amp_context:  # support bf16
             values: torch.Tensor = reward_model(**batch, return_dict=True, use_cache=False)[-1]

-        if self.finetuning_args.reward_model_type == "lora":
+        if self.finetuning_args.reward_model_type in ["lora", "oft"]:
             replace_model(unwrapped_model, target="default")

         rewards = values.gather(dim=-1, index=(batch["attention_mask"].sum(dim=-1, keepdim=True) - 1))