use log1p in orpo loss

https://github.com/huggingface/trl/pull/1491

Former-commit-id: 68aaa4904b8dfb6cc791fdcee613edc681a8a198
This commit is contained in:
hiyouga 2024-03-31 19:27:08 +08:00
parent ddad9be81d
commit b873dcb09d

View File

@@ -84,7 +84,7 @@ class CustomORPOTrainer(DPOTrainer):
         # Derived from Eqs. (4) and (7) from https://arxiv.org/abs/2403.07691 by using log identities and exp(log(P(y|x)) = P(y|x)
         log_odds = (chosen_logps - rejected_logps) - (
-            torch.log(1 - torch.exp(chosen_logps)) - torch.log(1 - torch.exp(rejected_logps))
+            torch.log1p(-torch.exp(chosen_logps)) - torch.log1p(-torch.exp(rejected_logps))
         )
         ratio = F.logsigmoid(log_odds)
         losses = self.beta * ratio