diff --git a/.env.local b/.env.local
index 38a5503a..88ac8a46 100644
--- a/.env.local
+++ b/.env.local
@@ -16,6 +16,8 @@ USE_MODELSCOPE_HUB=
 USE_OPENMIND_HUB=
 USE_RAY=
 RECORD_VRAM=
+OPTIM_TORCH=
+NPU_JIT_COMPILE=
 # torchrun
 FORCE_TORCHRUN=
 MASTER_ADDR=
diff --git a/src/llamafactory/extras/misc.py b/src/llamafactory/extras/misc.py
index 5a650dd2..19eee9b3 100644
--- a/src/llamafactory/extras/misc.py
+++ b/src/llamafactory/extras/misc.py
@@ -177,10 +177,10 @@ def get_peak_memory() -> tuple[int, int]:
     r"""Get the peak memory usage for the current device (in Bytes)."""
     if is_torch_npu_available():
         return torch.npu.max_memory_allocated(), torch.npu.max_memory_reserved()
-    elif is_torch_cuda_available():
-        return torch.cuda.max_memory_allocated(), torch.cuda.max_memory_reserved()
     elif is_torch_xpu_available():
         return torch.xpu.max_memory_allocated(), torch.xpu.max_memory_reserved()
+    elif is_torch_cuda_available():
+        return torch.cuda.max_memory_allocated(), torch.cuda.max_memory_reserved()
     else:
         return 0, 0

diff --git a/src/llamafactory/model/patcher.py b/src/llamafactory/model/patcher.py
index 8be0e7bd..1706d177 100644
--- a/src/llamafactory/model/patcher.py
+++ b/src/llamafactory/model/patcher.py
@@ -97,7 +97,7 @@ def patch_config(

     if is_torch_npu_available():
         # avoid JIT compile on NPU devices, see https://zhuanlan.zhihu.com/p/660875458
-        torch.npu.set_compile_mode(jit_compile=is_env_enabled("JIT_COMPILE"))
+        torch.npu.set_compile_mode(jit_compile=is_env_enabled("NPU_JIT_COMPILE"))

     configure_attn_implementation(config, model_args, is_trainable)
     configure_rope(config, model_args, is_trainable)
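
For reference, a minimal sketch of how the renamed flag is consumed, assuming is_env_enabled (defined in src/llamafactory/extras/misc.py, not shown in this diff) performs a simple truthy-string check on the environment variable:

# Sketch (assumption): is_env_enabled is expected to treat "1", "true", or "y"
# (case-insensitive) as enabled, so an unset NPU_JIT_COMPILE means JIT stays off.
import os

def is_env_enabled(env_var: str, default: str = "0") -> bool:
    return os.getenv(env_var, default).lower() in ("true", "y", "1")

# patch_config() passes the result to torch_npu:
#   torch.npu.set_compile_mode(jit_compile=is_env_enabled("NPU_JIT_COMPILE"))
# so leaving NPU_JIT_COMPILE unset (or "0") keeps JIT compilation disabled on Ascend
# NPU devices, while setting NPU_JIT_COMPILE=1 in .env.local re-enables it.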