[misc] upgrade cli (#7714)

This commit is contained in:
hoshi-hiyouga
2025-04-14 15:41:22 +08:00
committed by GitHub
parent c60971f4b8
commit 8f46aced51
6 changed files with 26 additions and 10 deletions

View File

@@ -15,6 +15,7 @@
import os
import subprocess
import sys
from copy import deepcopy
from enum import Enum, unique
from . import launcher
@@ -96,6 +97,13 @@ def main():
if int(nnodes) > 1:
print(f"Multi-node training enabled: num nodes: {nnodes}, node rank: {node_rank}")
env = deepcopy(os.environ)
if is_env_enabled("OPTIM_TORCH", "1"):
# optimize DDP, see https://zhuanlan.zhihu.com/p/671834539
env["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
env["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
# NOTE: DO NOT USE shell=True to avoid security risk
process = subprocess.run(
(
"torchrun --nnodes {nnodes} --node_rank {node_rank} --nproc_per_node {nproc_per_node} "
@@ -110,7 +118,9 @@ def main():
file_name=launcher.__file__,
args=" ".join(sys.argv[1:]),
)
.split()
.split(),
env=env,
check=True,
)
sys.exit(process.returncode)
else: