mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-23 14:22:51 +08:00
[misc] improve entrypoint (#7345)
* 纯粹优化下入口代码,因为看到if else太多了 (English: purely a cleanup of the entrypoint code, because there were too many if/else branches)
* Update cli.py

---------

Co-authored-by: hoshi-hiyouga <hiyouga@buaa.edu.cn>
This commit is contained in:
parent
e1fdd6e2f8
commit
8543400584
@@ -17,9 +17,11 @@ import subprocess
|
|||||||
import sys
|
import sys
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from enum import Enum, unique
|
from enum import Enum, unique
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
from .extras import logging
|
from .extras import logging
|
||||||
|
|
||||||
|
|
||||||
USAGE = (
|
USAGE = (
|
||||||
"-" * 70
|
"-" * 70
|
||||||
+ "\n"
|
+ "\n"
|
||||||
@@ -38,19 +40,6 @@ USAGE = (
|
|||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@unique
|
|
||||||
class Command(str, Enum):
|
|
||||||
API = "api"
|
|
||||||
CHAT = "chat"
|
|
||||||
ENV = "env"
|
|
||||||
EVAL = "eval"
|
|
||||||
EXPORT = "export"
|
|
||||||
TRAIN = "train"
|
|
||||||
WEBDEMO = "webchat"
|
|
||||||
WEBUI = "webui"
|
|
||||||
VER = "version"
|
|
||||||
HELP = "help"
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
from . import launcher
|
from . import launcher
|
||||||
from .api.app import run_api
|
from .api.app import run_api
|
||||||
@@ -73,67 +62,59 @@ def main():
|
|||||||
+ "-" * 58
|
+ "-" * 58
|
||||||
)
|
)
|
||||||
|
|
||||||
command = sys.argv.pop(1) if len(sys.argv) != 1 else Command.HELP
|
COMMANDS = {
|
||||||
if command == Command.API:
|
"api": run_api,
|
||||||
run_api()
|
"chat": run_chat,
|
||||||
elif command == Command.CHAT:
|
"env": print_env,
|
||||||
run_chat()
|
"eval": run_eval,
|
||||||
elif command == Command.ENV:
|
"export": export_model,
|
||||||
print_env()
|
"train": run_exp,
|
||||||
elif command == Command.EVAL:
|
"webchat": run_web_demo,
|
||||||
run_eval()
|
"webui": run_web_ui,
|
||||||
elif command == Command.EXPORT:
|
"version": partial(print, WELCOME),
|
||||||
export_model()
|
"help": partial(print, USAGE),
|
||||||
elif command == Command.TRAIN:
|
}
|
||||||
force_torchrun = is_env_enabled("FORCE_TORCHRUN")
|
|
||||||
if force_torchrun or (get_device_count() > 1 and not use_ray()):
|
|
||||||
nnodes = os.getenv("NNODES", "1")
|
|
||||||
node_rank = os.getenv("NODE_RANK", "0")
|
|
||||||
nproc_per_node = os.getenv("NPROC_PER_NODE", str(get_device_count()))
|
|
||||||
master_addr = os.getenv("MASTER_ADDR", "127.0.0.1")
|
|
||||||
master_port = os.getenv("MASTER_PORT", str(find_available_port()))
|
|
||||||
logger.info_rank0(f"Initializing {nproc_per_node} distributed tasks at: {master_addr}:{master_port}")
|
|
||||||
if int(nnodes) > 1:
|
|
||||||
print(f"Multi-node training enabled: num nodes: {nnodes}, node rank: {node_rank}")
|
|
||||||
|
|
||||||
env = deepcopy(os.environ)
|
command = sys.argv.pop(1) if len(sys.argv) != 1 else "help"
|
||||||
if is_env_enabled("OPTIM_TORCH", "1"):
|
force_torchrun = is_env_enabled("FORCE_TORCHRUN")
|
||||||
# optimize DDP, see https://zhuanlan.zhihu.com/p/671834539
|
if command == "train" and (force_torchrun or (get_device_count() > 1 and not use_ray())):
|
||||||
env["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
|
nnodes = os.getenv("NNODES", "1")
|
||||||
env["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
|
node_rank = os.getenv("NODE_RANK", "0")
|
||||||
|
nproc_per_node = os.getenv("NPROC_PER_NODE", str(get_device_count()))
|
||||||
|
master_addr = os.getenv("MASTER_ADDR", "127.0.0.1")
|
||||||
|
master_port = os.getenv("MASTER_PORT", str(find_available_port()))
|
||||||
|
logger.info_rank0(f"Initializing {nproc_per_node} distributed tasks at: {master_addr}:{master_port}")
|
||||||
|
if int(nnodes) > 1:
|
||||||
|
print(f"Multi-node training enabled: num nodes: {nnodes}, node rank: {node_rank}")
|
||||||
|
|
||||||
# NOTE: DO NOT USE shell=True to avoid security risk
|
env = deepcopy(os.environ)
|
||||||
process = subprocess.run(
|
if is_env_enabled("OPTIM_TORCH", "1"):
|
||||||
(
|
# optimize DDP, see https://zhuanlan.zhihu.com/p/671834539
|
||||||
"torchrun --nnodes {nnodes} --node_rank {node_rank} --nproc_per_node {nproc_per_node} "
|
env["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
|
||||||
"--master_addr {master_addr} --master_port {master_port} {file_name} {args}"
|
env["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
|
||||||
)
|
|
||||||
.format(
|
# NOTE: DO NOT USE shell=True to avoid security risk
|
||||||
nnodes=nnodes,
|
process = subprocess.run(
|
||||||
node_rank=node_rank,
|
(
|
||||||
nproc_per_node=nproc_per_node,
|
"torchrun --nnodes {nnodes} --node_rank {node_rank} --nproc_per_node {nproc_per_node} "
|
||||||
master_addr=master_addr,
|
"--master_addr {master_addr} --master_port {master_port} {file_name} {args}"
|
||||||
master_port=master_port,
|
|
||||||
file_name=launcher.__file__,
|
|
||||||
args=" ".join(sys.argv[1:]),
|
|
||||||
)
|
|
||||||
.split(),
|
|
||||||
env=env,
|
|
||||||
check=True,
|
|
||||||
)
|
)
|
||||||
sys.exit(process.returncode)
|
.format(
|
||||||
else:
|
nnodes=nnodes,
|
||||||
run_exp()
|
node_rank=node_rank,
|
||||||
elif command == Command.WEBDEMO:
|
nproc_per_node=nproc_per_node,
|
||||||
run_web_demo()
|
master_addr=master_addr,
|
||||||
elif command == Command.WEBUI:
|
master_port=master_port,
|
||||||
run_web_ui()
|
file_name=launcher.__file__,
|
||||||
elif command == Command.VER:
|
args=" ".join(sys.argv[1:]),
|
||||||
print(WELCOME)
|
)
|
||||||
elif command == Command.HELP:
|
.split(),
|
||||||
print(USAGE)
|
env=env,
|
||||||
|
check=True,
|
||||||
|
)
|
||||||
|
sys.exit(process.returncode)
|
||||||
else:
|
else:
|
||||||
print(f"Unknown command: {command}.\n{USAGE}")
|
COMMANDS[command]()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
Loading…
x
Reference in New Issue
Block a user