[misc] improve entrypoint (#7345)

* 纯粹优化下入口代码,因为看到if else太多了

* Update cli.py

---------

Co-authored-by: hoshi-hiyouga <hiyouga@buaa.edu.cn>
This commit is contained in:
ENg-122 2025-04-16 21:48:23 +08:00 committed by GitHub
parent e1fdd6e2f8
commit 8543400584
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -17,9 +17,11 @@ import subprocess
import sys
from copy import deepcopy
from enum import Enum, unique
from functools import partial
from .extras import logging
USAGE = (
"-" * 70
+ "\n"
@ -38,19 +40,6 @@ USAGE = (
logger = logging.get_logger(__name__)
@unique
class Command(str, Enum):
API = "api"
CHAT = "chat"
ENV = "env"
EVAL = "eval"
EXPORT = "export"
TRAIN = "train"
WEBDEMO = "webchat"
WEBUI = "webui"
VER = "version"
HELP = "help"
def main():
from . import launcher
from .api.app import run_api
@ -73,20 +62,22 @@ def main():
+ "-" * 58
)
command = sys.argv.pop(1) if len(sys.argv) != 1 else Command.HELP
if command == Command.API:
run_api()
elif command == Command.CHAT:
run_chat()
elif command == Command.ENV:
print_env()
elif command == Command.EVAL:
run_eval()
elif command == Command.EXPORT:
export_model()
elif command == Command.TRAIN:
COMMANDS = {
"api": run_api,
"chat": run_chat,
"env": print_env,
"eval": run_eval,
"export": export_model,
"train": run_exp,
"webchat": run_web_demo,
"webui": run_web_ui,
"version": partial(print, WELCOME),
"help": partial(print, USAGE),
}
command = sys.argv.pop(1) if len(sys.argv) != 1 else "help"
force_torchrun = is_env_enabled("FORCE_TORCHRUN")
if force_torchrun or (get_device_count() > 1 and not use_ray()):
if command == "train" and (force_torchrun or (get_device_count() > 1 and not use_ray())):
nnodes = os.getenv("NNODES", "1")
node_rank = os.getenv("NODE_RANK", "0")
nproc_per_node = os.getenv("NPROC_PER_NODE", str(get_device_count()))
@ -123,17 +114,7 @@ def main():
)
sys.exit(process.returncode)
else:
run_exp()
elif command == Command.WEBDEMO:
run_web_demo()
elif command == Command.WEBUI:
run_web_ui()
elif command == Command.VER:
print(WELCOME)
elif command == Command.HELP:
print(USAGE)
else:
print(f"Unknown command: {command}.\n{USAGE}")
COMMANDS[command]()
if __name__ == "__main__":