mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-08-04 12:42:51 +08:00)
improve get_current_device

commit e400f2e8ad
parent 3d291a82d3
Former-commit-id: 40dfcbc3d4571ce022b6aa39db581c8b88a75b8d
@@ -69,11 +69,11 @@ def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:

 def get_current_device() -> str:
     import accelerate
-    dummy_accelerator = accelerate.Accelerator()
+    local_rank = int(os.environ.get('LOCAL_RANK', '0'))
     if accelerate.utils.is_xpu_available():
-        return "xpu:{}".format(dummy_accelerator.local_process_index)
+        return "xpu:{}".format(local_rank)
     else:
-        return dummy_accelerator.local_process_index if torch.cuda.is_available() else "cpu"
+        return local_rank if torch.cuda.is_available() else "cpu"


 def get_logits_processor() -> "LogitsProcessorList":
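For reference, a minimal sketch of how get_current_device() reads after this change, reconstructed from the hunk above. The module-level imports of os and torch are an assumption here (torch is already implied by the count_parameters signature in the same hunk context; os is needed for the new environment lookup):

import os
import torch


def get_current_device() -> str:
    import accelerate
    # LOCAL_RANK is exported per process by torchrun / accelerate launch,
    # so reading it avoids constructing a dummy Accelerator just to query the index.
    local_rank = int(os.environ.get('LOCAL_RANK', '0'))
    if accelerate.utils.is_xpu_available():
        return "xpu:{}".format(local_rank)
    else:
        # On CUDA this returns the bare rank (an int) rather than a "cuda:N" string;
        # without CUDA or XPU it falls back to "cpu".
        return local_rank if torch.cuda.is_available() else "cpu"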