diff --git a/README.md b/README.md
index 6342c2e0..1fd4d59f 100644
--- a/README.md
+++ b/README.md
@@ -86,7 +86,7 @@ Choose your path:
 - **Various models**: LLaMA, LLaVA, Mistral, Mixtral-MoE, Qwen, Qwen2-VL, DeepSeek, Yi, Gemma, ChatGLM, Phi, etc.
 - **Integrated methods**: (Continuous) pre-training, (multimodal) supervised fine-tuning, reward modeling, PPO, DPO, KTO, ORPO, etc.
 - **Scalable resources**: 16-bit full-tuning, freeze-tuning, LoRA and 2/3/4/5/6/8-bit QLoRA via AQLM/AWQ/GPTQ/LLM.int8/HQQ/EETQ.
-- **Advanced algorithms**: [GaLore](https://github.com/jiaweizzhao/GaLore), [BAdam](https://github.com/Ledzy/BAdam), [APOLLO](https://github.com/zhuhanqing/APOLLO), [Adam-mini](https://github.com/zyushun/Adam-mini), [Muon](https://github.com/KellerJordan/Muon), DoRA, LongLoRA, LLaMA Pro, Mixture-of-Depths, LoRA+, LoftQ and PiSSA.
+- **Advanced algorithms**: [GaLore](https://github.com/jiaweizzhao/GaLore), [BAdam](https://github.com/Ledzy/BAdam), [APOLLO](https://github.com/zhuhanqing/APOLLO), [Adam-mini](https://github.com/zyushun/Adam-mini), [Muon](https://github.com/KellerJordan/Muon), [OFT](https://github.com/huggingface/peft/tree/main/src/peft/tuners/oft), DoRA, LongLoRA, LLaMA Pro, Mixture-of-Depths, LoRA+, LoftQ and PiSSA.
 - **Practical tricks**: [FlashAttention-2](https://github.com/Dao-AILab/flash-attention), [Unsloth](https://github.com/unslothai/unsloth), [Liger Kernel](https://github.com/linkedin/Liger-Kernel), RoPE scaling, NEFTune and rsLoRA.
 - **Wide tasks**: Multi-turn dialogue, tool using, image understanding, visual grounding, video recognition, audio understanding, etc.
 - **Experiment monitors**: LlamaBoard, TensorBoard, Wandb, MLflow, [SwanLab](https://github.com/SwanHubX/SwanLab), etc.
@@ -329,16 +329,16 @@ You also can add a custom chat template to [template.py](src/llamafactory/data/t
 ## Supported Training Approaches
 
-| Approach               | Full-tuning        | Freeze-tuning      | LoRA               | QLoRA              |
-| ---------------------- | ------------------ | ------------------ | ------------------ | ------------------ |
-| Pre-Training           | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
-| Supervised Fine-Tuning | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
-| Reward Modeling        | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
-| PPO Training           | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
-| DPO Training           | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
-| KTO Training           | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
-| ORPO Training          | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
-| SimPO Training         | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
+| Approach               | Full-tuning        | Freeze-tuning      | LoRA               | QLoRA              | OFT                | QOFT               |
+| ---------------------- | ------------------ | ------------------ | ------------------ | ------------------ | ------------------ | ------------------ |
+| Pre-Training           | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
+| Supervised Fine-Tuning | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
+| Reward Modeling        | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
+| PPO Training           | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
+| DPO Training           | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
+| KTO Training           | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
+| ORPO Training          | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
+| SimPO Training         | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
 
 > [!TIP]
 > The implementation details of PPO can be found in [this blog](https://newfacade.github.io/notes-on-reinforcement-learning/17-ppo-trl.html).
 
@@ -469,14 +469,14 @@ huggingface-cli login
 
 \* *estimated*
 
-| Method                          | Bits | 7B    | 14B   | 30B   | 70B    | `x`B    |
-| ------------------------------- | ---- | ----- | ----- | ----- | ------ | ------- |
-| Full (`bf16` or `fp16`)         | 32   | 120GB | 240GB | 600GB | 1200GB | `18x`GB |
-| Full (`pure_bf16`)              | 16   | 60GB  | 120GB | 300GB | 600GB  | `8x`GB  |
-| Freeze/LoRA/GaLore/APOLLO/BAdam | 16   | 16GB  | 32GB  | 64GB  | 160GB  | `2x`GB  |
-| QLoRA                           | 8    | 10GB  | 20GB  | 40GB  | 80GB   | `x`GB   |
-| QLoRA                           | 4    | 6GB   | 12GB  | 24GB  | 48GB   | `x/2`GB |
-| QLoRA                           | 2    | 4GB   | 8GB   | 16GB  | 24GB   | `x/4`GB |
+| Method                              | Bits | 7B    | 14B   | 30B   | 70B    | `x`B    |
+| ----------------------------------- | ---- | ----- | ----- | ----- | ------ | ------- |
+| Full (`bf16` or `fp16`)             | 32   | 120GB | 240GB | 600GB | 1200GB | `18x`GB |
+| Full (`pure_bf16`)                  | 16   | 60GB  | 120GB | 300GB | 600GB  | `8x`GB  |
+| Freeze/LoRA/GaLore/APOLLO/BAdam/OFT | 16   | 16GB  | 32GB  | 64GB  | 160GB  | `2x`GB  |
+| QLoRA / QOFT                        | 8    | 10GB  | 20GB  | 40GB  | 80GB   | `x`GB   |
+| QLoRA / QOFT                        | 4    | 6GB   | 12GB  | 24GB  | 48GB   | `x/2`GB |
+| QLoRA / QOFT                        | 2    | 4GB   | 8GB   | 16GB  | 24GB   | `x/4`GB |
 
 ## Getting Started
 
diff --git a/examples/README.md b/examples/README.md
index 1e4cd2df..3fa7b1d1 100644
--- a/examples/README.md
+++ b/examples/README.md
@@ -290,3 +290,15 @@ llamafactory-cli train examples/extras/llama_pro/llama3_freeze_sft.yaml
 ```bash
 bash examples/extras/fsdp_qlora/train.sh
 ```
+
+#### OFT Fine-Tuning
+
+```bash
+llamafactory-cli train examples/extras/oft/llama3_oft_sft.yaml
+```
+
+#### QOFT Fine-Tuning
+
+```bash
+llamafactory-cli train examples/extras/qoft/llama3_oft_sft_bnb_npu.yaml
+```
diff --git a/examples/README_zh.md b/examples/README_zh.md
index 1cb49a7d..aa42e491 100644
--- a/examples/README_zh.md
+++ b/examples/README_zh.md
@@ -290,3 +290,15 @@ llamafactory-cli train examples/extras/llama_pro/llama3_freeze_sft.yaml
 ```bash
 bash examples/extras/fsdp_qlora/train.sh
 ```
+
+#### OFT 微调
+
+```bash
+llamafactory-cli train examples/extras/oft/llama3_oft_sft.yaml
+```
+
+#### QOFT 微调
+
+```bash
+llamafactory-cli train examples/extras/qoft/llama3_oft_sft_bnb_npu.yaml
+```
diff --git a/examples/extras/oft/llama3_oft_sft.yaml b/examples/extras/oft/llama3_oft_sft.yaml
new file mode 100644
index 00000000..ae027c40
--- /dev/null
+++ b/examples/extras/oft/llama3_oft_sft.yaml
@@ -0,0 +1,46 @@
+### model
+model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
+trust_remote_code: true
+
+### method
+stage: sft
+do_train: true
+finetuning_type: oft
+oft_block_size: 32
+oft_target: all
+
+### dataset
+dataset: identity,alpaca_en_demo
+template: llama3
+cutoff_len: 2048
+max_samples: 1000
+overwrite_cache: true
+preprocessing_num_workers: 16
+dataloader_num_workers: 4
+
+### output
+output_dir: saves/llama3-8b/oft/sft
+logging_steps: 10
+save_steps: 500
+plot_loss: true
+overwrite_output_dir: true
+save_only_model: false
+report_to: none  # choices: [none, wandb, tensorboard, swanlab, mlflow]
+
+### train
+per_device_train_batch_size: 1
+gradient_accumulation_steps: 8
+learning_rate: 1.0e-4
+num_train_epochs: 3.0
+lr_scheduler_type: cosine
+warmup_ratio: 0.1
+bf16: true
+ddp_timeout: 180000000
+resume_from_checkpoint: null
+
+### eval
+# eval_dataset: alpaca_en_demo
+# val_size: 0.1
+# per_device_eval_batch_size: 1
+# eval_strategy: steps
+# eval_steps: 500
diff --git a/examples/extras/oft/qwen2_5vl_oft_sft.yaml b/examples/extras/oft/qwen2_5vl_oft_sft.yaml
new file mode 100644
index 00000000..a688bd52
--- /dev/null
+++ b/examples/extras/oft/qwen2_5vl_oft_sft.yaml
@@ -0,0 +1,47 @@
+### model
+model_name_or_path: Qwen/Qwen2.5-VL-7B-Instruct
+image_max_pixels: 262144
+video_max_pixels: 16384
+trust_remote_code: true
+
+### method
+stage: sft
+do_train: true
+finetuning_type: oft
+oft_block_size: 32
+oft_target: all
+
+### dataset
+dataset: mllm_demo,identity,alpaca_en_demo  # video: mllm_video_demo
+template: qwen2_vl
+cutoff_len: 2048
+max_samples: 1000
+overwrite_cache: true
+preprocessing_num_workers: 16
+dataloader_num_workers: 4
+
+### output
+output_dir: saves/qwen2_5vl-7b/oft/sft
+logging_steps: 10
+save_steps: 500
+plot_loss: true
+overwrite_output_dir: true
+save_only_model: false
+report_to: none  # choices: [none, wandb, tensorboard, swanlab, mlflow]
+
+### train
+per_device_train_batch_size: 1
+gradient_accumulation_steps: 8
+learning_rate: 1.0e-4
+num_train_epochs: 3.0
+lr_scheduler_type: cosine
+warmup_ratio: 0.1
+bf16: true
+ddp_timeout: 180000000
+resume_from_checkpoint: null
+
+### eval
+# val_size: 0.1
+# per_device_eval_batch_size: 1
+# eval_strategy: steps
+# eval_steps: 500
diff --git a/examples/extras/qoft/llama3_oft_sft_awq.yaml b/examples/extras/qoft/llama3_oft_sft_awq.yaml
new file mode 100644
index 00000000..37ebf25a
--- /dev/null
+++ b/examples/extras/qoft/llama3_oft_sft_awq.yaml
@@ -0,0 +1,44 @@
+### model
+model_name_or_path: TechxGenus/Meta-Llama-3-8B-Instruct-AWQ
+trust_remote_code: true
+
+### method
+stage: sft
+do_train: true
+finetuning_type: oft
+oft_block_size: 32
+oft_target: all
+
+### dataset
+dataset: identity,alpaca_en_demo
+template: llama3
+cutoff_len: 2048
+max_samples: 1000
+overwrite_cache: true
+preprocessing_num_workers: 16
+dataloader_num_workers: 4
+
+### output
+output_dir: saves/llama3-8b/oft/sft
+logging_steps: 10
+save_steps: 500
+plot_loss: true
+overwrite_output_dir: true
+save_only_model: false
+report_to: none  # choices: [none, wandb, tensorboard, swanlab, mlflow]
+
+### train
+per_device_train_batch_size: 1
+gradient_accumulation_steps: 8
+learning_rate: 1.0e-4
+num_train_epochs: 3.0
+lr_scheduler_type: cosine
+warmup_ratio: 0.1
+bf16: true
+ddp_timeout: 180000000
+
+### eval
+# val_size: 0.1
+# per_device_eval_batch_size: 1
+# eval_strategy: steps
+# eval_steps: 500
diff --git a/examples/extras/qoft/llama3_oft_sft_bnb_npu.yaml b/examples/extras/qoft/llama3_oft_sft_bnb_npu.yaml
new file mode 100644
index 00000000..5d57a6de
--- /dev/null
+++ b/examples/extras/qoft/llama3_oft_sft_bnb_npu.yaml
@@ -0,0 +1,47 @@
+### model
+model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
+quantization_bit: 4
+quantization_method: bnb
+double_quantization: false
+trust_remote_code: true
+
+### method
+stage: sft
+do_train: true
+finetuning_type: oft
+oft_block_size: 32
+oft_target: all
+
+### dataset
+dataset: identity,alpaca_en_demo
+template: llama3
+cutoff_len: 2048
+max_samples: 1000
+overwrite_cache: true
+preprocessing_num_workers: 16
+dataloader_num_workers: 4
+
+### output
+output_dir: saves/llama3-8b/oft/sft
+logging_steps: 10
+save_steps: 500
+plot_loss: true
+overwrite_output_dir: true
+save_only_model: false
+report_to: none  # choices: [none, wandb, tensorboard, swanlab, mlflow]
+
+### train
+per_device_train_batch_size: 1
+gradient_accumulation_steps: 8
+learning_rate: 1.0e-4
+num_train_epochs: 3.0
+lr_scheduler_type: cosine
+warmup_ratio: 0.1
+bf16: true
+ddp_timeout: 180000000
+
+### eval
+# val_size: 0.1
+# per_device_eval_batch_size: 1
+# eval_strategy: steps
+# eval_steps: 500
diff --git a/examples/extras/qoft/llama3_oft_sft_gptq.yaml b/examples/extras/qoft/llama3_oft_sft_gptq.yaml
new file mode 100644
index 00000000..3c098726
--- /dev/null
+++ b/examples/extras/qoft/llama3_oft_sft_gptq.yaml
@@ -0,0 +1,44 @@
+### model
+model_name_or_path: TechxGenus/Meta-Llama-3-8B-Instruct-GPTQ
+trust_remote_code: true
+
+### method
+stage: sft
+do_train: true
+finetuning_type: oft
+oft_block_size: 32
+oft_target: all
+
+### dataset
+dataset: identity,alpaca_en_demo
+template: llama3
+cutoff_len: 2048
+max_samples: 1000
+overwrite_cache: true
+preprocessing_num_workers: 16
+dataloader_num_workers: 4
+
+### output
+output_dir: saves/llama3-8b/oft/sft
+logging_steps: 10
+save_steps: 500
+plot_loss: true
+overwrite_output_dir: true
+save_only_model: false
+report_to: none  # choices: [none, wandb, tensorboard, swanlab, mlflow]
+
+### train
+per_device_train_batch_size: 1
+gradient_accumulation_steps: 8
+learning_rate: 1.0e-4
+num_train_epochs: 3.0
+lr_scheduler_type: cosine
+warmup_ratio: 0.1
+bf16: true
+ddp_timeout: 180000000
+
+### eval
+# val_size: 0.1
+# per_device_eval_batch_size: 1
+# eval_strategy: steps
+# eval_steps: 500
diff --git a/src/llamafactory/extras/constants.py b/src/llamafactory/extras/constants.py
index 586cd3e9..3bec8474 100644
--- a/src/llamafactory/extras/constants.py
+++ b/src/llamafactory/extras/constants.py
@@ -56,13 +56,13 @@ LAYERNORM_NAMES = {"norm", "ln"}
 
 LLAMABOARD_CONFIG = "llamaboard_config.yaml"
 
-METHODS = ["full", "freeze", "lora"]
+METHODS = ["full", "freeze", "lora", "oft"]
 
 MOD_SUPPORTED_MODELS = {"bloom", "falcon", "gemma", "llama", "mistral", "mixtral", "phi", "starcoder2"}
 
 MULTIMODAL_SUPPORTED_MODELS = set()
 
-PEFT_METHODS = {"lora"}
+PEFT_METHODS = {"lora", "oft"}
 
 RUNNING_LOG = "running_log.txt"
 
diff --git a/src/llamafactory/hparams/finetuning_args.py b/src/llamafactory/hparams/finetuning_args.py
index 21cf30a1..3130f86e 100644
--- a/src/llamafactory/hparams/finetuning_args.py
+++ b/src/llamafactory/hparams/finetuning_args.py
@@ -122,6 +122,48 @@ class LoraArguments:
     )
 
 
+@dataclass
+class OFTArguments:
+    r"""Arguments pertaining to the OFT training."""
+
+    additional_target: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": (
+                "Name(s) of modules apart from OFT layers to be set as trainable "
+                "and saved in the final checkpoint. "
+                "Use commas to separate multiple modules."
+            )
+        },
+    )
+    module_dropout: float = field(
+        default=0.0,
+        metadata={"help": "Dropout rate for the OFT fine-tuning."},
+    )
+    oft_rank: int = field(
+        default=0,
+        metadata={"help": "The intrinsic dimension for OFT fine-tuning."},
+    )
+    oft_block_size: int = field(
+        default=32,
+        metadata={"help": "The block size for OFT fine-tuning."},
+    )
+    oft_target: str = field(
+        default="all",
+        metadata={
+            "help": (
+                "Name(s) of target modules to apply OFT. "
+                "Use commas to separate multiple modules. "
+                "Use `all` to specify all the linear modules."
+            )
+        },
+    )
+    create_new_adapter: bool = field(
+        default=False,
+        metadata={"help": "Whether or not to create a new adapter with randomly initialized weight."},
+    )
+
+
 @dataclass
 class RLHFArguments:
     r"""Arguments pertaining to the PPO, DPO and KTO training."""
@@ -400,7 +442,14 @@ class SwanLabArguments:
 
 
 @dataclass
 class FinetuningArguments(
-    SwanLabArguments, BAdamArgument, ApolloArguments, GaloreArguments, RLHFArguments, LoraArguments, FreezeArguments
+    SwanLabArguments,
+    BAdamArgument,
+    ApolloArguments,
+    GaloreArguments,
+    RLHFArguments,
+    LoraArguments,
+    OFTArguments,
+    FreezeArguments,
 ):
     r"""Arguments pertaining to which techniques we are going to fine-tuning with."""
@@ -475,12 +524,13 @@ class FinetuningArguments(
         self.freeze_extra_modules: Optional[list[str]] = split_arg(self.freeze_extra_modules)
         self.lora_alpha: int = self.lora_alpha or self.lora_rank * 2
         self.lora_target: list[str] = split_arg(self.lora_target)
+        self.oft_target: list[str] = split_arg(self.oft_target)
         self.additional_target: Optional[list[str]] = split_arg(self.additional_target)
         self.galore_target: list[str] = split_arg(self.galore_target)
         self.apollo_target: list[str] = split_arg(self.apollo_target)
         self.use_ref_model = self.stage == "dpo" and self.pref_loss not in ["orpo", "simpo"]
 
-        assert self.finetuning_type in ["lora", "freeze", "full"], "Invalid fine-tuning method."
+        assert self.finetuning_type in ["lora", "oft", "freeze", "full"], "Invalid fine-tuning method."
         assert self.ref_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
         assert self.reward_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
 
@@ -490,6 +540,9 @@ class FinetuningArguments(
         if self.stage == "ppo" and self.reward_model_type == "lora" and self.finetuning_type != "lora":
             raise ValueError("`reward_model_type` cannot be lora for Freeze/Full PPO training.")
 
+        if self.stage == "ppo" and self.reward_model_type == "oft" and self.finetuning_type != "oft":
+            raise ValueError("`reward_model_type` cannot be oft for Freeze/Full PPO training.")
+
         if self.stage == "dpo" and self.pref_loss != "sigmoid" and self.dpo_label_smoothing > 1e-6:
             raise ValueError("`dpo_label_smoothing` is only valid for sigmoid loss function.")
 
diff --git a/src/llamafactory/hparams/parser.py b/src/llamafactory/hparams/parser.py
index 9b43198b..0f4e632f 100644
--- a/src/llamafactory/hparams/parser.py
+++ b/src/llamafactory/hparams/parser.py
@@ -111,8 +111,8 @@ def _verify_model_args(
         raise ValueError("Adapter is only valid for the LoRA method.")
 
     if model_args.quantization_bit is not None:
-        if finetuning_args.finetuning_type != "lora":
-            raise ValueError("Quantization is only compatible with the LoRA method.")
+        if finetuning_args.finetuning_type not in ["lora", "oft"]:
+            raise ValueError("Quantization is only compatible with the LoRA or OFT method.")
 
         if finetuning_args.pissa_init:
             raise ValueError("Please use scripts/pissa_init.py to initialize PiSSA for a quantized model.")
diff --git a/src/llamafactory/model/adapter.py b/src/llamafactory/model/adapter.py
index 9a000f41..07573266 100644
--- a/src/llamafactory/model/adapter.py
+++ b/src/llamafactory/model/adapter.py
@@ -16,10 +16,11 @@ import re
 from typing import TYPE_CHECKING
 
 import torch
-from peft import LoraConfig, LoraModel, PeftModel, TaskType, get_peft_model
+from peft import LoraConfig, LoraModel, OFTConfig, OFTModel, PeftModel, TaskType, get_peft_model
 from transformers.integrations import is_deepspeed_zero3_enabled
 
 from ..extras import logging
+from ..extras.misc import check_version
 from .model_utils.misc import find_all_linear_modules, find_expanded_modules
 from .model_utils.quantization import QuantizationMethod
 from .model_utils.unsloth import get_unsloth_peft_model, load_unsloth_peft_model
@@ -147,7 +148,10 @@ def _setup_lora_tuning(
     cast_trainable_params_to_fp32: bool,
 ) -> "PeftModel":
     if is_trainable:
-        logger.info_rank0("Fine-tuning method: {}".format("DoRA" if finetuning_args.use_dora else "LoRA"))
+        if finetuning_args.finetuning_type == "oft":
+            logger.info_rank0("Fine-tuning method: OFT")
+        else:
+            logger.info_rank0("Fine-tuning method: {}".format("DoRA" if finetuning_args.use_dora else "LoRA"))
 
     adapter_to_resume = None
 
@@ -223,17 +227,29 @@
         finetuning_args.additional_target = module_names
         logger.warning_rank0("Vocab has been resized, add {} to trainable params.".format(",".join(module_names)))
 
-    peft_kwargs = {
-        "r": finetuning_args.lora_rank,
-        "target_modules": target_modules,
-        "lora_alpha": finetuning_args.lora_alpha,
-        "lora_dropout": finetuning_args.lora_dropout,
-        "use_rslora": finetuning_args.use_rslora,
-        "use_dora": finetuning_args.use_dora,
-        "modules_to_save": finetuning_args.additional_target,
-    }
+    if finetuning_args.finetuning_type == "lora":
+        peft_kwargs = {
+            "r": finetuning_args.lora_rank,
+            "target_modules": target_modules,
+            "lora_alpha": finetuning_args.lora_alpha,
+            "lora_dropout": finetuning_args.lora_dropout,
+            "use_rslora": finetuning_args.use_rslora,
+            "use_dora": finetuning_args.use_dora,
+            "modules_to_save": finetuning_args.additional_target,
+        }
+    elif finetuning_args.finetuning_type == "oft":
+        peft_kwargs = {
+            "r": finetuning_args.oft_rank,
+            "oft_block_size": finetuning_args.oft_block_size,
+            "target_modules": target_modules,
+            "module_dropout": finetuning_args.module_dropout,
+            "modules_to_save": finetuning_args.additional_target,
+        }
 
     if model_args.use_unsloth:
+        if finetuning_args.finetuning_type == "oft":
+            raise ValueError("Unsloth is currently not supported for OFT.")
+
         model = get_unsloth_peft_model(model, model_args, peft_kwargs)
     else:
         if finetuning_args.pissa_init:
@@ -244,12 +260,19 @@
             logger.info_rank0(f"Using PiSSA initialization with FSVD steps {finetuning_args.pissa_iter}.")
             peft_kwargs["init_lora_weights"] = f"pissa_niter_{finetuning_args.pissa_iter}"
 
-        lora_config = LoraConfig(
-            task_type=TaskType.CAUSAL_LM,
-            inference_mode=False,
-            **peft_kwargs,
-        )
-        model = get_peft_model(model, lora_config)
+        if finetuning_args.finetuning_type == "lora":
+            peft_config = LoraConfig(
+                task_type=TaskType.CAUSAL_LM,
+                inference_mode=False,
+                **peft_kwargs,
+            )
+        elif finetuning_args.finetuning_type == "oft":
+            peft_config = OFTConfig(
+                task_type=TaskType.CAUSAL_LM,
+                inference_mode=False,
+                **peft_kwargs,
+            )
+        model = get_peft_model(model, peft_config)
 
     if is_trainable and cast_trainable_params_to_fp32:
         for param in filter(lambda p: p.requires_grad, model.parameters()):
@@ -272,8 +295,8 @@ def init_adapter(
     Note that the trainable parameters must be cast to float32.
     """
     if is_trainable and getattr(model, "quantization_method", None) is not None:
-        if finetuning_args.finetuning_type != "lora":
-            raise ValueError("Quantized models can only be used for the LoRA tuning.")
+        if finetuning_args.finetuning_type not in ["lora", "oft"]:
+            raise ValueError("Quantized models can only be used for the LoRA or OFT tuning.")
 
         if finetuning_args.pissa_init:
             raise ValueError("Cannot initialize PiSSA adapter on quantized models.")
@@ -296,7 +319,7 @@
         _setup_full_tuning(model, finetuning_args, is_trainable, cast_trainable_params_to_fp32)
     elif finetuning_args.finetuning_type == "freeze":
         _setup_freeze_tuning(model, finetuning_args, is_trainable, cast_trainable_params_to_fp32)
-    elif finetuning_args.finetuning_type == "lora":
+    elif finetuning_args.finetuning_type in ["lora", "oft"]:
         model = _setup_lora_tuning(
             config, model, model_args, finetuning_args, is_trainable, cast_trainable_params_to_fp32
         )
diff --git a/src/llamafactory/train/ppo/trainer.py b/src/llamafactory/train/ppo/trainer.py
index 14497459..09d12a85 100644
--- a/src/llamafactory/train/ppo/trainer.py
+++ b/src/llamafactory/train/ppo/trainer.py
@@ -390,7 +390,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
         batch: dict[str, torch.Tensor] = self.prepare_model_inputs(queries, responses)
 
         unwrapped_model: AutoModelForCausalLMWithValueHead = self.accelerator.unwrap_model(self.model)
-        if self.finetuning_args.reward_model_type == "lora":
+        if self.finetuning_args.reward_model_type in ["lora", "oft"]:
             replace_model(unwrapped_model, target="reward")
             reward_model = self.model
         else:
@@ -399,7 +399,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
         with unwrap_model_for_generation(reward_model, self.accelerator), self.amp_context:  # support bf16
             values: torch.Tensor = reward_model(**batch, return_dict=True, use_cache=False)[-1]
 
-        if self.finetuning_args.reward_model_type == "lora":
+        if self.finetuning_args.reward_model_type in ["lora", "oft"]:
             replace_model(unwrapped_model, target="default")
 
         rewards = values.gather(dim=-1, index=(batch["attention_mask"].sum(dim=-1, keepdim=True) - 1))
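The adapter changes above route `finetuning_type: oft` through PEFT's `OFTConfig` instead of `LoraConfig`. Below is a minimal sketch (not part of the patch) of what that wiring amounts to, assuming a recent PEFT release whose `OFTConfig` accepts `oft_block_size`; the model name and the explicit target-module list are illustrative assumptions only.

```python
# Standalone sketch of the OFT branch in `_setup_lora_tuning` (illustrative, not the patch itself).
# Assumes a PEFT version that exposes OFTConfig with the `oft_block_size` argument.
from peft import OFTConfig, TaskType, get_peft_model
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")

# Mirrors the `peft_kwargs` built when `finetuning_type == "oft"`:
# exactly one of `r` (oft_rank) or `oft_block_size` should be nonzero.
oft_config = OFTConfig(
    task_type=TaskType.CAUSAL_LM,
    inference_mode=False,
    r=0,                # corresponds to `oft_rank` (unused here because a block size is given)
    oft_block_size=32,  # corresponds to `oft_block_size`
    module_dropout=0.0, # corresponds to `module_dropout`
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],  # `oft_target: all` resolves to all linear modules
    modules_to_save=None,  # corresponds to `additional_target`
)

model = get_peft_model(model, oft_config)
model.print_trainable_parameters()  # only the orthogonal OFT parameters remain trainable
```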