[misc] fix env vars (#7715)

This commit is contained in:
hoshi-hiyouga 2025-04-14 16:04:04 +08:00 committed by GitHub
parent 3ef36d0057
commit 3a13d2cdb1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 5 additions and 3 deletions

View File

@ -16,6 +16,8 @@ USE_MODELSCOPE_HUB=
USE_OPENMIND_HUB=
USE_RAY=
RECORD_VRAM=
OPTIM_TORCH=
NPU_JIT_COMPILE=
# torchrun
FORCE_TORCHRUN=
MASTER_ADDR=

View File

@ -177,10 +177,10 @@ def get_peak_memory() -> tuple[int, int]:
r"""Get the peak memory usage for the current device (in Bytes)."""
if is_torch_npu_available():
return torch.npu.max_memory_allocated(), torch.npu.max_memory_reserved()
elif is_torch_cuda_available():
return torch.cuda.max_memory_allocated(), torch.cuda.max_memory_reserved()
elif is_torch_xpu_available():
return torch.xpu.max_memory_allocated(), torch.xpu.max_memory_reserved()
elif is_torch_cuda_available():
return torch.cuda.max_memory_allocated(), torch.cuda.max_memory_reserved()
else:
return 0, 0

View File

@ -97,7 +97,7 @@ def patch_config(
if is_torch_npu_available():
# avoid JIT compile on NPU devices, see https://zhuanlan.zhihu.com/p/660875458
torch.npu.set_compile_mode(jit_compile=is_env_enabled("JIT_COMPILE"))
torch.npu.set_compile_mode(jit_compile=is_env_enabled("NPU_JIT_COMPILE"))
configure_attn_implementation(config, model_args, is_trainable)
configure_rope(config, model_args, is_trainable)