This commit is contained in:
hiyouga
2023-12-03 11:33:12 +08:00
parent 5b78e269b6
commit 03d05991f8
2 changed files with 4 additions and 15 deletions

View File

@@ -68,18 +68,6 @@ def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
return trainable_params, all_param
def get_current_device() -> str:
import accelerate
if accelerate.utils.is_xpu_available():
return "xpu:{}".format(os.environ.get("LOCAL_RANK", "0"))
elif accelerate.utils.is_npu_available():
return "npu:{}".format(os.environ.get("LOCAL_RANK", "0"))
elif torch.cuda.is_available():
return "cuda:{}".format(os.environ.get("LOCAL_RANK", "0"))
else:
return "cpu"
def get_logits_processor() -> "LogitsProcessorList":
r"""
Gets logits processor that removes NaN and Inf logits.