mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-04 12:42:51 +08:00
76 lines
2.1 KiB
Python
76 lines
2.1 KiB
Python
import sys
|
|
from enum import Enum, unique
|
|
|
|
from .api.app import run_api
|
|
from .chat.chat_model import run_chat
|
|
from .eval.evaluator import run_eval
|
|
from .train.tuner import export_model, run_exp
|
|
from .webui.interface import run_web_demo, run_web_ui
|
|
|
|
|
|
# Help text printed by ``llamafactory-cli help`` (and when no sub-command is
# given). Every row is padded to the same 70-column width as the dashed
# borders so the right-hand "|" edge lines up; 66 = 70 minus the "| " and " |"
# framing. A row longer than 66 characters would not be truncated, so keep new
# entries within that width.
USAGE = (
    "-" * 70
    + "\n"
    + "".join(
        "| {} |\n".format(row.ljust(66))
        for row in (
            "Usage:",
            "  llamafactory-cli api -h: launch an OpenAI-style API server",
            "  llamafactory-cli chat -h: launch a chat interface in CLI",
            "  llamafactory-cli eval -h: evaluate models",
            "  llamafactory-cli export -h: merge LoRA adapters and export model",
            "  llamafactory-cli train -h: train models",
            "  llamafactory-cli webchat -h: launch a chat interface in Web UI",
            "  llamafactory-cli webui: launch LlamaBoard",
            "  llamafactory-cli version: show version info",
        )
    )
    + "-" * 70
)
|
|
|
|
VERSION = "0.7.2.dev0"

# Banner printed by ``llamafactory-cli version``: a 58-column box holding the
# version line, an empty spacer row and the project URL. The version row is
# left-justified to 57 columns so that its closing "|" lands in column 58.
WELCOME = "\n".join(
    (
        "-" * 58,
        f"| Welcome to LLaMA Factory, version {VERSION}".ljust(57) + "|",
        "|" + " " * 56 + "|",
        "| Project page: https://github.com/hiyouga/LLaMA-Factory |",
        "-" * 58,
    )
)
|
|
|
|
|
|
@unique
class Command(str, Enum):
    """Sub-commands understood by ``llamafactory-cli``.

    Because the enum mixes in ``str``, ``main`` can compare the raw token taken
    from ``sys.argv`` directly against a member (``"api" == Command.API``).
    ``@unique`` rejects two members sharing the same command string.
    """

    API = "api"  # dispatched to run_api
    CHAT = "chat"  # dispatched to run_chat
    EVAL = "eval"  # dispatched to run_eval
    EXPORT = "export"  # dispatched to export_model
    TRAIN = "train"  # dispatched to run_exp
    WEBDEMO = "webchat"  # dispatched to run_web_demo; note the CLI spelling
    WEBUI = "webui"  # dispatched to run_web_ui
    VER = "version"  # prints the WELCOME banner
    HELP = "help"  # prints the USAGE text
|
|
|
|
|
|
def main():
    """Entry point of ``llamafactory-cli``: dispatch on the first CLI argument.

    The sub-command token is removed from ``sys.argv`` before the selected
    handler runs (presumably so the handlers' own argument parsing does not see
    it — verify against the handlers). Running the CLI with no sub-command
    prints the usage text.

    Raises:
        NotImplementedError: if the sub-command is not one of ``Command``.
    """
    # A bare ``llamafactory-cli`` has no argv[1]; fall back to the help text
    # rather than letting ``pop(1)`` raise IndexError.
    command = sys.argv.pop(1) if len(sys.argv) > 1 else Command.HELP

    # ``Command`` mixes in ``str``, so the raw argv token compares equal to
    # the matching member.
    if command == Command.API:
        run_api()
    elif command == Command.CHAT:
        run_chat()
    elif command == Command.EVAL:
        run_eval()
    elif command == Command.EXPORT:
        export_model()
    elif command == Command.TRAIN:
        run_exp()
    elif command == Command.WEBDEMO:
        run_web_demo()
    elif command == Command.WEBUI:
        run_web_ui()
    elif command == Command.VER:
        print(WELCOME)
    elif command == Command.HELP:
        print(USAGE)
    else:
        raise NotImplementedError("Unknown command: {}".format(command))
|