Merge pull request #6598 from BUAADreamer/minicpmv

[model] Support MiniCPM-V

Former-commit-id: 6eec50c74dcbcc325ad6258228e19c19b4a03538
This commit is contained in:
hoshi-hiyouga 2025-01-13 15:24:02 +08:00 committed by GitHub
commit 0b47c2a293
7 changed files with 196 additions and 1 deletions

View File

@ -209,6 +209,7 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
| [LLaVA-NeXT](https://huggingface.co/llava-hf) | 7B/8B/13B/34B/72B/110B | llava_next | | [LLaVA-NeXT](https://huggingface.co/llava-hf) | 7B/8B/13B/34B/72B/110B | llava_next |
| [LLaVA-NeXT-Video](https://huggingface.co/llava-hf) | 7B/34B | llava_next_video | | [LLaVA-NeXT-Video](https://huggingface.co/llava-hf) | 7B/34B | llava_next_video |
| [MiniCPM](https://huggingface.co/openbmb) | 1B/2B/4B | cpm/cpm3 | | [MiniCPM](https://huggingface.co/openbmb) | 1B/2B/4B | cpm/cpm3 |
| [MiniCPM-V-2.6](https://huggingface.co/openbmb) | 8B | cpm_o |
| [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral | | [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral |
| [OLMo](https://huggingface.co/allenai) | 1B/7B | - | | [OLMo](https://huggingface.co/allenai) | 1B/7B | - |
| [PaliGemma/PaliGemma2](https://huggingface.co/google) | 3B/10B/28B | paligemma | | [PaliGemma/PaliGemma2](https://huggingface.co/google) | 3B/10B/28B | paligemma |

View File

@ -210,6 +210,7 @@ https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272
| [LLaVA-NeXT](https://huggingface.co/llava-hf) | 7B/8B/13B/34B/72B/110B | llava_next | | [LLaVA-NeXT](https://huggingface.co/llava-hf) | 7B/8B/13B/34B/72B/110B | llava_next |
| [LLaVA-NeXT-Video](https://huggingface.co/llava-hf) | 7B/34B | llava_next_video | | [LLaVA-NeXT-Video](https://huggingface.co/llava-hf) | 7B/34B | llava_next_video |
| [MiniCPM](https://huggingface.co/openbmb) | 1B/2B/4B | cpm/cpm3 | | [MiniCPM](https://huggingface.co/openbmb) | 1B/2B/4B | cpm/cpm3 |
| [MiniCPM-V-2.6](https://huggingface.co/openbmb) | 8B | cpm_o |
| [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral | | [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral |
| [OLMo](https://huggingface.co/allenai) | 1B/7B | - | | [OLMo](https://huggingface.co/allenai) | 1B/7B | - |
| [PaliGemma/PaliGemma2](https://huggingface.co/google) | 3B/10B/28B | paligemma | | [PaliGemma/PaliGemma2](https://huggingface.co/google) | 3B/10B/28B | paligemma |

View File

@ -20,6 +20,7 @@ from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Sequence
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from transformers import DataCollatorForSeq2Seq from transformers import DataCollatorForSeq2Seq
from ..extras.constants import IGNORE_INDEX, IMAGE_PLACEHOLDER from ..extras.constants import IGNORE_INDEX, IMAGE_PLACEHOLDER
@ -101,7 +102,9 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
batch_vidlens.append(len(videos)) batch_vidlens.append(len(videos))
batch_input_ids.append(feature["input_ids"]) batch_input_ids.append(feature["input_ids"])
if self.processor is not None and sum(batch_imglens) == 0: # avoid process hanging in zero3/fsdp case if (
self.processor is not None and sum(batch_imglens) == 0 and sum(batch_vidlens) == 0
): # avoid process hanging in zero3/fsdp case
fake_messages = [{"role": "user", "content": IMAGE_PLACEHOLDER}] fake_messages = [{"role": "user", "content": IMAGE_PLACEHOLDER}]
fake_images = [Image.new("RGB", (64, 64), (255, 255, 255))] fake_images = [Image.new("RGB", (64, 64), (255, 255, 255))]
fake_messages = self.template.mm_plugin.process_messages(fake_messages, fake_images, [], self.processor) fake_messages = self.template.mm_plugin.process_messages(fake_messages, fake_images, [], self.processor)
@ -150,6 +153,13 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
if isinstance(features.get("pixel_values"), list): # for pixtral inputs if isinstance(features.get("pixel_values"), list): # for pixtral inputs
features = features.data # use default_collate() instead of BatchEncoding.to() features = features.data # use default_collate() instead of BatchEncoding.to()
if "image_bound" in features: # for minicpmv inputs
features["position_ids"] = [torch.arange(input_ids.size(0)).long() for input_ids in features["input_ids"]]
features["position_ids"] = pad_sequence(features["position_ids"], batch_first=True, padding_value=0)
new_features = {"data": features}
new_features.update({"labels": features["labels"]})
features = new_features
return features return features

View File

@ -1,4 +1,5 @@
import math import math
import re
from copy import deepcopy from copy import deepcopy
from io import BytesIO from io import BytesIO
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, TypedDict, Union from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, TypedDict, Union
@ -253,6 +254,160 @@ class BasePlugin:
return {} return {}
class CpmOPlugin(BasePlugin):
@override
def process_messages(
self,
messages: Sequence[Dict[str, str]],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
processor: Optional["ProcessorMixin"],
) -> List[Dict[str, str]]:
self._validate_input(images, videos)
num_image_tokens = 0
num_video_tokens = 0
messages = deepcopy(messages)
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
mm_inputs = {}
if len(videos) != 0:
assert len(images) == 0, "Only support video and image sft seperately"
max_slice_nums = 2
use_image_id = False
mm_inputs = self._get_mm_inputs([], videos, processor)
else:
max_slice_nums = image_processor.max_slice_nums
use_image_id = image_processor.use_image_id
for message in messages:
content = message["content"]
while IMAGE_PLACEHOLDER in content:
num_image_tokens += 1
content = content.replace(IMAGE_PLACEHOLDER, "{{image}}", 1)
while VIDEO_PLACEHOLDER in content:
num_video_tokens += 1
content = content.replace(
VIDEO_PLACEHOLDER, "{{image}}" * len(mm_inputs["pixel_values"][num_video_tokens - 1]), 1
)
message["content"] = content.replace("{{image}}", "(<image>./</image>)")
if num_image_tokens > 0:
mm_inputs = self._get_mm_inputs(images, [], processor)
if mm_inputs:
pattern = "(<image>./</image>)"
image_sizes = mm_inputs["image_sizes"]
for index, message in enumerate(messages):
text = message["content"]
image_tags = re.findall(pattern, text)
text_chunks = text.split(pattern)
final_text = ""
for i in range(len(image_tags)):
final_text = (
final_text
+ text_chunks[i]
+ image_processor.get_slice_image_placeholder(
image_sizes[0][i],
i,
max_slice_nums,
use_image_id,
)
)
final_text += text_chunks[-1]
messages[index]["content"] = final_text
if len(images) != num_image_tokens:
raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
if len(videos) != num_video_tokens:
raise ValueError(f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens.")
return messages
@override
def _get_mm_inputs(
self,
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
processor: "ProcessorMixin",
**kwargs,
) -> Dict[str, "torch.Tensor"]:
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
mm_inputs = {}
if len(images) != 0:
images = self._regularize_images(
images,
image_resolution=getattr(processor, "image_resolution", 512 * 512),
)
if "valid_image_nums_ls" in kwargs:
valid_image_nums_ls = kwargs["valid_image_nums_ls"]
new_images = []
idx = 0
for valid_image_nums in valid_image_nums_ls:
new_images.append(images[idx : idx + valid_image_nums])
idx += valid_image_nums
images = new_images
image_inputs = image_processor(
images, do_pad=True, max_slice_nums=image_processor.max_slice_nums, return_tensors="pt"
)
mm_inputs.update(image_inputs)
if len(videos) != 0:
videos = self._regularize_videos(
videos,
image_resolution=getattr(processor, "video_resolution", 128 * 128),
video_fps=getattr(processor, "video_fps", 2.0),
video_maxlen=getattr(processor, "video_maxlen", 64),
)
video_inputs = image_processor(videos, do_pad=True, max_slice_nums=2, return_tensors="pt")
mm_inputs.update(video_inputs)
return mm_inputs
@override
def get_mm_inputs(
self,
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
imglens: Sequence[int],
vidlens: Sequence[int],
batch_ids: Sequence[List[int]],
processor: Optional["ProcessorMixin"],
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
self._validate_input(images, videos)
image_bounds_list = []
valid_image_nums_ls = []
for input_ids in batch_ids:
input_ids_ = torch.tensor(input_ids)
start_cond = (input_ids_ == processor.tokenizer.im_start_id) | (
input_ids_ == processor.tokenizer.slice_start_id
)
end_cond = (input_ids_ == processor.tokenizer.im_end_id) | (input_ids_ == processor.tokenizer.slice_end_id)
image_start_tokens = torch.where(start_cond)[0]
image_start_tokens += 1
image_end_tokens = torch.where(end_cond)[0]
valid_image_nums = max(len(image_start_tokens), len(image_end_tokens))
valid_image_nums_ls.append(valid_image_nums)
image_bounds = torch.hstack(
[
image_start_tokens[:valid_image_nums].unsqueeze(-1),
image_end_tokens[:valid_image_nums].unsqueeze(-1),
]
)
image_bounds_list.append(image_bounds)
mm_inputs = self._get_mm_inputs(images, videos, processor, valid_image_nums_ls=valid_image_nums_ls)
mm_inputs.update({"image_bound": image_bounds_list})
return mm_inputs
class LlavaPlugin(BasePlugin): class LlavaPlugin(BasePlugin):
@override @override
def process_messages( def process_messages(
@ -794,6 +949,7 @@ class MllamaPlugin(BasePlugin):
PLUGINS = { PLUGINS = {
"base": BasePlugin, "base": BasePlugin,
"cpm_o": CpmOPlugin,
"llava": LlavaPlugin, "llava": LlavaPlugin,
"llava_next": LlavaNextPlugin, "llava_next": LlavaNextPlugin,
"llava_next_video": LlavaNextVideoPlugin, "llava_next_video": LlavaNextVideoPlugin,

View File

@ -566,6 +566,16 @@ _register_template(
) )
_register_template(
name="cpm_o",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
stop_words=["<|im_end|>"],
mm_plugin=get_mm_plugin(name="cpm_o", image_token="<image>", video_token="<video>"),
)
# copied from chatml template # copied from chatml template
_register_template( _register_template(
name="dbrx", name="dbrx",

View File

@ -1163,6 +1163,17 @@ register_model_group(
) )
register_model_group(
models={
"MiniCPM-V-2_6-Chat": {
DownloadSource.DEFAULT: "openbmb/MiniCPM-V-2_6",
DownloadSource.MODELSCOPE: "OpenBMB/MiniCPM-V-2_6",
},
},
template="cpm_o",
)
register_model_group( register_model_group(
models={ models={
"Mistral-7B-v0.1": { "Mistral-7B-v0.1": {

View File

@ -250,6 +250,12 @@ _register_composite_model(
) )
_register_composite_model(
model_type="minicpmv",
vision_model_keys=["vpm", "apm", "resampler", "tts"],
)
_register_composite_model( _register_composite_model(
model_type="paligemma", model_type="paligemma",
) )