mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-23 14:22:51 +08:00
remove .cpu()
Former-commit-id: a6c6750e8af5bc1ece1dfe6111d3e484fd19ee75
This commit is contained in:
parent
b2a5f49a24
commit
5c9972a2d5
@ -54,7 +54,7 @@ def eval_logit_processor(logits: "torch.Tensor", labels: "torch.Tensor") -> "tor
|
||||
if logits.dim() != 3:
|
||||
raise ValueError("Cannot process the logits.")
|
||||
|
||||
return torch.argmax(logits, dim=-1).cpu()
|
||||
return torch.argmax(logits, dim=-1)
|
||||
|
||||
|
||||
@dataclass
|
||||
|
Loading…
x
Reference in New Issue
Block a user