mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-23 22:32:54 +08:00
Merge pull request #6598 from BUAADreamer/minicpmv
[model] Support MiniCPM-V Former-commit-id: 6eec50c74dcbcc325ad6258228e19c19b4a03538
This commit is contained in:
commit
0b47c2a293
@ -209,6 +209,7 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
|
||||
| [LLaVA-NeXT](https://huggingface.co/llava-hf) | 7B/8B/13B/34B/72B/110B | llava_next |
|
||||
| [LLaVA-NeXT-Video](https://huggingface.co/llava-hf) | 7B/34B | llava_next_video |
|
||||
| [MiniCPM](https://huggingface.co/openbmb) | 1B/2B/4B | cpm/cpm3 |
|
||||
| [MiniCPM-V-2.6](https://huggingface.co/openbmb) | 8B | cpm_o |
|
||||
| [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral |
|
||||
| [OLMo](https://huggingface.co/allenai) | 1B/7B | - |
|
||||
| [PaliGemma/PaliGemma2](https://huggingface.co/google) | 3B/10B/28B | paligemma |
|
||||
|
@ -210,6 +210,7 @@ https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272
|
||||
| [LLaVA-NeXT](https://huggingface.co/llava-hf) | 7B/8B/13B/34B/72B/110B | llava_next |
|
||||
| [LLaVA-NeXT-Video](https://huggingface.co/llava-hf) | 7B/34B | llava_next_video |
|
||||
| [MiniCPM](https://huggingface.co/openbmb) | 1B/2B/4B | cpm/cpm3 |
|
||||
| [MiniCPM-V-2.6](https://huggingface.co/openbmb) | 8B | cpm_o |
|
||||
| [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral |
|
||||
| [OLMo](https://huggingface.co/allenai) | 1B/7B | - |
|
||||
| [PaliGemma/PaliGemma2](https://huggingface.co/google) | 3B/10B/28B | paligemma |
|
||||
|
@ -20,6 +20,7 @@ from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Sequence
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
from transformers import DataCollatorForSeq2Seq
|
||||
|
||||
from ..extras.constants import IGNORE_INDEX, IMAGE_PLACEHOLDER
|
||||
@ -101,7 +102,9 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
||||
batch_vidlens.append(len(videos))
|
||||
batch_input_ids.append(feature["input_ids"])
|
||||
|
||||
if self.processor is not None and sum(batch_imglens) == 0: # avoid process hanging in zero3/fsdp case
|
||||
if (
|
||||
self.processor is not None and sum(batch_imglens) == 0 and sum(batch_vidlens) == 0
|
||||
): # avoid process hanging in zero3/fsdp case
|
||||
fake_messages = [{"role": "user", "content": IMAGE_PLACEHOLDER}]
|
||||
fake_images = [Image.new("RGB", (64, 64), (255, 255, 255))]
|
||||
fake_messages = self.template.mm_plugin.process_messages(fake_messages, fake_images, [], self.processor)
|
||||
@ -150,6 +153,13 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
||||
if isinstance(features.get("pixel_values"), list): # for pixtral inputs
|
||||
features = features.data # use default_collate() instead of BatchEncoding.to()
|
||||
|
||||
if "image_bound" in features: # for minicpmv inputs
|
||||
features["position_ids"] = [torch.arange(input_ids.size(0)).long() for input_ids in features["input_ids"]]
|
||||
features["position_ids"] = pad_sequence(features["position_ids"], batch_first=True, padding_value=0)
|
||||
new_features = {"data": features}
|
||||
new_features.update({"labels": features["labels"]})
|
||||
features = new_features
|
||||
|
||||
return features
|
||||
|
||||
|
||||
|
@ -1,4 +1,5 @@
|
||||
import math
|
||||
import re
|
||||
from copy import deepcopy
|
||||
from io import BytesIO
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, TypedDict, Union
|
||||
@ -253,6 +254,160 @@ class BasePlugin:
|
||||
return {}
|
||||
|
||||
|
||||
class CpmOPlugin(BasePlugin):
|
||||
@override
|
||||
def process_messages(
|
||||
self,
|
||||
messages: Sequence[Dict[str, str]],
|
||||
images: Sequence["ImageInput"],
|
||||
videos: Sequence["VideoInput"],
|
||||
processor: Optional["ProcessorMixin"],
|
||||
) -> List[Dict[str, str]]:
|
||||
self._validate_input(images, videos)
|
||||
num_image_tokens = 0
|
||||
num_video_tokens = 0
|
||||
messages = deepcopy(messages)
|
||||
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
|
||||
mm_inputs = {}
|
||||
|
||||
if len(videos) != 0:
|
||||
assert len(images) == 0, "Only support video and image sft seperately"
|
||||
max_slice_nums = 2
|
||||
use_image_id = False
|
||||
mm_inputs = self._get_mm_inputs([], videos, processor)
|
||||
else:
|
||||
max_slice_nums = image_processor.max_slice_nums
|
||||
use_image_id = image_processor.use_image_id
|
||||
|
||||
for message in messages:
|
||||
content = message["content"]
|
||||
while IMAGE_PLACEHOLDER in content:
|
||||
num_image_tokens += 1
|
||||
content = content.replace(IMAGE_PLACEHOLDER, "{{image}}", 1)
|
||||
|
||||
while VIDEO_PLACEHOLDER in content:
|
||||
num_video_tokens += 1
|
||||
content = content.replace(
|
||||
VIDEO_PLACEHOLDER, "{{image}}" * len(mm_inputs["pixel_values"][num_video_tokens - 1]), 1
|
||||
)
|
||||
|
||||
message["content"] = content.replace("{{image}}", "(<image>./</image>)")
|
||||
|
||||
if num_image_tokens > 0:
|
||||
mm_inputs = self._get_mm_inputs(images, [], processor)
|
||||
|
||||
if mm_inputs:
|
||||
pattern = "(<image>./</image>)"
|
||||
image_sizes = mm_inputs["image_sizes"]
|
||||
|
||||
for index, message in enumerate(messages):
|
||||
text = message["content"]
|
||||
image_tags = re.findall(pattern, text)
|
||||
text_chunks = text.split(pattern)
|
||||
final_text = ""
|
||||
for i in range(len(image_tags)):
|
||||
final_text = (
|
||||
final_text
|
||||
+ text_chunks[i]
|
||||
+ image_processor.get_slice_image_placeholder(
|
||||
image_sizes[0][i],
|
||||
i,
|
||||
max_slice_nums,
|
||||
use_image_id,
|
||||
)
|
||||
)
|
||||
final_text += text_chunks[-1]
|
||||
messages[index]["content"] = final_text
|
||||
|
||||
if len(images) != num_image_tokens:
|
||||
raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
|
||||
|
||||
if len(videos) != num_video_tokens:
|
||||
raise ValueError(f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens.")
|
||||
|
||||
return messages
|
||||
|
||||
@override
|
||||
def _get_mm_inputs(
|
||||
self,
|
||||
images: Sequence["ImageInput"],
|
||||
videos: Sequence["VideoInput"],
|
||||
processor: "ProcessorMixin",
|
||||
**kwargs,
|
||||
) -> Dict[str, "torch.Tensor"]:
|
||||
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
|
||||
|
||||
mm_inputs = {}
|
||||
|
||||
if len(images) != 0:
|
||||
images = self._regularize_images(
|
||||
images,
|
||||
image_resolution=getattr(processor, "image_resolution", 512 * 512),
|
||||
)
|
||||
if "valid_image_nums_ls" in kwargs:
|
||||
valid_image_nums_ls = kwargs["valid_image_nums_ls"]
|
||||
new_images = []
|
||||
idx = 0
|
||||
for valid_image_nums in valid_image_nums_ls:
|
||||
new_images.append(images[idx : idx + valid_image_nums])
|
||||
idx += valid_image_nums
|
||||
images = new_images
|
||||
|
||||
image_inputs = image_processor(
|
||||
images, do_pad=True, max_slice_nums=image_processor.max_slice_nums, return_tensors="pt"
|
||||
)
|
||||
mm_inputs.update(image_inputs)
|
||||
|
||||
if len(videos) != 0:
|
||||
videos = self._regularize_videos(
|
||||
videos,
|
||||
image_resolution=getattr(processor, "video_resolution", 128 * 128),
|
||||
video_fps=getattr(processor, "video_fps", 2.0),
|
||||
video_maxlen=getattr(processor, "video_maxlen", 64),
|
||||
)
|
||||
video_inputs = image_processor(videos, do_pad=True, max_slice_nums=2, return_tensors="pt")
|
||||
mm_inputs.update(video_inputs)
|
||||
|
||||
return mm_inputs
|
||||
|
||||
@override
|
||||
def get_mm_inputs(
|
||||
self,
|
||||
images: Sequence["ImageInput"],
|
||||
videos: Sequence["VideoInput"],
|
||||
imglens: Sequence[int],
|
||||
vidlens: Sequence[int],
|
||||
batch_ids: Sequence[List[int]],
|
||||
processor: Optional["ProcessorMixin"],
|
||||
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
|
||||
self._validate_input(images, videos)
|
||||
image_bounds_list = []
|
||||
valid_image_nums_ls = []
|
||||
|
||||
for input_ids in batch_ids:
|
||||
input_ids_ = torch.tensor(input_ids)
|
||||
start_cond = (input_ids_ == processor.tokenizer.im_start_id) | (
|
||||
input_ids_ == processor.tokenizer.slice_start_id
|
||||
)
|
||||
end_cond = (input_ids_ == processor.tokenizer.im_end_id) | (input_ids_ == processor.tokenizer.slice_end_id)
|
||||
image_start_tokens = torch.where(start_cond)[0]
|
||||
image_start_tokens += 1
|
||||
image_end_tokens = torch.where(end_cond)[0]
|
||||
valid_image_nums = max(len(image_start_tokens), len(image_end_tokens))
|
||||
valid_image_nums_ls.append(valid_image_nums)
|
||||
image_bounds = torch.hstack(
|
||||
[
|
||||
image_start_tokens[:valid_image_nums].unsqueeze(-1),
|
||||
image_end_tokens[:valid_image_nums].unsqueeze(-1),
|
||||
]
|
||||
)
|
||||
image_bounds_list.append(image_bounds)
|
||||
|
||||
mm_inputs = self._get_mm_inputs(images, videos, processor, valid_image_nums_ls=valid_image_nums_ls)
|
||||
mm_inputs.update({"image_bound": image_bounds_list})
|
||||
return mm_inputs
|
||||
|
||||
|
||||
class LlavaPlugin(BasePlugin):
|
||||
@override
|
||||
def process_messages(
|
||||
@ -794,6 +949,7 @@ class MllamaPlugin(BasePlugin):
|
||||
|
||||
PLUGINS = {
|
||||
"base": BasePlugin,
|
||||
"cpm_o": CpmOPlugin,
|
||||
"llava": LlavaPlugin,
|
||||
"llava_next": LlavaNextPlugin,
|
||||
"llava_next_video": LlavaNextVideoPlugin,
|
||||
|
@ -566,6 +566,16 @@ _register_template(
|
||||
)
|
||||
|
||||
|
||||
_register_template(
|
||||
name="cpm_o",
|
||||
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
||||
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
|
||||
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
|
||||
stop_words=["<|im_end|>"],
|
||||
mm_plugin=get_mm_plugin(name="cpm_o", image_token="<image>", video_token="<video>"),
|
||||
)
|
||||
|
||||
|
||||
# copied from chatml template
|
||||
_register_template(
|
||||
name="dbrx",
|
||||
|
@ -1163,6 +1163,17 @@ register_model_group(
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"MiniCPM-V-2_6-Chat": {
|
||||
DownloadSource.DEFAULT: "openbmb/MiniCPM-V-2_6",
|
||||
DownloadSource.MODELSCOPE: "OpenBMB/MiniCPM-V-2_6",
|
||||
},
|
||||
},
|
||||
template="cpm_o",
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"Mistral-7B-v0.1": {
|
||||
|
@ -250,6 +250,12 @@ _register_composite_model(
|
||||
)
|
||||
|
||||
|
||||
_register_composite_model(
|
||||
model_type="minicpmv",
|
||||
vision_model_keys=["vpm", "apm", "resampler", "tts"],
|
||||
)
|
||||
|
||||
|
||||
_register_composite_model(
|
||||
model_type="paligemma",
|
||||
)
|
||||
|
Loading…
x
Reference in New Issue
Block a user