diff --git a/src/llamafactory/extras/misc.py b/src/llamafactory/extras/misc.py
index 12f3304d..3c0f11f3 100644
--- a/src/llamafactory/extras/misc.py
+++ b/src/llamafactory/extras/misc.py
@@ -179,6 +179,8 @@ def get_peak_memory() -> tuple[int, int]:
         return torch.npu.max_memory_allocated(), torch.npu.max_memory_reserved()
     elif is_torch_cuda_available():
         return torch.cuda.max_memory_allocated(), torch.cuda.max_memory_reserved()
+    elif is_torch_xpu_available():
+        return torch.xpu.max_memory_allocated(), torch.xpu.max_memory_reserved()
     else:
         return 0, 0
 
@@ -200,7 +202,7 @@ def infer_optim_dtype(model_dtype: "torch.dtype") -> "torch.dtype":
 
 def is_gpu_or_npu_available() -> bool:
     r"""Check if the GPU or NPU is available."""
-    return is_torch_npu_available() or is_torch_cuda_available()
+    return is_torch_npu_available() or is_torch_cuda_available() or is_torch_xpu_available()
 
 
 def is_env_enabled(env_var: str, default: str = "0") -> bool:
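
Usage sketch (not part of the patch): a minimal illustration of what the new branch enables, assuming an Intel-GPU build of PyTorch (2.4+) where `torch.xpu` is present. On such a device, `get_peak_memory()` now reports XPU peaks instead of falling through to `(0, 0)`:

    import torch
    from llamafactory.extras.misc import get_peak_memory

    # torch.xpu.is_available() is the standard check on XPU-enabled builds.
    if torch.xpu.is_available():
        x = torch.randn(4096, 4096, device="xpu")  # allocate to raise the peak
        allocated, reserved = get_peak_memory()    # hits the new XPU branch
        print(f"peak allocated: {allocated} bytes, peak reserved: {reserved} bytes")

Both `is_torch_xpu_available` (from `transformers`) and the `torch.xpu` memory-stats API mirror their CUDA/NPU counterparts, so the change keeps the existing accelerator-dispatch pattern intact.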