diff --git a/setup.py b/setup.py
index 45e73343..23f532e7 100644
--- a/setup.py
+++ b/setup.py
@@ -25,7 +25,7 @@ extra_require = {
     "metrics": ["nltk", "jieba", "rouge-chinese"],
     "deepspeed": ["deepspeed>=0.10.0,<=0.14.0"],
     "bitsandbytes": ["bitsandbytes>=0.39.0"],
-    "vllm": ["vllm>=0.4.1"],
+    "vllm": ["vllm>=0.4.3"],
     "galore": ["galore-torch"],
     "badam": ["badam"],
     "gptq": ["optimum>=1.16.0", "auto-gptq>=0.5.0"],
diff --git a/src/llamafactory/chat/vllm_engine.py b/src/llamafactory/chat/vllm_engine.py
index 3310a864..8a067754 100644
--- a/src/llamafactory/chat/vllm_engine.py
+++ b/src/llamafactory/chat/vllm_engine.py
@@ -158,12 +158,10 @@ class VllmEngine(BaseEngine):
         )

         result_generator = self.model.generate(
-            prompt=None,
+            inputs={"prompt_token_ids": prompt_ids, "multi_modal_data": multi_modal_data},
             sampling_params=sampling_params,
             request_id=request_id,
-            prompt_token_ids=prompt_ids,
             lora_request=self.lora_request,
-            multi_modal_data=multi_modal_data,
         )
         return result_generator

diff --git a/src/llamafactory/hparams/parser.py b/src/llamafactory/hparams/parser.py
index b3c673be..ff1fbf5d 100644
--- a/src/llamafactory/hparams/parser.py
+++ b/src/llamafactory/hparams/parser.py
@@ -94,7 +94,7 @@ def _check_extra_dependencies(
     require_version("mixture-of-depth>=1.1.6", "To fix: pip install mixture-of-depth>=1.1.6")

     if model_args.infer_backend == "vllm":
-        require_version("vllm>=0.4.1", "To fix: pip install vllm>=0.4.1")
+        require_version("vllm>=0.4.3", "To fix: pip install vllm>=0.4.3")

     if finetuning_args.use_galore:
         require_version("galore_torch", "To fix: pip install galore_torch")
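
Context for the vllm_engine.py hunk: with vLLM >= 0.4.3 the separate prompt, prompt_token_ids, and multi_modal_data keyword arguments to AsyncLLMEngine.generate() are replaced by a single inputs dict, which is why the call site above is rewritten. The following is a minimal standalone sketch of that call pattern, not the project's code; the model name, prompt token ids, and sampling settings are illustrative placeholders.

import asyncio
import uuid

from vllm import AsyncEngineArgs, AsyncLLMEngine, SamplingParams


async def main() -> None:
    # Build an async engine; the model name here is only a placeholder.
    engine = AsyncLLMEngine.from_engine_args(AsyncEngineArgs(model="facebook/opt-125m"))
    sampling_params = SamplingParams(temperature=0.7, max_tokens=64)
    prompt_ids = [1, 2, 3]  # token ids normally produced by the tokenizer

    # Old API (vllm < 0.4.3): generate(prompt=None, prompt_token_ids=..., multi_modal_data=...)
    # New API (vllm >= 0.4.3): token ids (and any multi-modal data) go into one `inputs` dict,
    # mirroring the change applied to VllmEngine above.
    result_generator = engine.generate(
        inputs={"prompt_token_ids": prompt_ids},
        sampling_params=sampling_params,
        request_id=str(uuid.uuid4()),
    )
    async for output in result_generator:
        print(output.outputs[0].text)


if __name__ == "__main__":
    asyncio.run(main())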