From a769fb94b93f39c6859303b44546eca46a6f8c29 Mon Sep 17 00:00:00 2001
From: mrhaoxx
Date: Thu, 18 Dec 2025 21:26:25 +0800
Subject: [PATCH] [feat] support ktransformers for dpo (#9621)

Co-authored-by: poryfly
---
 src/llamafactory/train/dpo/ktrainer.py | 64 +++++++++++++++++++++++++
 src/llamafactory/train/dpo/workflow.py | 11 +++-
 2 files changed, 74 insertions(+), 1 deletion(-)
 create mode 100644 src/llamafactory/train/dpo/ktrainer.py

diff --git a/src/llamafactory/train/dpo/ktrainer.py b/src/llamafactory/train/dpo/ktrainer.py
new file mode 100644
index 000000000..d638d8890
--- /dev/null
+++ b/src/llamafactory/train/dpo/ktrainer.py
@@ -0,0 +1,64 @@
+# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
+#
+# This code is inspired by HuggingFace's TRL library.
+# https://github.com/huggingface/trl/blob/v0.8.0/trl/trainer/dpo_trainer.py
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import TYPE_CHECKING
+
+import torch
+from ktransformers.sft.lora import KTrainer
+from typing_extensions import override
+
+from ..trainer_utils import get_batch_logps, nested_detach
+from .trainer import CustomDPOTrainer as BaseDPOTrainer
+
+
+if TYPE_CHECKING:
+    from transformers import PreTrainedModel
+
+
+class CustomDPOTrainer(KTrainer, BaseDPOTrainer):
+    @override
+    def concatenated_forward(
+        self, model: "PreTrainedModel", batch: dict[str, "torch.Tensor"], is_ref_model: bool = False
+    ) -> tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]:
+        r"""Compute the sum of log probabilities of the labels under the given logits.
+
+        Returns the average log probabilities instead if loss_type is IPO, ORPO or SimPO.
+        """
+        if self.finetuning_args.use_ref_model:
+            batch = nested_detach(batch, clone=True)  # detach and clone to avoid modifying the batch reused by the reference model
+
+        # DPO does not compute the loss in the forward pass, so keeping labels would only waste memory
+        labels = batch["labels"]
+        del batch["labels"]
+        all_logits: torch.Tensor = model(**batch, return_dict=True, use_cache=False).logits.to(torch.float32)
+        all_logits = all_logits.to("cpu")  # offload logits so the log-prob computation does not hold GPU memory
+        labels = labels.to(all_logits.device)
+        all_logps, valid_length = get_batch_logps(
+            logits=all_logits, labels=labels, ld_alpha=(self.ld_alpha if not is_ref_model else None)
+        )
+        if self.loss_type in ["ipo", "orpo", "simpo"]:
+            all_logps = all_logps / valid_length
+
+        batch_size = batch["input_ids"].size(0) // 2  # the first half is chosen, the second half is rejected
+        chosen_logps, rejected_logps = all_logps.split(batch_size, dim=0)
+        chosen_logits, rejected_logits = all_logits.split(batch_size, dim=0)
+        chosen_length, _ = valid_length.split(batch_size, dim=0)
+
+        if self.loss_type in ["ipo", "orpo", "simpo"]:
+            return chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_logps
+        else:
+            return chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_logps / chosen_length
diff --git a/src/llamafactory/train/dpo/workflow.py b/src/llamafactory/train/dpo/workflow.py
index c0a107d2f..4e3c8f8f3 100644
--- a/src/llamafactory/train/dpo/workflow.py
+++ b/src/llamafactory/train/dpo/workflow.py
@@ -24,7 +24,6 @@ from ...extras.ploting import plot_loss
 from ...hparams import ModelArguments
 from ...model import load_model, load_tokenizer
 from ..trainer_utils import create_modelcard_and_push, create_ref_model
-from .trainer import CustomDPOTrainer
 
 
 if TYPE_CHECKING:
@@ -62,6 +61,16 @@ def run_dpo(
         ref_model = create_ref_model(model_args, finetuning_args)
     else:
         ref_model = None
+
+    if model_args.use_kt:
+        # set the ktransformers global mode to "sft" before importing the KT-backed trainer
+        from ktransformers.util.globals import GLOBAL_CONFIG
+
+        GLOBAL_CONFIG._config["mod"] = "sft"
+
+        from .ktrainer import CustomDPOTrainer
+    else:
+        from .trainer import CustomDPOTrainer
 
     # Initialize our Trainer
     trainer = CustomDPOTrainer(
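
Example usage (illustrative): the new code path is gated on use_kt, which this patch reads from the model arguments. The sketch below assumes use_kt is exposed as a top-level config key like other model-argument fields; every other key follows LlamaFactory's existing DPO LoRA examples, and the concrete values are placeholders.

    ### dpo_kt_example.yaml (hypothetical file name)
    ### model
    model_name_or_path: path/to/base/model
    use_kt: true              # assumption: ModelArguments field exposed in the YAML schema
    ### method
    stage: dpo
    do_train: true
    finetuning_type: lora
    lora_target: all
    pref_beta: 0.1
    pref_loss: sigmoid        # choices include sigmoid, ipo, orpo, simpo
    ### dataset
    dataset: dpo_en_demo
    template: llama3          # placeholder, match your base model
    ### output
    output_dir: saves/dpo-kt

Launch as usual with: llamafactory-cli train dpo_kt_example.yaml

Listing KTrainer before the base trainer in CustomDPOTrainer's bases means ktransformers' training-loop overrides take precedence in the MRO, while the DPO preference-loss logic is still inherited from the existing CustomDPOTrainer.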