upcast logits

This commit is contained in:
hiyouga
2024-07-02 22:32:05 +08:00
parent c47ab6c072
commit c13ae2df19
2 changed files with 2 additions and 2 deletions

View File

@@ -99,7 +99,7 @@ class PairwiseTrainer(Trainer):
chosen_scores = chosen_rewards.gather(dim=-1, index=(chosen_masks.sum(dim=-1, keepdim=True) - 1))
rejected_scores = rejected_rewards.gather(dim=-1, index=(rejected_masks.sum(dim=-1, keepdim=True) - 1))
chosen_scores, rejected_scores = chosen_scores.squeeze(), rejected_scores.squeeze()
loss = -torch.nn.functional.logsigmoid(chosen_scores - rejected_scores).mean()
loss = -torch.nn.functional.logsigmoid(chosen_scores.float() - rejected_scores.float()).mean()
if return_outputs:
return loss, (loss, chosen_scores, rejected_scores)
else: