[trainer] Add LD-DPO objective (#8362)

Author: Aman Gupta
Date: 2025-06-12 01:10:38 -07:00 (committed by GitHub)
Commit: 1cfe42916d
Parent: 5ed62a29c5
3 changed files with 35 additions and 5 deletions


@@ -202,6 +202,10 @@ class RLHFArguments:
         default="lora",
         metadata={"help": "The type of the reward model in PPO training. Lora model only supports lora training."},
     )
+    ld_alpha: Optional[float] = field(
+        default=None,
+        metadata={"help": "α parameter from the LD-DPO paper, which controls the weighting of the verbose token log-probabilities in responses"},
+    )


 @dataclass
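Note (not part of the diff): following the LD-DPO formulation, let ℓ_p be the number of response tokens in the shorter of a chosen/rejected pair (the "public" span) and α the new ld_alpha. The per-response log-likelihood used on the policy side becomes

$$\log \hat{\pi}_\theta(y \mid x) = \sum_{t \le \ell_p} \log \pi_\theta(y_t \mid x, y_{<t}) + \alpha \sum_{t > \ell_p} \log \pi_\theta(y_t \mid x, y_{<t}),$$

so α = 1 recovers the standard DPO sum over all response tokens, while smaller values down-weight the tokens by which the longer response exceeds the shorter one (implemented in get_batch_logps in the last file of this diff).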


@@ -80,6 +80,7 @@ class CustomDPOTrainer(DPOTrainer):
         self.ftx_gamma = finetuning_args.pref_ftx
         self.label_smoothing = finetuning_args.dpo_label_smoothing
         self.simpo_gamma = finetuning_args.simpo_gamma
+        self.ld_alpha = finetuning_args.ld_alpha

         Trainer.__init__(self, model=model, **kwargs)
         self.model_accepts_loss_kwargs = False  # overwrite trainer's default behavior
@@ -177,7 +178,7 @@ class CustomDPOTrainer(DPOTrainer):
     @override
     def concatenated_forward(
-        self, model: "PreTrainedModel", batch: dict[str, "torch.Tensor"]
+        self, model: "PreTrainedModel", batch: dict[str, "torch.Tensor"], is_ref_model: bool = False
     ) -> tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]:
         r"""Compute the sum log probabilities of the labels under given logits if loss_type is not IPO, ORPO or SimPO.
@@ -187,7 +188,8 @@ class CustomDPOTrainer(DPOTrainer):
         batch = nested_detach(batch, clone=True)  # avoid error
         all_logits: torch.Tensor = model(**batch, return_dict=True, use_cache=False).logits.to(torch.float32)
-        all_logps, valid_length = get_batch_logps(logits=all_logits, labels=batch["labels"])
+        all_logps, valid_length = get_batch_logps(
+            logits=all_logits, labels=batch["labels"], ld_alpha=(self.ld_alpha if not is_ref_model else None)
+        )
         if self.loss_type in ["ipo", "orpo", "simpo"]:
             all_logps = all_logps / valid_length
@@ -217,7 +219,8 @@ class CustomDPOTrainer(DPOTrainer):
             ref_context = nullcontext()

         with torch.no_grad(), ref_context:
-            reference_chosen_logps, reference_rejected_logps, *_ = self.concatenated_forward(ref_model, batch)
+            reference_chosen_logps, reference_rejected_logps, *_ = self.concatenated_forward(
+                ref_model, batch, is_ref_model=True
+            )

         return reference_chosen_logps, reference_rejected_logps
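Note (not part of the diff): the reference-model forward passes is_ref_model=True, so ld_alpha is applied only to the policy model's log-probabilities; the reference log-probabilities remain plain sums over all valid response tokens.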


@@ -585,7 +585,7 @@ def create_custom_scheduler(
 def get_batch_logps(
-    logits: "torch.Tensor", labels: "torch.Tensor", label_pad_token_id: int = IGNORE_INDEX
+    logits: "torch.Tensor", labels: "torch.Tensor", label_pad_token_id: int = IGNORE_INDEX, ld_alpha: Optional[float] = None
 ) -> tuple["torch.Tensor", "torch.Tensor"]:
     r"""Compute the log probabilities of the given labels under the given logits.
@@ -602,7 +602,30 @@ def get_batch_logps(
     loss_mask = labels != label_pad_token_id
     labels[labels == label_pad_token_id] = 0  # dummy token
     per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)
-    return (per_token_logps * loss_mask).sum(-1), loss_mask.sum(-1)
+    valid_length = loss_mask.sum(-1)
+
+    if ld_alpha is not None:
+        num_examples = labels.shape[0] // 2
+        chosen_lengths = valid_length[:num_examples]
+        rejected_lengths = valid_length[num_examples:]
+        min_lengths = torch.min(chosen_lengths, rejected_lengths)
+        start_positions = torch.argmax(loss_mask.int(), dim=1)
+        public_lengths = start_positions + torch.cat([min_lengths, min_lengths], dim=0)
+
+        seq_len = labels.shape[-1]
+        position_ids = torch.arange(seq_len, device=per_token_logps.device).expand_as(per_token_logps)
+
+        ld_mask = position_ids < public_lengths.unsqueeze(1)
+        front_mask = (ld_mask * loss_mask).float()
+        rear_mask = (~ld_mask * loss_mask).float()
+
+        front_logps = (per_token_logps * front_mask).sum(-1)
+        rear_logps = (per_token_logps * rear_mask).sum(-1)
+
+        logps = front_logps + ld_alpha * rear_logps
+    else:
+        logps = (per_token_logps * loss_mask).sum(-1)
+
+    return logps, valid_length


 def nested_detach(
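
A minimal sketch (not part of the commit) of what the new ld_alpha branch in get_batch_logps computes, run on a single toy chosen/rejected pair; the log-prob values and ld_alpha = 0.5 are made up for illustration:

import torch

ld_alpha = 0.5  # hypothetical value; None disables the weighting, as in the reference forward
loss_mask = torch.tensor([
    [0, 0, 1, 1, 1, 1, 1],  # chosen: response starts at position 2, 5 valid tokens
    [0, 0, 1, 1, 1, 0, 0],  # rejected: response starts at position 2, 3 valid tokens
]).bool()
per_token_logps = torch.full(loss_mask.shape, -1.0)  # pretend every token has log-prob -1

valid_length = loss_mask.sum(-1)                                          # tensor([5, 3])
min_lengths = torch.min(valid_length[:1], valid_length[1:])               # tensor([3])
start_positions = torch.argmax(loss_mask.int(), dim=1)                    # tensor([2, 2])
public_lengths = start_positions + torch.cat([min_lengths, min_lengths])  # tensor([5, 5])

position_ids = torch.arange(loss_mask.shape[-1]).expand_as(per_token_logps)
ld_mask = position_ids < public_lengths.unsqueeze(1)  # True inside the shared ("public") span
front_logps = (per_token_logps * (ld_mask * loss_mask).float()).sum(-1)   # tensor([-3., -3.])
rear_logps = (per_token_logps * (~ld_mask * loss_mask).float()).sum(-1)   # tensor([-2., 0.])
print(front_logps + ld_alpha * rear_logps)  # tensor([-4., -3.]): only the longer (chosen) tail is down-weighted

With ld_alpha = 1.0 the result equals the plain masked sum, matching the previous behavior.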