mirror of
				https://github.com/hiyouga/LLaMA-Factory.git
				synced 2025-11-04 09:52:14 +08:00 
			
		
		
		
	[cli] support lazy import (#9217)
Co-authored-by: frozenleaves <frozen@Mac.local>
This commit is contained in:
		
							parent
							
								
									6ffebe5ff7
								
							
						
					
					
						commit
						a04d777d7f
					
@@ -24,9 +24,6 @@ from typing import TYPE_CHECKING, Any, Optional
 | 
			
		||||
from ..extras.constants import EngineName
 | 
			
		||||
from ..extras.misc import torch_gc
 | 
			
		||||
from ..hparams import get_infer_args
 | 
			
		||||
from .hf_engine import HuggingfaceEngine
 | 
			
		||||
from .sglang_engine import SGLangEngine
 | 
			
		||||
from .vllm_engine import VllmEngine
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if TYPE_CHECKING:
 | 
			
		||||
@@ -49,12 +46,28 @@ class ChatModel:
 | 
			
		||||
 | 
			
		||||
    def __init__(self, args: Optional[dict[str, Any]] = None) -> None:
        """Build the inference engine selected by ``infer_backend``.

        Engine modules are imported lazily inside each branch so that merely
        importing this module (e.g. for ``llamafactory-cli train``) does not
        pull in heavy optional dependencies such as vLLM or SGLang.

        Args:
            args: Optional dict of inference arguments; forwarded to
                ``get_infer_args`` (falls back to CLI/env parsing when None).

        Raises:
            ImportError: if the chosen backend's package is not installed.
            NotImplementedError: if ``infer_backend`` is not a known engine.
        """
        model_args, data_args, finetuning_args, generating_args = get_infer_args(args)

        if model_args.infer_backend == EngineName.HF:
            # Always available: ships with the base install.
            from .hf_engine import HuggingfaceEngine

            self.engine: BaseEngine = HuggingfaceEngine(model_args, data_args, finetuning_args, generating_args)
        elif model_args.infer_backend == EngineName.VLLM:
            try:
                from .vllm_engine import VllmEngine

                self.engine: BaseEngine = VllmEngine(model_args, data_args, finetuning_args, generating_args)
            except ImportError as e:
                # Surface an actionable message instead of a bare import trace.
                raise ImportError(
                    "vLLM is not installed, you may need to run `pip install vllm`\n"
                    "or try to use HuggingFace backend: --infer_backend huggingface"
                ) from e
        elif model_args.infer_backend == EngineName.SGLANG:
            try:
                from .sglang_engine import SGLangEngine

                self.engine: BaseEngine = SGLangEngine(model_args, data_args, finetuning_args, generating_args)
            except ImportError as e:
                raise ImportError(
                    "SGLang is not installed, you may need to run `pip install sglang[all]`\n"
                    "or try to use HuggingFace backend: --infer_backend huggingface"
                ) from e
        else:
            raise NotImplementedError(f"Unknown backend: {model_args.infer_backend}")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@@ -35,16 +35,46 @@ USAGE = (
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _run_api():
    """Lazily import the API app module, then start the API server."""
    from .api.app import run_api as entrypoint

    return entrypoint()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _run_chat():
    """Lazily import the chat model module, then launch the CLI chat loop."""
    from .chat.chat_model import run_chat as entrypoint

    return entrypoint()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _run_eval():
    """Lazily import the evaluator module, then run evaluation."""
    from .eval.evaluator import run_eval as entrypoint

    return entrypoint()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _export_model():
    """Lazily import the tuner module, then export the model."""
    from .train.tuner import export_model as entrypoint

    return entrypoint()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _run_exp():
    """Lazily import the tuner module, then run the training experiment."""
    from .train.tuner import run_exp as entrypoint

    return entrypoint()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _run_web_demo():
    """Lazily import the web UI module, then launch the chat demo."""
    from .webui.interface import run_web_demo as entrypoint

    return entrypoint()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _run_web_ui():
    """Lazily import the web UI module, then launch the full web board."""
    from .webui.interface import run_web_ui as entrypoint

    return entrypoint()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def main():
 | 
			
		||||
    from . import launcher
 | 
			
		||||
    from .api.app import run_api
 | 
			
		||||
    from .chat.chat_model import run_chat
 | 
			
		||||
    from .eval.evaluator import run_eval
 | 
			
		||||
    from .extras import logging
 | 
			
		||||
    from .extras.env import VERSION, print_env
 | 
			
		||||
    from .extras.misc import find_available_port, get_device_count, is_env_enabled, use_ray
 | 
			
		||||
    from .train.tuner import export_model, run_exp
 | 
			
		||||
    from .webui.interface import run_web_demo, run_web_ui
 | 
			
		||||
 | 
			
		||||
    logger = logging.get_logger(__name__)
 | 
			
		||||
 | 
			
		||||
@@ -61,14 +91,14 @@ def main():
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    COMMAND_MAP = {
 | 
			
		||||
        "api": run_api,
 | 
			
		||||
        "chat": run_chat,
 | 
			
		||||
        "api": _run_api,
 | 
			
		||||
        "chat": _run_chat,
 | 
			
		||||
        "env": print_env,
 | 
			
		||||
        "eval": run_eval,
 | 
			
		||||
        "export": export_model,
 | 
			
		||||
        "train": run_exp,
 | 
			
		||||
        "webchat": run_web_demo,
 | 
			
		||||
        "webui": run_web_ui,
 | 
			
		||||
        "eval": _run_eval,
 | 
			
		||||
        "export": _export_model,
 | 
			
		||||
        "train": _run_exp,
 | 
			
		||||
        "webchat": _run_web_demo,
 | 
			
		||||
        "webui": _run_web_ui,
 | 
			
		||||
        "version": partial(print, WELCOME),
 | 
			
		||||
        "help": partial(print, USAGE),
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user