import torch
from collections import defaultdict
from typing import TYPE_CHECKING, Dict, Optional, Tuple, Union

from transformers import BatchEncoding, Trainer
from trl import DPOTrainer
from trl.trainer.utils import disable_dropout_in_model

from llmtuner.extras.constants import IGNORE_INDEX
from llmtuner.tuner.core.trainer import PeftModelMixin

if TYPE_CHECKING:
    from transformers import PreTrainedModel

    from llmtuner.hparams import FinetuningArguments


class DPOPeftTrainer(PeftModelMixin, DPOTrainer):
    r"""
    Inherits DPOTrainer to support PEFT models, reusing the Trainer setup from
    PeftModelMixin while keeping TRL's DPO loss computation.
    """

    def __init__(
        self,
        finetuning_args: "FinetuningArguments",
        model: Union["PreTrainedModel", torch.nn.Module],
        ref_model: Optional[Union["PreTrainedModel", torch.nn.Module]] = None,
        disable_dropout: Optional[bool] = True,
        **kwargs
    ):
        # Dropout adds noise to the log-probabilities, so it is disabled in both
        # the policy model and the reference model during DPO training.
        if disable_dropout:
            disable_dropout_in_model(model)
            if ref_model is not None:
                disable_dropout_in_model(ref_model)

        self.finetuning_args = finetuning_args
        self.ref_model = ref_model
        self.use_dpo_data_collator = True  # hack to avoid warning
        self.label_pad_token_id = IGNORE_INDEX
        self.padding_value = 0
        self.beta = finetuning_args.dpo_beta
        self._stored_metrics = defaultdict(lambda: defaultdict(list))

        # Bypass DPOTrainer.__init__ and initialize the plain Trainer directly,
        # since the attributes DPOTrainer would set are assigned above.
        Trainer.__init__(self, model=model, **kwargs)
        if not hasattr(self, "accelerator"):
            raise AttributeError("Please update `transformers`.")

        # Prepare the frozen reference model for distributed inference.
        if ref_model is not None:
            if self.is_deepspeed_enabled:
                # Accelerator._prepare_deepspeed returns a tuple of prepared
                # objects, hence the single-element unpacking.
                self.ref_model, = self.accelerator._prepare_deepspeed(self.ref_model)
                self.ref_model.eval()
            else:
                self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)

    def concatenated_forward(
        self,
        model: Optional[torch.nn.Module] = None,
        batch: Optional[Dict[str, torch.Tensor]] = None
    ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
        r"""
        Runs a single forward pass over a batch in which the chosen examples occupy
        the first half and the rejected examples occupy the second half.
        """
        # Clone the batch so the forward pass cannot modify the caller's tensors.
        batch_copied = BatchEncoding({k: v.detach().clone() for k, v in batch.items()})

        all_logits = model(
            input_ids=batch_copied["input_ids"],
            attention_mask=batch_copied["attention_mask"],
            return_dict=True
        ).logits.to(torch.float32)

        # Sum of per-token log-probabilities for each sequence (not averaged).
        all_logps = self._get_batch_logps(
            all_logits,
            batch["labels"],
            average_log_prob=False
        )
        # Split the concatenated batch back into its chosen and rejected halves.
        batch_size = batch["input_ids"].size(0) // 2
        chosen_logps, rejected_logps = all_logps.split(batch_size, dim=0)
        chosen_logits, rejected_logits = all_logits.split(batch_size, dim=0)
        return chosen_logps, rejected_logps, chosen_logits, rejected_logits
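

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only; in llmtuner this trainer is normally built
# by the library's own DPO workflow). The `model`, `ref_model`, dataset, and
# collator names below are hypothetical placeholders, not objects defined in
# this module.
# ---------------------------------------------------------------------------
#
# from transformers import TrainingArguments
# from llmtuner.hparams import FinetuningArguments
#
# trainer = DPOPeftTrainer(
#     finetuning_args=FinetuningArguments(dpo_beta=0.1),  # DPO temperature (beta)
#     model=model,          # policy model being fine-tuned (e.g. with PEFT adapters)
#     ref_model=ref_model,  # frozen reference model providing baseline log-probs
#     args=TrainingArguments(output_dir="dpo_output"),  # forwarded to Trainer via **kwargs
#     train_dataset=train_dataset,  # collated so chosen examples precede rejected ones
#     data_collator=data_collator,
# )
# trainer.train()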