mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-10-14 23:58:11 +08:00
remove .cpu()
Former-commit-id: 35c57cc9dcba305d40282a9757ddc23968c210ac
This commit is contained in:
parent
a7fbae47d5
commit
6579ec8c4c
@ -54,7 +54,7 @@ def eval_logit_processor(logits: "torch.Tensor", labels: "torch.Tensor") -> "tor
|
||||
if logits.dim() != 3:
|
||||
raise ValueError("Cannot process the logits.")
|
||||
|
||||
return torch.argmax(logits, dim=-1).cpu()
|
||||
return torch.argmax(logits, dim=-1)
|
||||
|
||||
|
||||
@dataclass
|
||||
|
Loading…
x
Reference in New Issue
Block a user