fix IPO and ORPO loss

Former-commit-id: 5b9b40403d59982431a526e337f31d394f8b882b
hiyouga 2024-04-01 14:37:53 +08:00
parent e7ade84bba
commit 69e1d39832
2 changed files with 27 additions and 49 deletions


@@ -87,16 +87,22 @@ class CustomDPOTrainer(DPOTrainer):
     def concatenated_forward(
         self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"]
     ) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]:
+        r"""
+        Computes the sum log probabilities of the labels under the given logits if loss_type != IPO.
+
+        Otherwise the average log probabilities.
+        """
         batch_copied = BatchEncoding({k: v.detach().clone() for k, v in batch.items()})  # avoid error
-        all_logits = model(
+
+        all_logits: "torch.Tensor" = model(
             input_ids=batch_copied["input_ids"], attention_mask=batch_copied["attention_mask"], return_dict=True
         ).logits.to(torch.float32)
 
         all_logps = self.get_batch_logps(
-            all_logits,
-            batch["labels"],
-            average_log_prob=False,
+            logits=all_logits,
+            labels=batch_copied["labels"],
+            average_log_prob=(self.loss_type == "ipo"),
             is_encoder_decoder=self.is_encoder_decoder,
             label_pad_token_id=self.label_pad_token_id,
         )
         batch_size = batch["input_ids"].size(0) // 2
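The substantive change above is `average_log_prob=(self.loss_type == "ipo")`: the IPO objective (https://arxiv.org/abs/2310.12036) is defined on length-normalized (average) log probabilities, while DPO's sigmoid loss uses the summed log probability of the whole response, so the flag must follow the loss type instead of being hard-coded to False. Below is a minimal sketch of what the flag controls, modeled on trl's DPOTrainer.get_batch_logps; the helper name and masking details are illustrative, not the library's verbatim internals.

import torch

def get_batch_logps_sketch(
    logits: torch.Tensor,    # (batch, seq_len, vocab_size)
    labels: torch.Tensor,    # (batch, seq_len); pad positions hold label_pad_token_id
    average_log_prob: bool,  # True for IPO, False for the DPO sigmoid loss
    label_pad_token_id: int = -100,
) -> torch.Tensor:
    # Shift so that position t scores token t+1 (decoder-only convention).
    labels = labels[:, 1:].clone()
    logits = logits[:, :-1, :]
    loss_mask = labels != label_pad_token_id
    labels[labels == label_pad_token_id] = 0  # any valid index; masked out below
    per_token_logps = torch.gather(
        logits.log_softmax(dim=-1), dim=2, index=labels.unsqueeze(2)
    ).squeeze(2)
    if average_log_prob:
        # IPO: mean log prob per response token (length-normalized).
        return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
    # DPO sigmoid loss: total log prob of the response.
    return (per_token_logps * loss_mask).sum(-1)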


@@ -56,55 +56,31 @@ class CustomORPOTrainer(DPOTrainer):
         create_custom_scheduler(self.args, num_training_steps, optimizer)
         return super().create_scheduler(num_training_steps, optimizer)
 
-    def sft_loss(self, chosen_logits: "torch.FloatTensor", chosen_labels: "torch.LongTensor") -> "torch.Tensor":
-        r"""
-        Computes supervised cross-entropy loss of given labels under the given logits.
-
-        Returns:
-            A tensor of shape (batch_size,) containing the cross-entropy loss of each samples.
-        """
-        all_logps = self.get_batch_logps(chosen_logits, chosen_labels, average_log_prob=True)
-        return -all_logps
-
-    # Borrowed from:
-    # https://github.com/huggingface/trl/blob/0ee349dcd43b0f4b3169449f16751c38ac4a609f/trl/trainer/orpo_trainer.py#L592
-    def odds_ratio_loss(
-        self, chosen_logps: "torch.Tensor", rejected_logps: "torch.Tensor"
-    ) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]:
+    def odds_ratio_loss(self, chosen_logps: "torch.Tensor", rejected_logps: "torch.Tensor") -> "torch.Tensor":
         r"""
         Computes ORPO's odds ratio (OR) loss.
-
-        Args:
-            policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (batch_size,)
-            policy_rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (batch_size,)
-
-        Returns:
-            A tuple of five tensors: (losses, chosen_rewards, rejected_rewards, log_odds_ratio, log_odds_chosen).
         """
-        # Derived from Eqs. (4) and (7) from https://arxiv.org/abs/2403.07691 by using log identities and exp(log(P(y|x)) = P(y|x)
         log_odds = (chosen_logps - rejected_logps) - (
             torch.log1p(-torch.exp(chosen_logps)) - torch.log1p(-torch.exp(rejected_logps))
         )
-        ratio = F.logsigmoid(log_odds)
-        losses = self.beta * ratio
-        chosen_rewards = self.beta * chosen_logps.detach()
-        rejected_rewards = self.beta * rejected_logps.detach()
-        return losses, chosen_rewards, rejected_rewards, ratio, log_odds
+        odds_ratio_loss = -F.logsigmoid(log_odds)
+        return odds_ratio_loss
 
     def concatenated_forward(
         self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"]
     ) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]:
-        all_logits = model(
+        r"""
+        Computes the average log probabilities of the labels under the given logits.
+        """
+        all_logits: "torch.Tensor" = model(
             input_ids=batch["input_ids"], attention_mask=batch["attention_mask"], return_dict=True
         ).logits.to(torch.float32)
 
         all_logps = self.get_batch_logps(
-            all_logits,
-            batch["labels"],
-            average_log_prob=False,
+            logits=all_logits,
+            labels=batch["labels"],
+            average_log_prob=True,
             is_encoder_decoder=self.is_encoder_decoder,
             label_pad_token_id=self.label_pad_token_id,
         )
         batch_size = batch["input_ids"].size(0) // 2
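The log_odds expression implements the odds ratio from the ORPO paper (the removed comment cited Eqs. (4) and (7) of https://arxiv.org/abs/2403.07691): with p = exp(logps), odds(p) = p / (1 - p), and log(1 - p) is evaluated stably as log1p(-exp(logps)). This also suggests why concatenated_forward switches to average_log_prob=True here: exp of a per-token average stays comfortably inside (0, 1), whereas exp of a summed log prob underflows toward 0 for long responses and the log1p correction vanishes. A standalone sanity check of the identity, with made-up values:

import torch
import torch.nn.functional as F

chosen_logps = torch.tensor([-0.5])    # hypothetical average log P(y_w | x)
rejected_logps = torch.tensor([-2.0])  # hypothetical average log P(y_l | x)

# Stable log-space form, as in the patch above.
log_odds = (chosen_logps - rejected_logps) - (
    torch.log1p(-torch.exp(chosen_logps)) - torch.log1p(-torch.exp(rejected_logps))
)

# Naive form: log(odds(p_w) / odds(p_l)) with odds(p) = p / (1 - p).
p_w, p_l = chosen_logps.exp(), rejected_logps.exp()
assert torch.allclose(log_odds, torch.log((p_w / (1 - p_w)) / (p_l / (1 - p_l))))

odds_ratio_loss = -F.logsigmoid(log_odds)  # what the simplified method returns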
@@ -123,15 +99,12 @@ class CustomORPOTrainer(DPOTrainer):
         """
         metrics = {}
         chosen_logps, rejected_logps, chosen_logits, rejected_logits = self.concatenated_forward(model, batch)
-        losses, chosen_rewards, rejected_rewards, log_odds_ratio, log_odds_chosen = self.odds_ratio_loss(
-            chosen_logps, rejected_logps
-        )
-
-        batch_size = batch["input_ids"].size(0) // 2
-        chosen_labels, _ = batch["labels"].split(batch_size, dim=0)
-        sft_loss = self.sft_loss(chosen_logits, chosen_labels)
-        batch_loss = (sft_loss - losses).mean()
+        sft_loss = -chosen_logps
+        odds_ratio_loss = self.odds_ratio_loss(chosen_logps, rejected_logps)
+        batch_loss = (sft_loss + self.beta * odds_ratio_loss).mean()
 
+        chosen_rewards = self.beta * chosen_logps.detach()
+        rejected_rewards = self.beta * rejected_logps.detach()
         reward_accuracies = (chosen_rewards > rejected_rewards).float()
 
         prefix = "eval_" if train_eval == "eval" else ""
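Put together, the corrected training objective is L = L_SFT + beta * L_OR from https://arxiv.org/abs/2403.07691, where L_SFT is the negative average log likelihood of the chosen response and both terms consume the same length-normalized chosen_logps. A compact end-to-end sketch of the composition (the logps values are invented; beta stands in for the configured ORPO beta):

import torch
import torch.nn.functional as F

beta = 0.1  # hypothetical beta value
chosen_logps = torch.tensor([-0.4, -0.7])    # per-example average log probs,
rejected_logps = torch.tensor([-1.9, -1.2])  # shape (batch_size,)

sft_loss = -chosen_logps  # NLL of the chosen responses
log_odds = (chosen_logps - rejected_logps) - (
    torch.log1p(-torch.exp(chosen_logps)) - torch.log1p(-torch.exp(rejected_logps))
)
odds_ratio_loss = -F.logsigmoid(log_odds)
batch_loss = (sft_loss + beta * odds_ratio_loss).mean()

This is also why the separate sft_loss helper could be deleted: once concatenated_forward averages log probs, the chosen half of its output already is the per-example NLL up to sign.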
@ -144,7 +117,6 @@ class CustomORPOTrainer(DPOTrainer):
metrics["{}logits/rejected".format(prefix)] = rejected_logits.detach().cpu().mean()
metrics["{}logits/chosen".format(prefix)] = chosen_logits.detach().cpu().mean()
metrics["{}sft_loss".format(prefix)] = sft_loss.detach().cpu().mean()
metrics["{}log_odds_ratio".format(prefix)] = log_odds_ratio.detach().cpu().mean()
metrics["{}log_odds_chosen".format(prefix)] = log_odds_chosen.detach().cpu().mean()
metrics["{}odds_ratio_loss".format(prefix)] = odds_ratio_loss.detach().cpu().mean()
return batch_loss, metrics
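The logged rewards keep the DPO-style convention of beta-scaled log probs, so the metric rename (log_odds_ratio / log_odds_chosen to odds_ratio_loss) tracks what the loss actually optimizes. A small illustration of the reward bookkeeping, with invented numbers:

import torch

beta = 0.1  # hypothetical beta value
chosen_logps = torch.tensor([-0.4, -0.9])
rejected_logps = torch.tensor([-1.1, -0.5])

chosen_rewards = beta * chosen_logps.detach()
rejected_rewards = beta * rejected_logps.detach()
reward_accuracies = (chosen_rewards > rejected_rewards).float()  # tensor([1., 0.])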