mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-19 12:12:48 +08:00
[feature] adding orthogononal finetuning (OFT) to llama factory (#8623)
Co-authored-by: Zeju <zqiu@g003.internal.cluster.is.localnet> Co-authored-by: Zeju <zqiu@login2.is.localnet> Co-authored-by: Yaowei Zheng <hiyouga@buaa.edu.cn>
This commit is contained in:
parent
0454c10456
commit
842595698b
32
README.md
32
README.md
@ -86,7 +86,7 @@ Choose your path:
|
||||
- **Various models**: LLaMA, LLaVA, Mistral, Mixtral-MoE, Qwen, Qwen2-VL, DeepSeek, Yi, Gemma, ChatGLM, Phi, etc.
|
||||
- **Integrated methods**: (Continuous) pre-training, (multimodal) supervised fine-tuning, reward modeling, PPO, DPO, KTO, ORPO, etc.
|
||||
- **Scalable resources**: 16-bit full-tuning, freeze-tuning, LoRA and 2/3/4/5/6/8-bit QLoRA via AQLM/AWQ/GPTQ/LLM.int8/HQQ/EETQ.
|
||||
- **Advanced algorithms**: [GaLore](https://github.com/jiaweizzhao/GaLore), [BAdam](https://github.com/Ledzy/BAdam), [APOLLO](https://github.com/zhuhanqing/APOLLO), [Adam-mini](https://github.com/zyushun/Adam-mini), [Muon](https://github.com/KellerJordan/Muon), DoRA, LongLoRA, LLaMA Pro, Mixture-of-Depths, LoRA+, LoftQ and PiSSA.
|
||||
- **Advanced algorithms**: [GaLore](https://github.com/jiaweizzhao/GaLore), [BAdam](https://github.com/Ledzy/BAdam), [APOLLO](https://github.com/zhuhanqing/APOLLO), [Adam-mini](https://github.com/zyushun/Adam-mini), [Muon](https://github.com/KellerJordan/Muon), [OFT] (https://github.com/huggingface/peft/tree/main/src/peft/tuners/oft), DoRA, LongLoRA, LLaMA Pro, Mixture-of-Depths, LoRA+, LoftQ and PiSSA.
|
||||
- **Practical tricks**: [FlashAttention-2](https://github.com/Dao-AILab/flash-attention), [Unsloth](https://github.com/unslothai/unsloth), [Liger Kernel](https://github.com/linkedin/Liger-Kernel), RoPE scaling, NEFTune and rsLoRA.
|
||||
- **Wide tasks**: Multi-turn dialogue, tool using, image understanding, visual grounding, video recognition, audio understanding, etc.
|
||||
- **Experiment monitors**: LlamaBoard, TensorBoard, Wandb, MLflow, [SwanLab](https://github.com/SwanHubX/SwanLab), etc.
|
||||
@ -329,16 +329,16 @@ You also can add a custom chat template to [template.py](src/llamafactory/data/t
|
||||
|
||||
## Supported Training Approaches
|
||||
|
||||
| Approach | Full-tuning | Freeze-tuning | LoRA | QLoRA |
|
||||
| ---------------------- | ------------------ | ------------------ | ------------------ | ------------------ |
|
||||
| Pre-Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
|
||||
| Supervised Fine-Tuning | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
|
||||
| Reward Modeling | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
|
||||
| PPO Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
|
||||
| DPO Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
|
||||
| KTO Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
|
||||
| ORPO Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
|
||||
| SimPO Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
|
||||
| Approach | Full-tuning | Freeze-tuning | LoRA | QLoRA | OFT | QOFT |
|
||||
| ---------------------- | ------------------ | ------------------ | ------------------ | ------------------ | ------------------ | ------------------ |
|
||||
| Pre-Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
|
||||
| Supervised Fine-Tuning | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
|
||||
| Reward Modeling | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
|
||||
| PPO Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
|
||||
| DPO Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
|
||||
| KTO Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
|
||||
| ORPO Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
|
||||
| SimPO Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
|
||||
|
||||
> [!TIP]
|
||||
> The implementation details of PPO can be found in [this blog](https://newfacade.github.io/notes-on-reinforcement-learning/17-ppo-trl.html).
|
||||
@ -470,13 +470,13 @@ huggingface-cli login
|
||||
\* *estimated*
|
||||
|
||||
| Method | Bits | 7B | 14B | 30B | 70B | `x`B |
|
||||
| ------------------------------- | ---- | ----- | ----- | ----- | ------ | ------- |
|
||||
| ----------------------------------- | ---- | ----- | ----- | ----- | ------ | ------- |
|
||||
| Full (`bf16` or `fp16`) | 32 | 120GB | 240GB | 600GB | 1200GB | `18x`GB |
|
||||
| Full (`pure_bf16`) | 16 | 60GB | 120GB | 300GB | 600GB | `8x`GB |
|
||||
| Freeze/LoRA/GaLore/APOLLO/BAdam | 16 | 16GB | 32GB | 64GB | 160GB | `2x`GB |
|
||||
| QLoRA | 8 | 10GB | 20GB | 40GB | 80GB | `x`GB |
|
||||
| QLoRA | 4 | 6GB | 12GB | 24GB | 48GB | `x/2`GB |
|
||||
| QLoRA | 2 | 4GB | 8GB | 16GB | 24GB | `x/4`GB |
|
||||
| Freeze/LoRA/GaLore/APOLLO/BAdam/OFT | 16 | 16GB | 32GB | 64GB | 160GB | `2x`GB |
|
||||
| QLoRA / QOFT | 8 | 10GB | 20GB | 40GB | 80GB | `x`GB |
|
||||
| QLoRA / QOFT | 4 | 6GB | 12GB | 24GB | 48GB | `x/2`GB |
|
||||
| QLoRA / QOFT | 2 | 4GB | 8GB | 16GB | 24GB | `x/4`GB |
|
||||
|
||||
## Getting Started
|
||||
|
||||
|
@ -290,3 +290,15 @@ llamafactory-cli train examples/extras/llama_pro/llama3_freeze_sft.yaml
|
||||
```bash
|
||||
bash examples/extras/fsdp_qlora/train.sh
|
||||
```
|
||||
|
||||
#### OFT Fine-Tuning
|
||||
|
||||
```bash
|
||||
llamafactory-cli train examples/extras/oft/llama3_oft_sft.yaml
|
||||
```
|
||||
|
||||
#### QOFT Fine-Tuning
|
||||
|
||||
```bash
|
||||
llamafactory-cli train examples/extras/qoft/llama3_oft_sft_bnb_npu.yaml
|
||||
```
|
||||
|
@ -290,3 +290,15 @@ llamafactory-cli train examples/extras/llama_pro/llama3_freeze_sft.yaml
|
||||
```bash
|
||||
bash examples/extras/fsdp_qlora/train.sh
|
||||
```
|
||||
|
||||
#### OFT 微调
|
||||
|
||||
```bash
|
||||
llamafactory-cli train examples/extras/oft/llama3_oft_sft.yaml
|
||||
```
|
||||
|
||||
#### QOFT 微调
|
||||
|
||||
```bash
|
||||
llamafactory-cli train examples/extras/qoft/llama3_oft_sft_bnb_npu.yaml
|
||||
```
|
||||
|
46
examples/extras/oft/llama3_oft_sft.yaml
Normal file
46
examples/extras/oft/llama3_oft_sft.yaml
Normal file
@ -0,0 +1,46 @@
|
||||
### model
|
||||
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
|
||||
trust_remote_code: true
|
||||
|
||||
### method
|
||||
stage: sft
|
||||
do_train: true
|
||||
finetuning_type: oft
|
||||
oft_block_size: 32
|
||||
oft_target: all
|
||||
|
||||
### dataset
|
||||
dataset: identity,alpaca_en_demo
|
||||
template: llama3
|
||||
cutoff_len: 2048
|
||||
max_samples: 1000
|
||||
overwrite_cache: true
|
||||
preprocessing_num_workers: 16
|
||||
dataloader_num_workers: 4
|
||||
|
||||
### output
|
||||
output_dir: saves/llama3-8b/oft/sft
|
||||
logging_steps: 10
|
||||
save_steps: 500
|
||||
plot_loss: true
|
||||
overwrite_output_dir: true
|
||||
save_only_model: false
|
||||
report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
|
||||
|
||||
### train
|
||||
per_device_train_batch_size: 1
|
||||
gradient_accumulation_steps: 8
|
||||
learning_rate: 1.0e-4
|
||||
num_train_epochs: 3.0
|
||||
lr_scheduler_type: cosine
|
||||
warmup_ratio: 0.1
|
||||
bf16: true
|
||||
ddp_timeout: 180000000
|
||||
resume_from_checkpoint: null
|
||||
|
||||
### eval
|
||||
# eval_dataset: alpaca_en_demo
|
||||
# val_size: 0.1
|
||||
# per_device_eval_batch_size: 1
|
||||
# eval_strategy: steps
|
||||
# eval_steps: 500
|
47
examples/extras/oft/qwen2_5vl_oft_sft.yaml
Normal file
47
examples/extras/oft/qwen2_5vl_oft_sft.yaml
Normal file
@ -0,0 +1,47 @@
|
||||
### model
|
||||
model_name_or_path: Qwen/Qwen2.5-VL-7B-Instruct
|
||||
image_max_pixels: 262144
|
||||
video_max_pixels: 16384
|
||||
trust_remote_code: true
|
||||
|
||||
### method
|
||||
stage: sft
|
||||
do_train: true
|
||||
finetuning_type: oft
|
||||
oft_block_size: 32
|
||||
oft_target: all
|
||||
|
||||
### dataset
|
||||
dataset: mllm_demo,identity,alpaca_en_demo # video: mllm_video_demo
|
||||
template: qwen2_vl
|
||||
cutoff_len: 2048
|
||||
max_samples: 1000
|
||||
overwrite_cache: true
|
||||
preprocessing_num_workers: 16
|
||||
dataloader_num_workers: 4
|
||||
|
||||
### output
|
||||
output_dir: saves/qwen2_5vl-7b/oft/sft
|
||||
logging_steps: 10
|
||||
save_steps: 500
|
||||
plot_loss: true
|
||||
overwrite_output_dir: true
|
||||
save_only_model: false
|
||||
report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
|
||||
|
||||
### train
|
||||
per_device_train_batch_size: 1
|
||||
gradient_accumulation_steps: 8
|
||||
learning_rate: 1.0e-4
|
||||
num_train_epochs: 3.0
|
||||
lr_scheduler_type: cosine
|
||||
warmup_ratio: 0.1
|
||||
bf16: true
|
||||
ddp_timeout: 180000000
|
||||
resume_from_checkpoint: null
|
||||
|
||||
### eval
|
||||
# val_size: 0.1
|
||||
# per_device_eval_batch_size: 1
|
||||
# eval_strategy: steps
|
||||
# eval_steps: 500
|
44
examples/extras/qoft/llama3_oft_sft_awq.yaml
Normal file
44
examples/extras/qoft/llama3_oft_sft_awq.yaml
Normal file
@ -0,0 +1,44 @@
|
||||
### model
|
||||
model_name_or_path: TechxGenus/Meta-Llama-3-8B-Instruct-AWQ
|
||||
trust_remote_code: true
|
||||
|
||||
### method
|
||||
stage: sft
|
||||
do_train: true
|
||||
finetuning_type: oft
|
||||
oft_block_size: 32
|
||||
oft_target: all
|
||||
|
||||
### dataset
|
||||
dataset: identity,alpaca_en_demo
|
||||
template: llama3
|
||||
cutoff_len: 2048
|
||||
max_samples: 1000
|
||||
overwrite_cache: true
|
||||
preprocessing_num_workers: 16
|
||||
dataloader_num_workers: 4
|
||||
|
||||
### output
|
||||
output_dir: saves/llama3-8b/oft/sft
|
||||
logging_steps: 10
|
||||
save_steps: 500
|
||||
plot_loss: true
|
||||
overwrite_output_dir: true
|
||||
save_only_model: false
|
||||
report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
|
||||
|
||||
### train
|
||||
per_device_train_batch_size: 1
|
||||
gradient_accumulation_steps: 8
|
||||
learning_rate: 1.0e-4
|
||||
num_train_epochs: 3.0
|
||||
lr_scheduler_type: cosine
|
||||
warmup_ratio: 0.1
|
||||
bf16: true
|
||||
ddp_timeout: 180000000
|
||||
|
||||
### eval
|
||||
# val_size: 0.1
|
||||
# per_device_eval_batch_size: 1
|
||||
# eval_strategy: steps
|
||||
# eval_steps: 500
|
47
examples/extras/qoft/llama3_oft_sft_bnb_npu.yaml
Normal file
47
examples/extras/qoft/llama3_oft_sft_bnb_npu.yaml
Normal file
@ -0,0 +1,47 @@
|
||||
### model
|
||||
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
|
||||
quantization_bit: 4
|
||||
quantization_method: bnb
|
||||
double_quantization: false
|
||||
trust_remote_code: true
|
||||
|
||||
### method
|
||||
stage: sft
|
||||
do_train: true
|
||||
finetuning_type: oft
|
||||
oft_block_size: 32
|
||||
oft_target: all
|
||||
|
||||
### dataset
|
||||
dataset: identity,alpaca_en_demo
|
||||
template: llama3
|
||||
cutoff_len: 2048
|
||||
max_samples: 1000
|
||||
overwrite_cache: true
|
||||
preprocessing_num_workers: 16
|
||||
dataloader_num_workers: 4
|
||||
|
||||
### output
|
||||
output_dir: saves/llama3-8b/oft/sft
|
||||
logging_steps: 10
|
||||
save_steps: 500
|
||||
plot_loss: true
|
||||
overwrite_output_dir: true
|
||||
save_only_model: false
|
||||
report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
|
||||
|
||||
### train
|
||||
per_device_train_batch_size: 1
|
||||
gradient_accumulation_steps: 8
|
||||
learning_rate: 1.0e-4
|
||||
num_train_epochs: 3.0
|
||||
lr_scheduler_type: cosine
|
||||
warmup_ratio: 0.1
|
||||
bf16: true
|
||||
ddp_timeout: 180000000
|
||||
|
||||
### eval
|
||||
# val_size: 0.1
|
||||
# per_device_eval_batch_size: 1
|
||||
# eval_strategy: steps
|
||||
# eval_steps: 500
|
44
examples/extras/qoft/llama3_oft_sft_gptq.yaml
Normal file
44
examples/extras/qoft/llama3_oft_sft_gptq.yaml
Normal file
@ -0,0 +1,44 @@
|
||||
### model
|
||||
model_name_or_path: TechxGenus/Meta-Llama-3-8B-Instruct-GPTQ
|
||||
trust_remote_code: true
|
||||
|
||||
### method
|
||||
stage: sft
|
||||
do_train: true
|
||||
finetuning_type: oft
|
||||
oft_block_size: 32
|
||||
oft_target: all
|
||||
|
||||
### dataset
|
||||
dataset: identity,alpaca_en_demo
|
||||
template: llama3
|
||||
cutoff_len: 2048
|
||||
max_samples: 1000
|
||||
overwrite_cache: true
|
||||
preprocessing_num_workers: 16
|
||||
dataloader_num_workers: 4
|
||||
|
||||
### output
|
||||
output_dir: saves/llama3-8b/oft/sft
|
||||
logging_steps: 10
|
||||
save_steps: 500
|
||||
plot_loss: true
|
||||
overwrite_output_dir: true
|
||||
save_only_model: false
|
||||
report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
|
||||
|
||||
### train
|
||||
per_device_train_batch_size: 1
|
||||
gradient_accumulation_steps: 8
|
||||
learning_rate: 1.0e-4
|
||||
num_train_epochs: 3.0
|
||||
lr_scheduler_type: cosine
|
||||
warmup_ratio: 0.1
|
||||
bf16: true
|
||||
ddp_timeout: 180000000
|
||||
|
||||
### eval
|
||||
# val_size: 0.1
|
||||
# per_device_eval_batch_size: 1
|
||||
# eval_strategy: steps
|
||||
# eval_steps: 500
|
@ -56,13 +56,13 @@ LAYERNORM_NAMES = {"norm", "ln"}
|
||||
|
||||
LLAMABOARD_CONFIG = "llamaboard_config.yaml"
|
||||
|
||||
METHODS = ["full", "freeze", "lora"]
|
||||
METHODS = ["full", "freeze", "lora", "oft"]
|
||||
|
||||
MOD_SUPPORTED_MODELS = {"bloom", "falcon", "gemma", "llama", "mistral", "mixtral", "phi", "starcoder2"}
|
||||
|
||||
MULTIMODAL_SUPPORTED_MODELS = set()
|
||||
|
||||
PEFT_METHODS = {"lora"}
|
||||
PEFT_METHODS = {"lora", "oft"}
|
||||
|
||||
RUNNING_LOG = "running_log.txt"
|
||||
|
||||
|
@ -122,6 +122,48 @@ class LoraArguments:
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class OFTArguments:
|
||||
r"""Arguments pertaining to the OFT training."""
|
||||
|
||||
additional_target: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": (
|
||||
"Name(s) of modules apart from LoRA layers to be set as trainable "
|
||||
"and saved in the final checkpoint. "
|
||||
"Use commas to separate multiple modules."
|
||||
)
|
||||
},
|
||||
)
|
||||
module_dropout: float = field(
|
||||
default=0.0,
|
||||
metadata={"help": "Dropout rate for the OFT fine-tuning."},
|
||||
)
|
||||
oft_rank: int = field(
|
||||
default=0,
|
||||
metadata={"help": "The intrinsic dimension for OFT fine-tuning."},
|
||||
)
|
||||
oft_block_size: int = field(
|
||||
default=32,
|
||||
metadata={"help": "The intrinsic dimension for OFT fine-tuning."},
|
||||
)
|
||||
oft_target: str = field(
|
||||
default="all",
|
||||
metadata={
|
||||
"help": (
|
||||
"Name(s) of target modules to apply OFT. "
|
||||
"Use commas to separate multiple modules. "
|
||||
"Use `all` to specify all the linear modules."
|
||||
)
|
||||
},
|
||||
)
|
||||
create_new_adapter: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Whether or not to create a new adapter with randomly initialized weight."},
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class RLHFArguments:
|
||||
r"""Arguments pertaining to the PPO, DPO and KTO training."""
|
||||
@ -400,7 +442,14 @@ class SwanLabArguments:
|
||||
|
||||
@dataclass
|
||||
class FinetuningArguments(
|
||||
SwanLabArguments, BAdamArgument, ApolloArguments, GaloreArguments, RLHFArguments, LoraArguments, FreezeArguments
|
||||
SwanLabArguments,
|
||||
BAdamArgument,
|
||||
ApolloArguments,
|
||||
GaloreArguments,
|
||||
RLHFArguments,
|
||||
LoraArguments,
|
||||
OFTArguments,
|
||||
FreezeArguments,
|
||||
):
|
||||
r"""Arguments pertaining to which techniques we are going to fine-tuning with."""
|
||||
|
||||
@ -475,12 +524,13 @@ class FinetuningArguments(
|
||||
self.freeze_extra_modules: Optional[list[str]] = split_arg(self.freeze_extra_modules)
|
||||
self.lora_alpha: int = self.lora_alpha or self.lora_rank * 2
|
||||
self.lora_target: list[str] = split_arg(self.lora_target)
|
||||
self.oft_target: list[str] = split_arg(self.oft_target)
|
||||
self.additional_target: Optional[list[str]] = split_arg(self.additional_target)
|
||||
self.galore_target: list[str] = split_arg(self.galore_target)
|
||||
self.apollo_target: list[str] = split_arg(self.apollo_target)
|
||||
self.use_ref_model = self.stage == "dpo" and self.pref_loss not in ["orpo", "simpo"]
|
||||
|
||||
assert self.finetuning_type in ["lora", "freeze", "full"], "Invalid fine-tuning method."
|
||||
assert self.finetuning_type in ["lora", "oft", "freeze", "full"], "Invalid fine-tuning method."
|
||||
assert self.ref_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
|
||||
assert self.reward_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
|
||||
|
||||
@ -490,6 +540,9 @@ class FinetuningArguments(
|
||||
if self.stage == "ppo" and self.reward_model_type == "lora" and self.finetuning_type != "lora":
|
||||
raise ValueError("`reward_model_type` cannot be lora for Freeze/Full PPO training.")
|
||||
|
||||
if self.stage == "ppo" and self.reward_model_type == "oft" and self.finetuning_type != "oft":
|
||||
raise ValueError("`reward_model_type` cannot be oft for Freeze/Full PPO training.")
|
||||
|
||||
if self.stage == "dpo" and self.pref_loss != "sigmoid" and self.dpo_label_smoothing > 1e-6:
|
||||
raise ValueError("`dpo_label_smoothing` is only valid for sigmoid loss function.")
|
||||
|
||||
|
@ -111,8 +111,8 @@ def _verify_model_args(
|
||||
raise ValueError("Adapter is only valid for the LoRA method.")
|
||||
|
||||
if model_args.quantization_bit is not None:
|
||||
if finetuning_args.finetuning_type != "lora":
|
||||
raise ValueError("Quantization is only compatible with the LoRA method.")
|
||||
if finetuning_args.finetuning_type not in ["lora", "oft"]:
|
||||
raise ValueError("Quantization is only compatible with the LoRA or OFT method.")
|
||||
|
||||
if finetuning_args.pissa_init:
|
||||
raise ValueError("Please use scripts/pissa_init.py to initialize PiSSA for a quantized model.")
|
||||
|
@ -16,10 +16,11 @@ import re
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
from peft import LoraConfig, LoraModel, PeftModel, TaskType, get_peft_model
|
||||
from peft import LoraConfig, LoraModel, OFTConfig, OFTModel, PeftModel, TaskType, get_peft_model
|
||||
from transformers.integrations import is_deepspeed_zero3_enabled
|
||||
|
||||
from ..extras import logging
|
||||
from ..extras.misc import check_version
|
||||
from .model_utils.misc import find_all_linear_modules, find_expanded_modules
|
||||
from .model_utils.quantization import QuantizationMethod
|
||||
from .model_utils.unsloth import get_unsloth_peft_model, load_unsloth_peft_model
|
||||
@ -147,6 +148,9 @@ def _setup_lora_tuning(
|
||||
cast_trainable_params_to_fp32: bool,
|
||||
) -> "PeftModel":
|
||||
if is_trainable:
|
||||
if finetuning_args.finetuning_type == "oft":
|
||||
logger.info_rank0("Fine-tuning method: OFT")
|
||||
else:
|
||||
logger.info_rank0("Fine-tuning method: {}".format("DoRA" if finetuning_args.use_dora else "LoRA"))
|
||||
|
||||
adapter_to_resume = None
|
||||
@ -223,6 +227,7 @@ def _setup_lora_tuning(
|
||||
finetuning_args.additional_target = module_names
|
||||
logger.warning_rank0("Vocab has been resized, add {} to trainable params.".format(",".join(module_names)))
|
||||
|
||||
if finetuning_args.finetuning_type == "lora":
|
||||
peft_kwargs = {
|
||||
"r": finetuning_args.lora_rank,
|
||||
"target_modules": target_modules,
|
||||
@ -232,8 +237,19 @@ def _setup_lora_tuning(
|
||||
"use_dora": finetuning_args.use_dora,
|
||||
"modules_to_save": finetuning_args.additional_target,
|
||||
}
|
||||
elif finetuning_args.finetuning_type == "oft":
|
||||
peft_kwargs = {
|
||||
"r": finetuning_args.oft_rank,
|
||||
"oft_block_size": finetuning_args.oft_block_size,
|
||||
"target_modules": target_modules,
|
||||
"module_dropout": finetuning_args.module_dropout,
|
||||
"modules_to_save": finetuning_args.additional_target,
|
||||
}
|
||||
|
||||
if model_args.use_unsloth:
|
||||
if finetuning_args.finetuning_type == "oft":
|
||||
raise ValueError("Unsloth is currently not supported for OFT.")
|
||||
|
||||
model = get_unsloth_peft_model(model, model_args, peft_kwargs)
|
||||
else:
|
||||
if finetuning_args.pissa_init:
|
||||
@ -244,12 +260,19 @@ def _setup_lora_tuning(
|
||||
logger.info_rank0(f"Using PiSSA initialization with FSVD steps {finetuning_args.pissa_iter}.")
|
||||
peft_kwargs["init_lora_weights"] = f"pissa_niter_{finetuning_args.pissa_iter}"
|
||||
|
||||
lora_config = LoraConfig(
|
||||
if finetuning_args.finetuning_type == "lora":
|
||||
peft_config = LoraConfig(
|
||||
task_type=TaskType.CAUSAL_LM,
|
||||
inference_mode=False,
|
||||
**peft_kwargs,
|
||||
)
|
||||
model = get_peft_model(model, lora_config)
|
||||
elif finetuning_args.finetuning_type == "oft":
|
||||
peft_config = OFTConfig(
|
||||
task_type=TaskType.CAUSAL_LM,
|
||||
inference_mode=False,
|
||||
**peft_kwargs,
|
||||
)
|
||||
model = get_peft_model(model, peft_config)
|
||||
|
||||
if is_trainable and cast_trainable_params_to_fp32:
|
||||
for param in filter(lambda p: p.requires_grad, model.parameters()):
|
||||
@ -272,8 +295,8 @@ def init_adapter(
|
||||
Note that the trainable parameters must be cast to float32.
|
||||
"""
|
||||
if is_trainable and getattr(model, "quantization_method", None) is not None:
|
||||
if finetuning_args.finetuning_type != "lora":
|
||||
raise ValueError("Quantized models can only be used for the LoRA tuning.")
|
||||
if finetuning_args.finetuning_type not in ["lora", "oft"]:
|
||||
raise ValueError("Quantized models can only be used for the LoRA or OFT tuning.")
|
||||
|
||||
if finetuning_args.pissa_init:
|
||||
raise ValueError("Cannot initialize PiSSA adapter on quantized models.")
|
||||
@ -296,7 +319,7 @@ def init_adapter(
|
||||
_setup_full_tuning(model, finetuning_args, is_trainable, cast_trainable_params_to_fp32)
|
||||
elif finetuning_args.finetuning_type == "freeze":
|
||||
_setup_freeze_tuning(model, finetuning_args, is_trainable, cast_trainable_params_to_fp32)
|
||||
elif finetuning_args.finetuning_type == "lora":
|
||||
elif finetuning_args.finetuning_type in ["lora", "oft"]:
|
||||
model = _setup_lora_tuning(
|
||||
config, model, model_args, finetuning_args, is_trainable, cast_trainable_params_to_fp32
|
||||
)
|
||||
|
@ -390,7 +390,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
||||
batch: dict[str, torch.Tensor] = self.prepare_model_inputs(queries, responses)
|
||||
unwrapped_model: AutoModelForCausalLMWithValueHead = self.accelerator.unwrap_model(self.model)
|
||||
|
||||
if self.finetuning_args.reward_model_type == "lora":
|
||||
if self.finetuning_args.reward_model_type in ["lora", "oft"]:
|
||||
replace_model(unwrapped_model, target="reward")
|
||||
reward_model = self.model
|
||||
else:
|
||||
@ -399,7 +399,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
||||
with unwrap_model_for_generation(reward_model, self.accelerator), self.amp_context: # support bf16
|
||||
values: torch.Tensor = reward_model(**batch, return_dict=True, use_cache=False)[-1]
|
||||
|
||||
if self.finetuning_args.reward_model_type == "lora":
|
||||
if self.finetuning_args.reward_model_type in ["lora", "oft"]:
|
||||
replace_model(unwrapped_model, target="default")
|
||||
|
||||
rewards = values.gather(dim=-1, index=(batch["attention_mask"].sum(dim=-1, keepdim=True) - 1))
|
||||
|
Loading…
x
Reference in New Issue
Block a user