From dcc67ac1a5a825a26ae6bb4c272f2019aed7c0ea Mon Sep 17 00:00:00 2001
From: hiyouga <467089858@qq.com>
Date: Sun, 24 Nov 2024 23:19:35 +0800
Subject: [PATCH] fix qwen2vl vllm infer

Former-commit-id: fa50fc470e46010214f0e509ef75a6c167caac7f
---
 src/llamafactory/chat/vllm_engine.py | 10 +++++++++-
 1 file changed, 9 insertions(+), 1 deletion(-)

diff --git a/src/llamafactory/chat/vllm_engine.py b/src/llamafactory/chat/vllm_engine.py
index c696978e..99542d23 100644
--- a/src/llamafactory/chat/vllm_engine.py
+++ b/src/llamafactory/chat/vllm_engine.py
@@ -112,7 +112,15 @@ class VllmEngine(BaseEngine):
         if not any(IMAGE_PLACEHOLDER in message["content"] for message in messages):
             messages[0]["content"] = IMAGE_PLACEHOLDER * len(images) + messages[0]["content"]

-        paired_messages = messages + [{"role": "assistant", "content": ""}]
+        if self.template.mm_plugin.__class__.__name__ == "Qwen2vlPlugin":  # temporary solution
+            image_str = "<|vision_start|>" + self.template.mm_plugin.image_token + "<|vision_end|>"
+        else:
+            image_str = self.template.mm_plugin.image_token
+
+        paired_messages = [
+            {"role": message["role"], "content": message["content"].replace(IMAGE_PLACEHOLDER, image_str)}
+            for message in messages
+        ] + [{"role": "assistant", "content": ""}]
         system = system or self.generating_args["default_system"]
         prompt_ids, _ = self.template.encode_oneturn(self.tokenizer, paired_messages, system, tools)
         prompt_length = len(prompt_ids)
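
Below is a minimal, self-contained sketch of the replacement logic this hunk
introduces, for readers skimming the diff. The IMAGE_PLACEHOLDER value, the
Qwen2-VL image token "<|image_pad|>", and the expand_image_placeholder helper
name are illustrative assumptions, not names from the repository; in the real
code the placeholder constant lives in llamafactory and the image token comes
from the template's mm_plugin.

    # Sketch: rewrite image placeholders before encoding the prompt.
    # Qwen2-VL expects its image token wrapped in vision delimiters;
    # other templates use the bare image token.
    IMAGE_PLACEHOLDER = "<image>"  # assumed placeholder value, for illustration

    def expand_image_placeholder(messages, image_token, is_qwen2vl):
        if is_qwen2vl:
            image_str = "<|vision_start|>" + image_token + "<|vision_end|>"
        else:
            image_str = image_token
        # Rewrite every turn, then append the empty assistant turn that
        # encode_oneturn pairs the prompt against.
        return [
            {"role": m["role"], "content": m["content"].replace(IMAGE_PLACEHOLDER, image_str)}
            for m in messages
        ] + [{"role": "assistant", "content": ""}]

    # Example: one user turn carrying a single image placeholder.
    messages = [{"role": "user", "content": IMAGE_PLACEHOLDER + "Describe this picture."}]
    print(expand_image_placeholder(messages, "<|image_pad|>", is_qwen2vl=True))
    # content becomes "<|vision_start|><|image_pad|><|vision_end|>Describe this picture."

The class-name check in the patch is what the author marks as a temporary
solution: it special-cases Qwen2vlPlugin in the engine rather than letting the
plugin itself supply the wrapped token.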