From 3bdc7e1e6cee79bc3298f95822e506fffcc8b5d2 Mon Sep 17 00:00:00 2001
From: jilongW <109333127+jilongW@users.noreply.github.com>
Date: Wed, 9 Apr 2025 21:37:54 +0800
Subject: [PATCH] [misc] fix cuda warn on intel GPU (#7655)

---
 src/llamafactory/extras/misc.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/src/llamafactory/extras/misc.py b/src/llamafactory/extras/misc.py
index 12f3304d..3c0f11f3 100644
--- a/src/llamafactory/extras/misc.py
+++ b/src/llamafactory/extras/misc.py
@@ -179,6 +179,8 @@ def get_peak_memory() -> tuple[int, int]:
         return torch.npu.max_memory_allocated(), torch.npu.max_memory_reserved()
     elif is_torch_cuda_available():
         return torch.cuda.max_memory_allocated(), torch.cuda.max_memory_reserved()
+    elif is_torch_xpu_available():
+        return torch.xpu.max_memory_allocated(), torch.xpu.max_memory_reserved()
     else:
         return 0, 0
 
@@ -200,7 +202,7 @@ def infer_optim_dtype(model_dtype: "torch.dtype") -> "torch.dtype":
 
 def is_gpu_or_npu_available() -> bool:
     r"""Check if the GPU or NPU is available."""
-    return is_torch_npu_available() or is_torch_cuda_available()
+    return is_torch_npu_available() or is_torch_cuda_available() or is_torch_xpu_available()
 
 
 def is_env_enabled(env_var: str, default: str = "0") -> bool:
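
For anyone who wants to check the change locally, below is a minimal sketch, not part of the patch itself. It assumes an XPU-enabled PyTorch build (torch.xpu is available in recent PyTorch releases and in Intel Extension for PyTorch builds) and that llamafactory is installed so the patched helpers can be imported.

```python
import torch
from llamafactory.extras.misc import get_peak_memory, is_gpu_or_npu_available

# After the patch, an Intel GPU counts as an available accelerator
# instead of falling through the NPU/CUDA-only checks.
print("accelerator available:", is_gpu_or_npu_available())

if torch.xpu.is_available():  # torch.xpu mirrors the torch.cuda API on XPU builds
    x = torch.randn(1024, 1024, device="xpu")
    y = x @ x  # allocate some device memory so the peak counters are nonzero
    allocated, reserved = get_peak_memory()
    print(f"peak allocated: {allocated} bytes, peak reserved: {reserved} bytes")
```

The XPU branch is placed after the CUDA one, mirroring the existing NPU-then-CUDA dispatch order, so behavior on NVIDIA and Ascend hardware is unchanged.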