# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by HuggingFace's TRL library.
# https://github.com/huggingface/trl/blob/v0.8.0/trl/trainer/dpo_trainer.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING

import torch
from ktransformers.sft.lora import KTrainer  # type: ignore
from typing_extensions import override

from ..trainer_utils import get_batch_logps, nested_detach
from .trainer import CustomDPOTrainer


if TYPE_CHECKING:
    from transformers import PreTrainedModel


class KDPOTrainer(KTrainer, CustomDPOTrainer):
    @override
    def concatenated_forward(
        self, model: "PreTrainedModel", batch: dict[str, "torch.Tensor"], is_ref_model: bool = False
    ) -> tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]:
        r"""Compute the sum of log probabilities of the labels under the given logits.

        Returns the average log probabilities instead when ``loss_type`` is IPO, ORPO, or SimPO.
        """
        if self.finetuning_args.use_ref_model:
            batch = nested_detach(batch, clone=True)  # avoid error

        labels = batch.pop("labels")  # DPO does not need to compute the loss in the forward pass
        all_logits: torch.Tensor = model(**batch, return_dict=True, use_cache=False).logits.to(torch.float32)
        all_logits = all_logits.to("cpu")  # move logits to CPU and compute the log probs there
        labels = labels.to(all_logits.device)
        all_logps, valid_length = get_batch_logps(
            logits=all_logits, labels=labels, ld_alpha=(self.ld_alpha if not is_ref_model else None)
        )
        if self.loss_type in ["ipo", "orpo", "simpo"]:
            all_logps = all_logps / valid_length

        # the batch concatenates chosen and rejected examples; split it back into the two halves
        batch_size = batch["input_ids"].size(0) // 2
        chosen_logps, rejected_logps = all_logps.split(batch_size, dim=0)
        chosen_logits, rejected_logits = all_logits.split(batch_size, dim=0)
        chosen_length, _ = valid_length.split(batch_size, dim=0)

        if self.loss_type in ["ipo", "orpo", "simpo"]:
            return chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_logps
        else:
            return chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_logps / chosen_length
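

# --- Illustrative sketch only; not used by the trainer above ---
# A minimal example of how the chosen/rejected log probabilities returned by
# ``concatenated_forward`` feed into the standard sigmoid DPO loss. The function
# name ``sketch_dpo_loss``, the ``beta`` value, and the reference log probs are
# illustrative assumptions, not values produced or consumed in this file.
def sketch_dpo_loss(
    chosen_logps: "torch.Tensor",
    rejected_logps: "torch.Tensor",
    ref_chosen_logps: "torch.Tensor",
    ref_rejected_logps: "torch.Tensor",
    beta: float = 0.1,
) -> "torch.Tensor":
    import torch.nn.functional as F

    # policy and reference log-prob margins between chosen and rejected responses
    policy_margin = chosen_logps - rejected_logps
    ref_margin = ref_chosen_logps - ref_rejected_logps
    # standard DPO objective: -log sigmoid(beta * (policy margin - reference margin))
    return -F.logsigmoid(beta * (policy_margin - ref_margin)).mean()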