remove .cpu()

Former-commit-id: 35c57cc9dcba305d40282a9757ddc23968c210ac
This commit is contained in:
hoshi-hiyouga 2024-09-02 10:10:53 +08:00 committed by GitHub
parent a7fbae47d5
commit 6579ec8c4c

View File

@ -54,7 +54,7 @@ def eval_logit_processor(logits: "torch.Tensor", labels: "torch.Tensor") -> "tor
if logits.dim() != 3:
raise ValueError("Cannot process the logits.")
return torch.argmax(logits, dim=-1).cpu()
return torch.argmax(logits, dim=-1)
@dataclass