From 6579ec8c4cd4f436bda7ee9990c0139130444f0a Mon Sep 17 00:00:00 2001 From: hoshi-hiyouga Date: Mon, 2 Sep 2024 10:10:53 +0800 Subject: [PATCH] remove .cpu() Former-commit-id: 35c57cc9dcba305d40282a9757ddc23968c210ac --- src/llamafactory/train/sft/metric.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/llamafactory/train/sft/metric.py b/src/llamafactory/train/sft/metric.py index 47657b75..69327379 100644 --- a/src/llamafactory/train/sft/metric.py +++ b/src/llamafactory/train/sft/metric.py @@ -54,7 +54,7 @@ def eval_logit_processor(logits: "torch.Tensor", labels: "torch.Tensor") -> "tor if logits.dim() != 3: raise ValueError("Cannot process the logits.") - return torch.argmax(logits, dim=-1).cpu() + return torch.argmax(logits, dim=-1) @dataclass