Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-10-14 23:58:11 +08:00)
[trainer] Add LD-DPO objective (#8362)

parent 44f1b9b5ad
commit 8e4ac78607
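Background for this change (a hedged summary; not part of the commit itself): LD-DPO makes DPO less sensitive to response length by reshaping the response likelihood. Tokens up to the "public length" ℓ_p, i.e. the length of the shorter response in a chosen/rejected pair, keep full weight, while the remaining verbose tokens are down-weighted by α ∈ [0, 1]. A sketch in the paper's notation (symbol names here are assumptions):

\[
\hat{\pi}_\theta(y \mid x)
  = \prod_{i=1}^{\ell_p} p_\theta(y_i \mid x, y_{<i})
    \cdot \Bigl( \prod_{i=\ell_p+1}^{\ell_y} p_\theta(y_i \mid x, y_{<i}) \Bigr)^{\alpha},
\qquad \ell_p = \min(\ell_{y_w}, \ell_{y_l}).
\]

In log space this is front_logps + α · rear_logps, which is exactly what get_batch_logps computes below; α = 1 recovers standard DPO, and α = 0 scores both responses only up to the public length.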
@@ -202,6 +202,10 @@ class RLHFArguments:
         default="lora",
         metadata={"help": "The type of the reward model in PPO training. Lora model only supports lora training."},
     )
+    ld_alpha: Optional[float] = field(
+        default=None,
+        metadata={"help": "The α parameter from the LD-DPO paper, which controls the weighting of the verbose token log-probabilities in responses."},
+    )
 
 
 @dataclass
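RLHFArguments is parsed with transformers' HfArgumentParser, so the new field should surface as a --ld_alpha command-line option. A minimal, self-contained sketch of that behavior (the dataclass below is a hypothetical stand-in, not the real class):

    from dataclasses import dataclass, field
    from typing import Optional

    from transformers import HfArgumentParser


    @dataclass
    class RLHFArgumentsSketch:  # hypothetical stand-in for RLHFArguments
        ld_alpha: Optional[float] = field(
            default=None,
            metadata={"help": "Weight on log-probs of tokens beyond the public length."},
        )


    parser = HfArgumentParser(RLHFArgumentsSketch)
    (args,) = parser.parse_args_into_dataclasses(["--ld_alpha", "0.5"])
    print(args.ld_alpha)  # 0.5; stays None (plain DPO behavior) when the flag is omitted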
@@ -80,6 +80,7 @@ class CustomDPOTrainer(DPOTrainer):
         self.ftx_gamma = finetuning_args.pref_ftx
         self.label_smoothing = finetuning_args.dpo_label_smoothing
         self.simpo_gamma = finetuning_args.simpo_gamma
+        self.ld_alpha = finetuning_args.ld_alpha
 
         Trainer.__init__(self, model=model, **kwargs)
         self.model_accepts_loss_kwargs = False  # overwrite trainer's default behavior
@@ -177,7 +178,7 @@ class CustomDPOTrainer(DPOTrainer):
 
     @override
     def concatenated_forward(
-        self, model: "PreTrainedModel", batch: dict[str, "torch.Tensor"]
+        self, model: "PreTrainedModel", batch: dict[str, "torch.Tensor"], is_ref_model: bool = False
     ) -> tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]:
         r"""Compute the sum log probabilities of the labels under given logits if loss_type is not IPO, ORPO or SimPO.
 
@@ -187,7 +188,8 @@ class CustomDPOTrainer(DPOTrainer):
         batch = nested_detach(batch, clone=True)  # avoid error
 
         all_logits: torch.Tensor = model(**batch, return_dict=True, use_cache=False).logits.to(torch.float32)
-        all_logps, valid_length = get_batch_logps(logits=all_logits, labels=batch["labels"])
+        all_logps, valid_length = get_batch_logps(logits=all_logits, labels=batch["labels"],
+                                                  ld_alpha=(self.ld_alpha if not is_ref_model else None))
         if self.loss_type in ["ipo", "orpo", "simpo"]:
             all_logps = all_logps / valid_length
 
@@ -217,7 +219,8 @@ class CustomDPOTrainer(DPOTrainer):
             ref_context = nullcontext()
 
         with torch.no_grad(), ref_context:
-            reference_chosen_logps, reference_rejected_logps, *_ = self.concatenated_forward(ref_model, batch)
+            reference_chosen_logps, reference_rejected_logps, *_ = self.concatenated_forward(ref_model, batch,
+                                                                                             is_ref_model=True)
 
         return reference_chosen_logps, reference_rejected_logps
 
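Note on the wiring above (an observation, not diff text): ld_alpha reaches get_batch_logps only on the policy pass; the reference pass sets is_ref_model=True, so the reference log-probabilities remain plain sums and only the policy side of the DPO implicit reward is length-desensitized.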
@@ -585,7 +585,7 @@ def create_custom_scheduler(
 
 
 def get_batch_logps(
-    logits: "torch.Tensor", labels: "torch.Tensor", label_pad_token_id: int = IGNORE_INDEX
+    logits: "torch.Tensor", labels: "torch.Tensor", label_pad_token_id: int = IGNORE_INDEX, ld_alpha: Optional[float] = None
 ) -> tuple["torch.Tensor", "torch.Tensor"]:
     r"""Compute the log probabilities of the given labels under the given logits.
 
@@ -602,7 +602,30 @@ def get_batch_logps(
     loss_mask = labels != label_pad_token_id
     labels[labels == label_pad_token_id] = 0  # dummy token
     per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)
-    return (per_token_logps * loss_mask).sum(-1), loss_mask.sum(-1)
+    valid_length = loss_mask.sum(-1)
+    if ld_alpha is not None:
+        num_examples = labels.shape[0] // 2
+        chosen_lengths = valid_length[:num_examples]
+        rejected_lengths = valid_length[num_examples:]
+        min_lengths = torch.min(chosen_lengths, rejected_lengths)
+        start_positions = torch.argmax(loss_mask.int(), dim=1)
+        public_lengths = start_positions + torch.cat([min_lengths, min_lengths], dim=0)
+
+        seq_len = labels.shape[-1]
+        position_ids = torch.arange(seq_len, device=per_token_logps.device).expand_as(per_token_logps)
+
+        ld_mask = position_ids < public_lengths.unsqueeze(1)
+        front_mask = (ld_mask * loss_mask).float()
+        rear_mask = (~ld_mask * loss_mask).float()
+
+        front_logps = (per_token_logps * front_mask).sum(-1)
+        rear_logps = (per_token_logps * rear_mask).sum(-1)
+        logps = front_logps + ld_alpha * rear_logps
+    else:
+        logps = (per_token_logps * loss_mask).sum(-1)
+
+    return logps, valid_length
 
 
 def nested_detach(
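A self-contained toy check of the new branch (not from the commit; the label values are made up) that replays the mask construction for one chosen/rejected pair where the chosen response is longer:

    import torch

    IGNORE_INDEX = -100  # label padding value, as in LLaMA-Factory

    # batch layout [chosen..., rejected...], matching num_examples = labels.shape[0] // 2
    labels = torch.tensor([
        [IGNORE_INDEX, 11, 12, 13, 14, 15],                                # chosen: 5 response tokens
        [IGNORE_INDEX, 21, 22, IGNORE_INDEX, IGNORE_INDEX, IGNORE_INDEX],  # rejected: 2 response tokens
    ])
    loss_mask = labels != IGNORE_INDEX
    valid_length = loss_mask.sum(-1)                         # tensor([5, 2])

    num_examples = labels.shape[0] // 2                      # 1 pair
    min_lengths = torch.min(valid_length[:num_examples], valid_length[num_examples:])  # tensor([2])
    start_positions = torch.argmax(loss_mask.int(), dim=1)   # first response position: tensor([1, 1])
    public_lengths = start_positions + torch.cat([min_lengths, min_lengths], dim=0)    # tensor([3, 3])

    position_ids = torch.arange(labels.shape[-1]).expand_as(labels)
    ld_mask = position_ids < public_lengths.unsqueeze(1)
    front_mask = (ld_mask * loss_mask).float()               # shared-length tokens, weight 1
    rear_mask = (~ld_mask * loss_mask).float()               # verbose tail, weight ld_alpha
    print(front_mask)  # [[0., 1., 1., 0., 0., 0.], [0., 1., 1., 0., 0., 0.]]
    print(rear_mask)   # [[0., 0., 0., 1., 1., 1.], [0., 0., 0., 0., 0., 0.]]

front_mask covers the first min-length response tokens of both sequences; everything past that point (present only in the longer, chosen response here) lands in rear_mask and is scaled by ld_alpha, so a verbose tail can no longer dominate the preference signal.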