Former-commit-id: 8c0f8357e1eebee32010fe715554f1136b68b4ba
This commit is contained in:
hiyouga 2024-07-15 22:32:07 +08:00
parent cb474c7b11
commit 1891b64072

View File

@ -55,7 +55,15 @@ def compute_accuracy(eval_preds: "EvalPrediction") -> Dict[str, float]:
def eval_logit_processor(logits: "torch.Tensor", labels: "torch.Tensor") -> "torch.Tensor":
    r"""
    Reduces model logits to predicted token ids via argmax over the last dimension.

    Args:
        logits: either a single tensor, or a list/tuple of model outputs. For a
            list/tuple, the first element is used when it is 3-dimensional;
            otherwise the second element is used (e.g. MoE models that emit an
            auxiliary loss ahead of the logits).
        labels: unused here; presumably kept so the signature matches the
            caller's expected callback interface -- TODO confirm against caller.

    Returns:
        A tensor of argmax indices with the last dimension of the logits removed.

    Raises:
        ValueError: if the selected logits tensor is not 3-dimensional.
    """
    # NOTE: the container must be inspected *before* any unwrapping. Taking
    # `logits[0]` unconditionally would make the MoE branch below unreachable.
    if isinstance(logits, (list, tuple)):
        if logits[0].dim() == 3:  # (batch_size, seq_len, vocab_size)
            logits = logits[0]
        else:  # moe models have aux loss
            logits = logits[1]

    if logits.dim() != 3:
        raise ValueError("Cannot process the logits.")

    return torch.argmax(logits, dim=-1)