mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-12-29 02:00:36 +08:00
[misc] upgrade cli (#7714)
This commit is contained in:
@@ -15,6 +15,7 @@
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
from copy import deepcopy
|
||||
from enum import Enum, unique
|
||||
|
||||
from . import launcher
|
||||
@@ -96,6 +97,13 @@ def main():
|
||||
if int(nnodes) > 1:
|
||||
print(f"Multi-node training enabled: num nodes: {nnodes}, node rank: {node_rank}")
|
||||
|
||||
env = deepcopy(os.environ)
|
||||
if is_env_enabled("OPTIM_TORCH", "1"):
|
||||
# optimize DDP, see https://zhuanlan.zhihu.com/p/671834539
|
||||
env["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
|
||||
env["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
|
||||
|
||||
# NOTE: DO NOT USE shell=True to avoid security risk
|
||||
process = subprocess.run(
|
||||
(
|
||||
"torchrun --nnodes {nnodes} --node_rank {node_rank} --nproc_per_node {nproc_per_node} "
|
||||
@@ -110,7 +118,9 @@ def main():
|
||||
file_name=launcher.__file__,
|
||||
args=" ".join(sys.argv[1:]),
|
||||
)
|
||||
.split()
|
||||
.split(),
|
||||
env=env,
|
||||
check=True,
|
||||
)
|
||||
sys.exit(process.returncode)
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user