Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-08-23 06:12:50 +08:00)
Parent: ae0f4ba2d3
Commit: 84e6715423
@@ -55,7 +55,15 @@ def compute_accuracy(eval_preds: "EvalPrediction") -> Dict[str, float]:


 def eval_logit_processor(logits: "torch.Tensor", labels: "torch.Tensor") -> "torch.Tensor":
-    logits = logits[0] if isinstance(logits, (list, tuple)) else logits
+    if isinstance(logits, (list, tuple)):
+        if logits[0].dim() == 3:  # (batch_size, seq_len, vocab_size)
+            logits = logits[0]
+        else:  # moe models have aux loss
+            logits = logits[1]
+
+    if logits.dim() != 3:
+        raise ValueError("Cannot process the logits.")
+
     return torch.argmax(logits, dim=-1)
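For context, a minimal, self-contained sketch of the changed function follows. The example shapes and the MoE-style tuple input are assumptions made for illustration; the diff itself only shows that list/tuple outputs are unpacked by picking the 3-D tensor (MoE models may put an auxiliary loss before the token logits) and that everything else is rejected before the argmax.

import torch


def eval_logit_processor(logits: "torch.Tensor", labels: "torch.Tensor") -> "torch.Tensor":
    # Some model outputs are tuples whose extra elements hold auxiliary
    # losses (e.g. MoE routing loss) rather than token logits, so select
    # whichever element is the (batch_size, seq_len, vocab_size) tensor.
    if isinstance(logits, (list, tuple)):
        if logits[0].dim() == 3:  # (batch_size, seq_len, vocab_size)
            logits = logits[0]
        else:  # moe models have aux loss
            logits = logits[1]

    if logits.dim() != 3:
        raise ValueError("Cannot process the logits.")

    # Reduce to predicted token ids so downstream metrics only need to
    # gather a (batch_size, seq_len) integer tensor instead of full logits.
    return torch.argmax(logits, dim=-1)


if __name__ == "__main__":
    # Hypothetical shapes, for illustration only.
    batch, seq_len, vocab = 2, 5, 11
    plain = torch.randn(batch, seq_len, vocab)
    labels = torch.zeros(batch, seq_len, dtype=torch.long)

    # Plain tensor input.
    print(eval_logit_processor(plain, labels).shape)  # torch.Size([2, 5])

    # Tuple input with a scalar aux loss first, as some MoE models emit;
    # the processor falls through to the second element.
    print(eval_logit_processor((torch.tensor(0.1), plain), labels).shape)  # torch.Size([2, 5])

A processor of this shape is typically passed to a Hugging Face Trainer via the preprocess_logits_for_metrics hook so that only argmaxed token ids are accumulated during evaluation, though that wiring is not shown in this commit.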