From 5c9972a2d570fa64842d38dade57a54d650dac15 Mon Sep 17 00:00:00 2001 From: hoshi-hiyouga Date: Mon, 2 Sep 2024 10:10:53 +0800 Subject: [PATCH] remove .cpu() Former-commit-id: a6c6750e8af5bc1ece1dfe6111d3e484fd19ee75 --- src/llamafactory/train/sft/metric.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/llamafactory/train/sft/metric.py b/src/llamafactory/train/sft/metric.py index 47657b75..69327379 100644 --- a/src/llamafactory/train/sft/metric.py +++ b/src/llamafactory/train/sft/metric.py @@ -54,7 +54,7 @@ def eval_logit_processor(logits: "torch.Tensor", labels: "torch.Tensor") -> "tor if logits.dim() != 3: raise ValueError("Cannot process the logits.") - return torch.argmax(logits, dim=-1).cpu() + return torch.argmax(logits, dim=-1) @dataclass