From ae32c148d1adcbe86701babf8b3d5b64f262039b Mon Sep 17 00:00:00 2001
From: Zhangchi Feng <64362896+BUAADreamer@users.noreply.github.com>
Date: Tue, 14 Jan 2025 00:26:19 +0800
Subject: [PATCH] Support new features of MiniCPM-V (#6626)
* fix template name
* tiny fix
* support minicpm-o-2.6
Former-commit-id: 53034a61c7654358f46916cbc370910fb2aeff3b
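
As a usage illustration (a minimal standalone sketch with dummy tensors; the names below are
not the LLaMA-Factory collator API), the collator change in this patch builds per-batch
position ids for MiniCPM-V inputs like so:

    import torch

    # Dummy batch standing in for tokenized MiniCPM-V inputs (values are arbitrary).
    input_ids = torch.randint(0, 100, (2, 8))  # (bsz, seq_length)

    # One arange(seq_length) row per sample, repeated along the batch dimension,
    # mirroring the repeat(bsz, 1) form introduced in collator.py below.
    bsz, seq_length = input_ids.shape
    position_ids = torch.arange(seq_length).long().repeat(bsz, 1)

    assert position_ids.shape == input_ids.shape  # both (2, 8)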
---
README.md | 2 +-
README_zh.md | 2 +-
setup.py | 10 +
src/llamafactory/data/collator.py | 5 +-
src/llamafactory/data/mm_plugin.py | 302 +++++++++----------
src/llamafactory/data/template.py | 22 +-
src/llamafactory/extras/constants.py | 13 +-
src/llamafactory/model/model_utils/visual.py | 1 +
8 files changed, 189 insertions(+), 168 deletions(-)
diff --git a/README.md b/README.md
index 4327c951..df9a254f 100644
--- a/README.md
+++ b/README.md
@@ -209,7 +209,7 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
| [LLaVA-NeXT](https://huggingface.co/llava-hf) | 7B/8B/13B/34B/72B/110B | llava_next |
| [LLaVA-NeXT-Video](https://huggingface.co/llava-hf) | 7B/34B | llava_next_video |
| [MiniCPM](https://huggingface.co/openbmb) | 1B/2B/4B | cpm/cpm3 |
-| [MiniCPM-V-2.6](https://huggingface.co/openbmb) | 8B | cpm_v |
+| [MiniCPM-o-2.6/MiniCPM-V-2.6](https://huggingface.co/openbmb) | 8B | minicpm_v |
| [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral |
| [OLMo](https://huggingface.co/allenai) | 1B/7B | - |
| [PaliGemma/PaliGemma2](https://huggingface.co/google) | 3B/10B/28B | paligemma |
diff --git a/README_zh.md b/README_zh.md
index bf4eb4d3..c9f28c49 100644
--- a/README_zh.md
+++ b/README_zh.md
@@ -210,7 +210,7 @@ https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272
| [LLaVA-NeXT](https://huggingface.co/llava-hf) | 7B/8B/13B/34B/72B/110B | llava_next |
| [LLaVA-NeXT-Video](https://huggingface.co/llava-hf) | 7B/34B | llava_next_video |
| [MiniCPM](https://huggingface.co/openbmb) | 1B/2B/4B | cpm/cpm3 |
-| [MiniCPM-V-2.6](https://huggingface.co/openbmb) | 8B | cpm_v |
+| [MiniCPM-o-2.6/MiniCPM-V-2.6](https://huggingface.co/openbmb) | 8B | minicpm_v |
| [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral |
| [OLMo](https://huggingface.co/allenai) | 1B/7B | - |
| [PaliGemma/PaliGemma2](https://huggingface.co/google) | 3B/10B/28B | paligemma |
diff --git a/setup.py b/setup.py
index 10d93551..75d2b7e3 100644
--- a/setup.py
+++ b/setup.py
@@ -59,6 +59,16 @@ extra_require = {
"badam": ["badam>=1.2.1"],
"adam-mini": ["adam-mini"],
"qwen": ["transformers_stream_generator"],
+ "minicpm_v": [
+ "soundfile",
+ "torchvision",
+ "torchaudio",
+ "vector_quantize_pytorch",
+ "vocos",
+ "msgpack",
+ "referencing",
+ "jsonschema_specifications",
+ ],
"modelscope": ["modelscope"],
"openmind": ["openmind"],
"swanlab": ["swanlab"],
diff --git a/src/llamafactory/data/collator.py b/src/llamafactory/data/collator.py
index 01742bdc..c14f17e0 100644
--- a/src/llamafactory/data/collator.py
+++ b/src/llamafactory/data/collator.py
@@ -153,9 +153,8 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
features = features.data # use default_collate() instead of BatchEncoding.to()
if "image_bound" in features: # for minicpmv inputs
- features["position_ids"] = (
- torch.arange(features["input_ids"].size(1)).long().unsqueeze(0).expand_as(features["input_ids"])
- )
+ bsz, seq_length = features["input_ids"].shape
+ features["position_ids"] = torch.arange(seq_length).long().repeat(bsz, 1)
return {"data": features, "labels": features["labels"]}
return features
diff --git a/src/llamafactory/data/mm_plugin.py b/src/llamafactory/data/mm_plugin.py
index 7cae556a..e4447ee5 100644
--- a/src/llamafactory/data/mm_plugin.py
+++ b/src/llamafactory/data/mm_plugin.py
@@ -254,156 +254,6 @@ class BasePlugin:
return {}
-class CpmVPlugin(BasePlugin):
- @override
- def process_messages(
- self,
- messages: Sequence[Dict[str, str]],
- images: Sequence["ImageInput"],
- videos: Sequence["VideoInput"],
- processor: Optional["ProcessorMixin"],
- ) -> List[Dict[str, str]]:
- self._validate_input(images, videos)
- num_image_tokens = 0
- num_video_tokens = 0
- messages = deepcopy(messages)
- image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
- mm_inputs = {}
- if len(images) != 0 and len(videos) != 0:
- raise ValueError("MiniCPM-V model does not support input images and videos at the same time.")
-
- if len(videos) != 0:
- max_slice_nums = 2
- use_image_id = False
- mm_inputs = self._get_mm_inputs([], videos, processor)
- else:
- max_slice_nums = image_processor.max_slice_nums
- use_image_id = image_processor.use_image_id
-
- for message in messages:
- content = message["content"]
- while IMAGE_PLACEHOLDER in content:
- num_image_tokens += 1
- content = content.replace(IMAGE_PLACEHOLDER, "{{image}}", 1)
-
- while VIDEO_PLACEHOLDER in content:
- video_seqlen = len(mm_inputs["pixel_values"][num_video_tokens]) if self.expand_mm_tokens else 1
- content = content.replace(VIDEO_PLACEHOLDER, "{{image}}" * video_seqlen, 1)
- num_video_tokens += 1
-
- message["content"] = content.replace("{{image}}", "(<image>./</image>)")
-
- if num_image_tokens > 0:
- mm_inputs = self._get_mm_inputs(images, [], processor)
-
- if mm_inputs:
- pattern = "(<image>./</image>)"
- image_sizes = mm_inputs["image_sizes"]
-
- for index, message in enumerate(messages):
- text = message["content"]
- image_tags = re.findall(pattern, text)
- text_chunks = text.split(pattern)
- final_text = ""
- for i in range(len(image_tags)):
- final_text = (
- final_text
- + text_chunks[i]
- + image_processor.get_slice_image_placeholder(
- image_sizes[0][i], i, max_slice_nums, use_image_id
- )
- )
- final_text += text_chunks[-1]
- messages[index]["content"] = final_text
-
- if len(images) != num_image_tokens:
- raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
-
- if len(videos) != num_video_tokens:
- raise ValueError(f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens.")
-
- return messages
-
- @override
- def _get_mm_inputs(
- self,
- images: Sequence["ImageInput"],
- videos: Sequence["VideoInput"],
- processor: "ProcessorMixin",
- **kwargs,
- ) -> Dict[str, "torch.Tensor"]:
- image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
-
- mm_inputs = {}
- if len(images) != 0:
- images = self._regularize_images(
- images,
- image_resolution=getattr(processor, "image_resolution", 512 * 512),
- )
- if "valid_image_nums_ls" in kwargs:
- valid_image_nums_ls = kwargs["valid_image_nums_ls"]
- new_images = []
- idx = 0
- for valid_image_nums in valid_image_nums_ls:
- new_images.append(images[idx : idx + valid_image_nums])
- idx += valid_image_nums
-
- images = new_images
-
- image_inputs = image_processor(
- images, do_pad=True, max_slice_nums=image_processor.max_slice_nums, return_tensors="pt"
- )
- mm_inputs.update(image_inputs)
-
- if len(videos) != 0:
- videos = self._regularize_videos(
- videos,
- image_resolution=getattr(processor, "video_resolution", 128 * 128),
- video_fps=getattr(processor, "video_fps", 2.0),
- video_maxlen=getattr(processor, "video_maxlen", 64),
- )
- video_inputs = image_processor(videos, do_pad=True, max_slice_nums=2, return_tensors="pt")
- mm_inputs.update(video_inputs)
-
- return mm_inputs
-
- @override
- def get_mm_inputs(
- self,
- images: Sequence["ImageInput"],
- videos: Sequence["VideoInput"],
- imglens: Sequence[int],
- vidlens: Sequence[int],
- batch_ids: Sequence[List[int]],
- processor: Optional["ProcessorMixin"],
- ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
- self._validate_input(images, videos)
- image_bounds_list = []
- valid_image_nums_ls = []
- for input_ids in batch_ids:
- input_ids_ = torch.tensor(input_ids)
- start_cond = (input_ids_ == processor.tokenizer.im_start_id) | (
- input_ids_ == processor.tokenizer.slice_start_id
- )
- end_cond = (input_ids_ == processor.tokenizer.im_end_id) | (input_ids_ == processor.tokenizer.slice_end_id)
- image_start_tokens = torch.where(start_cond)[0]
- image_start_tokens += 1
- image_end_tokens = torch.where(end_cond)[0]
- valid_image_nums = max(len(image_start_tokens), len(image_end_tokens))
- valid_image_nums_ls.append(valid_image_nums)
- image_bounds = torch.hstack(
- [
- image_start_tokens[:valid_image_nums].unsqueeze(-1),
- image_end_tokens[:valid_image_nums].unsqueeze(-1),
- ]
- )
- image_bounds_list.append(image_bounds)
-
- mm_inputs = self._get_mm_inputs(images, videos, processor, valid_image_nums_ls=valid_image_nums_ls)
- mm_inputs.update({"image_bound": image_bounds_list})
- return mm_inputs
-
-
class LlavaPlugin(BasePlugin):
@override
def process_messages(
@@ -567,6 +417,156 @@ class LlavaNextVideoPlugin(BasePlugin):
return self._get_mm_inputs(images, videos, processor)
+class MiniCPMVPlugin(BasePlugin):
+ @override
+ def process_messages(
+ self,
+ messages: Sequence[Dict[str, str]],
+ images: Sequence["ImageInput"],
+ videos: Sequence["VideoInput"],
+ processor: Optional["ProcessorMixin"],
+ ) -> List[Dict[str, str]]:
+ self._validate_input(images, videos)
+ num_image_tokens = 0
+ num_video_tokens = 0
+ messages = deepcopy(messages)
+ image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
+ mm_inputs = {}
+ if len(images) != 0 and len(videos) != 0:
+ raise ValueError("MiniCPM-V model does not support input images and videos at the same time.")
+
+ if len(videos) != 0:
+ max_slice_nums = 2
+ use_image_id = False
+ mm_inputs = self._get_mm_inputs([], videos, processor)
+ else:
+ max_slice_nums = image_processor.max_slice_nums
+ use_image_id = image_processor.use_image_id
+
+ for message in messages:
+ content = message["content"]
+ while IMAGE_PLACEHOLDER in content:
+ content = content.replace(IMAGE_PLACEHOLDER, "{{image}}", 1)
+ num_image_tokens += 1
+
+ while VIDEO_PLACEHOLDER in content:
+ video_seqlen = len(mm_inputs["pixel_values"][num_video_tokens]) if self.expand_mm_tokens else 1
+ content = content.replace(VIDEO_PLACEHOLDER, "{{image}}" * video_seqlen, 1)
+ num_video_tokens += 1
+
+ message["content"] = content.replace("{{image}}", "(<image>./</image>)")
+
+ if num_image_tokens > 0:
+ mm_inputs = self._get_mm_inputs(images, [], processor)
+
+ if mm_inputs:
+ pattern = "(<image>./</image>)"
+ image_sizes = mm_inputs["image_sizes"]
+
+ for index, message in enumerate(messages):
+ text = message["content"]
+ image_tags = re.findall(pattern, text)
+ text_chunks = text.split(pattern)
+ final_text = ""
+ for i in range(len(image_tags)):
+ final_text = (
+ final_text
+ + text_chunks[i]
+ + image_processor.get_slice_image_placeholder(
+ image_sizes[0][i], i, max_slice_nums, use_image_id
+ )
+ )
+
+ final_text += text_chunks[-1]
+ messages[index]["content"] = final_text
+
+ if len(images) != num_image_tokens:
+ raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
+
+ if len(videos) != num_video_tokens:
+ raise ValueError(f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens.")
+
+ return messages
+
+ @override
+ def _get_mm_inputs(
+ self,
+ images: Sequence["ImageInput"],
+ videos: Sequence["VideoInput"],
+ processor: "ProcessorMixin",
+ **kwargs,
+ ) -> Dict[str, "torch.Tensor"]:
+ image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
+ mm_inputs = {}
+ if len(images) != 0:
+ images = self._regularize_images(
+ images,
+ image_resolution=getattr(processor, "image_resolution", 512 * 512),
+ )
+ if "valid_image_nums_ls" in kwargs:
+ valid_image_nums_ls = kwargs["valid_image_nums_ls"]
+ new_images = []
+ idx = 0
+ for valid_image_nums in valid_image_nums_ls:
+ new_images.append(images[idx : idx + valid_image_nums])
+ idx += valid_image_nums
+
+ images = new_images
+
+ image_inputs = image_processor(
+ images, do_pad=True, max_slice_nums=image_processor.max_slice_nums, return_tensors="pt"
+ )
+ mm_inputs.update(image_inputs)
+
+ if len(videos) != 0:
+ videos = self._regularize_videos(
+ videos,
+ image_resolution=getattr(processor, "video_resolution", 128 * 128),
+ video_fps=getattr(processor, "video_fps", 2.0),
+ video_maxlen=getattr(processor, "video_maxlen", 64),
+ )
+ video_inputs = image_processor(videos, do_pad=True, max_slice_nums=2, return_tensors="pt")
+ mm_inputs.update(video_inputs)
+
+ return mm_inputs
+
+ @override
+ def get_mm_inputs(
+ self,
+ images: Sequence["ImageInput"],
+ videos: Sequence["VideoInput"],
+ imglens: Sequence[int],
+ vidlens: Sequence[int],
+ batch_ids: Sequence[List[int]],
+ processor: Optional["ProcessorMixin"],
+ ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
+ self._validate_input(images, videos)
+ image_bounds_list = []
+ valid_image_nums_ls = []
+ for input_ids in batch_ids:
+ input_ids_ = torch.tensor(input_ids)
+ start_cond = (input_ids_ == processor.tokenizer.im_start_id) | (
+ input_ids_ == processor.tokenizer.slice_start_id
+ )
+ end_cond = (input_ids_ == processor.tokenizer.im_end_id) | (input_ids_ == processor.tokenizer.slice_end_id)
+ image_start_tokens = torch.where(start_cond)[0]
+ image_start_tokens += 1
+ image_end_tokens = torch.where(end_cond)[0]
+ valid_image_nums = max(len(image_start_tokens), len(image_end_tokens))
+ valid_image_nums_ls.append(valid_image_nums)
+ image_bounds = torch.hstack(
+ [
+ image_start_tokens[:valid_image_nums].unsqueeze(-1),
+ image_end_tokens[:valid_image_nums].unsqueeze(-1),
+ ]
+ )
+ image_bounds_list.append(image_bounds)
+
+ mm_inputs = self._get_mm_inputs(images, videos, processor, valid_image_nums_ls=valid_image_nums_ls)
+ mm_inputs.update({"image_bound": image_bounds_list})
+ return mm_inputs
+
+
class PaliGemmaPlugin(BasePlugin):
@override
def process_messages(
@@ -945,10 +945,10 @@ class MllamaPlugin(BasePlugin):
PLUGINS = {
"base": BasePlugin,
- "cpm_v": CpmVPlugin,
"llava": LlavaPlugin,
"llava_next": LlavaNextPlugin,
"llava_next_video": LlavaNextVideoPlugin,
+ "minicpm_v": MiniCPMVPlugin,
"paligemma": PaliGemmaPlugin,
"pixtral": PixtralPlugin,
"qwen2_vl": Qwen2vlPlugin,
diff --git a/src/llamafactory/data/template.py b/src/llamafactory/data/template.py
index 05eb0cda..c9340afc 100644
--- a/src/llamafactory/data/template.py
+++ b/src/llamafactory/data/template.py
@@ -576,17 +576,6 @@ _register_template(
)
-# copied from chatml template
-_register_template(
- name="cpm_v",
- format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
- format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
- format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
- stop_words=["<|im_end|>"],
- mm_plugin=get_mm_plugin(name="cpm_v", image_token="<image>", video_token="<video>"),
-)