diff --git a/src/llamafactory/chat/chat_model.py b/src/llamafactory/chat/chat_model.py
index 0022eed9..b0c9a1cf 100644
--- a/src/llamafactory/chat/chat_model.py
+++ b/src/llamafactory/chat/chat_model.py
@@ -24,9 +24,6 @@ from typing import TYPE_CHECKING, Any, Optional
 from ..extras.constants import EngineName
 from ..extras.misc import torch_gc
 from ..hparams import get_infer_args
-from .hf_engine import HuggingfaceEngine
-from .sglang_engine import SGLangEngine
-from .vllm_engine import VllmEngine
 
 
 if TYPE_CHECKING:
@@ -49,12 +46,28 @@ class ChatModel:
     def __init__(self, args: Optional[dict[str, Any]] = None) -> None:
         model_args, data_args, finetuning_args, generating_args = get_infer_args(args)
+
         if model_args.infer_backend == EngineName.HF:
+            from .hf_engine import HuggingfaceEngine
             self.engine: BaseEngine = HuggingfaceEngine(model_args, data_args, finetuning_args, generating_args)
         elif model_args.infer_backend == EngineName.VLLM:
-            self.engine: BaseEngine = VllmEngine(model_args, data_args, finetuning_args, generating_args)
+            try:
+                from .vllm_engine import VllmEngine
+                self.engine: BaseEngine = VllmEngine(model_args, data_args, finetuning_args, generating_args)
+            except ImportError as e:
+                raise ImportError(
+                    "vLLM is not installed, you may need to run `pip install vllm`\n"
+                    "or try to use the HuggingFace backend: --infer_backend huggingface"
+                ) from e
         elif model_args.infer_backend == EngineName.SGLANG:
-            self.engine: BaseEngine = SGLangEngine(model_args, data_args, finetuning_args, generating_args)
+            try:
+                from .sglang_engine import SGLangEngine
+                self.engine: BaseEngine = SGLangEngine(model_args, data_args, finetuning_args, generating_args)
+            except ImportError as e:
+                raise ImportError(
+                    "SGLang is not installed, you may need to run `pip install sglang[all]`\n"
+                    "or try to use the HuggingFace backend: --infer_backend huggingface"
+                ) from e
         else:
             raise NotImplementedError(f"Unknown backend: {model_args.infer_backend}")
 
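The chat_model.py change above defers each optional backend import into the branch that selects it: importing ChatModel no longer requires vllm or sglang to be installed, and a missing package now surfaces an actionable hint instead of failing at module import time. A minimal, self-contained sketch of the same pattern (the load_backend helper and its message are illustrative, not part of this patch):

    def load_backend(name: str):
        # Import a heavy optional dependency only when it is actually selected,
        # and chain the original ImportError so the traceback stays informative.
        if name == "vllm":
            try:
                import vllm  # heavy optional dependency
            except ImportError as e:
                raise ImportError(
                    "vLLM is not installed, you may need to run `pip install vllm`"
                ) from e
            return vllm
        raise NotImplementedError(f"Unknown backend: {name}")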
diff --git a/src/llamafactory/cli.py b/src/llamafactory/cli.py
index fb94b35c..f7ff80d3 100644
--- a/src/llamafactory/cli.py
+++ b/src/llamafactory/cli.py
@@ -35,16 +35,46 @@ USAGE = (
 )
 
 
+def _run_api():
+    from .api.app import run_api
+    return run_api()
+
+
+def _run_chat():
+    from .chat.chat_model import run_chat
+    return run_chat()
+
+
+def _run_eval():
+    from .eval.evaluator import run_eval
+    return run_eval()
+
+
+def _export_model():
+    from .train.tuner import export_model
+    return export_model()
+
+
+def _run_exp():
+    from .train.tuner import run_exp
+    return run_exp()
+
+
+def _run_web_demo():
+    from .webui.interface import run_web_demo
+    return run_web_demo()
+
+
+def _run_web_ui():
+    from .webui.interface import run_web_ui
+    return run_web_ui()
+
+
 def main():
     from . import launcher
-    from .api.app import run_api
-    from .chat.chat_model import run_chat
-    from .eval.evaluator import run_eval
     from .extras import logging
     from .extras.env import VERSION, print_env
     from .extras.misc import find_available_port, get_device_count, is_env_enabled, use_ray
-    from .train.tuner import export_model, run_exp
-    from .webui.interface import run_web_demo, run_web_ui
 
     logger = logging.get_logger(__name__)
 
@@ -61,14 +91,14 @@ def main():
 
     COMMAND_MAP = {
-        "api": run_api,
-        "chat": run_chat,
+        "api": _run_api,
+        "chat": _run_chat,
         "env": print_env,
-        "eval": run_eval,
-        "export": export_model,
-        "train": run_exp,
-        "webchat": run_web_demo,
-        "webui": run_web_ui,
+        "eval": _run_eval,
+        "export": _export_model,
+        "train": _run_exp,
+        "webchat": _run_web_demo,
+        "webui": _run_web_ui,
         "version": partial(print, WELCOME),
         "help": partial(print, USAGE),
     }
 
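The cli.py change applies the same idea to command dispatch: each subcommand's imports move into a private wrapper, so fast paths such as version and help no longer pay for the training, eval, or webui stacks. A self-contained sketch of the lazy COMMAND_MAP pattern (the json command and usage string here are illustrative only, not the project's):

    import sys
    from functools import partial

    def _run_json_tool():
        # Deferred import: json is only loaded when this command is dispatched.
        import json
        print(json.dumps({"ok": True}))

    COMMAND_MAP = {
        "json": _run_json_tool,
        "help": partial(print, "usage: cli.py {json,help}"),
    }

    if __name__ == "__main__":
        command = sys.argv[1] if len(sys.argv) > 1 else "help"
        COMMAND_MAP.get(command, COMMAND_MAP["help"])()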