fix qwen2vl vllm infer

Former-commit-id: fa50fc470e46010214f0e509ef75a6c167caac7f
commit dcc67ac1a5
parent 7ed5a712f8
Author: hiyouga
Date:   2024-11-24 23:19:35 +08:00


@@ -112,7 +112,15 @@ class VllmEngine(BaseEngine):
             if not any(IMAGE_PLACEHOLDER in message["content"] for message in messages):
                 messages[0]["content"] = IMAGE_PLACEHOLDER * len(images) + messages[0]["content"]
 
-        paired_messages = messages + [{"role": "assistant", "content": ""}]
+        if self.template.mm_plugin.__class__.__name__ == "Qwen2vlPlugin":  # temporary solution
+            image_str = "<|vision_start|>" + self.template.mm_plugin.image_token + "<|vision_end|>"
+        else:
+            image_str = self.template.mm_plugin.image_token
+
+        paired_messages = [
+            {"role": message["role"], "content": message["content"].replace(IMAGE_PLACEHOLDER, image_str)}
+            for message in messages
+        ] + [{"role": "assistant", "content": ""}]
         system = system or self.generating_args["default_system"]
         prompt_ids, _ = self.template.encode_oneturn(self.tokenizer, paired_messages, system, tools)
         prompt_length = len(prompt_ids)
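
The hunk above rewrites every image placeholder in the chat messages before prompt encoding, because Qwen2-VL expects its image token to be wrapped in <|vision_start|>/<|vision_end|> markers. A minimal standalone sketch of that substitution follows; the concrete values IMAGE_PLACEHOLDER = "<image>" and image_token = "<|image_pad|>" are assumptions for illustration, not taken from this diff, and build_paired_messages is a hypothetical helper rather than the engine's actual code path.

    # Sketch of the substitution introduced by the diff.
    # Assumed values: IMAGE_PLACEHOLDER = "<image>", image_token = "<|image_pad|>".
    IMAGE_PLACEHOLDER = "<image>"

    def build_paired_messages(messages, image_token, is_qwen2vl):
        # Qwen2-VL expects its image token wrapped in vision markers.
        if is_qwen2vl:  # temporary solution, mirroring the diff
            image_str = "<|vision_start|>" + image_token + "<|vision_end|>"
        else:
            image_str = image_token

        # Replace every placeholder and append an empty assistant turn,
        # as the encoder expects paired (user, assistant) messages.
        return [
            {"role": m["role"], "content": m["content"].replace(IMAGE_PLACEHOLDER, image_str)}
            for m in messages
        ] + [{"role": "assistant", "content": ""}]

    messages = [{"role": "user", "content": IMAGE_PLACEHOLDER + "What is in this image?"}]
    print(build_paired_messages(messages, "<|image_pad|>", is_qwen2vl=True))
    # [{'role': 'user', 'content': '<|vision_start|><|image_pad|><|vision_end|>What is in this image?'},
    #  {'role': 'assistant', 'content': ''}]

Doing the replacement on the message contents (rather than passing the raw placeholder through) means encode_oneturn already sees the exact token sequence Qwen2-VL was trained with, so vLLM can match the image features to the wrapped token positions.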