From ae32c148d1adcbe86701babf8b3d5b64f262039b Mon Sep 17 00:00:00 2001
From: Zhangchi Feng <64362896+BUAADreamer@users.noreply.github.com>
Date: Tue, 14 Jan 2025 00:26:19 +0800
Subject: [PATCH] Support new features of MiniCPM-V (#6626)

* fix template name

* tiny fix

* support minicpm-o-2.6

Former-commit-id: 53034a61c7654358f46916cbc370910fb2aeff3b
---
 README.md                                    |   2 +-
 README_zh.md                                 |   2 +-
 setup.py                                     |  10 +
 src/llamafactory/data/collator.py            |   5 +-
 src/llamafactory/data/mm_plugin.py           | 302 +++++++++----------
 src/llamafactory/data/template.py            |  22 +-
 src/llamafactory/extras/constants.py         |  13 +-
 src/llamafactory/model/model_utils/visual.py |   1 +
 8 files changed, 189 insertions(+), 168 deletions(-)

diff --git a/README.md b/README.md
index 4327c951..df9a254f 100644
--- a/README.md
+++ b/README.md
@@ -209,7 +209,7 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
 | [LLaVA-NeXT](https://huggingface.co/llava-hf) | 7B/8B/13B/34B/72B/110B | llava_next |
 | [LLaVA-NeXT-Video](https://huggingface.co/llava-hf) | 7B/34B | llava_next_video |
 | [MiniCPM](https://huggingface.co/openbmb) | 1B/2B/4B | cpm/cpm3 |
-| [MiniCPM-V-2.6](https://huggingface.co/openbmb) | 8B | cpm_v |
+| [MiniCPM-o-2.6/MiniCPM-V-2.6](https://huggingface.co/openbmb) | 8B | minicpm_v |
 | [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral |
 | [OLMo](https://huggingface.co/allenai) | 1B/7B | - |
 | [PaliGemma/PaliGemma2](https://huggingface.co/google) | 3B/10B/28B | paligemma |
diff --git a/README_zh.md b/README_zh.md
index bf4eb4d3..c9f28c49 100644
--- a/README_zh.md
+++ b/README_zh.md
@@ -210,7 +210,7 @@ https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272
 | [LLaVA-NeXT](https://huggingface.co/llava-hf) | 7B/8B/13B/34B/72B/110B | llava_next |
 | [LLaVA-NeXT-Video](https://huggingface.co/llava-hf) | 7B/34B | llava_next_video |
 | [MiniCPM](https://huggingface.co/openbmb) | 1B/2B/4B | cpm/cpm3 |
-| [MiniCPM-V-2.6](https://huggingface.co/openbmb) | 8B | cpm_v |
+| [MiniCPM-o-2.6/MiniCPM-V-2.6](https://huggingface.co/openbmb) | 8B | minicpm_v |
 | [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral |
 | [OLMo](https://huggingface.co/allenai) | 1B/7B | - |
 | [PaliGemma/PaliGemma2](https://huggingface.co/google) | 3B/10B/28B | paligemma |
diff --git a/setup.py b/setup.py
index 10d93551..75d2b7e3 100644
--- a/setup.py
+++ b/setup.py
@@ -59,6 +59,16 @@ extra_require = {
     "badam": ["badam>=1.2.1"],
     "adam-mini": ["adam-mini"],
     "qwen": ["transformers_stream_generator"],
+    "minicpm_v": [
+        "soundfile",
+        "torchvision",
+        "torchaudio",
+        "vector_quantize_pytorch",
+        "vocos",
+        "msgpack",
+        "referencing",
+        "jsonschema_specifications",
+    ],
     "modelscope": ["modelscope"],
     "openmind": ["openmind"],
     "swanlab": ["swanlab"],
diff --git a/src/llamafactory/data/collator.py b/src/llamafactory/data/collator.py
index 01742bdc..c14f17e0 100644
--- a/src/llamafactory/data/collator.py
+++ b/src/llamafactory/data/collator.py
@@ -153,9 +153,8 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
             features = features.data  # use default_collate() instead of BatchEncoding.to()
 
         if "image_bound" in features:  # for minicpmv inputs
-            features["position_ids"] = (
-                torch.arange(features["input_ids"].size(1)).long().unsqueeze(0).expand_as(features["input_ids"])
-            )
+            bsz, seq_length = features["input_ids"].shape
+            features["position_ids"] = torch.arange(seq_length).long().repeat(bsz, 1)
             return {"data": features, "labels": features["labels"]}
 
         return features
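
Aside on the collator hunk above: a minimal sketch, outside the patch, of what the new position_ids construction yields. The toy tensors here are made up for illustration; only the two lines building position_ids come from the diff.

import torch

# Toy batch standing in for the collated features: 2 samples, 5 tokens each.
input_ids = torch.zeros(2, 5, dtype=torch.long)

# Construction used by the patch: one arange per sequence, repeated across the batch.
bsz, seq_length = input_ids.shape
position_ids = torch.arange(seq_length).long().repeat(bsz, 1)

# repeat() materializes a real contiguous tensor, whereas the previous
# unsqueeze(0).expand_as(...) form produced a broadcasted view; the values are identical.
assert position_ids.shape == input_ids.shape
assert torch.equal(position_ids, torch.arange(seq_length).unsqueeze(0).expand_as(input_ids))
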
diff --git a/src/llamafactory/data/mm_plugin.py b/src/llamafactory/data/mm_plugin.py
index 7cae556a..e4447ee5 100644
--- a/src/llamafactory/data/mm_plugin.py
+++ b/src/llamafactory/data/mm_plugin.py
@@ -254,156 +254,6 @@ class BasePlugin:
         return {}
 
 
-class CpmVPlugin(BasePlugin):
-    @override
-    def process_messages(
-        self,
-        messages: Sequence[Dict[str, str]],
-        images: Sequence["ImageInput"],
-        videos: Sequence["VideoInput"],
-        processor: Optional["ProcessorMixin"],
-    ) -> List[Dict[str, str]]:
-        self._validate_input(images, videos)
-        num_image_tokens = 0
-        num_video_tokens = 0
-        messages = deepcopy(messages)
-        image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
-        mm_inputs = {}
-        if len(images) != 0 and len(videos) != 0:
-            raise ValueError("MiniCPM-V model does not support input images and videos at the same time.")
-
-        if len(videos) != 0:
-            max_slice_nums = 2
-            use_image_id = False
-            mm_inputs = self._get_mm_inputs([], videos, processor)
-        else:
-            max_slice_nums = image_processor.max_slice_nums
-            use_image_id = image_processor.use_image_id
-
-        for message in messages:
-            content = message["content"]
-            while IMAGE_PLACEHOLDER in content:
-                num_image_tokens += 1
-                content = content.replace(IMAGE_PLACEHOLDER, "{{image}}", 1)
-
-            while VIDEO_PLACEHOLDER in content:
-                video_seqlen = len(mm_inputs["pixel_values"][num_video_tokens]) if self.expand_mm_tokens else 1
-                content = content.replace(VIDEO_PLACEHOLDER, "{{image}}" * video_seqlen, 1)
-                num_video_tokens += 1
-
-            message["content"] = content.replace("{{image}}", "(<image>./</image>)")
-
-        if num_image_tokens > 0:
-            mm_inputs = self._get_mm_inputs(images, [], processor)
-
-        if mm_inputs:
-            pattern = "(<image>./</image>)"
-            image_sizes = mm_inputs["image_sizes"]
-
-            for index, message in enumerate(messages):
-                text = message["content"]
-                image_tags = re.findall(pattern, text)
-                text_chunks = text.split(pattern)
-                final_text = ""
-                for i in range(len(image_tags)):
-                    final_text = (
-                        final_text
-                        + text_chunks[i]
-                        + image_processor.get_slice_image_placeholder(
-                            image_sizes[0][i], i, max_slice_nums, use_image_id
-                        )
-                    )
-                final_text += text_chunks[-1]
-                messages[index]["content"] = final_text
-
-        if len(images) != num_image_tokens:
-            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
-
-        if len(videos) != num_video_tokens:
-            raise ValueError(f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens.")
-
-        return messages
-
-    @override
-    def _get_mm_inputs(
-        self,
-        images: Sequence["ImageInput"],
-        videos: Sequence["VideoInput"],
-        processor: "ProcessorMixin",
-        **kwargs,
-    ) -> Dict[str, "torch.Tensor"]:
-        image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
-
-        mm_inputs = {}
-        if len(images) != 0:
-            images = self._regularize_images(
-                images,
-                image_resolution=getattr(processor, "image_resolution", 512 * 512),
-            )
-            if "valid_image_nums_ls" in kwargs:
-                valid_image_nums_ls = kwargs["valid_image_nums_ls"]
-                new_images = []
-                idx = 0
-                for valid_image_nums in valid_image_nums_ls:
-                    new_images.append(images[idx : idx + valid_image_nums])
-                    idx += valid_image_nums
-
-                images = new_images
-
-            image_inputs = image_processor(
-                images, do_pad=True, max_slice_nums=image_processor.max_slice_nums, return_tensors="pt"
-            )
-            mm_inputs.update(image_inputs)
-
-        if len(videos) != 0:
-            videos = self._regularize_videos(
-                videos,
-                image_resolution=getattr(processor, "video_resolution", 128 * 128),
-                video_fps=getattr(processor, "video_fps", 2.0),
-                video_maxlen=getattr(processor, "video_maxlen", 64),
-            )
-            video_inputs = image_processor(videos, do_pad=True, max_slice_nums=2, return_tensors="pt")
-            mm_inputs.update(video_inputs)
-
-        return mm_inputs
-
-    @override
-    def get_mm_inputs(
-        self,
-        images: Sequence["ImageInput"],
-        videos: Sequence["VideoInput"],
-        imglens: Sequence[int],
-        vidlens: Sequence[int],
-        batch_ids: Sequence[List[int]],
-        processor: Optional["ProcessorMixin"],
-    ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
-        self._validate_input(images, videos)
-        image_bounds_list = []
-        valid_image_nums_ls = []
-        for input_ids in batch_ids:
-            input_ids_ = torch.tensor(input_ids)
-            start_cond = (input_ids_ == processor.tokenizer.im_start_id) | (
-                input_ids_ == processor.tokenizer.slice_start_id
-            )
-            end_cond = (input_ids_ == processor.tokenizer.im_end_id) | (input_ids_ == processor.tokenizer.slice_end_id)
-            image_start_tokens = torch.where(start_cond)[0]
-            image_start_tokens += 1
-            image_end_tokens = torch.where(end_cond)[0]
-            valid_image_nums = max(len(image_start_tokens), len(image_end_tokens))
-            valid_image_nums_ls.append(valid_image_nums)
-            image_bounds = torch.hstack(
-                [
-                    image_start_tokens[:valid_image_nums].unsqueeze(-1),
-                    image_end_tokens[:valid_image_nums].unsqueeze(-1),
-                ]
-            )
-            image_bounds_list.append(image_bounds)
-
-        mm_inputs = self._get_mm_inputs(images, videos, processor, valid_image_nums_ls=valid_image_nums_ls)
-        mm_inputs.update({"image_bound": image_bounds_list})
-        return mm_inputs
-
-
 class LlavaPlugin(BasePlugin):
     @override
     def process_messages(
@@ -567,6 +417,156 @@ class LlavaNextVideoPlugin(BasePlugin):
         return self._get_mm_inputs(images, videos, processor)
 
 
+class MiniCPMVPlugin(BasePlugin):
+    @override
+    def process_messages(
+        self,
+        messages: Sequence[Dict[str, str]],
+        images: Sequence["ImageInput"],
+        videos: Sequence["VideoInput"],
+        processor: Optional["ProcessorMixin"],
+    ) -> List[Dict[str, str]]:
+        self._validate_input(images, videos)
+        num_image_tokens = 0
+        num_video_tokens = 0
+        messages = deepcopy(messages)
+        image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
+        mm_inputs = {}
+        if len(images) != 0 and len(videos) != 0:
+            raise ValueError("MiniCPM-V model does not support input images and videos at the same time.")
+
+        if len(videos) != 0:
+            max_slice_nums = 2
+            use_image_id = False
+            mm_inputs = self._get_mm_inputs([], videos, processor)
+        else:
+            max_slice_nums = image_processor.max_slice_nums
+            use_image_id = image_processor.use_image_id
+
+        for message in messages:
+            content = message["content"]
+            while IMAGE_PLACEHOLDER in content:
+                content = content.replace(IMAGE_PLACEHOLDER, "{{image}}", 1)
+                num_image_tokens += 1
+
+            while VIDEO_PLACEHOLDER in content:
+                video_seqlen = len(mm_inputs["pixel_values"][num_video_tokens]) if self.expand_mm_tokens else 1
+                content = content.replace(VIDEO_PLACEHOLDER, "{{image}}" * video_seqlen, 1)
+                num_video_tokens += 1
+
+            message["content"] = content.replace("{{image}}", "(<image>./</image>)")
+
+        if num_image_tokens > 0:
+            mm_inputs = self._get_mm_inputs(images, [], processor)
+
+        if mm_inputs:
+            pattern = "(<image>./</image>)"
+            image_sizes = mm_inputs["image_sizes"]
+
+            for index, message in enumerate(messages):
+                text = message["content"]
+                image_tags = re.findall(pattern, text)
+                text_chunks = text.split(pattern)
+                final_text = ""
+                for i in range(len(image_tags)):
+                    final_text = (
+                        final_text
+                        + text_chunks[i]
+                        + image_processor.get_slice_image_placeholder(
+                            image_sizes[0][i], i, max_slice_nums, use_image_id
+                        )
+                    )
+
+                final_text += text_chunks[-1]
+                messages[index]["content"] = final_text
+
+        if len(images) != num_image_tokens:
+            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
+
+        if len(videos) != num_video_tokens:
+            raise ValueError(f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens.")
+
+        return messages
+
+    @override
+    def _get_mm_inputs(
+        self,
+        images: Sequence["ImageInput"],
+        videos: Sequence["VideoInput"],
+        processor: "ProcessorMixin",
+        **kwargs,
+    ) -> Dict[str, "torch.Tensor"]:
+        image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
+        mm_inputs = {}
+        if len(images) != 0:
+            images = self._regularize_images(
+                images,
+                image_resolution=getattr(processor, "image_resolution", 512 * 512),
+            )
+            if "valid_image_nums_ls" in kwargs:
+                valid_image_nums_ls = kwargs["valid_image_nums_ls"]
+                new_images = []
+                idx = 0
+                for valid_image_nums in valid_image_nums_ls:
+                    new_images.append(images[idx : idx + valid_image_nums])
+                    idx += valid_image_nums
+
+                images = new_images
+
+            image_inputs = image_processor(
+                images, do_pad=True, max_slice_nums=image_processor.max_slice_nums, return_tensors="pt"
+            )
+            mm_inputs.update(image_inputs)
+
+        if len(videos) != 0:
+            videos = self._regularize_videos(
+                videos,
+                image_resolution=getattr(processor, "video_resolution", 128 * 128),
+                video_fps=getattr(processor, "video_fps", 2.0),
+                video_maxlen=getattr(processor, "video_maxlen", 64),
+            )
+            video_inputs = image_processor(videos, do_pad=True, max_slice_nums=2, return_tensors="pt")
+            mm_inputs.update(video_inputs)
+
+        return mm_inputs
+
+    @override
+    def get_mm_inputs(
+        self,
+        images: Sequence["ImageInput"],
+        videos: Sequence["VideoInput"],
+        imglens: Sequence[int],
+        vidlens: Sequence[int],
+        batch_ids: Sequence[List[int]],
+        processor: Optional["ProcessorMixin"],
+    ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
+        self._validate_input(images, videos)
+        image_bounds_list = []
+        valid_image_nums_ls = []
+        for input_ids in batch_ids:
+            input_ids_ = torch.tensor(input_ids)
+            start_cond = (input_ids_ == processor.tokenizer.im_start_id) | (
+                input_ids_ == processor.tokenizer.slice_start_id
+            )
+            end_cond = (input_ids_ == processor.tokenizer.im_end_id) | (input_ids_ == processor.tokenizer.slice_end_id)
+            image_start_tokens = torch.where(start_cond)[0]
+            image_start_tokens += 1
+            image_end_tokens = torch.where(end_cond)[0]
+            valid_image_nums = max(len(image_start_tokens), len(image_end_tokens))
+            valid_image_nums_ls.append(valid_image_nums)
+            image_bounds = torch.hstack(
+                [
+                    image_start_tokens[:valid_image_nums].unsqueeze(-1),
+                    image_end_tokens[:valid_image_nums].unsqueeze(-1),
+                ]
+            )
+            image_bounds_list.append(image_bounds)
+
+        mm_inputs = self._get_mm_inputs(images, videos, processor, valid_image_nums_ls=valid_image_nums_ls)
+        mm_inputs.update({"image_bound": image_bounds_list})
+        return mm_inputs
+
+
 class PaliGemmaPlugin(BasePlugin):
     @override
     def process_messages(
@@ -945,10 +945,10 @@ class MllamaPlugin(BasePlugin):
 
 PLUGINS = {
     "base": BasePlugin,
-    "cpm_v": CpmVPlugin,
     "llava": LlavaPlugin,
     "llava_next": LlavaNextPlugin,
     "llava_next_video": LlavaNextVideoPlugin,
+    "minicpm_v": MiniCPMVPlugin,
     "paligemma": PaliGemmaPlugin,
     "pixtral": PixtralPlugin,
     "qwen2_vl": Qwen2vlPlugin,
diff --git a/src/llamafactory/data/template.py b/src/llamafactory/data/template.py
index 05eb0cda..c9340afc 100644
--- a/src/llamafactory/data/template.py
+++ b/src/llamafactory/data/template.py
@@ -576,17 +576,6 @@ _register_template(
 )
 
 
-# copied from chatml template
-_register_template(
-    name="cpm_v",
-    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
-    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
-    format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
-    stop_words=["<|im_end|>"],
-    mm_plugin=get_mm_plugin(name="cpm_v", image_token="<image>", video_token="