remove .cpu()

Former-commit-id: a6c6750e8af5bc1ece1dfe6111d3e484fd19ee75
This commit is contained in:
hoshi-hiyouga 2024-09-02 10:10:53 +08:00 committed by GitHub
parent b2a5f49a24
commit 5c9972a2d5

View File

@ -54,7 +54,7 @@ def eval_logit_processor(logits: "torch.Tensor", labels: "torch.Tensor") -> "tor
if logits.dim() != 3:
raise ValueError("Cannot process the logits.")
return torch.argmax(logits, dim=-1).cpu()
return torch.argmax(logits, dim=-1)
@dataclass