[misc] improve entrypoint (#7345)

* 纯粹优化下入口代码,因为看到if else太多了

* Update cli.py

---------

Co-authored-by: hoshi-hiyouga <hiyouga@buaa.edu.cn>
This commit is contained in:
ENg-122 2025-04-16 21:48:23 +08:00 committed by GitHub
parent b9263ff5ac
commit 8f88a4e6a4

View File

@ -17,9 +17,11 @@ import subprocess
import sys import sys
from copy import deepcopy from copy import deepcopy
from enum import Enum, unique from enum import Enum, unique
from functools import partial
from .extras import logging from .extras import logging
USAGE = ( USAGE = (
"-" * 70 "-" * 70
+ "\n" + "\n"
@ -38,19 +40,6 @@ USAGE = (
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
@unique
class Command(str, Enum):
API = "api"
CHAT = "chat"
ENV = "env"
EVAL = "eval"
EXPORT = "export"
TRAIN = "train"
WEBDEMO = "webchat"
WEBUI = "webui"
VER = "version"
HELP = "help"
def main(): def main():
from . import launcher from . import launcher
from .api.app import run_api from .api.app import run_api
@ -73,20 +62,22 @@ def main():
+ "-" * 58 + "-" * 58
) )
command = sys.argv.pop(1) if len(sys.argv) != 1 else Command.HELP COMMANDS = {
if command == Command.API: "api": run_api,
run_api() "chat": run_chat,
elif command == Command.CHAT: "env": print_env,
run_chat() "eval": run_eval,
elif command == Command.ENV: "export": export_model,
print_env() "train": run_exp,
elif command == Command.EVAL: "webchat": run_web_demo,
run_eval() "webui": run_web_ui,
elif command == Command.EXPORT: "version": partial(print, WELCOME),
export_model() "help": partial(print, USAGE),
elif command == Command.TRAIN: }
command = sys.argv.pop(1) if len(sys.argv) != 1 else "help"
force_torchrun = is_env_enabled("FORCE_TORCHRUN") force_torchrun = is_env_enabled("FORCE_TORCHRUN")
if force_torchrun or (get_device_count() > 1 and not use_ray()): if command == "train" and (force_torchrun or (get_device_count() > 1 and not use_ray())):
nnodes = os.getenv("NNODES", "1") nnodes = os.getenv("NNODES", "1")
node_rank = os.getenv("NODE_RANK", "0") node_rank = os.getenv("NODE_RANK", "0")
nproc_per_node = os.getenv("NPROC_PER_NODE", str(get_device_count())) nproc_per_node = os.getenv("NPROC_PER_NODE", str(get_device_count()))
@ -123,17 +114,7 @@ def main():
) )
sys.exit(process.returncode) sys.exit(process.returncode)
else: else:
run_exp() COMMANDS[command]()
elif command == Command.WEBDEMO:
run_web_demo()
elif command == Command.WEBUI:
run_web_ui()
elif command == Command.VER:
print(WELCOME)
elif command == Command.HELP:
print(USAGE)
else:
print(f"Unknown command: {command}.\n{USAGE}")
if __name__ == "__main__": if __name__ == "__main__":