diff --git a/README.md b/README.md
index 739a13af..7b70a89a 100644
--- a/README.md
+++ b/README.md
@@ -209,6 +209,7 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
 | [LLaVA-NeXT](https://huggingface.co/llava-hf) | 7B/8B/13B/34B/72B/110B | llava_next |
 | [LLaVA-NeXT-Video](https://huggingface.co/llava-hf) | 7B/34B | llava_next_video |
 | [MiniCPM](https://huggingface.co/openbmb) | 1B/2B/4B | cpm/cpm3 |
+| [MiniCPM-V-2.6](https://huggingface.co/openbmb) | 8B | cpm_o |
 | [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral |
 | [OLMo](https://huggingface.co/allenai) | 1B/7B | - |
 | [PaliGemma/PaliGemma2](https://huggingface.co/google) | 3B/10B/28B | paligemma |
diff --git a/README_zh.md b/README_zh.md
index e21560d2..80903fcb 100644
--- a/README_zh.md
+++ b/README_zh.md
@@ -210,6 +210,7 @@ https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272
 | [LLaVA-NeXT](https://huggingface.co/llava-hf) | 7B/8B/13B/34B/72B/110B | llava_next |
 | [LLaVA-NeXT-Video](https://huggingface.co/llava-hf) | 7B/34B | llava_next_video |
 | [MiniCPM](https://huggingface.co/openbmb) | 1B/2B/4B | cpm/cpm3 |
+| [MiniCPM-V-2.6](https://huggingface.co/openbmb) | 8B | cpm_o |
 | [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral |
 | [OLMo](https://huggingface.co/allenai) | 1B/7B | - |
 | [PaliGemma/PaliGemma2](https://huggingface.co/google) | 3B/10B/28B | paligemma |
diff --git a/src/llamafactory/data/collator.py b/src/llamafactory/data/collator.py
index f360f0f5..dfd853ca 100644
--- a/src/llamafactory/data/collator.py
+++ b/src/llamafactory/data/collator.py
@@ -20,6 +20,7 @@ from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Sequence
 
 import torch
 import torch.nn.functional as F
+from torch.nn.utils.rnn import pad_sequence
 from transformers import DataCollatorForSeq2Seq
 
 from ..extras.constants import IGNORE_INDEX, IMAGE_PLACEHOLDER
@@ -101,7 +102,9 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
                 batch_vidlens.append(len(videos))
                 batch_input_ids.append(feature["input_ids"])
 
-        if self.processor is not None and sum(batch_imglens) == 0:  # avoid process hanging in zero3/fsdp case
+        if (
+            self.processor is not None and sum(batch_imglens) == 0 and sum(batch_vidlens) == 0
+        ):  # avoid process hanging in zero3/fsdp case
             fake_messages = [{"role": "user", "content": IMAGE_PLACEHOLDER}]
             fake_images = [Image.new("RGB", (64, 64), (255, 255, 255))]
             fake_messages = self.template.mm_plugin.process_messages(fake_messages, fake_images, [], self.processor)
@@ -150,6 +153,13 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
         if isinstance(features.get("pixel_values"), list):  # for pixtral inputs
             features = features.data  # use default_collate() instead of BatchEncoding.to()
 
+        if "image_bound" in features:  # for minicpmv inputs
+            features["position_ids"] = [torch.arange(input_ids.size(0)).long() for input_ids in features["input_ids"]]
+            features["position_ids"] = pad_sequence(features["position_ids"], batch_first=True, padding_value=0)
+            new_features = {"data": features}
+            new_features.update({"labels": features["labels"]})
+            features = new_features
+
         return features
diff --git a/src/llamafactory/data/mm_plugin.py b/src/llamafactory/data/mm_plugin.py
index a8c46d11..909ce7c0 100644
--- a/src/llamafactory/data/mm_plugin.py
+++ b/src/llamafactory/data/mm_plugin.py
@@ -1,4 +1,5 @@
 import math
+import re
 from copy import deepcopy
 from io import BytesIO
 from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, TypedDict, Union
@@ -253,6 +254,160 @@ class BasePlugin:
         return {}
 
 
+class CpmOPlugin(BasePlugin):
+    @override
+    def process_messages(
+        self,
+        messages: Sequence[Dict[str, str]],
+        images: Sequence["ImageInput"],
+        videos: Sequence["VideoInput"],
+        processor: Optional["ProcessorMixin"],
+    ) -> List[Dict[str, str]]:
+        self._validate_input(images, videos)
+        num_image_tokens = 0
+        num_video_tokens = 0
+        messages = deepcopy(messages)
+        image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
+        mm_inputs = {}
+
+        if len(videos) != 0:
+            assert len(images) == 0, "Only support video and image sft separately"
+            max_slice_nums = 2
+            use_image_id = False
+            mm_inputs = self._get_mm_inputs([], videos, processor)
+        else:
+            max_slice_nums = image_processor.max_slice_nums
+            use_image_id = image_processor.use_image_id
+
+        for message in messages:
+            content = message["content"]
+            while IMAGE_PLACEHOLDER in content:
+                num_image_tokens += 1
+                content = content.replace(IMAGE_PLACEHOLDER, "{{image}}", 1)
+
+            while VIDEO_PLACEHOLDER in content:
+                num_video_tokens += 1
+                content = content.replace(
+                    VIDEO_PLACEHOLDER, "{{image}}" * len(mm_inputs["pixel_values"][num_video_tokens - 1]), 1
+                )
+
+            message["content"] = content.replace("{{image}}", "(<image>./</image>)")
+
+        if num_image_tokens > 0:
+            mm_inputs = self._get_mm_inputs(images, [], processor)
+
+        if mm_inputs:
+            pattern = "(<image>./</image>)"
+            image_sizes = mm_inputs["image_sizes"]
+
+            for index, message in enumerate(messages):
+                text = message["content"]
+                image_tags = re.findall(pattern, text)
+                text_chunks = text.split(pattern)
+                final_text = ""
+                for i in range(len(image_tags)):
+                    final_text = (
+                        final_text
+                        + text_chunks[i]
+                        + image_processor.get_slice_image_placeholder(
+                            image_sizes[0][i],
+                            i,
+                            max_slice_nums,
+                            use_image_id,
+                        )
+                    )
+                final_text += text_chunks[-1]
+                messages[index]["content"] = final_text
+
+        if len(images) != num_image_tokens:
+            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
+
+        if len(videos) != num_video_tokens:
+            raise ValueError(f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens.")
+
+        return messages
+
+    @override
+    def _get_mm_inputs(
+        self,
+        images: Sequence["ImageInput"],
+        videos: Sequence["VideoInput"],
+        processor: "ProcessorMixin",
+        **kwargs,
+    ) -> Dict[str, "torch.Tensor"]:
+        image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
+
+        mm_inputs = {}
+
+        if len(images) != 0:
+            images = self._regularize_images(
+                images,
+                image_resolution=getattr(processor, "image_resolution", 512 * 512),
+            )
+            if "valid_image_nums_ls" in kwargs:
+                valid_image_nums_ls = kwargs["valid_image_nums_ls"]
+                new_images = []
+                idx = 0
+                for valid_image_nums in valid_image_nums_ls:
+                    new_images.append(images[idx : idx + valid_image_nums])
+                    idx += valid_image_nums
+                images = new_images
+
+            image_inputs = image_processor(
+                images, do_pad=True, max_slice_nums=image_processor.max_slice_nums, return_tensors="pt"
+            )
+            mm_inputs.update(image_inputs)
+
+        if len(videos) != 0:
+            videos = self._regularize_videos(
+                videos,
+                image_resolution=getattr(processor, "video_resolution", 128 * 128),
+                video_fps=getattr(processor, "video_fps", 2.0),
+                video_maxlen=getattr(processor, "video_maxlen", 64),
+            )
+            video_inputs = image_processor(videos, do_pad=True, max_slice_nums=2, return_tensors="pt")
+            mm_inputs.update(video_inputs)
+
+        return mm_inputs
+
+    @override
+    def get_mm_inputs(
+        self,
+        images: Sequence["ImageInput"],
+        videos: Sequence["VideoInput"],
+        imglens: Sequence[int],
+        vidlens: Sequence[int],
+        batch_ids: Sequence[List[int]],
+        processor: Optional["ProcessorMixin"],
+    ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
+        self._validate_input(images, videos)
+        image_bounds_list = []
+        valid_image_nums_ls = []
+
+        for input_ids in batch_ids:
+            input_ids_ = torch.tensor(input_ids)
+            start_cond = (input_ids_ == processor.tokenizer.im_start_id) | (
+                input_ids_ == processor.tokenizer.slice_start_id
+            )
+            end_cond = (input_ids_ == processor.tokenizer.im_end_id) | (input_ids_ == processor.tokenizer.slice_end_id)
+            image_start_tokens = torch.where(start_cond)[0]
+            image_start_tokens += 1
+            image_end_tokens = torch.where(end_cond)[0]
+            valid_image_nums = max(len(image_start_tokens), len(image_end_tokens))
+            valid_image_nums_ls.append(valid_image_nums)
+            image_bounds = torch.hstack(
+                [
+                    image_start_tokens[:valid_image_nums].unsqueeze(-1),
+                    image_end_tokens[:valid_image_nums].unsqueeze(-1),
+                ]
+            )
+            image_bounds_list.append(image_bounds)
+
+        mm_inputs = self._get_mm_inputs(images, videos, processor, valid_image_nums_ls=valid_image_nums_ls)
+        mm_inputs.update({"image_bound": image_bounds_list})
+        return mm_inputs
+
+
 class LlavaPlugin(BasePlugin):
     @override
     def process_messages(
@@ -794,6 +949,7 @@ class MllamaPlugin(BasePlugin):
 
 PLUGINS = {
     "base": BasePlugin,
+    "cpm_o": CpmOPlugin,
     "llava": LlavaPlugin,
     "llava_next": LlavaNextPlugin,
     "llava_next_video": LlavaNextVideoPlugin,
diff --git a/src/llamafactory/data/template.py b/src/llamafactory/data/template.py
index 682ccb10..58dcf561 100644
--- a/src/llamafactory/data/template.py
+++ b/src/llamafactory/data/template.py
@@ -566,6 +566,16 @@ _register_template(
 )
 
 
+_register_template(
+    name="cpm_o",
+    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
+    format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
+    stop_words=["<|im_end|>"],
+    mm_plugin=get_mm_plugin(name="cpm_o", image_token="<image>", video_token="