mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-10-14 15:52:49 +08:00
[cli] support lazy import (#9217)
Co-authored-by: frozenleaves <frozen@Mac.local>
This commit is contained in:
parent
6ffebe5ff7
commit
a04d777d7f
@ -24,9 +24,6 @@ from typing import TYPE_CHECKING, Any, Optional
|
|||||||
from ..extras.constants import EngineName
|
from ..extras.constants import EngineName
|
||||||
from ..extras.misc import torch_gc
|
from ..extras.misc import torch_gc
|
||||||
from ..hparams import get_infer_args
|
from ..hparams import get_infer_args
|
||||||
from .hf_engine import HuggingfaceEngine
|
|
||||||
from .sglang_engine import SGLangEngine
|
|
||||||
from .vllm_engine import VllmEngine
|
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@ -49,12 +46,28 @@ class ChatModel:
|
|||||||
|
|
||||||
def __init__(self, args: Optional[dict[str, Any]] = None) -> None:
|
def __init__(self, args: Optional[dict[str, Any]] = None) -> None:
|
||||||
model_args, data_args, finetuning_args, generating_args = get_infer_args(args)
|
model_args, data_args, finetuning_args, generating_args = get_infer_args(args)
|
||||||
|
|
||||||
if model_args.infer_backend == EngineName.HF:
|
if model_args.infer_backend == EngineName.HF:
|
||||||
|
from .hf_engine import HuggingfaceEngine
|
||||||
self.engine: BaseEngine = HuggingfaceEngine(model_args, data_args, finetuning_args, generating_args)
|
self.engine: BaseEngine = HuggingfaceEngine(model_args, data_args, finetuning_args, generating_args)
|
||||||
elif model_args.infer_backend == EngineName.VLLM:
|
elif model_args.infer_backend == EngineName.VLLM:
|
||||||
|
try:
|
||||||
|
from .vllm_engine import VllmEngine
|
||||||
self.engine: BaseEngine = VllmEngine(model_args, data_args, finetuning_args, generating_args)
|
self.engine: BaseEngine = VllmEngine(model_args, data_args, finetuning_args, generating_args)
|
||||||
|
except ImportError as e:
|
||||||
|
raise ImportError(
|
||||||
|
"vLLM not install, you may need to run `pip install vllm`\n"
|
||||||
|
"or try to use HuggingFace backend: --infer_backend huggingface"
|
||||||
|
) from e
|
||||||
elif model_args.infer_backend == EngineName.SGLANG:
|
elif model_args.infer_backend == EngineName.SGLANG:
|
||||||
|
try:
|
||||||
|
from .sglang_engine import SGLangEngine
|
||||||
self.engine: BaseEngine = SGLangEngine(model_args, data_args, finetuning_args, generating_args)
|
self.engine: BaseEngine = SGLangEngine(model_args, data_args, finetuning_args, generating_args)
|
||||||
|
except ImportError as e:
|
||||||
|
raise ImportError(
|
||||||
|
"SGLang not install, you may need to run `pip install sglang[all]`\n"
|
||||||
|
"or try to use HuggingFace backend: --infer_backend huggingface"
|
||||||
|
) from e
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError(f"Unknown backend: {model_args.infer_backend}")
|
raise NotImplementedError(f"Unknown backend: {model_args.infer_backend}")
|
||||||
|
|
||||||
|
@ -35,16 +35,46 @@ USAGE = (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _run_api():
|
||||||
|
from .api.app import run_api
|
||||||
|
return run_api()
|
||||||
|
|
||||||
|
|
||||||
|
def _run_chat():
|
||||||
|
from .chat.chat_model import run_chat
|
||||||
|
return run_chat()
|
||||||
|
|
||||||
|
|
||||||
|
def _run_eval():
|
||||||
|
from .eval.evaluator import run_eval
|
||||||
|
return run_eval()
|
||||||
|
|
||||||
|
|
||||||
|
def _export_model():
|
||||||
|
from .train.tuner import export_model
|
||||||
|
return export_model()
|
||||||
|
|
||||||
|
|
||||||
|
def _run_exp():
|
||||||
|
from .train.tuner import run_exp
|
||||||
|
return run_exp()
|
||||||
|
|
||||||
|
|
||||||
|
def _run_web_demo():
|
||||||
|
from .webui.interface import run_web_demo
|
||||||
|
return run_web_demo()
|
||||||
|
|
||||||
|
|
||||||
|
def _run_web_ui():
|
||||||
|
from .webui.interface import run_web_ui
|
||||||
|
return run_web_ui()
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
from . import launcher
|
from . import launcher
|
||||||
from .api.app import run_api
|
|
||||||
from .chat.chat_model import run_chat
|
|
||||||
from .eval.evaluator import run_eval
|
|
||||||
from .extras import logging
|
from .extras import logging
|
||||||
from .extras.env import VERSION, print_env
|
from .extras.env import VERSION, print_env
|
||||||
from .extras.misc import find_available_port, get_device_count, is_env_enabled, use_ray
|
from .extras.misc import find_available_port, get_device_count, is_env_enabled, use_ray
|
||||||
from .train.tuner import export_model, run_exp
|
|
||||||
from .webui.interface import run_web_demo, run_web_ui
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
@ -61,14 +91,14 @@ def main():
|
|||||||
)
|
)
|
||||||
|
|
||||||
COMMAND_MAP = {
|
COMMAND_MAP = {
|
||||||
"api": run_api,
|
"api": _run_api,
|
||||||
"chat": run_chat,
|
"chat": _run_chat,
|
||||||
"env": print_env,
|
"env": print_env,
|
||||||
"eval": run_eval,
|
"eval": _run_eval,
|
||||||
"export": export_model,
|
"export": _export_model,
|
||||||
"train": run_exp,
|
"train": _run_exp,
|
||||||
"webchat": run_web_demo,
|
"webchat": _run_web_demo,
|
||||||
"webui": run_web_ui,
|
"webui": _run_web_ui,
|
||||||
"version": partial(print, WELCOME),
|
"version": partial(print, WELCOME),
|
||||||
"help": partial(print, USAGE),
|
"help": partial(print, USAGE),
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user