mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-11-04 18:02:19 +08:00)
	[misc] improve entrypoint (#7345)
* Purely a cleanup of the entrypoint code, since it had accumulated too many if/else branches
* Update cli.py
---------
Co-authored-by: hoshi-hiyouga <hiyouga@buaa.edu.cn>
This commit is contained in:
parent b9263ff5ac
commit 8f88a4e6a4
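The heart of the change: an if/elif chain over a Command enum becomes a table-driven dispatch dict, and functools.partial turns "print a fixed message" into a zero-argument callable like every other entry. A minimal self-contained sketch of the pattern (the names below are illustrative stand-ins, not code from this repo):

from functools import partial

WELCOME = "Welcome to LLaMA Factory"
USAGE = "usage: llamafactory-cli <command> [args...]"

def run_demo() -> None:  # stand-in for run_api, run_chat, run_exp, ...
    print("running demo")

# Every value is a zero-argument callable, so dispatch is a lookup plus a call.
COMMANDS = {
    "demo": run_demo,
    "version": partial(print, WELCOME),
    "help": partial(print, USAGE),
}

def dispatch(command: str) -> None:
    # .get() with a default keeps a graceful fallback for unknown commands.
    COMMANDS.get(command, partial(print, f"Unknown command: {command}.\n{USAGE}"))()

dispatch("version")  # prints the welcome banner
dispatch("oops")     # prints the unknown-command message

One caveat visible in the diff below: the new code indexes COMMANDS[command] directly and drops the old "Unknown command" fallback, so an unrecognized command would raise KeyError; the .get() form above is one way to preserve the old behavior.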
					
src/llamafactory/cli.py

@@ -17,9 +17,11 @@ import subprocess
 import sys
+from copy import deepcopy
 from enum import Enum, unique
+from functools import partial
 
 from .extras import logging
 
 
 USAGE = (
     "-" * 70
     + "\n"
@@ -38,19 +40,6 @@ USAGE = (
 logger = logging.get_logger(__name__)
 
 
-@unique
-class Command(str, Enum):
-    API = "api"
-    CHAT = "chat"
-    ENV = "env"
-    EVAL = "eval"
-    EXPORT = "export"
-    TRAIN = "train"
-    WEBDEMO = "webchat"
-    WEBUI = "webui"
-    VER = "version"
-    HELP = "help"
-
 def main():
     from . import launcher
     from .api.app import run_api
@@ -73,67 +62,59 @@ def main():
         + "-" * 58
     )
 
-    command = sys.argv.pop(1) if len(sys.argv) != 1 else Command.HELP
-    if command == Command.API:
-        run_api()
-    elif command == Command.CHAT:
-        run_chat()
-    elif command == Command.ENV:
-        print_env()
-    elif command == Command.EVAL:
-        run_eval()
-    elif command == Command.EXPORT:
-        export_model()
-    elif command == Command.TRAIN:
-        force_torchrun = is_env_enabled("FORCE_TORCHRUN")
-        if force_torchrun or (get_device_count() > 1 and not use_ray()):
-            nnodes = os.getenv("NNODES", "1")
-            node_rank = os.getenv("NODE_RANK", "0")
-            nproc_per_node = os.getenv("NPROC_PER_NODE", str(get_device_count()))
-            master_addr = os.getenv("MASTER_ADDR", "127.0.0.1")
-            master_port = os.getenv("MASTER_PORT", str(find_available_port()))
-            logger.info_rank0(f"Initializing {nproc_per_node} distributed tasks at: {master_addr}:{master_port}")
-            if int(nnodes) > 1:
-                print(f"Multi-node training enabled: num nodes: {nnodes}, node rank: {node_rank}")
+    COMMANDS = {
+        "api": run_api,
+        "chat": run_chat,
+        "env": print_env,
+        "eval": run_eval,
+        "export": export_model,
+        "train": run_exp,
+        "webchat": run_web_demo,
+        "webui": run_web_ui,
+        "version": partial(print, WELCOME),
+        "help": partial(print, USAGE),
+    }
 
-            env = deepcopy(os.environ)
-            if is_env_enabled("OPTIM_TORCH", "1"):
-                # optimize DDP, see https://zhuanlan.zhihu.com/p/671834539
-                env["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
-                env["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
+    command = sys.argv.pop(1) if len(sys.argv) != 1 else "help"
+    force_torchrun = is_env_enabled("FORCE_TORCHRUN")
+    if command == "train" and (force_torchrun or (get_device_count() > 1 and not use_ray())):
+        nnodes = os.getenv("NNODES", "1")
+        node_rank = os.getenv("NODE_RANK", "0")
+        nproc_per_node = os.getenv("NPROC_PER_NODE", str(get_device_count()))
+        master_addr = os.getenv("MASTER_ADDR", "127.0.0.1")
+        master_port = os.getenv("MASTER_PORT", str(find_available_port()))
+        logger.info_rank0(f"Initializing {nproc_per_node} distributed tasks at: {master_addr}:{master_port}")
+        if int(nnodes) > 1:
+            print(f"Multi-node training enabled: num nodes: {nnodes}, node rank: {node_rank}")
 
-            # NOTE: DO NOT USE shell=True to avoid security risk
-            process = subprocess.run(
-                (
-                    "torchrun --nnodes {nnodes} --node_rank {node_rank} --nproc_per_node {nproc_per_node} "
-                    "--master_addr {master_addr} --master_port {master_port} {file_name} {args}"
-                )
-                .format(
-                    nnodes=nnodes,
-                    node_rank=node_rank,
-                    nproc_per_node=nproc_per_node,
-                    master_addr=master_addr,
-                    master_port=master_port,
-                    file_name=launcher.__file__,
-                    args=" ".join(sys.argv[1:]),
-                )
-                .split(),
-                env=env,
-                check=True,
+        env = deepcopy(os.environ)
+        if is_env_enabled("OPTIM_TORCH", "1"):
+            # optimize DDP, see https://zhuanlan.zhihu.com/p/671834539
+            env["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
+            env["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
+
+        # NOTE: DO NOT USE shell=True to avoid security risk
+        process = subprocess.run(
+            (
+                "torchrun --nnodes {nnodes} --node_rank {node_rank} --nproc_per_node {nproc_per_node} "
+                "--master_addr {master_addr} --master_port {master_port} {file_name} {args}"
             )
-            sys.exit(process.returncode)
-        else:
-            run_exp()
-    elif command == Command.WEBDEMO:
-        run_web_demo()
-    elif command == Command.WEBUI:
-        run_web_ui()
-    elif command == Command.VER:
-        print(WELCOME)
-    elif command == Command.HELP:
-        print(USAGE)
+            .format(
+                nnodes=nnodes,
+                node_rank=node_rank,
+                nproc_per_node=nproc_per_node,
+                master_addr=master_addr,
+                master_port=master_port,
+                file_name=launcher.__file__,
+                args=" ".join(sys.argv[1:]),
+            )
+            .split(),
+            env=env,
+            check=True,
+        )
+        sys.exit(process.returncode)
     else:
-        print(f"Unknown command: {command}.\n{USAGE}")
+        COMMANDS[command]()
 
 
 if __name__ == "__main__":
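The train branch keeps the same launch mechanics after the refactor: every knob comes from an environment variable with a sane default, the DDP tweaks are applied to a copy of the environment, and the torchrun command is built with str.format() and .split() so no shell ever parses user input. A runnable sketch of that assembly, with hypothetical stand-ins for the repo's helpers (get_device_count, find_available_port) and placeholder launcher arguments:

import os
from copy import deepcopy

device_count = 2      # hypothetical; the CLI calls get_device_count()
default_port = 29500  # hypothetical; the CLI calls find_available_port()

nnodes = os.getenv("NNODES", "1")
node_rank = os.getenv("NODE_RANK", "0")
nproc_per_node = os.getenv("NPROC_PER_NODE", str(device_count))
master_addr = os.getenv("MASTER_ADDR", "127.0.0.1")
master_port = os.getenv("MASTER_PORT", str(default_port))

# Copy the environment so the DDP tweaks only affect the child process.
env = deepcopy(os.environ)
env["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"  # reduce allocator fragmentation
env["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"

# Format, then split into argv: passing a list to subprocess.run avoids shell=True.
argv = (
    "torchrun --nnodes {nnodes} --node_rank {node_rank} "
    "--nproc_per_node {nproc_per_node} --master_addr {master_addr} "
    "--master_port {master_port} launcher.py train example.yaml"  # placeholder script/args
).format(
    nnodes=nnodes,
    node_rank=node_rank,
    nproc_per_node=nproc_per_node,
    master_addr=master_addr,
    master_port=master_port,
).split()

print(argv)  # would be handed to subprocess.run(argv, env=env, check=True)

Setting FORCE_TORCHRUN=1 takes this path even on a single device, and NNODES > 1 additionally prints the multi-node banner.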