diff --git a/src/llamafactory/data/mm_plugin.py b/src/llamafactory/data/mm_plugin.py
index 5e32fab4..a83dab20 100644
--- a/src/llamafactory/data/mm_plugin.py
+++ b/src/llamafactory/data/mm_plugin.py
@@ -898,115 +898,6 @@ class LlavaNextVideoPlugin(BasePlugin):
@dataclass
class MiniCPMVPlugin(BasePlugin):
- @override
- def process_messages(
- self,
- messages: list[dict[str, str]],
- images: list["ImageInput"],
- videos: list["VideoInput"],
- audios: list["AudioInput"],
- processor: Optional["MMProcessor"],
- ) -> list[dict[str, str]]:
- self._validate_input(processor, images, videos, audios)
- num_image_tokens, num_video_tokens, num_audio_tokens = 0, 0, 0
- messages = deepcopy(messages)
- image_processor: BaseImageProcessor = getattr(processor, "image_processor")
- mm_inputs = {}
- audio_inputs = {}
- if len(images) != 0 and len(videos) != 0:
- raise ValueError("MiniCPM-V model does not support input images and videos at the same time.")
-
- if len(videos) != 0:
- max_slice_nums = 2
- use_image_id = False
- mm_inputs = self._get_mm_inputs([], videos, [], processor)
- else:
- max_slice_nums = image_processor.max_slice_nums
- use_image_id = image_processor.use_image_id
-
- for i, message in enumerate(messages):
- content = message["content"]
- while IMAGE_PLACEHOLDER in content:
- if num_image_tokens >= len(images):
- raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
-
- content = content.replace(IMAGE_PLACEHOLDER, "{{image}}", 1)
- num_image_tokens += 1
-
- while VIDEO_PLACEHOLDER in content:
- if num_video_tokens >= len(videos):
- raise ValueError(f"`len(videos)` is less than the number of {VIDEO_PLACEHOLDER} tokens.")
-
- video_seqlen = len(mm_inputs["pixel_values"][num_video_tokens]) if self.expand_mm_tokens else 1
- content = content.replace(VIDEO_PLACEHOLDER, "{{image}}" * video_seqlen, 1)
- num_video_tokens += 1
-
- while AUDIO_PLACEHOLDER in content:
- if num_audio_tokens >= len(audios):
- raise ValueError(f"`len(audios)` is less than the number of {AUDIO_PLACEHOLDER} tokens.")
-
- content = content.replace(AUDIO_PLACEHOLDER, "{{audio}}", 1)
- num_audio_tokens += 1
-
- message["content"] = content.replace("{{image}}", "(./)").replace(
- "{{audio}}", "()"
- )
-
- if num_image_tokens > 0:
- mm_inputs = self._get_mm_inputs(images, [], [], processor)
-
- if num_audio_tokens > 0:
- audio_inputs = self._get_mm_inputs([], [], audios, processor, ret_phs=True)
-
- if mm_inputs:
- pattern = "(./)"
- image_sizes = mm_inputs["image_sizes"]
- idx = 0
- for index, message in enumerate(messages):
- text = message["content"]
- image_tags = re.findall(pattern, text)
- text_chunks = text.split(pattern)
- final_text = ""
- for i in range(len(image_tags)):
- final_text = (
- final_text
- + text_chunks[i]
- + image_processor.get_slice_image_placeholder(
- image_sizes[0][idx], idx, max_slice_nums, use_image_id
- )
- )
- idx += 1
-
- final_text += text_chunks[-1]
- messages[index]["content"] = final_text
-
- if audio_inputs:
- pattern = "()"
- idx = 0
- for index, message in enumerate(messages):
- text = message["content"]
- audio_tags = re.findall(pattern, text)
- text_chunks = text.split(pattern)
- final_text = ""
- for i in range(len(audio_tags)):
- audio_placeholder = audio_inputs["audio_phs"][0][idx]
- final_text = final_text + text_chunks[i] + audio_placeholder
- idx += 1
-
- final_text += text_chunks[-1]
- messages[index]["content"] = final_text
-
- if len(images) != num_image_tokens:
- raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
-
- if len(videos) != num_video_tokens:
- raise ValueError(f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens.")
-
- if len(audios) != num_audio_tokens:
- raise ValueError(f"The number of audios does not match the number of {AUDIO_PLACEHOLDER} tokens.")
-
- return messages
-
@override
def _get_mm_inputs(
self,
@@ -1077,6 +968,114 @@ class MiniCPMVPlugin(BasePlugin):
return mm_inputs
+ @override
+ def process_messages(
+ self,
+ messages: list[dict[str, str]],
+ images: list["ImageInput"],
+ videos: list["VideoInput"],
+ audios: list["AudioInput"],
+ processor: Optional["MMProcessor"],
+ ) -> list[dict[str, str]]:
+ self._validate_input(processor, images, videos, audios)
+ num_image_tokens, num_video_tokens, num_audio_tokens = 0, 0, 0
+ messages = deepcopy(messages)
+ image_processor: BaseImageProcessor = getattr(processor, "image_processor")
+ mm_inputs, audio_inputs = {}, {}
+ if len(images) != 0 and len(videos) != 0:
+ raise ValueError("MiniCPM-V model does not support input images and videos at the same time.")
+
+ if len(videos) != 0:
+ max_slice_nums = 2
+ use_image_id = False
+ mm_inputs = self._get_mm_inputs([], videos, [], processor)
+ else:
+ max_slice_nums = image_processor.max_slice_nums
+ use_image_id = image_processor.use_image_id
+
+ for i, message in enumerate(messages):
+ content = message["content"]
+ while IMAGE_PLACEHOLDER in content:
+ if num_image_tokens >= len(images):
+ raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
+
+ content = content.replace(IMAGE_PLACEHOLDER, "{{image}}", 1)
+ num_image_tokens += 1
+
+ while VIDEO_PLACEHOLDER in content:
+ if num_video_tokens >= len(videos):
+ raise ValueError(f"`len(videos)` is less than the number of {VIDEO_PLACEHOLDER} tokens.")
+
+ video_seqlen = len(mm_inputs["pixel_values"][num_video_tokens]) if self.expand_mm_tokens else 1
+ content = content.replace(VIDEO_PLACEHOLDER, "{{image}}" * video_seqlen, 1)
+ num_video_tokens += 1
+
+ while AUDIO_PLACEHOLDER in content:
+ if num_audio_tokens >= len(audios):
+ raise ValueError(f"`len(audios)` is less than the number of {AUDIO_PLACEHOLDER} tokens.")
+
+ content = content.replace(AUDIO_PLACEHOLDER, "{{audio}}", 1)
+ num_audio_tokens += 1
+
+ message["content"] = content.replace("{{image}}", "(./)").replace(
+ "{{audio}}", "()"
+ )
+
+ if len(images):
+ mm_inputs = self._get_mm_inputs(images, [], [], processor)
+
+ if len(audios):
+ audio_inputs = self._get_mm_inputs([], [], audios, processor, ret_phs=True)
+
+ if self.expand_mm_tokens and mm_inputs:
+ pattern = "(./)"
+ image_sizes = mm_inputs["image_sizes"]
+ idx = 0
+ for index, message in enumerate(messages):
+ text = message["content"]
+ image_tags = re.findall(pattern, text)
+ text_chunks = text.split(pattern)
+ final_text = ""
+ for i in range(len(image_tags)):
+ final_text = (
+ final_text
+ + text_chunks[i]
+ + image_processor.get_slice_image_placeholder(
+ image_sizes[0][idx], idx, max_slice_nums, use_image_id
+ )
+ )
+ idx += 1
+
+ final_text += text_chunks[-1]
+ messages[index]["content"] = final_text
+
+ if self.expand_mm_tokens and audio_inputs:
+ pattern = "()"
+ idx = 0
+ for index, message in enumerate(messages):
+ text = message["content"]
+ audio_tags = re.findall(pattern, text)
+ text_chunks = text.split(pattern)
+ final_text = ""
+ for i in range(len(audio_tags)):
+ audio_placeholder = audio_inputs["audio_phs"][0][idx]
+ final_text = final_text + text_chunks[i] + audio_placeholder
+ idx += 1
+
+ final_text += text_chunks[-1]
+ messages[index]["content"] = final_text
+
+ if len(images) != num_image_tokens:
+ raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
+
+ if len(videos) != num_video_tokens:
+ raise ValueError(f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens.")
+
+ if len(audios) != num_audio_tokens:
+ raise ValueError(f"The number of audios does not match the number of {AUDIO_PLACEHOLDER} tokens.")
+
+ return messages
+
@override
def get_mm_inputs(
self,