diff --git a/README.md b/README.md
index 739a13af..7b70a89a 100644
--- a/README.md
+++ b/README.md
@@ -209,6 +209,7 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
| [LLaVA-NeXT](https://huggingface.co/llava-hf) | 7B/8B/13B/34B/72B/110B | llava_next |
| [LLaVA-NeXT-Video](https://huggingface.co/llava-hf) | 7B/34B | llava_next_video |
| [MiniCPM](https://huggingface.co/openbmb) | 1B/2B/4B | cpm/cpm3 |
+| [MiniCPM-V-2.6](https://huggingface.co/openbmb) | 8B | cpm_o |
| [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral |
| [OLMo](https://huggingface.co/allenai) | 1B/7B | - |
| [PaliGemma/PaliGemma2](https://huggingface.co/google) | 3B/10B/28B | paligemma |
diff --git a/README_zh.md b/README_zh.md
index e21560d2..80903fcb 100644
--- a/README_zh.md
+++ b/README_zh.md
@@ -210,6 +210,7 @@ https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272
| [LLaVA-NeXT](https://huggingface.co/llava-hf) | 7B/8B/13B/34B/72B/110B | llava_next |
| [LLaVA-NeXT-Video](https://huggingface.co/llava-hf) | 7B/34B | llava_next_video |
| [MiniCPM](https://huggingface.co/openbmb) | 1B/2B/4B | cpm/cpm3 |
+| [MiniCPM-V-2.6](https://huggingface.co/openbmb) | 8B | cpm_o |
| [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral |
| [OLMo](https://huggingface.co/allenai) | 1B/7B | - |
| [PaliGemma/PaliGemma2](https://huggingface.co/google) | 3B/10B/28B | paligemma |
diff --git a/src/llamafactory/data/collator.py b/src/llamafactory/data/collator.py
index f360f0f5..dfd853ca 100644
--- a/src/llamafactory/data/collator.py
+++ b/src/llamafactory/data/collator.py
@@ -20,6 +20,7 @@ from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Sequence
import torch
import torch.nn.functional as F
+from torch.nn.utils.rnn import pad_sequence
from transformers import DataCollatorForSeq2Seq
from ..extras.constants import IGNORE_INDEX, IMAGE_PLACEHOLDER
@@ -101,7 +102,9 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
batch_vidlens.append(len(videos))
batch_input_ids.append(feature["input_ids"])
- if self.processor is not None and sum(batch_imglens) == 0: # avoid process hanging in zero3/fsdp case
+ if (
+ self.processor is not None and sum(batch_imglens) == 0 and sum(batch_vidlens) == 0
+ ): # avoid process hanging in zero3/fsdp case
fake_messages = [{"role": "user", "content": IMAGE_PLACEHOLDER}]
fake_images = [Image.new("RGB", (64, 64), (255, 255, 255))]
fake_messages = self.template.mm_plugin.process_messages(fake_messages, fake_images, [], self.processor)
@@ -150,6 +153,13 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
if isinstance(features.get("pixel_values"), list): # for pixtral inputs
features = features.data # use default_collate() instead of BatchEncoding.to()
+ if "image_bound" in features: # for minicpmv inputs
+ features["position_ids"] = [torch.arange(input_ids.size(0)).long() for input_ids in features["input_ids"]]
+ features["position_ids"] = pad_sequence(features["position_ids"], batch_first=True, padding_value=0)
+ new_features = {"data": features}
+ new_features.update({"labels": features["labels"]})
+ features = new_features
+
return features
diff --git a/src/llamafactory/data/mm_plugin.py b/src/llamafactory/data/mm_plugin.py
index a8c46d11..909ce7c0 100644
--- a/src/llamafactory/data/mm_plugin.py
+++ b/src/llamafactory/data/mm_plugin.py
@@ -1,4 +1,5 @@
import math
+import re
from copy import deepcopy
from io import BytesIO
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, TypedDict, Union
@@ -253,6 +254,160 @@ class BasePlugin:
return {}
+class CpmOPlugin(BasePlugin):
+ @override
+ def process_messages(
+ self,
+ messages: Sequence[Dict[str, str]],
+ images: Sequence["ImageInput"],
+ videos: Sequence["VideoInput"],
+ processor: Optional["ProcessorMixin"],
+ ) -> List[Dict[str, str]]:
+ self._validate_input(images, videos)
+ num_image_tokens = 0
+ num_video_tokens = 0
+ messages = deepcopy(messages)
+ image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
+ mm_inputs = {}
+
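+        # for video inputs, run the processor now so each video's frame count is known when expanding its placeholder below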
+ if len(videos) != 0:
+            assert len(images) == 0, "Video and image inputs are only supported separately, not in the same sample."
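+            # video frames use a reduced slicing setup (max_slice_nums=2, no image ids) instead of the image processor defaults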
+ max_slice_nums = 2
+ use_image_id = False
+ mm_inputs = self._get_mm_inputs([], videos, processor)
+ else:
+ max_slice_nums = image_processor.max_slice_nums
+ use_image_id = image_processor.use_image_id
+
+ for message in messages:
+ content = message["content"]
+ while IMAGE_PLACEHOLDER in content:
+ num_image_tokens += 1
+ content = content.replace(IMAGE_PLACEHOLDER, "{{image}}", 1)
+
+ while VIDEO_PLACEHOLDER in content:
+ num_video_tokens += 1
+ content = content.replace(
+ VIDEO_PLACEHOLDER, "{{image}}" * len(mm_inputs["pixel_values"][num_video_tokens - 1]), 1
+ )
+
+            message["content"] = content.replace("{{image}}", "(<image>./</image>)")
+
+ if num_image_tokens > 0:
+ mm_inputs = self._get_mm_inputs(images, [], processor)
+
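+        # expand each "(<image>./</image>)" tag into the processor's slice placeholder, which depends on the
+        # image size, slice count and use_image_id setting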
+ if mm_inputs:
+            pattern = "(<image>./</image>)"
+ image_sizes = mm_inputs["image_sizes"]
+
+ for index, message in enumerate(messages):
+ text = message["content"]
+ image_tags = re.findall(pattern, text)
+ text_chunks = text.split(pattern)
+ final_text = ""
+ for i in range(len(image_tags)):
+ final_text = (
+ final_text
+ + text_chunks[i]
+ + image_processor.get_slice_image_placeholder(
+ image_sizes[0][i],
+ i,
+ max_slice_nums,
+ use_image_id,
+ )
+ )
+ final_text += text_chunks[-1]
+ messages[index]["content"] = final_text
+
+ if len(images) != num_image_tokens:
+ raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
+
+ if len(videos) != num_video_tokens:
+ raise ValueError(f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens.")
+
+ return messages
+
+ @override
+ def _get_mm_inputs(
+ self,
+ images: Sequence["ImageInput"],
+ videos: Sequence["VideoInput"],
+ processor: "ProcessorMixin",
+ **kwargs,
+ ) -> Dict[str, "torch.Tensor"]:
+ image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
+
+ mm_inputs = {}
+
+ if len(images) != 0:
+ images = self._regularize_images(
+ images,
+ image_resolution=getattr(processor, "image_resolution", 512 * 512),
+ )
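+            # regroup the flat image list into per-sample sublists so the image processor batches them per example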
+ if "valid_image_nums_ls" in kwargs:
+ valid_image_nums_ls = kwargs["valid_image_nums_ls"]
+ new_images = []
+ idx = 0
+ for valid_image_nums in valid_image_nums_ls:
+ new_images.append(images[idx : idx + valid_image_nums])
+ idx += valid_image_nums
+ images = new_images
+
+ image_inputs = image_processor(
+ images, do_pad=True, max_slice_nums=image_processor.max_slice_nums, return_tensors="pt"
+ )
+ mm_inputs.update(image_inputs)
+
+ if len(videos) != 0:
+ videos = self._regularize_videos(
+ videos,
+ image_resolution=getattr(processor, "video_resolution", 128 * 128),
+ video_fps=getattr(processor, "video_fps", 2.0),
+ video_maxlen=getattr(processor, "video_maxlen", 64),
+ )
+ video_inputs = image_processor(videos, do_pad=True, max_slice_nums=2, return_tensors="pt")
+ mm_inputs.update(video_inputs)
+
+ return mm_inputs
+
+ @override
+ def get_mm_inputs(
+ self,
+ images: Sequence["ImageInput"],
+ videos: Sequence["VideoInput"],
+ imglens: Sequence[int],
+ vidlens: Sequence[int],
+ batch_ids: Sequence[List[int]],
+ processor: Optional["ProcessorMixin"],
+ ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
+ self._validate_input(images, videos)
+ image_bounds_list = []
+ valid_image_nums_ls = []
+
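+        # collect the (start, end) token positions of each image/slice segment; MiniCPM-V uses these bounds
+        # to place the vision embeddings inside the text sequence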
+ for input_ids in batch_ids:
+ input_ids_ = torch.tensor(input_ids)
+ start_cond = (input_ids_ == processor.tokenizer.im_start_id) | (
+ input_ids_ == processor.tokenizer.slice_start_id
+ )
+ end_cond = (input_ids_ == processor.tokenizer.im_end_id) | (input_ids_ == processor.tokenizer.slice_end_id)
+ image_start_tokens = torch.where(start_cond)[0]
+ image_start_tokens += 1
+ image_end_tokens = torch.where(end_cond)[0]
+ valid_image_nums = max(len(image_start_tokens), len(image_end_tokens))
+ valid_image_nums_ls.append(valid_image_nums)
+ image_bounds = torch.hstack(
+ [
+ image_start_tokens[:valid_image_nums].unsqueeze(-1),
+ image_end_tokens[:valid_image_nums].unsqueeze(-1),
+ ]
+ )
+ image_bounds_list.append(image_bounds)
+
+ mm_inputs = self._get_mm_inputs(images, videos, processor, valid_image_nums_ls=valid_image_nums_ls)
+ mm_inputs.update({"image_bound": image_bounds_list})
+ return mm_inputs
+
+
class LlavaPlugin(BasePlugin):
@override
def process_messages(
@@ -794,6 +949,7 @@ class MllamaPlugin(BasePlugin):
PLUGINS = {
"base": BasePlugin,
+ "cpm_o": CpmOPlugin,
"llava": LlavaPlugin,
"llava_next": LlavaNextPlugin,
"llava_next_video": LlavaNextVideoPlugin,
diff --git a/src/llamafactory/data/template.py b/src/llamafactory/data/template.py
index 682ccb10..58dcf561 100644
--- a/src/llamafactory/data/template.py
+++ b/src/llamafactory/data/template.py
@@ -566,6 +566,16 @@ _register_template(
)
+_register_template(
+ name="cpm_o",
+ format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+ format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
+ format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
+ stop_words=["<|im_end|>"],
+    mm_plugin=get_mm_plugin(name="cpm_o", image_token="<image>", video_token="<video>"),