optimize predict vram

This commit is contained in:
hiyouga
2024-08-30 23:08:45 +08:00
parent e08045a946
commit a244f143f4
5 changed files with 10 additions and 10 deletions

View File

@@ -54,7 +54,7 @@ def eval_logit_processor(logits: "torch.Tensor", labels: "torch.Tensor") -> "tor
if logits.dim() != 3:
raise ValueError("Cannot process the logits.")
return torch.argmax(logits, dim=-1)
return torch.argmax(logits, dim=-1).cpu()
@dataclass