diff --git a/src/llmtuner/chat/vllm_engine.py b/src/llmtuner/chat/vllm_engine.py
index 67a19b68..786e743d 100644
--- a/src/llmtuner/chat/vllm_engine.py
+++ b/src/llmtuner/chat/vllm_engine.py
@@ -10,6 +10,7 @@ from .base_engine import BaseEngine, Response
 
 if is_vllm_available():
     from vllm import AsyncEngineArgs, AsyncLLMEngine, RequestOutput, SamplingParams
+    from vllm.lora.request import LoRARequest
 
 
 if TYPE_CHECKING:
     from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
@@ -24,7 +25,8 @@ class VllmEngine(BaseEngine):
         generating_args: "GeneratingArguments",
     ) -> None:
         config = load_config(model_args)  # may download model from ms hub
-        load_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None))
+        infer_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None))
+        infer_dtype = str(infer_dtype).split(".")[-1]
 
         self.can_generate = finetuning_args.stage == "sft"
         self.tokenizer = load_tokenizer(model_args)
@@ -36,15 +38,20 @@ class VllmEngine(BaseEngine):
             model=model_args.model_name_or_path,
             trust_remote_code=True,
             download_dir=model_args.cache_dir,
-            dtype=str(load_dtype).split(".")[-1],
+            dtype=infer_dtype,
             max_model_len=model_args.vllm_maxlen,
             tensor_parallel_size=get_device_count() or 1,
             gpu_memory_utilization=model_args.vllm_gpu_util,
             disable_log_stats=True,
             disable_log_requests=True,
             enforce_eager=model_args.vllm_enforce_eager,
+            enable_lora=model_args.adapter_name_or_path is not None,
         )
         self.model = AsyncLLMEngine.from_engine_args(engine_args)
+        if model_args.adapter_name_or_path is not None:
+            self.lora_request = LoRARequest("default", 1, model_args.adapter_name_or_path[0])
+        else:
+            self.lora_request = None
 
     async def _generate(
         self,
@@ -98,7 +105,11 @@ class VllmEngine(BaseEngine):
             skip_special_tokens=True,
         )
 
         result_generator = self.model.generate(
-            prompt=None, sampling_params=sampling_params, request_id=request_id, prompt_token_ids=prompt_ids
+            prompt=None,
+            sampling_params=sampling_params,
+            request_id=request_id,
+            prompt_token_ids=prompt_ids,
+            lora_request=self.lora_request,
         )
         return result_generator
diff --git a/src/llmtuner/hparams/parser.py b/src/llmtuner/hparams/parser.py
index a7d0a17f..c922dc47 100644
--- a/src/llmtuner/hparams/parser.py
+++ b/src/llmtuner/hparams/parser.py
@@ -308,15 +308,15 @@ def get_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:
         if finetuning_args.stage != "sft":
             raise ValueError("vLLM engine only supports auto-regressive models.")
 
-        if model_args.adapter_name_or_path is not None:
-            raise ValueError("vLLM engine does not support LoRA adapters. Merge them first.")
-
         if model_args.quantization_bit is not None:
-            raise ValueError("vLLM engine does not support quantization.")
+            raise ValueError("vLLM engine does not support bnb quantization (GPTQ and AWQ are supported).")
 
         if model_args.rope_scaling is not None:
             raise ValueError("vLLM engine does not support RoPE scaling.")
 
+        if model_args.adapter_name_or_path is not None and len(model_args.adapter_name_or_path) != 1:
+            raise ValueError("vLLM only accepts a single adapter. Merge them first.")
+
     _verify_model_args(model_args, finetuning_args)
     _check_extra_dependencies(model_args, finetuning_args)
 
diff --git a/src/llmtuner/webui/chatter.py b/src/llmtuner/webui/chatter.py
index ee28603e..82e7b7f1 100644
--- a/src/llmtuner/webui/chatter.py
+++ b/src/llmtuner/webui/chatter.py
@@ -31,7 +31,10 @@ class WebChatModel(ChatModel):
         if demo_mode and os.environ.get("DEMO_MODEL") and os.environ.get("DEMO_TEMPLATE"):  # load demo model
             model_name_or_path = os.environ.get("DEMO_MODEL")
             template = os.environ.get("DEMO_TEMPLATE")
-            super().__init__(dict(model_name_or_path=model_name_or_path, template=template))
+            infer_backend = os.environ.get("DEMO_BACKEND", "huggingface")
+            super().__init__(
+                dict(model_name_or_path=model_name_or_path, template=template, infer_backend=infer_backend)
+            )
 
     @property
     def loaded(self) -> bool:
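
Reviewer note (not part of the patch): a minimal standalone sketch of the vLLM LoRA API this change relies on, i.e. enable_lora set at engine construction plus a LoRARequest passed with each generate call. The model name and adapter path below are placeholder assumptions, not values from this repo, and the sketch uses the synchronous LLM class rather than AsyncLLMEngine for brevity.

    # Sketch only; model name and adapter path are placeholders.
    from vllm import LLM, SamplingParams
    from vllm.lora.request import LoRARequest

    # LoRA support must be enabled when the engine is built, mirroring the new
    # enable_lora=model_args.adapter_name_or_path is not None engine arg above.
    llm = LLM(model="meta-llama/Llama-2-7b-hf", enable_lora=True)

    # A LoRARequest bundles (name, unique int id, local adapter path); the patch
    # constructs it as LoRARequest("default", 1, model_args.adapter_name_or_path[0]).
    lora_request = LoRARequest("default", 1, "/path/to/lora_adapter")

    outputs = llm.generate(
        ["Hello, how are you?"],
        SamplingParams(temperature=0.7, max_tokens=128),
        lora_request=lora_request,  # applied per request, as in VllmEngine._generate
    )
    print(outputs[0].outputs[0].text)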