diff --git a/requirements.txt b/requirements.txt
index b47643e3..07034651 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,4 +1,4 @@
-transformers>=4.41.2,<=4.46.1
+transformers>=4.41.2
datasets>=2.16.0,<=3.1.0
accelerate>=0.34.0,<=1.0.1
peft>=0.11.1,<=0.12.0
diff --git a/src/llamafactory/data/collator.py b/src/llamafactory/data/collator.py
index f360f0f5..90abf34c 100644
--- a/src/llamafactory/data/collator.py
+++ b/src/llamafactory/data/collator.py
@@ -149,6 +149,13 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
         features.update(mm_inputs)
         if isinstance(features.get("pixel_values"), list):  # for pixtral inputs
             features = features.data  # use default_collate() instead of BatchEncoding.to()
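+        # MiniCPM-o ("cpm_o") inputs carry an "image_bound" key from the mm plugin. Its
+        # position_ids were built before padding, so right-pad them to the padded input_ids
+        # length, and also expose the whole batch under a "data" key (the MiniCPM-o forward
+        # pass expects a `data` dict) while keeping the flat keys available.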
+ if "image_bound" in features:
+ input_ids, position_ids = features['input_ids'], features['position_ids']
+ features['position_ids'] = F.pad(position_ids, (0, input_ids.shape[-1] - position_ids.shape[-1]))
+ new_features = {}
+ new_features.update({"data": features})
+ new_features.update(features)
+ features = new_features
return features
diff --git a/src/llamafactory/data/mm_plugin.py b/src/llamafactory/data/mm_plugin.py
index 4e1f418a..0fbb5b1d 100644
--- a/src/llamafactory/data/mm_plugin.py
+++ b/src/llamafactory/data/mm_plugin.py
@@ -2,6 +2,7 @@ import math
from copy import deepcopy
from io import BytesIO
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, TypedDict, Union
+import re
import numpy as np
import torch
@@ -249,6 +250,130 @@ class BasePlugin:
return {}
+class CpmOPlugin(BasePlugin):
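+    r"""
+    Plugin for MiniCPM-o / MiniCPM-V style multimodal inputs: image placeholders in the
+    messages are expanded into the processor's slice placeholders, and the token spans of
+    each image are returned as per-sample ``image_bound`` index pairs alongside position ids.
+    """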
+    @override
+    def process_messages(
+        self,
+        messages: Sequence[Dict[str, str]],
+        images: Sequence["ImageInput"],
+        videos: Sequence["VideoInput"],
+        processor: Optional["ProcessorMixin"],
+    ) -> List[Dict[str, str]]:
+        self._validate_input(images, videos)
+        num_image_tokens = 0
+        messages = deepcopy(messages)
+        image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
+
+        for message in messages:
+            content = message["content"]
+            while IMAGE_PLACEHOLDER in content:
+                num_image_tokens += 1
+                content = content.replace(IMAGE_PLACEHOLDER, "{{image}}", 1)
+
+            message["content"] = content.replace("{{image}}", "(<image>./</image>)")
+
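+        # Expand every "(<image>./</image>)" placeholder into the image processor's slice
+        # placeholder text, which reflects how each image is tiled for its size.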
+        if num_image_tokens > 0:
+            mm_inputs = self._get_mm_inputs(images, videos, processor)
+            pattern = "(<image>./</image>)"
+            image_sizes = mm_inputs["image_sizes"]
+            image_index = 0
+            for index, message in enumerate(messages):
+                text = message["content"]
+                image_tags = re.findall(re.escape(pattern), text)
+                text_chunks = text.split(pattern)
+                final_text = ""
+                for i in range(len(image_tags)):
+                    final_text += text_chunks[i] + image_processor.get_slice_image_placeholder(
+                        image_sizes[image_index][i],
+                        i,
+                        image_processor.max_slice_nums,
+                        image_processor.use_image_id,
+                    )
+                    image_index += 1
+
+                final_text += text_chunks[-1]
+                messages[index]["content"] = final_text
+
+        if len(images) != num_image_tokens:
+            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
+
+        return messages
+
+    @override
+    def _get_mm_inputs(
+        self,
+        images: Sequence["ImageInput"],
+        videos: Sequence["VideoInput"],
+        processor: "ProcessorMixin",
+    ) -> Dict[str, "torch.Tensor"]:
+        image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
+        mm_inputs = {}
+
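+        # Regularize the raw images, then run the MiniCPM-V image processor; do_pad and
+        # max_slice_nums are the processor's own padding and image-slicing options.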
+        if len(images) != 0:
+            images = self._regularize_images(
+                images,
+                image_resolution=getattr(processor, "image_resolution", 512 * 512),
+            )
+            image_inputs = image_processor(
+                images, do_pad=True, max_slice_nums=image_processor.max_slice_nums, return_tensors="pt"
+            )
+            mm_inputs.update(image_inputs)
+
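+        # NOTE: videos are only regularized (frame sampling / resizing) here; the resulting
+        # frames are not added to mm_inputs yet.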
+        if len(videos) != 0:
+            videos = self._regularize_videos(
+                videos,
+                image_resolution=getattr(processor, "video_resolution", 128 * 128),
+                video_fps=getattr(processor, "video_fps", 2.0),
+                video_maxlen=getattr(processor, "video_maxlen", 64),
+            )
+
+        return mm_inputs
+
+    @override
+    def get_mm_inputs(
+        self,
+        images: Sequence["ImageInput"],
+        videos: Sequence["VideoInput"],
+        imglens: Sequence[int],
+        vidlens: Sequence[int],
+        batch_ids: Sequence[List[int]],
+        processor: Optional["ProcessorMixin"],
+    ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
+        self._validate_input(images, videos)
+        mm_inputs = self._get_mm_inputs(images, videos, processor)
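+        # For each sequence, collect the positions enclosed by the image/slice start and end
+        # token ids into (num_images, 2) "image_bound" pairs, which tell the model where to
+        # place the visual embeddings, and build dense position ids of the same length.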
+        image_bounds_list = []
+        position_ids = []
+        for input_ids in batch_ids:
+            input_ids_ = torch.tensor(input_ids)
+            start_cond = (input_ids_ == processor.tokenizer.im_start_id) | (
+                input_ids_ == processor.tokenizer.slice_start_id
+            )
+            end_cond = (input_ids_ == processor.tokenizer.im_end_id) | (
+                input_ids_ == processor.tokenizer.slice_end_id
+            )
+            image_start_tokens = torch.where(start_cond)[0] + 1  # bounds begin right after the start token
+            image_end_tokens = torch.where(end_cond)[0]
+            valid_image_nums = max(len(image_start_tokens), len(image_end_tokens))
+            image_bounds = torch.hstack(
+                [
+                    image_start_tokens[:valid_image_nums].unsqueeze(-1),
+                    image_end_tokens[:valid_image_nums].unsqueeze(-1),
+                ]
+            )
+            image_bounds_list.append(image_bounds)
+            position_ids.append(list(range(input_ids_.size(0))))
+
+        position_ids = torch.tensor(position_ids, dtype=torch.int64)
+        mm_inputs.update({"image_bound": image_bounds_list, "position_ids": position_ids})
+        return mm_inputs
+
+
 class LlavaPlugin(BasePlugin):
     @override
     def process_messages(
@@ -790,6 +915,7 @@ class MllamaPlugin(BasePlugin):
 PLUGINS = {
     "base": BasePlugin,
+    "cpm_o": CpmOPlugin,
     "llava": LlavaPlugin,
     "llava_next": LlavaNextPlugin,
     "llava_next_video": LlavaNextVideoPlugin,
diff --git a/src/llamafactory/data/template.py b/src/llamafactory/data/template.py
index 5768cf7b..64911c58 100644
--- a/src/llamafactory/data/template.py
+++ b/src/llamafactory/data/template.py
@@ -583,6 +583,22 @@ _register_template(
)
+_register_template(
+    name="cpm_o",
+    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+    format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
+    format_function=FunctionFormatter(slots=["{{content}}", "<|im_end|>"], tool_format="qwen"),
+    format_observation=StringFormatter(
+        slots=["<|im_start|>user\n<tool_response>\n{{content}}\n</tool_response><|im_end|>\n<|im_start|>assistant\n"]
+    ),
+    format_tools=ToolFormatter(tool_format="qwen"),
+    format_separator=EmptyFormatter(slots=["\n"]),
+    default_system="You are a helpful assistant.",
+    stop_words=["<|im_end|>"],
+    mm_plugin=get_mm_plugin(name="cpm_o", image_token="<image>", video_token="