remove .cpu()

This commit is contained in:
hoshi-hiyouga
2024-09-02 10:10:53 +08:00
committed by GitHub
parent 60fc6b926e
commit a6c6750e8a

View File

@@ -54,7 +54,7 @@ def eval_logit_processor(logits: "torch.Tensor", labels: "torch.Tensor") -> "tor
if logits.dim() != 3: if logits.dim() != 3:
raise ValueError("Cannot process the logits.") raise ValueError("Cannot process the logits.")
return torch.argmax(logits, dim=-1).cpu() return torch.argmax(logits, dim=-1)
@dataclass @dataclass