Former-commit-id: 4a6ca621c0
This commit is contained in:
hiyouga
2024-04-01 22:53:52 +08:00
parent 8d987b7af7
commit 829cf6458a
4 changed files with 23 additions and 15 deletions

View File

@@ -73,7 +73,7 @@ class CustomORPOTrainer(DPOTrainer):
Computes the average log probabilities of the labels under the given logits.
"""
all_logits: "torch.Tensor" = model(
input_ids=batch["input_ids"], attention_mask=batch["attention_mask"], return_dict=True
input_ids=batch["input_ids"], attention_mask=batch["attention_mask"], return_dict=True, use_cache=False
).logits.to(torch.float32)
all_logps = self.get_batch_logps(