[misc] improve entrypoint (#7345)

* 纯粹优化下入口代码,因为看到if else太多了

* Update cli.py

---------

Co-authored-by: hoshi-hiyouga <hiyouga@buaa.edu.cn>
This commit is contained in:
ENg-122 2025-04-16 21:48:23 +08:00 committed by GitHub
parent e1fdd6e2f8
commit 8543400584
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -17,9 +17,11 @@ import subprocess
import sys import sys
from copy import deepcopy from copy import deepcopy
from enum import Enum, unique from enum import Enum, unique
from functools import partial
from .extras import logging from .extras import logging
USAGE = ( USAGE = (
"-" * 70 "-" * 70
+ "\n" + "\n"
@ -38,19 +40,6 @@ USAGE = (
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
@unique
class Command(str, Enum):
API = "api"
CHAT = "chat"
ENV = "env"
EVAL = "eval"
EXPORT = "export"
TRAIN = "train"
WEBDEMO = "webchat"
WEBUI = "webui"
VER = "version"
HELP = "help"
def main(): def main():
from . import launcher from . import launcher
from .api.app import run_api from .api.app import run_api
@ -73,67 +62,59 @@ def main():
+ "-" * 58 + "-" * 58
) )
command = sys.argv.pop(1) if len(sys.argv) != 1 else Command.HELP COMMANDS = {
if command == Command.API: "api": run_api,
run_api() "chat": run_chat,
elif command == Command.CHAT: "env": print_env,
run_chat() "eval": run_eval,
elif command == Command.ENV: "export": export_model,
print_env() "train": run_exp,
elif command == Command.EVAL: "webchat": run_web_demo,
run_eval() "webui": run_web_ui,
elif command == Command.EXPORT: "version": partial(print, WELCOME),
export_model() "help": partial(print, USAGE),
elif command == Command.TRAIN: }
force_torchrun = is_env_enabled("FORCE_TORCHRUN")
if force_torchrun or (get_device_count() > 1 and not use_ray()):
nnodes = os.getenv("NNODES", "1")
node_rank = os.getenv("NODE_RANK", "0")
nproc_per_node = os.getenv("NPROC_PER_NODE", str(get_device_count()))
master_addr = os.getenv("MASTER_ADDR", "127.0.0.1")
master_port = os.getenv("MASTER_PORT", str(find_available_port()))
logger.info_rank0(f"Initializing {nproc_per_node} distributed tasks at: {master_addr}:{master_port}")
if int(nnodes) > 1:
print(f"Multi-node training enabled: num nodes: {nnodes}, node rank: {node_rank}")
env = deepcopy(os.environ) command = sys.argv.pop(1) if len(sys.argv) != 1 else "help"
if is_env_enabled("OPTIM_TORCH", "1"): force_torchrun = is_env_enabled("FORCE_TORCHRUN")
# optimize DDP, see https://zhuanlan.zhihu.com/p/671834539 if command == "train" and (force_torchrun or (get_device_count() > 1 and not use_ray())):
env["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" nnodes = os.getenv("NNODES", "1")
env["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1" node_rank = os.getenv("NODE_RANK", "0")
nproc_per_node = os.getenv("NPROC_PER_NODE", str(get_device_count()))
master_addr = os.getenv("MASTER_ADDR", "127.0.0.1")
master_port = os.getenv("MASTER_PORT", str(find_available_port()))
logger.info_rank0(f"Initializing {nproc_per_node} distributed tasks at: {master_addr}:{master_port}")
if int(nnodes) > 1:
print(f"Multi-node training enabled: num nodes: {nnodes}, node rank: {node_rank}")
# NOTE: DO NOT USE shell=True to avoid security risk env = deepcopy(os.environ)
process = subprocess.run( if is_env_enabled("OPTIM_TORCH", "1"):
( # optimize DDP, see https://zhuanlan.zhihu.com/p/671834539
"torchrun --nnodes {nnodes} --node_rank {node_rank} --nproc_per_node {nproc_per_node} " env["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
"--master_addr {master_addr} --master_port {master_port} {file_name} {args}" env["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
)
.format( # NOTE: DO NOT USE shell=True to avoid security risk
nnodes=nnodes, process = subprocess.run(
node_rank=node_rank, (
nproc_per_node=nproc_per_node, "torchrun --nnodes {nnodes} --node_rank {node_rank} --nproc_per_node {nproc_per_node} "
master_addr=master_addr, "--master_addr {master_addr} --master_port {master_port} {file_name} {args}"
master_port=master_port,
file_name=launcher.__file__,
args=" ".join(sys.argv[1:]),
)
.split(),
env=env,
check=True,
) )
sys.exit(process.returncode) .format(
else: nnodes=nnodes,
run_exp() node_rank=node_rank,
elif command == Command.WEBDEMO: nproc_per_node=nproc_per_node,
run_web_demo() master_addr=master_addr,
elif command == Command.WEBUI: master_port=master_port,
run_web_ui() file_name=launcher.__file__,
elif command == Command.VER: args=" ".join(sys.argv[1:]),
print(WELCOME) )
elif command == Command.HELP: .split(),
print(USAGE) env=env,
check=True,
)
sys.exit(process.returncode)
else: else:
print(f"Unknown command: {command}.\n{USAGE}") COMMANDS[command]()
if __name__ == "__main__": if __name__ == "__main__":