[data] fix minicpmo vllm infer (#7870)

This commit is contained in:
hoshi-hiyouga 2025-04-28 01:59:53 +08:00 committed by GitHub
parent 035e98035c
commit fcca3b0b0d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -898,115 +898,6 @@ class LlavaNextVideoPlugin(BasePlugin):
@dataclass
class MiniCPMVPlugin(BasePlugin):
@override
def process_messages(
    self,
    messages: list[dict[str, str]],
    images: list["ImageInput"],
    videos: list["VideoInput"],
    audios: list["AudioInput"],
    processor: Optional["MMProcessor"],
) -> list[dict[str, str]]:
    r"""Rewrite media placeholders in ``messages`` into MiniCPM-V/O chat tags.

    Every ``IMAGE_PLACEHOLDER`` / ``VIDEO_PLACEHOLDER`` / ``AUDIO_PLACEHOLDER``
    occurrence becomes a ``(<image>./</image>)`` or ``(<audio>./</audio>)``
    marker; when ``self.expand_mm_tokens`` is set, those markers are further
    expanded into the processor's per-slice placeholder strings.

    Fixes (vLLM inference): multimodal inputs must be computed whenever media
    is actually supplied (``len(images)`` / ``len(audios)``), not only when
    placeholder tokens were consumed, and the slice-placeholder expansion must
    honor ``self.expand_mm_tokens``.

    Raises:
        ValueError: if both images and videos are given, or if the number of
            placeholders does not match the number of supplied media items.
    """
    self._validate_input(processor, images, videos, audios)
    num_image_tokens, num_video_tokens, num_audio_tokens = 0, 0, 0
    messages = deepcopy(messages)
    image_processor: BaseImageProcessor = getattr(processor, "image_processor")
    mm_inputs, audio_inputs = {}, {}
    if len(images) != 0 and len(videos) != 0:
        raise ValueError("MiniCPM-V model does not support input images and videos at the same time.")

    if len(videos) != 0:
        # Video frames use fewer slices and no image ids than still images.
        max_slice_nums = 2
        use_image_id = False
        mm_inputs = self._get_mm_inputs([], videos, [], processor)
    else:
        max_slice_nums = image_processor.max_slice_nums
        use_image_id = image_processor.use_image_id

    for i, message in enumerate(messages):
        content = message["content"]
        while IMAGE_PLACEHOLDER in content:
            if num_image_tokens >= len(images):
                raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")

            content = content.replace(IMAGE_PLACEHOLDER, "{{image}}", 1)
            num_image_tokens += 1

        while VIDEO_PLACEHOLDER in content:
            if num_video_tokens >= len(videos):
                raise ValueError(f"`len(videos)` is less than the number of {VIDEO_PLACEHOLDER} tokens.")

            # One video expands to one image marker per sampled frame.
            video_seqlen = len(mm_inputs["pixel_values"][num_video_tokens]) if self.expand_mm_tokens else 1
            content = content.replace(VIDEO_PLACEHOLDER, "{{image}}" * video_seqlen, 1)
            num_video_tokens += 1

        while AUDIO_PLACEHOLDER in content:
            if num_audio_tokens >= len(audios):
                raise ValueError(f"`len(audios)` is less than the number of {AUDIO_PLACEHOLDER} tokens.")

            content = content.replace(AUDIO_PLACEHOLDER, "{{audio}}", 1)
            num_audio_tokens += 1

        message["content"] = content.replace("{{image}}", "(<image>./</image>)").replace(
            "{{audio}}", "(<audio>./</audio>)"
        )

    # FIX: key on the media actually supplied, not on placeholder counts —
    # callers (e.g. vLLM inference) may pass media without inline placeholders.
    if len(images):
        mm_inputs = self._get_mm_inputs(images, [], [], processor)

    if len(audios):
        audio_inputs = self._get_mm_inputs([], [], audios, processor, ret_phs=True)

    # FIX: only expand markers into slice placeholders when expand_mm_tokens
    # is set; otherwise leave the compact markers for downstream expansion.
    if self.expand_mm_tokens and mm_inputs:
        pattern = "(<image>./</image>)"
        image_sizes = mm_inputs["image_sizes"]
        idx = 0  # global image index across all messages
        for index, message in enumerate(messages):
            text = message["content"]
            image_tags = re.findall(pattern, text)
            text_chunks = text.split(pattern)
            final_text = ""
            for i in range(len(image_tags)):
                final_text = (
                    final_text
                    + text_chunks[i]
                    + image_processor.get_slice_image_placeholder(
                        image_sizes[0][idx], idx, max_slice_nums, use_image_id
                    )
                )
                idx += 1

            final_text += text_chunks[-1]
            messages[index]["content"] = final_text

    if self.expand_mm_tokens and audio_inputs:
        pattern = "(<audio>./</audio>)"
        idx = 0  # global audio index across all messages
        for index, message in enumerate(messages):
            text = message["content"]
            audio_tags = re.findall(pattern, text)
            text_chunks = text.split(pattern)
            final_text = ""
            for i in range(len(audio_tags)):
                audio_placeholder = audio_inputs["audio_phs"][0][idx]
                final_text = final_text + text_chunks[i] + audio_placeholder
                idx += 1

            final_text += text_chunks[-1]
            messages[index]["content"] = final_text

    if len(images) != num_image_tokens:
        raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")

    if len(videos) != num_video_tokens:
        raise ValueError(f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens.")

    if len(audios) != num_audio_tokens:
        raise ValueError(f"The number of audios does not match the number of {AUDIO_PLACEHOLDER} tokens.")

    return messages
@override
def _get_mm_inputs(
self,
@ -1077,6 +968,114 @@ class MiniCPMVPlugin(BasePlugin):
return mm_inputs
@override
def process_messages(
    self,
    messages: list[dict[str, str]],
    images: list["ImageInput"],
    videos: list["VideoInput"],
    audios: list["AudioInput"],
    processor: Optional["MMProcessor"],
) -> list[dict[str, str]]:
    r"""Rewrite media placeholders in ``messages`` into MiniCPM-V/O chat tags.

    Every ``IMAGE_PLACEHOLDER`` / ``VIDEO_PLACEHOLDER`` / ``AUDIO_PLACEHOLDER``
    occurrence becomes a ``(<image>./</image>)`` or ``(<audio>./</audio>)``
    marker; when ``self.expand_mm_tokens`` is set, those markers are further
    expanded into the processor's per-slice placeholder strings.

    Raises:
        ValueError: if both images and videos are given, or if the number of
            placeholders does not match the number of supplied media items.
    """
    self._validate_input(processor, images, videos, audios)
    num_image_tokens, num_video_tokens, num_audio_tokens = 0, 0, 0
    messages = deepcopy(messages)
    image_processor: BaseImageProcessor = getattr(processor, "image_processor")
    mm_inputs, audio_inputs = {}, {}
    if len(images) != 0 and len(videos) != 0:
        raise ValueError("MiniCPM-V model does not support input images and videos at the same time.")

    if len(videos) != 0:
        # Video frames use fewer slices and no image ids than still images.
        max_slice_nums = 2
        use_image_id = False
        mm_inputs = self._get_mm_inputs([], videos, [], processor)
    else:
        max_slice_nums = image_processor.max_slice_nums
        use_image_id = image_processor.use_image_id

    for i, message in enumerate(messages):
        content = message["content"]
        while IMAGE_PLACEHOLDER in content:
            if num_image_tokens >= len(images):
                raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")

            content = content.replace(IMAGE_PLACEHOLDER, "{{image}}", 1)
            num_image_tokens += 1

        while VIDEO_PLACEHOLDER in content:
            if num_video_tokens >= len(videos):
                raise ValueError(f"`len(videos)` is less than the number of {VIDEO_PLACEHOLDER} tokens.")

            # One video expands to one image marker per sampled frame.
            video_seqlen = len(mm_inputs["pixel_values"][num_video_tokens]) if self.expand_mm_tokens else 1
            content = content.replace(VIDEO_PLACEHOLDER, "{{image}}" * video_seqlen, 1)
            num_video_tokens += 1

        while AUDIO_PLACEHOLDER in content:
            if num_audio_tokens >= len(audios):
                raise ValueError(f"`len(audios)` is less than the number of {AUDIO_PLACEHOLDER} tokens.")

            content = content.replace(AUDIO_PLACEHOLDER, "{{audio}}", 1)
            num_audio_tokens += 1

        message["content"] = content.replace("{{image}}", "(<image>./</image>)").replace(
            "{{audio}}", "(<audio>./</audio>)"
        )

    # Compute mm inputs whenever media is supplied (not only when placeholders
    # were consumed), so media passed without inline placeholders still works.
    if len(images):
        mm_inputs = self._get_mm_inputs(images, [], [], processor)

    if len(audios):
        audio_inputs = self._get_mm_inputs([], [], audios, processor, ret_phs=True)

    # Expand compact image markers into per-slice placeholders only when
    # expand_mm_tokens is set; downstream consumers otherwise expand them.
    if self.expand_mm_tokens and mm_inputs:
        pattern = "(<image>./</image>)"
        image_sizes = mm_inputs["image_sizes"]
        idx = 0  # global image index across all messages
        for index, message in enumerate(messages):
            text = message["content"]
            image_tags = re.findall(pattern, text)
            text_chunks = text.split(pattern)
            final_text = ""
            for i in range(len(image_tags)):
                final_text = (
                    final_text
                    + text_chunks[i]
                    + image_processor.get_slice_image_placeholder(
                        image_sizes[0][idx], idx, max_slice_nums, use_image_id
                    )
                )
                idx += 1

            final_text += text_chunks[-1]
            messages[index]["content"] = final_text

    if self.expand_mm_tokens and audio_inputs:
        pattern = "(<audio>./</audio>)"
        idx = 0  # global audio index across all messages
        for index, message in enumerate(messages):
            text = message["content"]
            audio_tags = re.findall(pattern, text)
            text_chunks = text.split(pattern)
            final_text = ""
            for i in range(len(audio_tags)):
                # audio_phs holds pre-rendered placeholder strings per audio clip.
                audio_placeholder = audio_inputs["audio_phs"][0][idx]
                final_text = final_text + text_chunks[i] + audio_placeholder
                idx += 1

            final_text += text_chunks[-1]
            messages[index]["content"] = final_text

    if len(images) != num_image_tokens:
        raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")

    if len(videos) != num_video_tokens:
        raise ValueError(f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens.")

    if len(audios) != num_audio_tokens:
        raise ValueError(f"The number of audios does not match the number of {AUDIO_PLACEHOLDER} tokens.")

    return messages
@override
def get_mm_inputs(
self,