update trainers

This commit is contained in:
hiyouga
2024-06-06 18:45:49 +08:00
parent 67aa78cde0
commit fad2591e31
4 changed files with 12 additions and 21 deletions

View File

@@ -187,13 +187,7 @@ class CustomDPOTrainer(DPOTrainer):
ref_context = nullcontext()
with torch.no_grad(), ref_context:
(
reference_chosen_logps,
reference_rejected_logps,
_,
_,
_,
) = self.concatenated_forward(ref_model, batch)
reference_chosen_logps, reference_rejected_logps, *_ = self.concatenated_forward(ref_model, batch)
return reference_chosen_logps, reference_rejected_logps