mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-01 11:12:50 +08:00
[trainer] Add LD-DPO objective (#8362)
This commit is contained in:
parent
5ed62a29c5
commit
1cfe42916d
@ -202,6 +202,10 @@ class RLHFArguments:
|
||||
default="lora",
|
||||
metadata={"help": "The type of the reward model in PPO training. Lora model only supports lora training."},
|
||||
)
|
||||
ld_alpha: Optional[float] = field(
|
||||
default=None,
|
||||
metadata={"help": "α parameter from the LD-DPO paper, which controls the weighting of the verbose token log-probabilities in responses"},
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
|
@ -80,6 +80,7 @@ class CustomDPOTrainer(DPOTrainer):
|
||||
self.ftx_gamma = finetuning_args.pref_ftx
|
||||
self.label_smoothing = finetuning_args.dpo_label_smoothing
|
||||
self.simpo_gamma = finetuning_args.simpo_gamma
|
||||
self.ld_alpha = finetuning_args.ld_alpha
|
||||
|
||||
Trainer.__init__(self, model=model, **kwargs)
|
||||
self.model_accepts_loss_kwargs = False # overwrite trainer's default behavior
|
||||
@ -177,7 +178,7 @@ class CustomDPOTrainer(DPOTrainer):
|
||||
|
||||
@override
|
||||
def concatenated_forward(
|
||||
self, model: "PreTrainedModel", batch: dict[str, "torch.Tensor"]
|
||||
self, model: "PreTrainedModel", batch: dict[str, "torch.Tensor"], is_ref_model: bool = False
|
||||
) -> tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]:
|
||||
r"""Compute the sum log probabilities of the labels under given logits if loss_type is not IPO, ORPO or SimPO.
|
||||
|
||||
@ -187,7 +188,8 @@ class CustomDPOTrainer(DPOTrainer):
|
||||
batch = nested_detach(batch, clone=True) # avoid error
|
||||
|
||||
all_logits: torch.Tensor = model(**batch, return_dict=True, use_cache=False).logits.to(torch.float32)
|
||||
all_logps, valid_length = get_batch_logps(logits=all_logits, labels=batch["labels"])
|
||||
all_logps, valid_length = get_batch_logps(logits=all_logits, labels=batch["labels"],
|
||||
ld_alpha=(self.ld_alpha if not is_ref_model else None))
|
||||
if self.loss_type in ["ipo", "orpo", "simpo"]:
|
||||
all_logps = all_logps / valid_length
|
||||
|
||||
@ -217,7 +219,8 @@ class CustomDPOTrainer(DPOTrainer):
|
||||
ref_context = nullcontext()
|
||||
|
||||
with torch.no_grad(), ref_context:
|
||||
reference_chosen_logps, reference_rejected_logps, *_ = self.concatenated_forward(ref_model, batch)
|
||||
reference_chosen_logps, reference_rejected_logps, *_ = self.concatenated_forward(ref_model, batch,
|
||||
is_ref_model=True)
|
||||
|
||||
return reference_chosen_logps, reference_rejected_logps
|
||||
|
||||
|
@ -585,7 +585,7 @@ def create_custom_scheduler(
|
||||
|
||||
|
||||
def get_batch_logps(
|
||||
logits: "torch.Tensor", labels: "torch.Tensor", label_pad_token_id: int = IGNORE_INDEX
|
||||
logits: "torch.Tensor", labels: "torch.Tensor", label_pad_token_id: int = IGNORE_INDEX, ld_alpha: Optional[float] = None
|
||||
) -> tuple["torch.Tensor", "torch.Tensor"]:
|
||||
r"""Compute the log probabilities of the given labels under the given logits.
|
||||
|
||||
@ -602,7 +602,30 @@ def get_batch_logps(
|
||||
loss_mask = labels != label_pad_token_id
|
||||
labels[labels == label_pad_token_id] = 0 # dummy token
|
||||
per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)
|
||||
return (per_token_logps * loss_mask).sum(-1), loss_mask.sum(-1)
|
||||
|
||||
valid_length = loss_mask.sum(-1)
|
||||
if ld_alpha is not None:
|
||||
num_examples = labels.shape[0] // 2
|
||||
chosen_lengths = valid_length[:num_examples]
|
||||
rejected_lengths = valid_length[num_examples:]
|
||||
min_lengths = torch.min(chosen_lengths, rejected_lengths)
|
||||
start_positions = torch.argmax(loss_mask.int(), dim=1)
|
||||
public_lengths = start_positions + torch.cat([min_lengths, min_lengths], dim=0)
|
||||
|
||||
seq_len = labels.shape[-1]
|
||||
position_ids = torch.arange(seq_len, device=per_token_logps.device).expand_as(per_token_logps)
|
||||
|
||||
ld_mask = position_ids < public_lengths.unsqueeze(1)
|
||||
front_mask = (ld_mask * loss_mask).float()
|
||||
rear_mask = (~ld_mask * loss_mask).float()
|
||||
|
||||
front_logps = (per_token_logps * front_mask).sum(-1)
|
||||
rear_logps = (per_token_logps * rear_mask).sum(-1)
|
||||
logps = front_logps + ld_alpha * rear_logps
|
||||
else:
|
||||
logps = (per_token_logps * loss_mask).sum(-1)
|
||||
|
||||
return logps, valid_length
|
||||
|
||||
|
||||
def nested_detach(
|
||||
|
Loading…
x
Reference in New Issue
Block a user