[misc] improve entrypoint (#7345)

* 纯粹优化下入口代码,因为看到if else太多了

* Update cli.py

---------

Co-authored-by: hoshi-hiyouga <hiyouga@buaa.edu.cn>
This commit is contained in:
ENg-122 2025-04-16 21:48:23 +08:00 committed by GitHub
parent e1fdd6e2f8
commit 8543400584
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -17,9 +17,11 @@ import subprocess
import sys import sys
from copy import deepcopy from copy import deepcopy
from enum import Enum, unique from enum import Enum, unique
from functools import partial
from .extras import logging from .extras import logging
USAGE = ( USAGE = (
"-" * 70 "-" * 70
+ "\n" + "\n"
@ -38,19 +40,6 @@ USAGE = (
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
@unique
class Command(str, Enum):
API = "api"
CHAT = "chat"
ENV = "env"
EVAL = "eval"
EXPORT = "export"
TRAIN = "train"
WEBDEMO = "webchat"
WEBUI = "webui"
VER = "version"
HELP = "help"
def main(): def main():
from . import launcher from . import launcher
from .api.app import run_api from .api.app import run_api
@ -73,20 +62,22 @@ def main():
+ "-" * 58 + "-" * 58
) )
command = sys.argv.pop(1) if len(sys.argv) != 1 else Command.HELP COMMANDS = {
if command == Command.API: "api": run_api,
run_api() "chat": run_chat,
elif command == Command.CHAT: "env": print_env,
run_chat() "eval": run_eval,
elif command == Command.ENV: "export": export_model,
print_env() "train": run_exp,
elif command == Command.EVAL: "webchat": run_web_demo,
run_eval() "webui": run_web_ui,
elif command == Command.EXPORT: "version": partial(print, WELCOME),
export_model() "help": partial(print, USAGE),
elif command == Command.TRAIN: }
command = sys.argv.pop(1) if len(sys.argv) != 1 else "help"
force_torchrun = is_env_enabled("FORCE_TORCHRUN") force_torchrun = is_env_enabled("FORCE_TORCHRUN")
if force_torchrun or (get_device_count() > 1 and not use_ray()): if command == "train" and (force_torchrun or (get_device_count() > 1 and not use_ray())):
nnodes = os.getenv("NNODES", "1") nnodes = os.getenv("NNODES", "1")
node_rank = os.getenv("NODE_RANK", "0") node_rank = os.getenv("NODE_RANK", "0")
nproc_per_node = os.getenv("NPROC_PER_NODE", str(get_device_count())) nproc_per_node = os.getenv("NPROC_PER_NODE", str(get_device_count()))
@ -123,17 +114,7 @@ def main():
) )
sys.exit(process.returncode) sys.exit(process.returncode)
else: else:
run_exp() COMMANDS[command]()
elif command == Command.WEBDEMO:
run_web_demo()
elif command == Command.WEBUI:
run_web_ui()
elif command == Command.VER:
print(WELCOME)
elif command == Command.HELP:
print(USAGE)
else:
print(f"Unknown command: {command}.\n{USAGE}")
if __name__ == "__main__": if __name__ == "__main__":