Mirror of https://github.com/hiyouga/LLaMA-Factory.git
Synced 2025-08-01 11:12:50 +08:00
142 lines · 4.8 KiB · Python
# Copyright 2025 the LlamaFactory team.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
import os
|
|
import subprocess
|
|
import sys
|
|
from copy import deepcopy
|
|
from enum import Enum, unique
|
|
|
|
from . import launcher
|
|
from .api.app import run_api
|
|
from .chat.chat_model import run_chat
|
|
from .eval.evaluator import run_eval
|
|
from .extras import logging
|
|
from .extras.env import VERSION, print_env
|
|
from .extras.misc import find_available_port, get_device_count, is_env_enabled, use_ray
|
|
from .train.tuner import export_model, run_exp
|
|
from .webui.interface import run_web_demo, run_web_ui
|
|
|
|
|
|
# Help text printed for `llamafactory-cli help` (or no sub-command) and after an
# unknown command. Rows are padded to a fixed width so the right-hand `|` border
# lines up with the 70-character rules above and below the box.
USAGE = (
    "-" * 70
    + "\n"
    + "".join(
        f"| {row:<66} |\n"  # "| " + 66 chars + " |" == 70 columns
        for row in (
            "Usage:",
            "  llamafactory-cli api -h: launch an OpenAI-style API server",
            "  llamafactory-cli chat -h: launch a chat interface in CLI",
            "  llamafactory-cli env: print environment info",
            "  llamafactory-cli eval -h: evaluate models",
            "  llamafactory-cli export -h: merge LoRA adapters and export model",
            "  llamafactory-cli train -h: train models",
            "  llamafactory-cli webchat -h: launch a chat interface in Web UI",
            "  llamafactory-cli webui: launch LlamaBoard",
            "  llamafactory-cli version: show version info",
        )
    )
    + "-" * 70
)
|
|
|
|
# Banner printed by `llamafactory-cli version`: a 58-column box showing the
# installed VERSION and the project URL. The version row is left-justified to
# 57 characters so its closing `|` lands in the last column (no padding is
# added if the version string is too long to fit).
WELCOME = "\n".join(
    (
        "-" * 58,
        f"| Welcome to LLaMA Factory, version {VERSION}".ljust(57) + "|",
        "|" + " " * 56 + "|",
        "| Project page: https://github.com/hiyouga/LLaMA-Factory |",
        "-" * 58,
    )
)
|
|
|
|
# Module-level logger obtained from the project's `.extras.logging` wrapper (not the
# stdlib module). `main` calls its `info_rank0` method when launching distributed
# training — presumably it logs only on rank 0; confirm in `.extras.logging`.
logger = logging.get_logger(__name__)
|
|
|
|
|
|
@unique
class Command(str, Enum):
    """Sub-commands accepted by ``llamafactory-cli`` as its first positional argument.

    The ``str`` mixin makes each member equal to its string value, so the raw
    ``sys.argv`` entry can be compared directly (e.g. ``"train" == Command.TRAIN``).
    ``@unique`` rejects accidental duplicate values at class-creation time.
    """

    API = "api"  # dispatched to run_api()
    CHAT = "chat"  # dispatched to run_chat()
    ENV = "env"  # dispatched to print_env()
    EVAL = "eval"  # dispatched to run_eval()
    EXPORT = "export"  # dispatched to export_model()
    TRAIN = "train"  # run_exp() in-process, or relaunched under torchrun for multi-device runs
    WEBDEMO = "webchat"  # dispatched to run_web_demo(); CLI string differs from the member name
    WEBUI = "webui"  # dispatched to run_web_ui()
    VER = "version"  # prints the WELCOME banner
    HELP = "help"  # prints USAGE; also used when no sub-command is given
|
|
|
|
|
|
def _build_torchrun_command(nnodes, node_rank, nproc_per_node, master_addr, master_port, file_name, args):
    """Assemble the argv list that runs ``file_name`` under ``torchrun``.

    Every value becomes its own argv element, so script paths or forwarded CLI
    overrides that contain spaces reach the child process intact. Building a
    single string and splitting it on whitespace would break them apart.

    Args:
        nnodes: Value for ``--nnodes``.
        node_rank: Value for ``--node_rank``.
        nproc_per_node: Value for ``--nproc_per_node``.
        master_addr: Value for ``--master_addr``.
        master_port: Value for ``--master_port``.
        file_name: Script that torchrun executes on every worker.
        args: Extra arguments forwarded verbatim after ``file_name``.

    Returns:
        The command as a list of strings, suitable for ``subprocess.run`` without a shell.
    """
    return [
        "torchrun",
        "--nnodes",
        str(nnodes),
        "--node_rank",
        str(node_rank),
        "--nproc_per_node",
        str(nproc_per_node),
        "--master_addr",
        str(master_addr),
        "--master_port",
        str(master_port),
        str(file_name),
        *args,
    ]


def _launch_distributed_training():
    """Relaunch the training entry point (``launcher``) under ``torchrun``, then exit.

    The topology is read from the NNODES, NODE_RANK, NPROC_PER_NODE, MASTER_ADDR
    and MASTER_PORT environment variables. Unset variables fall back to a
    single-node run on the local device count, 127.0.0.1 and a free port. The
    remaining ``sys.argv[1:]`` arguments are forwarded to the launcher.
    """
    nnodes = os.getenv("NNODES", "1")
    node_rank = os.getenv("NODE_RANK", "0")
    nproc_per_node = os.getenv("NPROC_PER_NODE")
    if nproc_per_node is None:
        nproc_per_node = str(get_device_count())
    master_addr = os.getenv("MASTER_ADDR", "127.0.0.1")
    master_port = os.getenv("MASTER_PORT")
    if master_port is None:
        # Only probe for a free port when the user has not pinned one.
        master_port = str(find_available_port())

    logger.info_rank0(f"Initializing {nproc_per_node} distributed tasks at: {master_addr}:{master_port}")
    if int(nnodes) > 1:
        print(f"Multi-node training enabled: num nodes: {nnodes}, node rank: {node_rank}")

    # A plain dict copy: a deepcopy of os.environ is still an os._Environ whose
    # __setitem__ calls os.putenv, which would leak these settings into this
    # process's real environment instead of only the child's.
    env = os.environ.copy()
    if is_env_enabled("OPTIM_TORCH", "1"):
        # optimize DDP, see https://zhuanlan.zhihu.com/p/671834539
        env["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
        env["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"

    # NOTE: DO NOT USE shell=True to avoid security risk
    process = subprocess.run(
        _build_torchrun_command(
            nnodes=nnodes,
            node_rank=node_rank,
            nproc_per_node=nproc_per_node,
            master_addr=master_addr,
            master_port=master_port,
            file_name=launcher.__file__,
            args=sys.argv[1:],
        ),
        env=env,
        check=True,
    )
    # NOTE(review): with check=True a non-zero torchrun exit raises
    # subprocess.CalledProcessError before this line, so returncode is always 0 here.
    sys.exit(process.returncode)


def main():
    """Dispatch ``llamafactory-cli <command> [args...]`` to the matching entry point.

    The sub-command is popped from ``sys.argv``, presumably so each entry point's
    own argument handling sees only the remaining arguments. With no sub-command,
    the usage text is printed. Unknown commands print a message plus the usage text.
    """
    command = sys.argv.pop(1) if len(sys.argv) != 1 else Command.HELP
    if command == Command.API:
        run_api()
    elif command == Command.CHAT:
        run_chat()
    elif command == Command.ENV:
        print_env()
    elif command == Command.EVAL:
        run_eval()
    elif command == Command.EXPORT:
        export_model()
    elif command == Command.TRAIN:
        force_torchrun = is_env_enabled("FORCE_TORCHRUN")
        # Multi-device runs go through torchrun unless Ray handles distribution.
        if force_torchrun or (get_device_count() > 1 and not use_ray()):
            _launch_distributed_training()
        else:
            run_exp()
    elif command == Command.WEBDEMO:
        run_web_demo()
    elif command == Command.WEBUI:
        run_web_ui()
    elif command == Command.VER:
        print(WELCOME)
    elif command == Command.HELP:
        print(USAGE)
    else:
        print(f"Unknown command: {command}.\n{USAGE}")
|
|
|
|
|
|
# Allow executing this module directly. Because it uses relative imports
# (e.g. `from . import launcher`), run it with `python -m <package>.<module>`
# rather than as a plain file path.
if __name__ == "__main__":
    main()
|