From 8543400584e7b774af13e3b863208ec588cd47eb Mon Sep 17 00:00:00 2001 From: ENg-122 <86090741+ENg-122@users.noreply.github.com> Date: Wed, 16 Apr 2025 21:48:23 +0800 Subject: [PATCH] [misc] improve entrypoint (#7345) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * 纯粹优化下入口代码,因为看到if else太多了 * Update cli.py --------- Co-authored-by: hoshi-hiyouga --- src/llamafactory/cli.py | 121 +++++++++++++++++----------------------- 1 file changed, 51 insertions(+), 70 deletions(-) diff --git a/src/llamafactory/cli.py b/src/llamafactory/cli.py index 99a089b2..92085c80 100644 --- a/src/llamafactory/cli.py +++ b/src/llamafactory/cli.py @@ -17,9 +17,11 @@ import subprocess import sys from copy import deepcopy from enum import Enum, unique +from functools import partial from .extras import logging + USAGE = ( "-" * 70 + "\n" @@ -38,19 +40,6 @@ USAGE = ( logger = logging.get_logger(__name__) -@unique -class Command(str, Enum): - API = "api" - CHAT = "chat" - ENV = "env" - EVAL = "eval" - EXPORT = "export" - TRAIN = "train" - WEBDEMO = "webchat" - WEBUI = "webui" - VER = "version" - HELP = "help" - def main(): from . import launcher from .api.app import run_api @@ -73,67 +62,59 @@ def main(): + "-" * 58 ) - command = sys.argv.pop(1) if len(sys.argv) != 1 else Command.HELP - if command == Command.API: - run_api() - elif command == Command.CHAT: - run_chat() - elif command == Command.ENV: - print_env() - elif command == Command.EVAL: - run_eval() - elif command == Command.EXPORT: - export_model() - elif command == Command.TRAIN: - force_torchrun = is_env_enabled("FORCE_TORCHRUN") - if force_torchrun or (get_device_count() > 1 and not use_ray()): - nnodes = os.getenv("NNODES", "1") - node_rank = os.getenv("NODE_RANK", "0") - nproc_per_node = os.getenv("NPROC_PER_NODE", str(get_device_count())) - master_addr = os.getenv("MASTER_ADDR", "127.0.0.1") - master_port = os.getenv("MASTER_PORT", str(find_available_port())) - logger.info_rank0(f"Initializing {nproc_per_node} distributed tasks at: {master_addr}:{master_port}") - if int(nnodes) > 1: - print(f"Multi-node training enabled: num nodes: {nnodes}, node rank: {node_rank}") + COMMANDS = { + "api": run_api, + "chat": run_chat, + "env": print_env, + "eval": run_eval, + "export": export_model, + "train": run_exp, + "webchat": run_web_demo, + "webui": run_web_ui, + "version": partial(print, WELCOME), + "help": partial(print, USAGE), + } - env = deepcopy(os.environ) - if is_env_enabled("OPTIM_TORCH", "1"): - # optimize DDP, see https://zhuanlan.zhihu.com/p/671834539 - env["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" - env["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1" + command = sys.argv.pop(1) if len(sys.argv) != 1 else "help" + force_torchrun = is_env_enabled("FORCE_TORCHRUN") + if command == "train" and (force_torchrun or (get_device_count() > 1 and not use_ray())): + nnodes = os.getenv("NNODES", "1") + node_rank = os.getenv("NODE_RANK", "0") + nproc_per_node = os.getenv("NPROC_PER_NODE", str(get_device_count())) + master_addr = os.getenv("MASTER_ADDR", "127.0.0.1") + master_port = os.getenv("MASTER_PORT", str(find_available_port())) + logger.info_rank0(f"Initializing {nproc_per_node} distributed tasks at: {master_addr}:{master_port}") + if int(nnodes) > 1: + print(f"Multi-node training enabled: num nodes: {nnodes}, node rank: {node_rank}") - # NOTE: DO NOT USE shell=True to avoid security risk - process = subprocess.run( - ( - "torchrun --nnodes {nnodes} --node_rank {node_rank} --nproc_per_node {nproc_per_node} " - "--master_addr {master_addr} --master_port {master_port} {file_name} {args}" - ) - .format( - nnodes=nnodes, - node_rank=node_rank, - nproc_per_node=nproc_per_node, - master_addr=master_addr, - master_port=master_port, - file_name=launcher.__file__, - args=" ".join(sys.argv[1:]), - ) - .split(), - env=env, - check=True, + env = deepcopy(os.environ) + if is_env_enabled("OPTIM_TORCH", "1"): + # optimize DDP, see https://zhuanlan.zhihu.com/p/671834539 + env["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" + env["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1" + + # NOTE: DO NOT USE shell=True to avoid security risk + process = subprocess.run( + ( + "torchrun --nnodes {nnodes} --node_rank {node_rank} --nproc_per_node {nproc_per_node} " + "--master_addr {master_addr} --master_port {master_port} {file_name} {args}" ) - sys.exit(process.returncode) - else: - run_exp() - elif command == Command.WEBDEMO: - run_web_demo() - elif command == Command.WEBUI: - run_web_ui() - elif command == Command.VER: - print(WELCOME) - elif command == Command.HELP: - print(USAGE) + .format( + nnodes=nnodes, + node_rank=node_rank, + nproc_per_node=nproc_per_node, + master_addr=master_addr, + master_port=master_port, + file_name=launcher.__file__, + args=" ".join(sys.argv[1:]), + ) + .split(), + env=env, + check=True, + ) + sys.exit(process.returncode) else: - print(f"Unknown command: {command}.\n{USAGE}") + COMMANDS[command]() if __name__ == "__main__":