diff --git a/src/llamafactory/extras/misc.py b/src/llamafactory/extras/misc.py index c1395552..ac7b2b80 100644 --- a/src/llamafactory/extras/misc.py +++ b/src/llamafactory/extras/misc.py @@ -141,6 +141,8 @@ def get_device_count() -> int: return torch.npu.device_count() elif is_torch_cuda_available(): return torch.cuda.device_count() + elif is_torch_xpu_available(): + return torch.xpu.device_count() else: return 0