mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-23 22:32:54 +08:00
remove .cpu()
Former-commit-id: a6c6750e8af5bc1ece1dfe6111d3e484fd19ee75
This commit is contained in:
parent
b2a5f49a24
commit
5c9972a2d5
@ -54,7 +54,7 @@ def eval_logit_processor(logits: "torch.Tensor", labels: "torch.Tensor") -> "tor
|
|||||||
if logits.dim() != 3:
|
if logits.dim() != 3:
|
||||||
raise ValueError("Cannot process the logits.")
|
raise ValueError("Cannot process the logits.")
|
||||||
|
|
||||||
return torch.argmax(logits, dim=-1).cpu()
|
return torch.argmax(logits, dim=-1)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
Loading…
x
Reference in New Issue
Block a user