[feature] adding orthogonal finetuning (OFT) to llama factory (#8623)

Co-authored-by: Zeju <zqiu@g003.internal.cluster.is.localnet>
Co-authored-by: Zeju <zqiu@login2.is.localnet>
Co-authored-by: Yaowei Zheng <hiyouga@buaa.edu.cn>
Zeju Qiu 2025-08-18 12:22:47 +02:00 committed by GitHub
parent 0454c10456
commit 842595698b
13 changed files with 375 additions and 47 deletions

View File

@@ -86,7 +86,7 @@ Choose your path:
- **Various models**: LLaMA, LLaVA, Mistral, Mixtral-MoE, Qwen, Qwen2-VL, DeepSeek, Yi, Gemma, ChatGLM, Phi, etc.
- **Integrated methods**: (Continuous) pre-training, (multimodal) supervised fine-tuning, reward modeling, PPO, DPO, KTO, ORPO, etc.
- **Scalable resources**: 16-bit full-tuning, freeze-tuning, LoRA and 2/3/4/5/6/8-bit QLoRA via AQLM/AWQ/GPTQ/LLM.int8/HQQ/EETQ.
- **Advanced algorithms**: [GaLore](https://github.com/jiaweizzhao/GaLore), [BAdam](https://github.com/Ledzy/BAdam), [APOLLO](https://github.com/zhuhanqing/APOLLO), [Adam-mini](https://github.com/zyushun/Adam-mini), [Muon](https://github.com/KellerJordan/Muon), [OFT](https://github.com/huggingface/peft/tree/main/src/peft/tuners/oft), DoRA, LongLoRA, LLaMA Pro, Mixture-of-Depths, LoRA+, LoftQ and PiSSA.
- **Practical tricks**: [FlashAttention-2](https://github.com/Dao-AILab/flash-attention), [Unsloth](https://github.com/unslothai/unsloth), [Liger Kernel](https://github.com/linkedin/Liger-Kernel), RoPE scaling, NEFTune and rsLoRA.
- **Wide tasks**: Multi-turn dialogue, tool using, image understanding, visual grounding, video recognition, audio understanding, etc.
- **Experiment monitors**: LlamaBoard, TensorBoard, Wandb, MLflow, [SwanLab](https://github.com/SwanHubX/SwanLab), etc.
@@ -329,16 +329,16 @@ You also can add a custom chat template to [template.py](src/llamafactory/data/t
## Supported Training Approaches

| Approach               | Full-tuning        | Freeze-tuning      | LoRA               | QLoRA              | OFT                | QOFT               |
| ---------------------- | ------------------ | ------------------ | ------------------ | ------------------ | ------------------ | ------------------ |
| Pre-Training           | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
| Supervised Fine-Tuning | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
| Reward Modeling        | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
| PPO Training           | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
| DPO Training           | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
| KTO Training           | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
| ORPO Training          | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
| SimPO Training         | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |

> [!TIP]
> The implementation details of PPO can be found in [this blog](https://newfacade.github.io/notes-on-reinforcement-learning/17-ppo-trl.html).
@@ -469,14 +469,14 @@ huggingface-cli login
\* *estimated*

| Method                              | Bits | 7B    | 14B   | 30B   | 70B    | `x`B    |
| ----------------------------------- | ---- | ----- | ----- | ----- | ------ | ------- |
| Full (`bf16` or `fp16`)             | 32   | 120GB | 240GB | 600GB | 1200GB | `18x`GB |
| Full (`pure_bf16`)                  | 16   | 60GB  | 120GB | 300GB | 600GB  | `8x`GB  |
| Freeze/LoRA/GaLore/APOLLO/BAdam/OFT | 16   | 16GB  | 32GB  | 64GB  | 160GB  | `2x`GB  |
| QLoRA / QOFT                        | 8    | 10GB  | 20GB  | 40GB  | 80GB   | `x`GB   |
| QLoRA / QOFT                        | 4    | 6GB   | 12GB  | 24GB  | 48GB   | `x/2`GB |
| QLoRA / QOFT                        | 2    | 4GB   | 8GB   | 16GB  | 24GB   | `x/4`GB |

## Getting Started

View File

@@ -290,3 +290,15 @@ llamafactory-cli train examples/extras/llama_pro/llama3_freeze_sft.yaml
```bash
bash examples/extras/fsdp_qlora/train.sh
```

#### OFT Fine-Tuning

```bash
llamafactory-cli train examples/extras/oft/llama3_oft_sft.yaml
```

#### QOFT Fine-Tuning

```bash
llamafactory-cli train examples/extras/qoft/llama3_oft_sft_bnb_npu.yaml
```
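Once either recipe above finishes, the trained OFT adapter in the configured `output_dir` can be loaded back onto the base model with PEFT, just like a LoRA adapter. A minimal sketch, assuming the `saves/llama3-8b/oft/sft` path from the OFT example config added in this commit (chat-template formatting is omitted for brevity):

```python
# Sketch: load a trained OFT adapter for inference.
# Assumes the adapter was written to the example config's output_dir.
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

base_id = "meta-llama/Meta-Llama-3-8B-Instruct"
adapter_dir = "saves/llama3-8b/oft/sft"  # output_dir from the OFT example config

tokenizer = AutoTokenizer.from_pretrained(base_id)
model = AutoModelForCausalLM.from_pretrained(base_id, torch_dtype=torch.bfloat16, device_map="auto")
model = PeftModel.from_pretrained(model, adapter_dir)  # attaches the OFT adapter weights

inputs = tokenizer("Who are you?", return_tensors="pt").to(model.device)
output = model.generate(**inputs, max_new_tokens=64)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```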

View File

@@ -290,3 +290,15 @@ llamafactory-cli train examples/extras/llama_pro/llama3_freeze_sft.yaml
```bash
bash examples/extras/fsdp_qlora/train.sh
```

#### OFT Fine-Tuning

```bash
llamafactory-cli train examples/extras/oft/llama3_oft_sft.yaml
```

#### QOFT Fine-Tuning

```bash
llamafactory-cli train examples/extras/qoft/llama3_oft_sft_bnb_npu.yaml
```

View File

@@ -0,0 +1,46 @@
### model
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
trust_remote_code: true
### method
stage: sft
do_train: true
finetuning_type: oft
oft_block_size: 32
oft_target: all
### dataset
dataset: identity,alpaca_en_demo
template: llama3
cutoff_len: 2048
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16
dataloader_num_workers: 4
### output
output_dir: saves/llama3-8b/oft/sft
logging_steps: 10
save_steps: 500
plot_loss: true
overwrite_output_dir: true
save_only_model: false
report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
### train
per_device_train_batch_size: 1
gradient_accumulation_steps: 8
learning_rate: 1.0e-4
num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
bf16: true
ddp_timeout: 180000000
resume_from_checkpoint: null
### eval
# eval_dataset: alpaca_en_demo
# val_size: 0.1
# per_device_eval_batch_size: 1
# eval_strategy: steps
# eval_steps: 500
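For orientation, the `### method` block above is what the adapter setup added later in this commit turns into a PEFT `OFTConfig`. The sketch below is illustrative only, not the trainer's exact code path; the explicit `target_modules` list is a hand-written stand-in for what `oft_target: all` resolves to on Llama-style models:

```python
# Illustrative sketch of the PEFT config that the `### method` block above roughly maps to
# (see the adapter.py changes later in this commit); not the exact code path of the trainer.
from peft import OFTConfig, TaskType, get_peft_model
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
oft_config = OFTConfig(
    task_type=TaskType.CAUSAL_LM,
    r=0,                  # oft_rank default; unused because oft_block_size is set
    oft_block_size=32,    # oft_block_size: 32 from the YAML above
    module_dropout=0.0,   # module_dropout default
    # roughly what `oft_target: all` resolves to for Llama (every linear module except the head):
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    inference_mode=False,
)
model = get_peft_model(model, oft_config)
model.print_trainable_parameters()  # OFT trains only block-diagonal orthogonal factors on the targeted layers
```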

View File

@@ -0,0 +1,47 @@
### model
model_name_or_path: Qwen/Qwen2.5-VL-7B-Instruct
image_max_pixels: 262144
video_max_pixels: 16384
trust_remote_code: true
### method
stage: sft
do_train: true
finetuning_type: oft
oft_block_size: 32
oft_target: all
### dataset
dataset: mllm_demo,identity,alpaca_en_demo # video: mllm_video_demo
template: qwen2_vl
cutoff_len: 2048
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16
dataloader_num_workers: 4
### output
output_dir: saves/qwen2_5vl-7b/oft/sft
logging_steps: 10
save_steps: 500
plot_loss: true
overwrite_output_dir: true
save_only_model: false
report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
### train
per_device_train_batch_size: 1
gradient_accumulation_steps: 8
learning_rate: 1.0e-4
num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
bf16: true
ddp_timeout: 180000000
resume_from_checkpoint: null
### eval
# val_size: 0.1
# per_device_eval_batch_size: 1
# eval_strategy: steps
# eval_steps: 500

View File

@@ -0,0 +1,44 @@
### model
model_name_or_path: TechxGenus/Meta-Llama-3-8B-Instruct-AWQ
trust_remote_code: true
### method
stage: sft
do_train: true
finetuning_type: oft
oft_block_size: 32
oft_target: all
### dataset
dataset: identity,alpaca_en_demo
template: llama3
cutoff_len: 2048
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16
dataloader_num_workers: 4
### output
output_dir: saves/llama3-8b/oft/sft
logging_steps: 10
save_steps: 500
plot_loss: true
overwrite_output_dir: true
save_only_model: false
report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
### train
per_device_train_batch_size: 1
gradient_accumulation_steps: 8
learning_rate: 1.0e-4
num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
bf16: true
ddp_timeout: 180000000
### eval
# val_size: 0.1
# per_device_eval_batch_size: 1
# eval_strategy: steps
# eval_steps: 500

View File

@@ -0,0 +1,47 @@
### model
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
quantization_bit: 4
quantization_method: bnb
double_quantization: false
trust_remote_code: true
### method
stage: sft
do_train: true
finetuning_type: oft
oft_block_size: 32
oft_target: all
### dataset
dataset: identity,alpaca_en_demo
template: llama3
cutoff_len: 2048
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16
dataloader_num_workers: 4
### output
output_dir: saves/llama3-8b/oft/sft
logging_steps: 10
save_steps: 500
plot_loss: true
overwrite_output_dir: true
save_only_model: false
report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
### train
per_device_train_batch_size: 1
gradient_accumulation_steps: 8
learning_rate: 1.0e-4
num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
bf16: true
ddp_timeout: 180000000
### eval
# val_size: 0.1
# per_device_eval_batch_size: 1
# eval_strategy: steps
# eval_steps: 500
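This QOFT recipe pairs a 4-bit bitsandbytes-quantized base model with an OFT adapter. Outside LLaMA-Factory, the same combination can be sketched directly with transformers and PEFT; the snippet below is a rough equivalent under those assumptions, not what the trainer executes internally:

```python
# Sketch: QOFT = 4-bit bitsandbytes base model + OFT adapter, analogous to the config above.
import torch
from peft import OFTConfig, TaskType, get_peft_model, prepare_model_for_kbit_training
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,                      # quantization_bit: 4
    bnb_4bit_use_double_quant=False,        # double_quantization: false
    bnb_4bit_compute_dtype=torch.bfloat16,  # bf16: true
)
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B-Instruct",
    quantization_config=bnb_config,
    device_map="auto",
)
model = prepare_model_for_kbit_training(model)  # standard PEFT preparation for adapters on a k-bit model
oft_config = OFTConfig(
    task_type=TaskType.CAUSAL_LM,
    r=0,
    oft_block_size=32,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
)
model = get_peft_model(model, oft_config)
model.print_trainable_parameters()
```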

View File

@@ -0,0 +1,44 @@
### model
model_name_or_path: TechxGenus/Meta-Llama-3-8B-Instruct-GPTQ
trust_remote_code: true
### method
stage: sft
do_train: true
finetuning_type: oft
oft_block_size: 32
oft_target: all
### dataset
dataset: identity,alpaca_en_demo
template: llama3
cutoff_len: 2048
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16
dataloader_num_workers: 4
### output
output_dir: saves/llama3-8b/oft/sft
logging_steps: 10
save_steps: 500
plot_loss: true
overwrite_output_dir: true
save_only_model: false
report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
### train
per_device_train_batch_size: 1
gradient_accumulation_steps: 8
learning_rate: 1.0e-4
num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
bf16: true
ddp_timeout: 180000000
### eval
# val_size: 0.1
# per_device_eval_batch_size: 1
# eval_strategy: steps
# eval_steps: 500

View File

@@ -56,13 +56,13 @@ LAYERNORM_NAMES = {"norm", "ln"}
LLAMABOARD_CONFIG = "llamaboard_config.yaml"

METHODS = ["full", "freeze", "lora", "oft"]

MOD_SUPPORTED_MODELS = {"bloom", "falcon", "gemma", "llama", "mistral", "mixtral", "phi", "starcoder2"}

MULTIMODAL_SUPPORTED_MODELS = set()

PEFT_METHODS = {"lora", "oft"}

RUNNING_LOG = "running_log.txt"

View File

@@ -122,6 +122,48 @@ class LoraArguments:
    )


@dataclass
class OFTArguments:
    r"""Arguments pertaining to the OFT training."""

    additional_target: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "Name(s) of modules apart from OFT layers to be set as trainable "
                "and saved in the final checkpoint. "
                "Use commas to separate multiple modules."
            )
        },
    )
    module_dropout: float = field(
        default=0.0,
        metadata={"help": "Dropout rate for the OFT fine-tuning."},
    )
    oft_rank: int = field(
        default=0,
        metadata={"help": "The intrinsic dimension for OFT fine-tuning."},
    )
    oft_block_size: int = field(
        default=32,
        metadata={"help": "The block size for OFT fine-tuning."},
    )
    oft_target: str = field(
        default="all",
        metadata={
            "help": (
                "Name(s) of target modules to apply OFT. "
                "Use commas to separate multiple modules. "
                "Use `all` to specify all the linear modules."
            )
        },
    )
    create_new_adapter: bool = field(
        default=False,
        metadata={"help": "Whether or not to create a new adapter with randomly initialized weight."},
    )


@dataclass
class RLHFArguments:
    r"""Arguments pertaining to the PPO, DPO and KTO training."""
@@ -400,7 +442,14 @@ class SwanLabArguments:
@dataclass
class FinetuningArguments(
    SwanLabArguments,
    BAdamArgument,
    ApolloArguments,
    GaloreArguments,
    RLHFArguments,
    LoraArguments,
    OFTArguments,
    FreezeArguments,
):
    r"""Arguments pertaining to which techniques we are going to fine-tune with."""
@@ -475,12 +524,13 @@ class FinetuningArguments(
        self.freeze_extra_modules: Optional[list[str]] = split_arg(self.freeze_extra_modules)
        self.lora_alpha: int = self.lora_alpha or self.lora_rank * 2
        self.lora_target: list[str] = split_arg(self.lora_target)
        self.oft_target: list[str] = split_arg(self.oft_target)
        self.additional_target: Optional[list[str]] = split_arg(self.additional_target)
        self.galore_target: list[str] = split_arg(self.galore_target)
        self.apollo_target: list[str] = split_arg(self.apollo_target)
        self.use_ref_model = self.stage == "dpo" and self.pref_loss not in ["orpo", "simpo"]

        assert self.finetuning_type in ["lora", "oft", "freeze", "full"], "Invalid fine-tuning method."
        assert self.ref_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
        assert self.reward_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
@@ -490,6 +540,9 @@ class FinetuningArguments(
        if self.stage == "ppo" and self.reward_model_type == "lora" and self.finetuning_type != "lora":
            raise ValueError("`reward_model_type` cannot be lora for Freeze/Full PPO training.")

        if self.stage == "ppo" and self.reward_model_type == "oft" and self.finetuning_type != "oft":
            raise ValueError("`reward_model_type` cannot be oft for Freeze/Full PPO training.")

        if self.stage == "dpo" and self.pref_loss != "sigmoid" and self.dpo_label_smoothing > 1e-6:
            raise ValueError("`dpo_label_smoothing` is only valid for sigmoid loss function.")

View File

@@ -111,8 +111,8 @@ def _verify_model_args(
            raise ValueError("Adapter is only valid for the LoRA method.")

    if model_args.quantization_bit is not None:
        if finetuning_args.finetuning_type not in ["lora", "oft"]:
            raise ValueError("Quantization is only compatible with the LoRA or OFT method.")

        if finetuning_args.pissa_init:
            raise ValueError("Please use scripts/pissa_init.py to initialize PiSSA for a quantized model.")

View File

@@ -16,10 +16,11 @@ import re
from typing import TYPE_CHECKING

import torch
from peft import LoraConfig, LoraModel, OFTConfig, OFTModel, PeftModel, TaskType, get_peft_model
from transformers.integrations import is_deepspeed_zero3_enabled

from ..extras import logging
from ..extras.misc import check_version
from .model_utils.misc import find_all_linear_modules, find_expanded_modules
from .model_utils.quantization import QuantizationMethod
from .model_utils.unsloth import get_unsloth_peft_model, load_unsloth_peft_model
@@ -147,7 +148,10 @@ def _setup_lora_tuning(
    cast_trainable_params_to_fp32: bool,
) -> "PeftModel":
    if is_trainable:
        if finetuning_args.finetuning_type == "oft":
            logger.info_rank0("Fine-tuning method: OFT")
        else:
            logger.info_rank0("Fine-tuning method: {}".format("DoRA" if finetuning_args.use_dora else "LoRA"))

    adapter_to_resume = None
@@ -223,17 +227,29 @@ def _setup_lora_tuning(
            finetuning_args.additional_target = module_names
            logger.warning_rank0("Vocab has been resized, add {} to trainable params.".format(",".join(module_names)))

        if finetuning_args.finetuning_type == "lora":
            peft_kwargs = {
                "r": finetuning_args.lora_rank,
                "target_modules": target_modules,
                "lora_alpha": finetuning_args.lora_alpha,
                "lora_dropout": finetuning_args.lora_dropout,
                "use_rslora": finetuning_args.use_rslora,
                "use_dora": finetuning_args.use_dora,
                "modules_to_save": finetuning_args.additional_target,
            }
        elif finetuning_args.finetuning_type == "oft":
            peft_kwargs = {
                "r": finetuning_args.oft_rank,
                "oft_block_size": finetuning_args.oft_block_size,
                "target_modules": target_modules,
                "module_dropout": finetuning_args.module_dropout,
                "modules_to_save": finetuning_args.additional_target,
            }

        if model_args.use_unsloth:
            if finetuning_args.finetuning_type == "oft":
                raise ValueError("Unsloth is currently not supported for OFT.")

            model = get_unsloth_peft_model(model, model_args, peft_kwargs)
        else:
            if finetuning_args.pissa_init:
@@ -244,12 +260,19 @@ def _setup_lora_tuning(
                logger.info_rank0(f"Using PiSSA initialization with FSVD steps {finetuning_args.pissa_iter}.")
                peft_kwargs["init_lora_weights"] = f"pissa_niter_{finetuning_args.pissa_iter}"

            if finetuning_args.finetuning_type == "lora":
                peft_config = LoraConfig(
                    task_type=TaskType.CAUSAL_LM,
                    inference_mode=False,
                    **peft_kwargs,
                )
            elif finetuning_args.finetuning_type == "oft":
                peft_config = OFTConfig(
                    task_type=TaskType.CAUSAL_LM,
                    inference_mode=False,
                    **peft_kwargs,
                )

            model = get_peft_model(model, peft_config)

    if is_trainable and cast_trainable_params_to_fp32:
        for param in filter(lambda p: p.requires_grad, model.parameters()):
@@ -272,8 +295,8 @@ def init_adapter(
    Note that the trainable parameters must be cast to float32.
    """
    if is_trainable and getattr(model, "quantization_method", None) is not None:
        if finetuning_args.finetuning_type not in ["lora", "oft"]:
            raise ValueError("Quantized models can only be used for the LoRA or OFT tuning.")

        if finetuning_args.pissa_init:
            raise ValueError("Cannot initialize PiSSA adapter on quantized models.")
@@ -296,7 +319,7 @@ def init_adapter(
        _setup_full_tuning(model, finetuning_args, is_trainable, cast_trainable_params_to_fp32)
    elif finetuning_args.finetuning_type == "freeze":
        _setup_freeze_tuning(model, finetuning_args, is_trainable, cast_trainable_params_to_fp32)
    elif finetuning_args.finetuning_type in ["lora", "oft"]:
        model = _setup_lora_tuning(
            config, model, model_args, finetuning_args, is_trainable, cast_trainable_params_to_fp32
        )
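Because the OFT branch above ends in the same `get_peft_model` call as LoRA, the result is an ordinary PEFT model and the usual downstream utilities apply unchanged. For example, a trained OFT adapter can be folded into the base weights for export; a hedged sketch, assuming the `saves/llama3-8b/oft/sft` output path from the example config and that the installed PEFT version supports merging OFT layers:

```python
# Hedged sketch: merge a trained OFT adapter into the base weights for export.
import torch
from peft import AutoPeftModelForCausalLM

model = AutoPeftModelForCausalLM.from_pretrained("saves/llama3-8b/oft/sft", torch_dtype=torch.bfloat16)
merged = model.merge_and_unload()  # folds the orthogonal factors into the underlying Linear weights
merged.save_pretrained("saves/llama3-8b/oft/sft-merged")
```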

View File

@@ -390,7 +390,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
        batch: dict[str, torch.Tensor] = self.prepare_model_inputs(queries, responses)
        unwrapped_model: AutoModelForCausalLMWithValueHead = self.accelerator.unwrap_model(self.model)

        if self.finetuning_args.reward_model_type in ["lora", "oft"]:
            replace_model(unwrapped_model, target="reward")
            reward_model = self.model
        else:
@@ -399,7 +399,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
        with unwrap_model_for_generation(reward_model, self.accelerator), self.amp_context:  # support bf16
            values: torch.Tensor = reward_model(**batch, return_dict=True, use_cache=False)[-1]

        if self.finetuning_args.reward_model_type in ["lora", "oft"]:
            replace_model(unwrapped_model, target="default")

        rewards = values.gather(dim=-1, index=(batch["attention_mask"].sum(dim=-1, keepdim=True) - 1))
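The `["lora", "oft"]` check works because, for adapter-based reward models, the reward model is loaded as a second adapter on the same PEFT-wrapped policy, and `replace_model` toggles which adapter (and value head) is active around scoring. Reduced to plain PEFT calls, the pattern looks roughly like the sketch below; the adapter names and target modules are illustrative, not the trainer's actual internals:

```python
# Hedged sketch of the two-adapters-on-one-model pattern behind reward_model_type in {"lora", "oft"}.
from peft import OFTConfig, TaskType, get_peft_model
from transformers import AutoModelForCausalLM


def oft_cfg() -> OFTConfig:
    # illustrative config; fresh instance per adapter
    return OFTConfig(task_type=TaskType.CAUSAL_LM, r=0, oft_block_size=32, target_modules=["q_proj", "v_proj"])


base = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
policy = get_peft_model(base, oft_cfg(), adapter_name="default")  # policy adapter
policy.add_adapter("reward", oft_cfg())                           # reward adapter, trained separately

policy.set_adapter("reward")   # activate the reward adapter while scoring responses
# ... run the forward pass that produces reward values ...
policy.set_adapter("default")  # switch back to the policy adapter for PPO updates
```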