diff --git a/src/llamafactory/chat/hf_engine.py b/src/llamafactory/chat/hf_engine.py index 1e670b92c..38db80733 100644 --- a/src/llamafactory/chat/hf_engine.py +++ b/src/llamafactory/chat/hf_engine.py @@ -205,6 +205,9 @@ class HuggingfaceEngine(BaseEngine): gen_kwargs.pop("image_sizes", None) + if getattr(model.config, "model_type", None) == "minicpmv4_6": + gen_kwargs["downsample_mode"] = os.getenv("DOWNSAMPLE_MODE", "16x") + return gen_kwargs, prompt_length @staticmethod diff --git a/src/llamafactory/data/collator.py b/src/llamafactory/data/collator.py index af234d99b..d6ecaddc3 100644 --- a/src/llamafactory/data/collator.py +++ b/src/llamafactory/data/collator.py @@ -17,6 +17,7 @@ import copy import inspect +import os from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Literal, Optional @@ -474,6 +475,13 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq): features["position_ids"] = torch.arange(seq_length).long().repeat(bsz, 1) return {"data": features, "input_ids": features["input_ids"], "labels": features["labels"]} + if ( + self.model is not None + and getattr(self.model.config, "model_type", None) == "minicpmv4_6" + and "target_sizes" in features + ): # for minicpmv4_6 with new transformers (NaViT API, no image_bound) + features["downsample_mode"] = os.getenv("DOWNSAMPLE_MODE", "16x") + return features diff --git a/src/llamafactory/data/mm_plugin.py b/src/llamafactory/data/mm_plugin.py index 6827ea400..dc45d431c 100644 --- a/src/llamafactory/data/mm_plugin.py +++ b/src/llamafactory/data/mm_plugin.py @@ -1209,6 +1209,23 @@ class LlavaNextVideoPlugin(BasePlugin): @dataclass class MiniCPMVPlugin(BasePlugin): + def _resolve_token_id(self, tokenizer: Any, attr_name: str, token_text: str | None = None) -> int | None: + token_id = getattr(tokenizer, attr_name, None) + if isinstance(token_id, int) and token_id >= 0: + return token_id + + if token_text is None or not hasattr(tokenizer, "convert_tokens_to_ids"): + return None + + converted_id = tokenizer.convert_tokens_to_ids(token_text) + if isinstance(converted_id, list): + converted_id = converted_id[0] if len(converted_id) else None + + if isinstance(converted_id, int) and converted_id >= 0: + return converted_id + + return None + @override def _get_mm_inputs( self, @@ -1220,6 +1237,8 @@ class MiniCPMVPlugin(BasePlugin): ) -> dict[str, "torch.Tensor"]: image_processor: BaseImageProcessor = getattr(processor, "image_processor") mm_inputs = {} + preprocess_params = inspect.signature(image_processor.preprocess).parameters + downsample_mode = os.getenv("DOWNSAMPLE_MODE", "16x") if "downsample_mode" in preprocess_params else None if len(images) != 0: images = self._regularize_images( images, @@ -1236,9 +1255,15 @@ class MiniCPMVPlugin(BasePlugin): images = new_images - image_inputs = image_processor( - images, do_pad=True, max_slice_nums=image_processor.max_slice_nums, return_tensors="pt" - ) + image_processor_kwargs = { + "do_pad": True, + "max_slice_nums": image_processor.max_slice_nums, + "return_tensors": "pt", + } + if downsample_mode is not None: + image_processor_kwargs["downsample_mode"] = downsample_mode + + image_inputs = image_processor(images, **image_processor_kwargs) mm_inputs.update(image_inputs) if len(videos) != 0: @@ -1249,7 +1274,15 @@ class MiniCPMVPlugin(BasePlugin): video_fps=getattr(processor, "video_fps", 2.0), video_maxlen=getattr(processor, "video_maxlen", 128), )["videos"] - video_inputs = image_processor(videos, do_pad=True, max_slice_nums=2, return_tensors="pt") + video_processor_kwargs = { + "do_pad": True, + "max_slice_nums": 2, + "return_tensors": "pt", + } + if downsample_mode is not None: + video_processor_kwargs["downsample_mode"] = downsample_mode + + video_inputs = image_processor(videos, **video_processor_kwargs) mm_inputs.update(video_inputs) if len(audios) != 0: @@ -1334,7 +1367,8 @@ class MiniCPMVPlugin(BasePlugin): if self.expand_mm_tokens and mm_inputs: pattern = "(./)" - image_sizes = mm_inputs["image_sizes"] + image_sizes = mm_inputs.get("image_sizes") + image_grids = mm_inputs.get("grids") idx = 0 for index, message in enumerate(messages): text = message["content"] @@ -1342,13 +1376,21 @@ class MiniCPMVPlugin(BasePlugin): text_chunks = text.split(pattern) final_text = "" for i in range(len(image_tags)): - final_text = ( - final_text - + text_chunks[i] - + image_processor.get_slice_image_placeholder( - image_sizes[0][idx], idx, max_slice_nums, use_image_id + grid = image_grids[0][idx] if image_grids and len(image_grids[0]) > idx else [1, 1] + image_size = image_sizes[0][idx] if image_sizes and len(image_sizes[0]) > idx else None + + placeholder_fn = image_processor.get_slice_image_placeholder + if image_size is not None: + image_placeholder = placeholder_fn( + image_size, + image_idx=idx, + max_slice_nums=max_slice_nums, + use_image_id=use_image_id, ) - ) + else: + image_placeholder = placeholder_fn(grid) + + final_text = final_text + text_chunks[i] + image_placeholder idx += 1 final_text += text_chunks[-1] @@ -1385,15 +1427,25 @@ class MiniCPMVPlugin(BasePlugin): processor: Optional["MMProcessor"], ) -> dict[str, Union[list[int], "torch.Tensor"]]: self._validate_input(processor, images, videos, audios) + tokenizer = processor.tokenizer + im_start_id = self._resolve_token_id(tokenizer, "im_start_id", "") + slice_start_id = self._resolve_token_id(tokenizer, "slice_start_id", "") + im_end_id = self._resolve_token_id(tokenizer, "im_end_id", "") + slice_end_id = self._resolve_token_id(tokenizer, "slice_end_id", "") + if None in (im_start_id, slice_start_id, im_end_id, slice_end_id): + raise AttributeError( + "Cannot resolve MiniCPM image boundary token ids from tokenizer. " + "Expected attributes (im_start_id/slice_start_id/im_end_id/slice_end_id) " + "or corresponding special tokens (, , , )." + ) + # image bound image_bounds_list = [] valid_image_nums_ls = [] for i, input_ids in enumerate(batch_ids): input_ids_ = torch.tensor(input_ids) - start_cond = (input_ids_ == processor.tokenizer.im_start_id) | ( - input_ids_ == processor.tokenizer.slice_start_id - ) - end_cond = (input_ids_ == processor.tokenizer.im_end_id) | (input_ids_ == processor.tokenizer.slice_end_id) + start_cond = (input_ids_ == im_start_id) | (input_ids_ == slice_start_id) + end_cond = (input_ids_ == im_end_id) | (input_ids_ == slice_end_id) image_start_tokens = torch.where(start_cond)[0] image_start_tokens += 1 image_end_tokens = torch.where(end_cond)[0] @@ -1414,6 +1466,16 @@ class MiniCPMVPlugin(BasePlugin): mm_inputs.update({"image_bound": image_bounds_list}) if len(audios) > 0: + audio_start_id = self._resolve_token_id(tokenizer, "audio_start_id", "") + spk_start_id = self._resolve_token_id(tokenizer, "spk_start_id", "") + spk_end_id = self._resolve_token_id(tokenizer, "spk_end_id", "") + if None in (audio_start_id, audio_end_id, spk_start_id, spk_end_id): + raise AttributeError( + "Cannot resolve MiniCPM audio/speaker boundary token ids from tokenizer. " + "Expected *_id attributes or corresponding special tokens." + ) + # audio bound audio_bounds_ls = [] spk_bounds_ls = [] @@ -1421,15 +1483,15 @@ class MiniCPMVPlugin(BasePlugin): for input_ids, audiolen in zip(batch_ids, audlens): input_ids_ = torch.tensor(input_ids) - audio_start_idx = torch.where(input_ids_ == processor.tokenizer.audio_start_id)[0] - audio_end_idx = torch.where(input_ids_ == processor.tokenizer.audio_end_id)[0] + audio_start_idx = torch.where(input_ids_ == audio_start_id)[0] + audio_end_idx = torch.where(input_ids_ == audio_end_id)[0] assert len(audio_start_idx) == len(audio_end_idx) audio_bounds = torch.hstack([(audio_start_idx + 1).unsqueeze(-1), audio_end_idx.unsqueeze(-1)]) audio_bounds_ls.append(audio_bounds) valid_audio_nums_ls.append(audiolen) - spk_start_idx = torch.where(input_ids_ == processor.tokenizer.spk_start_id)[0] - spk_end_idx = torch.where(input_ids_ == processor.tokenizer.spk_end_id)[0] + spk_start_idx = torch.where(input_ids_ == spk_start_id)[0] + spk_end_idx = torch.where(input_ids_ == spk_end_id)[0] assert len(spk_start_idx) == len(spk_end_idx) spk_bounds = torch.hstack([(spk_start_idx + 1).unsqueeze(-1), spk_end_idx.unsqueeze(-1)]) spk_bounds_ls.append(spk_bounds) @@ -1441,6 +1503,255 @@ class MiniCPMVPlugin(BasePlugin): return mm_inputs +@dataclass +class MiniCPMV4_6Plugin(BasePlugin): + """Plugin for MiniCPM-V-4.6 with new transformers (NaViT vision + get_placeholder_mask API).""" + + def _get_mm_inputs( + self, + images: list["ImageInput"], + videos: list["VideoInput"], + audios: list["AudioInput"], + processor: "MMProcessor", + **kwargs, + ) -> dict[str, "torch.Tensor"]: + image_processor = getattr(processor, "image_processor") + video_processor = getattr(processor, "video_processor", None) + mm_inputs = {} + preprocess_params = inspect.signature(image_processor.preprocess).parameters + downsample_mode = os.getenv("DOWNSAMPLE_MODE", "16x") if "downsample_mode" in preprocess_params else None + + if len(images) != 0: + images = self._regularize_images( + images, + image_max_pixels=getattr(processor, "image_max_pixels", 768 * 768), + image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32), + )["images"] + image_processor_kwargs = { + "max_slice_nums": getattr(image_processor, "max_slice_nums", 9), + "return_tensors": "pt", + } + if downsample_mode is not None: + image_processor_kwargs["downsample_mode"] = downsample_mode + image_inputs = image_processor(images, **image_processor_kwargs) + mm_inputs.update(image_inputs) + + if len(videos) != 0: + videos = self._regularize_videos( + videos, + image_max_pixels=getattr(processor, "video_max_pixels", 256 * 256), + image_min_pixels=getattr(processor, "video_min_pixels", 16 * 16), + video_fps=getattr(processor, "video_fps", 2.0), + video_maxlen=getattr(processor, "video_maxlen", 128), + )["videos"] + if video_processor is not None: + video_processor_kwargs = { + "max_slice_nums": 2, + "return_tensors": "pt", + } + if downsample_mode is not None: + video_processor_kwargs["downsample_mode"] = downsample_mode + video_inputs = video_processor(videos, **video_processor_kwargs) + mm_inputs["pixel_values_videos"] = video_inputs["pixel_values_videos"] + mm_inputs["target_sizes_videos"] = video_inputs["target_sizes_videos"] + else: + # Fallback to image processor for video + video_processor_kwargs = { + "max_slice_nums": 2, + "return_tensors": "pt", + } + if downsample_mode is not None: + video_processor_kwargs["downsample_mode"] = downsample_mode + video_inputs = image_processor(videos, **video_processor_kwargs) + mm_inputs["pixel_values_videos"] = video_inputs["pixel_values"] + mm_inputs["target_sizes_videos"] = video_inputs["target_sizes"] + + if len(audios) != 0: + audios = self._regularize_audios( + audios, + sampling_rate=getattr(processor, "audio_sampling_rate", 16000), + )["audios"] + audio_features, audio_feature_lens, audio_phs = processor.audio_feature_extract( + [audios], + chunk_input=True, + sampling_rate=getattr(processor, "audio_sampling_rate", 16000), + ) + audio_feature_lens = [ + x.clone().detach() if isinstance(x, torch.Tensor) else torch.tensor(x) for x in audio_feature_lens + ] + mm_inputs.update({"audio_features": audio_features, "audio_feature_lens": audio_feature_lens}) + if kwargs.get("ret_phs", False): + mm_inputs.update({"audio_phs": audio_phs}) + + return mm_inputs + + def _build_v4_6_placeholder( + self, + image_inputs: dict[str, Any], + image_idx: int, + use_image_id: bool, + processor: "MMProcessor", + ) -> str: + """Build image placeholder for MiniCPM-V-4.6 using NaViT token count computation.""" + grids = image_inputs.get("grids", [[0, 0]]) + num_patches_per_image = image_inputs.get("num_patches_per_image", [1]) + target_sizes = image_inputs.get("target_sizes") + + downsample_mode = os.getenv("DOWNSAMPLE_MODE") + if downsample_mode is None: + image_processor = getattr(processor, "image_processor") + downsample_mode = getattr(image_processor, "downsample_mode", "16x") + token_divisor = 4 if downsample_mode == "4x" else 16 + + flat_index = 0 + for idx in range(image_idx): + flat_index += num_patches_per_image[idx] + n_patches = num_patches_per_image[image_idx] + + img_target_sizes = target_sizes[flat_index : flat_index + n_patches] + num_tokens_per_patch = img_target_sizes.prod(-1) // token_divisor + num_rows, num_cols = grids[image_idx] + + image_start = getattr(processor, "image_start_token", "") + image_end = getattr(processor, "image_end_token", "") + slice_start = getattr(processor, "slice_start_token", "") + slice_end = getattr(processor, "slice_end_token", "") + image_id_start = getattr(processor, "image_id_start_token", "") + image_id_end = getattr(processor, "image_id_end_token", "") + image_token = ( + getattr(processor, "image_token", None) + or getattr(getattr(processor, "tokenizer", None), "image_token", None) + or "" + ) + + image_placeholder = image_start + "<|ph|>" * int(num_tokens_per_patch[0]) + image_end + if use_image_id: + image_placeholder = f"{image_id_start}{image_idx}{image_id_end}" + image_placeholder + + slice_mode = getattr(processor, "slice_mode", True) + if slice_mode and num_rows > 0 and num_cols > 0: + per_slice_tokens = int(num_tokens_per_patch[1]) if len(num_tokens_per_patch) > 1 else 0 + slice_placeholder = slice_start + "<|ph|>" * per_slice_tokens + slice_end + slices = [slice_placeholder * num_cols for _ in range(num_rows)] + image_placeholder += "\n".join(slices) + + return image_placeholder.replace("<|ph|>", image_token) + + @override + def process_messages( + self, + messages: list[dict[str, str]], + images: list["ImageInput"], + videos: list["VideoInput"], + audios: list["AudioInput"], + processor: Optional["MMProcessor"], + ) -> list[dict[str, str]]: + self._validate_input(processor, images, videos, audios) + self._validate_messages(messages, images, videos, audios) + num_image_tokens, num_video_tokens, num_audio_tokens = 0, 0, 0 + messages = deepcopy(messages) + mm_inputs, audio_inputs = {}, {} + if len(images) != 0 and len(videos) != 0: + raise ValueError("MiniCPM-V model does not support input images and videos at the same time.") + + use_image_id = getattr(processor, "default_use_image_id", True) + + if len(videos) != 0: + use_image_id = False + mm_inputs = self._get_mm_inputs([], videos, [], processor) + + for i, message in enumerate(messages): + content = message["content"] + while IMAGE_PLACEHOLDER in content: + content = content.replace(IMAGE_PLACEHOLDER, "{{image}}", 1) + num_image_tokens += 1 + + while VIDEO_PLACEHOLDER in content: + num_frames = 1 + if "num_frames_per_video" in mm_inputs: + num_frames = sum(mm_inputs["num_frames_per_video"]) + content = content.replace(VIDEO_PLACEHOLDER, "{{image}}" * num_frames, 1) + num_video_tokens += 1 + + while AUDIO_PLACEHOLDER in content: + content = content.replace(AUDIO_PLACEHOLDER, "{{audio}}", 1) + num_audio_tokens += 1 + + message["content"] = content.replace("{{image}}", "(./)").replace( + "{{audio}}", "()" + ) + + if len(images): + mm_inputs = self._get_mm_inputs(images, [], [], processor) + + if len(audios): + audio_inputs = self._get_mm_inputs([], [], audios, processor, ret_phs=True) + + if self.expand_mm_tokens and mm_inputs: + pattern = "(./)" + idx = 0 + for index, message in enumerate(messages): + text = message["content"] + image_tags = re.findall(pattern, text) + text_chunks = text.split(pattern) + final_text = "" + for i in range(len(image_tags)): + image_placeholder = self._build_v4_6_placeholder(mm_inputs, idx, use_image_id, processor) + final_text = final_text + text_chunks[i] + image_placeholder + idx += 1 + final_text += text_chunks[-1] + messages[index]["content"] = final_text + + if self.expand_mm_tokens and audio_inputs: + pattern = "()" + idx = 0 + for index, message in enumerate(messages): + text = message["content"] + audio_tags = re.findall(pattern, text) + text_chunks = text.split(pattern) + final_text = "" + for i in range(len(audio_tags)): + audio_placeholder = audio_inputs["audio_phs"][0][idx] + final_text = final_text + text_chunks[i] + audio_placeholder + idx += 1 + final_text += text_chunks[-1] + messages[index]["content"] = final_text + + return messages + + @override + def get_mm_inputs( + self, + images: list["ImageInput"], + videos: list["VideoInput"], + audios: list["AudioInput"], + imglens: list[int], + vidlens: list[int], + audlens: list[int], + batch_ids: list[list[int]], + processor: Optional["MMProcessor"], + ) -> dict[str, Union[list[int], "torch.Tensor"]]: + self._validate_input(processor, images, videos, audios) + mm_inputs = self._get_mm_inputs(images, videos, audios, processor) + + # v4.6 does NOT use image_bound — the model finds image tokens via get_placeholder_mask + # Ensure target_sizes key name matches the model's expected input + if "target_sizes" not in mm_inputs and "tgt_sizes" in mm_inputs: + mm_inputs["target_sizes"] = mm_inputs.pop("tgt_sizes") + + if "target_sizes" not in mm_inputs: + mm_inputs["target_sizes"] = torch.empty(0, 2, dtype=torch.int32) + + if "pixel_values" not in mm_inputs: + mm_inputs["pixel_values"] = torch.empty(1, 3, 14, 0) + + if len(audios) > 0: + audio_inputs = self._get_mm_inputs([], [], audios, processor) + mm_inputs.update(audio_inputs) + + return mm_inputs + + @dataclass class MllamaPlugin(BasePlugin): @override @@ -2695,6 +3006,7 @@ PLUGINS = { "llava_next_video": LlavaNextVideoPlugin, "lfm2_vl": LFMVLPlugin, "minicpm_v": MiniCPMVPlugin, + "minicpm_v_4_6": MiniCPMV4_6Plugin, "mllama": MllamaPlugin, "paligemma": PaliGemmaPlugin, "pixtral": PixtralPlugin, diff --git a/src/llamafactory/data/template.py b/src/llamafactory/data/template.py index cb90eb3ec..89b70d114 100644 --- a/src/llamafactory/data/template.py +++ b/src/llamafactory/data/template.py @@ -1704,6 +1704,17 @@ register_template( ) +register_template( + name="minicpm_v_4_6", + format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), + format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]), + format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]), + stop_words=["<|im_end|>"], + default_system="You are a helpful assistant.", + mm_plugin=get_mm_plugin(name="minicpm_v_4_6", image_token="", video_token="