Mirror of https://github.com/hiyouga/LLaMA-Factory.git
support llama pro #2338, add rslora
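For context on the commit message: LLaMA Pro (#2338) refers to block-expansion training, where copies of existing transformer layers are inserted and only the copies are tuned; rsLoRA (rank-stabilized LoRA) changes the adapter scaling factor from lora_alpha / r to lora_alpha / sqrt(r), so the magnitude of the LoRA update does not shrink as the rank r grows. A minimal illustrative sketch of that scaling rule (the function name is hypothetical, not this repo's code):

import math

def lora_scaling(lora_alpha: float, r: int, use_rslora: bool) -> float:
    # rsLoRA divides by sqrt(r) instead of r, keeping the update
    # magnitude stable as the rank increases.
    return lora_alpha / math.sqrt(r) if use_rslora else lora_alpha / r

print(lora_scaling(16, 64, use_rslora=False))  # 0.25 (classic LoRA)
print(lora_scaling(16, 64, use_rslora=True))   # 2.0  (rsLoRA)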
@@ -10,6 +10,7 @@ from transformers.utils import (
     WEIGHTS_NAME,
     is_torch_bf16_gpu_available,
     is_torch_cuda_available,
+    is_torch_mps_available,
     is_torch_npu_available,
     is_torch_xpu_available,
 )
@@ -133,6 +134,8 @@ def get_current_device() -> torch.device:
         device = "xpu:{}".format(os.environ.get("LOCAL_RANK", "0"))
     elif is_torch_npu_available():
         device = "npu:{}".format(os.environ.get("LOCAL_RANK", "0"))
+    elif is_torch_mps_available():
+        device = "mps:{}".format(os.environ.get("LOCAL_RANK", "0"))
     elif is_torch_cuda_available():
         device = "cuda:{}".format(os.environ.get("LOCAL_RANK", "0"))
     else:
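For orientation, here is what the patched get_current_device amounts to as a standalone sketch. Everything outside the hunk (the body of the else branch and the final return) is assumed from context, not shown in this diff:

import os

import torch
from transformers.utils import (
    is_torch_cuda_available,
    is_torch_mps_available,
    is_torch_npu_available,
    is_torch_xpu_available,
)

def get_current_device() -> torch.device:
    # Resolve one device per process, keyed off the launcher's LOCAL_RANK
    # environment variable (defaulting to rank 0 for single-process runs).
    if is_torch_xpu_available():
        device = "xpu:{}".format(os.environ.get("LOCAL_RANK", "0"))
    elif is_torch_npu_available():
        device = "npu:{}".format(os.environ.get("LOCAL_RANK", "0"))
    elif is_torch_mps_available():  # added by this commit: Apple MPS backend
        device = "mps:{}".format(os.environ.get("LOCAL_RANK", "0"))
    elif is_torch_cuda_available():
        device = "cuda:{}".format(os.environ.get("LOCAL_RANK", "0"))
    else:
        device = "cpu"  # assumed fallback; the else-body is outside the hunk
    return torch.device(device)  # assumed return, matching the annotated type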