From 8e4ac786071f45de1a3b0e38d540ec17b52c75c8 Mon Sep 17 00:00:00 2001
From: Aman Gupta
Date: Thu, 12 Jun 2025 01:10:38 -0700
Subject: [PATCH] [trainer] Add LD-DPO objective (#8362)

---
 src/llamafactory/hparams/finetuning_args.py |  4 +++
 src/llamafactory/train/dpo/trainer.py       |  9 ++++---
 src/llamafactory/train/trainer_utils.py     | 27 +++++++++++++++++++--
 3 files changed, 35 insertions(+), 5 deletions(-)

diff --git a/src/llamafactory/hparams/finetuning_args.py b/src/llamafactory/hparams/finetuning_args.py
index 6217015d..cd403e8c 100644
--- a/src/llamafactory/hparams/finetuning_args.py
+++ b/src/llamafactory/hparams/finetuning_args.py
@@ -202,6 +202,10 @@ class RLHFArguments:
         default="lora",
         metadata={"help": "The type of the reward model in PPO training. Lora model only supports lora training."},
     )
+    ld_alpha: Optional[float] = field(
+        default=None,
+        metadata={"help": "α parameter from the LD-DPO paper, which controls the weighting of the verbose token log-probabilities in responses"},
+    )


 @dataclass
diff --git a/src/llamafactory/train/dpo/trainer.py b/src/llamafactory/train/dpo/trainer.py
index 2539127c..44587264 100644
--- a/src/llamafactory/train/dpo/trainer.py
+++ b/src/llamafactory/train/dpo/trainer.py
@@ -80,6 +80,7 @@ class CustomDPOTrainer(DPOTrainer):
         self.ftx_gamma = finetuning_args.pref_ftx
         self.label_smoothing = finetuning_args.dpo_label_smoothing
         self.simpo_gamma = finetuning_args.simpo_gamma
+        self.ld_alpha = finetuning_args.ld_alpha

         Trainer.__init__(self, model=model, **kwargs)
         self.model_accepts_loss_kwargs = False  # overwrite trainer's default behavior
@@ -177,7 +178,7 @@ class CustomDPOTrainer(DPOTrainer):

     @override
     def concatenated_forward(
-        self, model: "PreTrainedModel", batch: dict[str, "torch.Tensor"]
+        self, model: "PreTrainedModel", batch: dict[str, "torch.Tensor"], is_ref_model: bool = False
     ) -> tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]:
         r"""Compute the sum log probabilities of the labels under given logits if loss_type is not IPO, ORPO or SimPO.

@@ -187,7 +188,8 @@ class CustomDPOTrainer(DPOTrainer):

         batch = nested_detach(batch, clone=True)  # avoid error
         all_logits: torch.Tensor = model(**batch, return_dict=True, use_cache=False).logits.to(torch.float32)
-        all_logps, valid_length = get_batch_logps(logits=all_logits, labels=batch["labels"])
+        all_logps, valid_length = get_batch_logps(logits=all_logits, labels=batch["labels"],
+                                                  ld_alpha=(self.ld_alpha if not is_ref_model else None))
         if self.loss_type in ["ipo", "orpo", "simpo"]:
             all_logps = all_logps / valid_length

@@ -217,7 +219,8 @@ class CustomDPOTrainer(DPOTrainer):
             ref_context = nullcontext()

         with torch.no_grad(), ref_context:
-            reference_chosen_logps, reference_rejected_logps, *_ = self.concatenated_forward(ref_model, batch)
+            reference_chosen_logps, reference_rejected_logps, *_ = self.concatenated_forward(ref_model, batch,
+                                                                                             is_ref_model=True)

         return reference_chosen_logps, reference_rejected_logps

diff --git a/src/llamafactory/train/trainer_utils.py b/src/llamafactory/train/trainer_utils.py
index 87396220..688c7035 100644
--- a/src/llamafactory/train/trainer_utils.py
+++ b/src/llamafactory/train/trainer_utils.py
@@ -585,7 +585,7 @@ def create_custom_scheduler(


 def get_batch_logps(
-    logits: "torch.Tensor", labels: "torch.Tensor", label_pad_token_id: int = IGNORE_INDEX
+    logits: "torch.Tensor", labels: "torch.Tensor", label_pad_token_id: int = IGNORE_INDEX, ld_alpha: Optional[float] = None
 ) -> tuple["torch.Tensor", "torch.Tensor"]:
     r"""Compute the log probabilities of the given labels under the given logits.

@@ -602,7 +602,30 @@ def get_batch_logps(
     loss_mask = labels != label_pad_token_id
     labels[labels == label_pad_token_id] = 0  # dummy token
     per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)
-    return (per_token_logps * loss_mask).sum(-1), loss_mask.sum(-1)
+
+    valid_length = loss_mask.sum(-1)
+    if ld_alpha is not None:
+        num_examples = labels.shape[0] // 2
+        chosen_lengths = valid_length[:num_examples]
+        rejected_lengths = valid_length[num_examples:]
+        min_lengths = torch.min(chosen_lengths, rejected_lengths)
+        start_positions = torch.argmax(loss_mask.int(), dim=1)
+        public_lengths = start_positions + torch.cat([min_lengths, min_lengths], dim=0)
+
+        seq_len = labels.shape[-1]
+        position_ids = torch.arange(seq_len, device=per_token_logps.device).expand_as(per_token_logps)
+
+        ld_mask = position_ids < public_lengths.unsqueeze(1)
+        front_mask = (ld_mask * loss_mask).float()
+        rear_mask = (~ld_mask * loss_mask).float()
+
+        front_logps = (per_token_logps * front_mask).sum(-1)
+        rear_logps = (per_token_logps * rear_mask).sum(-1)
+        logps = front_logps + ld_alpha * rear_logps
+    else:
+        logps = (per_token_logps * loss_mask).sum(-1)
+
+    return logps, valid_length


 def nested_detach(
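
Reviewer note (not part of the patch): below is a minimal, self-contained sketch of the length-desensitized weighting that the new ld_alpha branch of get_batch_logps applies. The helper name ld_weighted_logps and the toy tensors are illustrative assumptions, not code from the repository; the patch keeps the equivalent logic inline in get_batch_logps.

# Sketch only: mirrors the ld_alpha branch added to get_batch_logps above.
import torch


def ld_weighted_logps(per_token_logps: torch.Tensor, loss_mask: torch.Tensor, ld_alpha: float) -> torch.Tensor:
    """per_token_logps and loss_mask have shape (2 * num_examples, seq_len);
    the first half of the batch holds chosen responses, the second half rejected ones."""
    valid_length = loss_mask.sum(-1)
    num_examples = loss_mask.shape[0] // 2
    # Length of the shorter response in each (chosen, rejected) pair.
    min_lengths = torch.min(valid_length[:num_examples], valid_length[num_examples:])
    # Index of the first label token (earlier positions belong to the prompt).
    start_positions = torch.argmax(loss_mask.int(), dim=1)
    # Tokens before this boundary form the "public" prefix; the rest is the verbose tail.
    public_lengths = start_positions + torch.cat([min_lengths, min_lengths], dim=0)

    position_ids = torch.arange(loss_mask.shape[-1]).expand_as(per_token_logps)
    ld_mask = position_ids < public_lengths.unsqueeze(1)
    front_logps = (per_token_logps * (ld_mask & loss_mask).float()).sum(-1)
    rear_logps = (per_token_logps * (~ld_mask & loss_mask).float()).sum(-1)
    # ld_alpha = 1.0 recovers the plain sum of label log-probabilities.
    return front_logps + ld_alpha * rear_logps


# Toy batch: the chosen response has 4 label tokens, the rejected one has 2.
probs = torch.tensor([[0.5, 0.5, 0.5, 0.5], [0.5, 0.5, 1.0, 1.0]])
per_token_logps = probs.log()
loss_mask = torch.tensor([[1, 1, 1, 1], [1, 1, 0, 0]], dtype=torch.bool)
print(ld_weighted_logps(per_token_logps, loss_mask, ld_alpha=0.5))
# tensor([-2.0794, -1.3863]): the chosen tail (tokens 3-4) is down-weighted by 0.5;
# the rejected response is no longer than its partner, so it has no tail to shrink.

Since ld_alpha is added as an RLHFArguments field, it can be set like any other DPO option in the training config or CLI (for example, a value such as ld_alpha: 0.8; the exact value is a user choice, not fixed by this patch). Setting it to 1.0 makes the weighted sum identical to the ordinary sum, and the reference-model forward passes is_ref_model=True so that ld_alpha is disabled (None) and the reference log-probabilities stay on the vanilla computation.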