diff --git a/src/llamafactory/cli.py b/src/llamafactory/cli.py index 6fb4b236..9649cd69 100644 --- a/src/llamafactory/cli.py +++ b/src/llamafactory/cli.py @@ -36,52 +36,16 @@ USAGE = ( ) -def _run_api(): - from .api.app import run_api - - return run_api() - - -def _run_chat(): - from .chat.chat_model import run_chat - - return run_chat() - - -def _run_eval(): - raise NotImplementedError("Evaluation will be deprecated in the future.") - - -def _export_model(): - from .train.tuner import export_model - - return export_model() - - -def _run_exp(): - from .train.tuner import run_exp - - return run_exp() - - -def _run_web_demo(): - from .webui.interface import run_web_demo - - return run_web_demo() - - -def _run_web_ui(): - from .webui.interface import run_web_ui - - return run_web_ui() - - def main(): - from . import launcher from .extras import logging from .extras.env import VERSION, print_env from .extras.misc import find_available_port, get_device_count, is_env_enabled, use_ray + if is_env_enabled("USE_V1"): + from .v1 import launcher + else: + from . import launcher + logger = logging.get_logger(__name__) WELCOME = ( @@ -97,14 +61,14 @@ def main(): ) COMMAND_MAP = { - "api": _run_api, - "chat": _run_chat, + "api": launcher.run_api, + "chat": launcher.run_chat, "env": print_env, - "eval": _run_eval, - "export": _export_model, - "train": _run_exp, - "webchat": _run_web_demo, - "webui": _run_web_ui, + "eval": launcher.run_eval, + "export": launcher.export_model, + "train": launcher.run_exp, + "webchat": launcher.run_web_demo, + "webui": launcher.run_web_ui, "version": partial(print, WELCOME), "help": partial(print, USAGE), } diff --git a/src/llamafactory/launcher.py b/src/llamafactory/launcher.py index 8d1435cf..8b4ae586 100644 --- a/src/llamafactory/launcher.py +++ b/src/llamafactory/launcher.py @@ -13,11 +13,45 @@ # limitations under the License. -def launch(): - from llamafactory.train.tuner import run_exp # use absolute import +def run_api(): + from llamafactory.api.app import run_api as _run_api - run_exp() + _run_api() + + +def run_chat(): + from llamafactory.chat.chat_model import run_chat as _run_chat + + return _run_chat() + + +def run_eval(): + raise NotImplementedError("Evaluation will be deprecated in the future.") + + +def export_model(): + from llamafactory.train.tuner import export_model as _export_model + + return _export_model() + + +def run_exp(): + from llamafactory.train.tuner import run_exp as _run_exp + + return _run_exp() # use absolute import + + +def run_web_demo(): + from llamafactory.webui.interface import run_web_demo as _run_web_demo + + return _run_web_demo() + + +def run_web_ui(): + from llamafactory.webui.interface import run_web_ui as _run_web_ui + + return _run_web_ui() if __name__ == "__main__": - launch() + run_exp() diff --git a/src/llamafactory/v1/launcher.py b/src/llamafactory/v1/launcher.py index e69de29b..36f87403 100644 --- a/src/llamafactory/v1/launcher.py +++ b/src/llamafactory/v1/launcher.py @@ -0,0 +1,33 @@ +# Copyright 2025 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +def run_train(): + raise NotImplementedError("Please use `llamafactory-cli sft` or `llamafactory-cli rm`.") + + +def run_chat(): + from llamafactory.v1.core.chat_sampler import Sampler + + Sampler().cli() + + +def run_sft(): + from llamafactory.v1.train.sft import SFTTrainer + + SFTTrainer().run() + + +if __name__ == "__main__": + run_train()