Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-08-02 03:32:50 +08:00)
[misc] fix cuda warn on intel GPU (#7655)
This commit is contained in:
parent 34fdabe005
commit 3bdc7e1e6c
@@ -179,6 +179,8 @@ def get_peak_memory() -> tuple[int, int]:
         return torch.npu.max_memory_allocated(), torch.npu.max_memory_reserved()
     elif is_torch_cuda_available():
         return torch.cuda.max_memory_allocated(), torch.cuda.max_memory_reserved()
+    elif is_torch_xpu_available():
+        return torch.xpu.max_memory_allocated(), torch.xpu.max_memory_reserved()
     else:
         return 0, 0
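A minimal usage sketch of the function touched above: with the new branch, peak memory is read from the torch.xpu counters on Intel GPUs instead of falling through to (0, 0). The import path llamafactory.extras.misc is an assumption based on the repository layout, not shown in this diff.

    # Usage sketch (assumed module path, not part of the commit)
    from llamafactory.extras.misc import get_peak_memory

    peak_alloc, peak_reserved = get_peak_memory()  # bytes; (0, 0) when no accelerator backend is detected
    print(f"peak allocated: {peak_alloc / 2**20:.1f} MiB, peak reserved: {peak_reserved / 2**20:.1f} MiB")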
@@ -200,7 +202,7 @@ def infer_optim_dtype(model_dtype: "torch.dtype") -> "torch.dtype":
 
 def is_gpu_or_npu_available() -> bool:
     r"""Check if the GPU or NPU is available."""
-    return is_torch_npu_available() or is_torch_cuda_available()
+    return is_torch_npu_available() or is_torch_cuda_available() or is_torch_xpu_available()
 
 
 def is_env_enabled(env_var: str, default: str = "0") -> bool:
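A hedged sketch of the kind of call site this change fixes: a guard that warns when no accelerator is detected. Before the patch, an XPU-only (Intel GPU) machine failed the check and saw a spurious CUDA warning. The warning text, logger setup, and import path below are illustrative assumptions, not copied from the repository.

    # Hypothetical call site guarded by is_gpu_or_npu_available (wording/location assumed)
    import logging

    from llamafactory.extras.misc import is_gpu_or_npu_available  # assumed module path

    logger = logging.getLogger(__name__)

    if not is_gpu_or_npu_available():
        # With the old check, Intel XPU setups landed here and logged a misleading CUDA warning.
        logger.warning("No CUDA/NPU/XPU device detected, falling back to CPU.")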