[misc] fix cuda warn on intel GPU (#7655)

This commit is contained in:
jilongW 2025-04-09 21:37:54 +08:00 committed by GitHub
parent 34fdabe005
commit 3bdc7e1e6c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -179,6 +179,8 @@ def get_peak_memory() -> tuple[int, int]:
return torch.npu.max_memory_allocated(), torch.npu.max_memory_reserved()
elif is_torch_cuda_available():
return torch.cuda.max_memory_allocated(), torch.cuda.max_memory_reserved()
elif is_torch_xpu_available():
return torch.xpu.max_memory_allocated(), torch.xpu.max_memory_reserved()
else:
return 0, 0
@ -200,7 +202,7 @@ def infer_optim_dtype(model_dtype: "torch.dtype") -> "torch.dtype":
def is_gpu_or_npu_available() -> bool:
    r"""Check if a GPU (NVIDIA CUDA or Intel XPU) or an Ascend NPU is available.

    Returns:
        ``True`` if any supported accelerator backend reports availability,
        ``False`` otherwise.
    """
    # XPU check added so Intel GPUs are recognized and no spurious
    # CUDA-unavailable warning is raised on those machines.
    return is_torch_npu_available() or is_torch_cuda_available() or is_torch_xpu_available()
def is_env_enabled(env_var: str, default: str = "0") -> bool: