From fcca3b0b0d81ab77a777f3b2203232ad2de52296 Mon Sep 17 00:00:00 2001
From: hoshi-hiyouga
Date: Mon, 28 Apr 2025 01:59:53 +0800
Subject: [PATCH] [data] fix minicpmo vllm infer (#7870)

---
 src/llamafactory/data/mm_plugin.py | 217 ++++++++++++++---------------
 1 file changed, 108 insertions(+), 109 deletions(-)

diff --git a/src/llamafactory/data/mm_plugin.py b/src/llamafactory/data/mm_plugin.py
index 5e32fab4..a83dab20 100644
--- a/src/llamafactory/data/mm_plugin.py
+++ b/src/llamafactory/data/mm_plugin.py
@@ -898,115 +898,6 @@ class LlavaNextVideoPlugin(BasePlugin):
 
 @dataclass
 class MiniCPMVPlugin(BasePlugin):
-    @override
-    def process_messages(
-        self,
-        messages: list[dict[str, str]],
-        images: list["ImageInput"],
-        videos: list["VideoInput"],
-        audios: list["AudioInput"],
-        processor: Optional["MMProcessor"],
-    ) -> list[dict[str, str]]:
-        self._validate_input(processor, images, videos, audios)
-        num_image_tokens, num_video_tokens, num_audio_tokens = 0, 0, 0
-        messages = deepcopy(messages)
-        image_processor: BaseImageProcessor = getattr(processor, "image_processor")
-        mm_inputs = {}
-        audio_inputs = {}
-        if len(images) != 0 and len(videos) != 0:
-            raise ValueError("MiniCPM-V model does not support input images and videos at the same time.")
-
-        if len(videos) != 0:
-            max_slice_nums = 2
-            use_image_id = False
-            mm_inputs = self._get_mm_inputs([], videos, [], processor)
-        else:
-            max_slice_nums = image_processor.max_slice_nums
-            use_image_id = image_processor.use_image_id
-
-        for i, message in enumerate(messages):
-            content = message["content"]
-            while IMAGE_PLACEHOLDER in content:
-                if num_image_tokens >= len(images):
-                    raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
-
-                content = content.replace(IMAGE_PLACEHOLDER, "{{image}}", 1)
-                num_image_tokens += 1
-
-            while VIDEO_PLACEHOLDER in content:
-                if num_video_tokens >= len(videos):
-                    raise ValueError(f"`len(videos)` is less than the number of {VIDEO_PLACEHOLDER} tokens.")
-
-                video_seqlen = len(mm_inputs["pixel_values"][num_video_tokens]) if self.expand_mm_tokens else 1
-                content = content.replace(VIDEO_PLACEHOLDER, "{{image}}" * video_seqlen, 1)
-                num_video_tokens += 1
-
-            while AUDIO_PLACEHOLDER in content:
-                if num_audio_tokens >= len(audios):
-                    raise ValueError(f"`len(audios)` is less than the number of {AUDIO_PLACEHOLDER} tokens.")
-
-                content = content.replace(AUDIO_PLACEHOLDER, "{{audio}}", 1)
-                num_audio_tokens += 1
-
-            message["content"] = content.replace("{{image}}", "(<image>./</image>)").replace(
-                "{{audio}}", "(<audio>./</audio>)"
-            )
-
-        if num_image_tokens > 0:
-            mm_inputs = self._get_mm_inputs(images, [], [], processor)
-
-        if num_audio_tokens > 0:
-            audio_inputs = self._get_mm_inputs([], [], audios, processor, ret_phs=True)
-
-        if mm_inputs:
-            pattern = "(<image>./</image>)"
-            image_sizes = mm_inputs["image_sizes"]
-            idx = 0
-            for index, message in enumerate(messages):
-                text = message["content"]
-                image_tags = re.findall(pattern, text)
-                text_chunks = text.split(pattern)
-                final_text = ""
-                for i in range(len(image_tags)):
-                    final_text = (
-                        final_text
-                        + text_chunks[i]
-                        + image_processor.get_slice_image_placeholder(
-                            image_sizes[0][idx], idx, max_slice_nums, use_image_id
-                        )
-                    )
-                    idx += 1
-
-                final_text += text_chunks[-1]
-                messages[index]["content"] = final_text
-
-        if audio_inputs:
-            pattern = "(<audio>./</audio>)"
-            idx = 0
-            for index, message in enumerate(messages):
-                text = message["content"]
-                audio_tags = re.findall(pattern, text)
-                text_chunks = text.split(pattern)
-                final_text = ""
-                for i in range(len(audio_tags)):
-                    audio_placeholder = audio_inputs["audio_phs"][0][idx]
-                    final_text = final_text + text_chunks[i] + audio_placeholder
-                    idx += 1
-
-                final_text += text_chunks[-1]
-                messages[index]["content"] = final_text
-
-        if len(images) != num_image_tokens:
-            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
-
-        if len(videos) != num_video_tokens:
-            raise ValueError(f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens.")
-
-        if len(audios) != num_audio_tokens:
-            raise ValueError(f"The number of audios does not match the number of {AUDIO_PLACEHOLDER} tokens.")
-
-        return messages
-
     @override
     def _get_mm_inputs(
         self,
@@ -1077,6 +968,114 @@ class MiniCPMVPlugin(BasePlugin):
 
         return mm_inputs
 
+    @override
+    def process_messages(
+        self,
+        messages: list[dict[str, str]],
+        images: list["ImageInput"],
+        videos: list["VideoInput"],
+        audios: list["AudioInput"],
+        processor: Optional["MMProcessor"],
+    ) -> list[dict[str, str]]:
+        self._validate_input(processor, images, videos, audios)
+        num_image_tokens, num_video_tokens, num_audio_tokens = 0, 0, 0
+        messages = deepcopy(messages)
+        image_processor: BaseImageProcessor = getattr(processor, "image_processor")
+        mm_inputs, audio_inputs = {}, {}
+        if len(images) != 0 and len(videos) != 0:
+            raise ValueError("MiniCPM-V model does not support input images and videos at the same time.")
+
+        if len(videos) != 0:
+            max_slice_nums = 2
+            use_image_id = False
+            mm_inputs = self._get_mm_inputs([], videos, [], processor)
+        else:
+            max_slice_nums = image_processor.max_slice_nums
+            use_image_id = image_processor.use_image_id
+
+        for i, message in enumerate(messages):
+            content = message["content"]
+            while IMAGE_PLACEHOLDER in content:
+                if num_image_tokens >= len(images):
+                    raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
+
+                content = content.replace(IMAGE_PLACEHOLDER, "{{image}}", 1)
+                num_image_tokens += 1
+
+            while VIDEO_PLACEHOLDER in content:
+                if num_video_tokens >= len(videos):
+                    raise ValueError(f"`len(videos)` is less than the number of {VIDEO_PLACEHOLDER} tokens.")
+
+                video_seqlen = len(mm_inputs["pixel_values"][num_video_tokens]) if self.expand_mm_tokens else 1
+                content = content.replace(VIDEO_PLACEHOLDER, "{{image}}" * video_seqlen, 1)
+                num_video_tokens += 1
+
+            while AUDIO_PLACEHOLDER in content:
+                if num_audio_tokens >= len(audios):
+                    raise ValueError(f"`len(audios)` is less than the number of {AUDIO_PLACEHOLDER} tokens.")
+
+                content = content.replace(AUDIO_PLACEHOLDER, "{{audio}}", 1)
+                num_audio_tokens += 1
+
+            message["content"] = content.replace("{{image}}", "(<image>./</image>)").replace(
+                "{{audio}}", "(<audio>./</audio>)"
+            )
+
+        if len(images):
+            mm_inputs = self._get_mm_inputs(images, [], [], processor)
+
+        if len(audios):
+            audio_inputs = self._get_mm_inputs([], [], audios, processor, ret_phs=True)
+
+        if self.expand_mm_tokens and mm_inputs:
+            pattern = "(<image>./</image>)"
+            image_sizes = mm_inputs["image_sizes"]
+            idx = 0
+            for index, message in enumerate(messages):
+                text = message["content"]
+                image_tags = re.findall(pattern, text)
+                text_chunks = text.split(pattern)
+                final_text = ""
+                for i in range(len(image_tags)):
+                    final_text = (
+                        final_text
+                        + text_chunks[i]
+                        + image_processor.get_slice_image_placeholder(
+                            image_sizes[0][idx], idx, max_slice_nums, use_image_id
+                        )
+                    )
+                    idx += 1
+
+                final_text += text_chunks[-1]
+                messages[index]["content"] = final_text
+
+        if self.expand_mm_tokens and audio_inputs:
+            pattern = "(<audio>./</audio>)"
+            idx = 0
+            for index, message in enumerate(messages):
+                text = message["content"]
+                audio_tags = re.findall(pattern, text)
+                text_chunks = text.split(pattern)
+                final_text = ""
+                for i in range(len(audio_tags)):
+                    audio_placeholder = audio_inputs["audio_phs"][0][idx]
+                    final_text = final_text + text_chunks[i] + audio_placeholder
+                    idx += 1
+
+                final_text += text_chunks[-1]
+                messages[index]["content"] = final_text
+
+        if len(images) != num_image_tokens:
+            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
+
+        if len(videos) != num_video_tokens:
+            raise ValueError(f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens.")
+
+        if len(audios) != num_audio_tokens:
+            raise ValueError(f"The number of audios does not match the number of {AUDIO_PLACEHOLDER} tokens.")
+
+        return messages
+
     @override
     def get_mm_inputs(
         self,