diff --git a/src/llmtuner/train/dpo/trainer.py b/src/llmtuner/train/dpo/trainer.py
index 7582e16f..11727420 100644
--- a/src/llmtuner/train/dpo/trainer.py
+++ b/src/llmtuner/train/dpo/trainer.py
@@ -87,16 +87,22 @@ class CustomDPOTrainer(DPOTrainer):
     def concatenated_forward(
         self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"]
     ) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]:
+        r"""
+        Computes the sum of log probabilities of the labels under the given logits if loss_type != IPO.
+
+        Otherwise computes the average log probabilities.
+        """
         batch_copied = BatchEncoding({k: v.detach().clone() for k, v in batch.items()})  # avoid error

-        all_logits = model(
+        all_logits: "torch.Tensor" = model(
             input_ids=batch_copied["input_ids"], attention_mask=batch_copied["attention_mask"], return_dict=True
         ).logits.to(torch.float32)

         all_logps = self.get_batch_logps(
-            all_logits,
-            batch["labels"],
-            average_log_prob=False,
+            logits=all_logits,
+            labels=batch_copied["labels"],
+            average_log_prob=(self.loss_type == "ipo"),
+            is_encoder_decoder=self.is_encoder_decoder,
             label_pad_token_id=self.label_pad_token_id,
         )
         batch_size = batch["input_ids"].size(0) // 2
diff --git a/src/llmtuner/train/orpo/trainer.py b/src/llmtuner/train/orpo/trainer.py
index af34b55e..50b999f8 100644
--- a/src/llmtuner/train/orpo/trainer.py
+++ b/src/llmtuner/train/orpo/trainer.py
@@ -56,55 +56,31 @@ class CustomORPOTrainer(DPOTrainer):
         create_custom_scheduler(self.args, num_training_steps, optimizer)
         return super().create_scheduler(num_training_steps, optimizer)

-    def sft_loss(self, chosen_logits: "torch.FloatTensor", chosen_labels: "torch.LongTensor") -> "torch.Tensor":
-        r"""
-        Computes supervised cross-entropy loss of given labels under the given logits.
-
-        Returns:
-            A tensor of shape (batch_size,) containing the cross-entropy loss of each samples.
-        """
-        all_logps = self.get_batch_logps(chosen_logits, chosen_labels, average_log_prob=True)
-        return -all_logps
-
-    # Borrowed from:
-    # https://github.com/huggingface/trl/blob/0ee349dcd43b0f4b3169449f16751c38ac4a609f/trl/trainer/orpo_trainer.py#L592
-    def odds_ratio_loss(
-        self, chosen_logps: "torch.Tensor", rejected_logps: "torch.Tensor"
-    ) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]:
+    def odds_ratio_loss(self, chosen_logps: "torch.Tensor", rejected_logps: "torch.Tensor") -> "torch.Tensor":
         r"""
         Computes ORPO's odds ratio (OR) loss.
-
-        Args:
-            policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (batch_size,)
-            policy_rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (batch_size,)
-
-        Returns:
-            A tuple of five tensors: (losses, chosen_rewards, rejected_rewards, log_odds_ratio, log_odds_chosen).
         """
-
-        # Derived from Eqs. (4) and (7) from https://arxiv.org/abs/2403.07691 by using log identities and exp(log(P(y|x)) = P(y|x)
         log_odds = (chosen_logps - rejected_logps) - (
             torch.log1p(-torch.exp(chosen_logps)) - torch.log1p(-torch.exp(rejected_logps))
         )
-        ratio = F.logsigmoid(log_odds)
-        losses = self.beta * ratio
-
-        chosen_rewards = self.beta * chosen_logps.detach()
-        rejected_rewards = self.beta * rejected_logps.detach()
-
-        return losses, chosen_rewards, rejected_rewards, ratio, log_odds
+        odds_ratio_loss = -F.logsigmoid(log_odds)
+        return odds_ratio_loss

     def concatenated_forward(
         self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"]
     ) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]:
-        all_logits = model(
+        r"""
+        Computes the average log probabilities of the labels under the given logits.
+        """
+        all_logits: "torch.Tensor" = model(
             input_ids=batch["input_ids"], attention_mask=batch["attention_mask"], return_dict=True
         ).logits.to(torch.float32)

         all_logps = self.get_batch_logps(
-            all_logits,
-            batch["labels"],
-            average_log_prob=False,
+            logits=all_logits,
+            labels=batch["labels"],
+            average_log_prob=True,
+            is_encoder_decoder=self.is_encoder_decoder,
             label_pad_token_id=self.label_pad_token_id,
         )
         batch_size = batch["input_ids"].size(0) // 2
@@ -123,15 +99,12 @@ class CustomORPOTrainer(DPOTrainer):
         """
         metrics = {}
         chosen_logps, rejected_logps, chosen_logits, rejected_logits = self.concatenated_forward(model, batch)
+        sft_loss = -chosen_logps
+        odds_ratio_loss = self.odds_ratio_loss(chosen_logps, rejected_logps)
+        batch_loss = (sft_loss + self.beta * odds_ratio_loss).mean()
-        losses, chosen_rewards, rejected_rewards, log_odds_ratio, log_odds_chosen = self.odds_ratio_loss(
-            chosen_logps, rejected_logps
-        )
-        batch_size = batch["input_ids"].size(0) // 2
-        chosen_labels, _ = batch["labels"].split(batch_size, dim=0)
-        sft_loss = self.sft_loss(chosen_logits, chosen_labels)
-        batch_loss = (sft_loss - losses).mean()
-
+        chosen_rewards = self.beta * chosen_logps.detach()
+        rejected_rewards = self.beta * rejected_logps.detach()
         reward_accuracies = (chosen_rewards > rejected_rewards).float()

         prefix = "eval_" if train_eval == "eval" else ""
@@ -144,7 +117,6 @@ class CustomORPOTrainer(DPOTrainer):
         metrics["{}logits/rejected".format(prefix)] = rejected_logits.detach().cpu().mean()
         metrics["{}logits/chosen".format(prefix)] = chosen_logits.detach().cpu().mean()
         metrics["{}sft_loss".format(prefix)] = sft_loss.detach().cpu().mean()
-        metrics["{}log_odds_ratio".format(prefix)] = log_odds_ratio.detach().cpu().mean()
-        metrics["{}log_odds_chosen".format(prefix)] = log_odds_chosen.detach().cpu().mean()
+        metrics["{}odds_ratio_loss".format(prefix)] = odds_ratio_loss.detach().cpu().mean()
         return batch_loss, metrics
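
A note on the DPO change: `average_log_prob=(self.loss_type == "ipo")` switches `get_batch_logps` to length-normalized log probabilities for IPO, whose loss is defined on per-token averages, while DPO's sigmoid loss keeps the per-sequence sums. Below is a minimal sketch of that sum-vs-average semantics only, assuming a per-token log-prob tensor and a 0/1 label mask; `sequence_logps` is an illustrative helper, not TRL's implementation:

```python
import torch

def sequence_logps(per_token_logps: torch.Tensor, label_mask: torch.Tensor, average: bool) -> torch.Tensor:
    # per_token_logps, label_mask: (batch, seq_len); returns one value per sequence
    summed = (per_token_logps * label_mask).sum(-1)
    if average:
        return summed / label_mask.sum(-1)  # IPO: length-normalized log probability
    return summed  # DPO sigmoid loss: summed log probability
```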
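On the ORPO side, the training loss is now `-chosen_logps + self.beta * (-logsigmoid(log_odds))`, i.e. NLL on the chosen responses plus the odds-ratio term from https://arxiv.org/abs/2403.07691. Here is a self-contained sketch under the assumption that `chosen_logps` / `rejected_logps` are per-sequence average label log-probs (as `concatenated_forward` now returns with `average_log_prob=True`); the standalone `orpo_loss` function and the toy tensors are illustrative only:

```python
import torch
import torch.nn.functional as F

def orpo_loss(chosen_logps: torch.Tensor, rejected_logps: torch.Tensor, beta: float) -> torch.Tensor:
    # log odds of chosen vs. rejected: log(p_c / (1 - p_c)) - log(p_r / (1 - p_r)),
    # with p = exp(average log probability), kept in log space for stability
    log_odds = (chosen_logps - rejected_logps) - (
        torch.log1p(-torch.exp(chosen_logps)) - torch.log1p(-torch.exp(rejected_logps))
    )
    odds_ratio_loss = -F.logsigmoid(log_odds)  # relative-preference term (L_OR)
    sft_loss = -chosen_logps  # per-token NLL of the chosen responses
    return (sft_loss + beta * odds_ratio_loss).mean()

# Toy check: average log probs are strictly negative, so log1p(-exp(.)) stays finite.
chosen, rejected = torch.tensor([-0.5, -0.8]), torch.tensor([-1.2, -1.5])
print(orpo_loss(chosen, rejected, beta=0.1))
```

Using average rather than summed log-probs here matches TRL's ORPO trainer and keeps `torch.log1p(-torch.exp(logps))` well-behaved, since averages stay bounded away from 0.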