mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2026-05-28 02:48:54 +08:00
[fix] Fix MiniCPM-V-4.6 image preprocessing behavior (#10478)
This commit is contained in:
@@ -205,9 +205,6 @@ class HuggingfaceEngine(BaseEngine):
|
||||
|
||||
gen_kwargs.pop("image_sizes", None)
|
||||
|
||||
if getattr(model.config, "model_type", None) == "minicpmv4_6":
|
||||
gen_kwargs["downsample_mode"] = os.getenv("DOWNSAMPLE_MODE", "16x")
|
||||
|
||||
return gen_kwargs, prompt_length
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -17,7 +17,6 @@
|
||||
|
||||
import copy
|
||||
import inspect
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any, Literal, Optional
|
||||
|
||||
@@ -475,13 +474,6 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
||||
features["position_ids"] = torch.arange(seq_length).long().repeat(bsz, 1)
|
||||
return {"data": features, "input_ids": features["input_ids"], "labels": features["labels"]}
|
||||
|
||||
if (
|
||||
self.model is not None
|
||||
and getattr(self.model.config, "model_type", None) == "minicpmv4_6"
|
||||
and "target_sizes" in features
|
||||
): # for minicpmv4_6 with new transformers (NaViT API, no image_bound)
|
||||
features["downsample_mode"] = os.getenv("DOWNSAMPLE_MODE", "16x")
|
||||
|
||||
return features
|
||||
|
||||
|
||||
|
||||
@@ -1209,23 +1209,6 @@ class LlavaNextVideoPlugin(BasePlugin):
|
||||
|
||||
@dataclass
|
||||
class MiniCPMVPlugin(BasePlugin):
|
||||
def _resolve_token_id(self, tokenizer: Any, attr_name: str, token_text: str | None = None) -> int | None:
|
||||
token_id = getattr(tokenizer, attr_name, None)
|
||||
if isinstance(token_id, int) and token_id >= 0:
|
||||
return token_id
|
||||
|
||||
if token_text is None or not hasattr(tokenizer, "convert_tokens_to_ids"):
|
||||
return None
|
||||
|
||||
converted_id = tokenizer.convert_tokens_to_ids(token_text)
|
||||
if isinstance(converted_id, list):
|
||||
converted_id = converted_id[0] if len(converted_id) else None
|
||||
|
||||
if isinstance(converted_id, int) and converted_id >= 0:
|
||||
return converted_id
|
||||
|
||||
return None
|
||||
|
||||
@override
|
||||
def _get_mm_inputs(
|
||||
self,
|
||||
@@ -1237,8 +1220,6 @@ class MiniCPMVPlugin(BasePlugin):
|
||||
) -> dict[str, "torch.Tensor"]:
|
||||
image_processor: BaseImageProcessor = getattr(processor, "image_processor")
|
||||
mm_inputs = {}
|
||||
preprocess_params = inspect.signature(image_processor.preprocess).parameters
|
||||
downsample_mode = os.getenv("DOWNSAMPLE_MODE", "16x") if "downsample_mode" in preprocess_params else None
|
||||
if len(images) != 0:
|
||||
images = self._regularize_images(
|
||||
images,
|
||||
@@ -1255,15 +1236,9 @@ class MiniCPMVPlugin(BasePlugin):
|
||||
|
||||
images = new_images
|
||||
|
||||
image_processor_kwargs = {
|
||||
"do_pad": True,
|
||||
"max_slice_nums": image_processor.max_slice_nums,
|
||||
"return_tensors": "pt",
|
||||
}
|
||||
if downsample_mode is not None:
|
||||
image_processor_kwargs["downsample_mode"] = downsample_mode
|
||||
|
||||
image_inputs = image_processor(images, **image_processor_kwargs)
|
||||
image_inputs = image_processor(
|
||||
images, do_pad=True, max_slice_nums=image_processor.max_slice_nums, return_tensors="pt"
|
||||
)
|
||||
mm_inputs.update(image_inputs)
|
||||
|
||||
if len(videos) != 0:
|
||||
@@ -1274,15 +1249,7 @@ class MiniCPMVPlugin(BasePlugin):
|
||||
video_fps=getattr(processor, "video_fps", 2.0),
|
||||
video_maxlen=getattr(processor, "video_maxlen", 128),
|
||||
)["videos"]
|
||||
video_processor_kwargs = {
|
||||
"do_pad": True,
|
||||
"max_slice_nums": 2,
|
||||
"return_tensors": "pt",
|
||||
}
|
||||
if downsample_mode is not None:
|
||||
video_processor_kwargs["downsample_mode"] = downsample_mode
|
||||
|
||||
video_inputs = image_processor(videos, **video_processor_kwargs)
|
||||
video_inputs = image_processor(videos, do_pad=True, max_slice_nums=2, return_tensors="pt")
|
||||
mm_inputs.update(video_inputs)
|
||||
|
||||
if len(audios) != 0:
|
||||
@@ -1367,8 +1334,7 @@ class MiniCPMVPlugin(BasePlugin):
|
||||
|
||||
if self.expand_mm_tokens and mm_inputs:
|
||||
pattern = "(<image>./</image>)"
|
||||
image_sizes = mm_inputs.get("image_sizes")
|
||||
image_grids = mm_inputs.get("grids")
|
||||
image_sizes = mm_inputs["image_sizes"]
|
||||
idx = 0
|
||||
for index, message in enumerate(messages):
|
||||
text = message["content"]
|
||||
@@ -1376,21 +1342,13 @@ class MiniCPMVPlugin(BasePlugin):
|
||||
text_chunks = text.split(pattern)
|
||||
final_text = ""
|
||||
for i in range(len(image_tags)):
|
||||
grid = image_grids[0][idx] if image_grids and len(image_grids[0]) > idx else [1, 1]
|
||||
image_size = image_sizes[0][idx] if image_sizes and len(image_sizes[0]) > idx else None
|
||||
|
||||
placeholder_fn = image_processor.get_slice_image_placeholder
|
||||
if image_size is not None:
|
||||
image_placeholder = placeholder_fn(
|
||||
image_size,
|
||||
image_idx=idx,
|
||||
max_slice_nums=max_slice_nums,
|
||||
use_image_id=use_image_id,
|
||||
final_text = (
|
||||
final_text
|
||||
+ text_chunks[i]
|
||||
+ image_processor.get_slice_image_placeholder(
|
||||
image_sizes[0][idx], idx, max_slice_nums, use_image_id
|
||||
)
|
||||
else:
|
||||
image_placeholder = placeholder_fn(grid)
|
||||
|
||||
final_text = final_text + text_chunks[i] + image_placeholder
|
||||
)
|
||||
idx += 1
|
||||
|
||||
final_text += text_chunks[-1]
|
||||
@@ -1427,25 +1385,15 @@ class MiniCPMVPlugin(BasePlugin):
|
||||
processor: Optional["MMProcessor"],
|
||||
) -> dict[str, Union[list[int], "torch.Tensor"]]:
|
||||
self._validate_input(processor, images, videos, audios)
|
||||
tokenizer = processor.tokenizer
|
||||
im_start_id = self._resolve_token_id(tokenizer, "im_start_id", "<image>")
|
||||
slice_start_id = self._resolve_token_id(tokenizer, "slice_start_id", "<slice>")
|
||||
im_end_id = self._resolve_token_id(tokenizer, "im_end_id", "</image>")
|
||||
slice_end_id = self._resolve_token_id(tokenizer, "slice_end_id", "</slice>")
|
||||
if None in (im_start_id, slice_start_id, im_end_id, slice_end_id):
|
||||
raise AttributeError(
|
||||
"Cannot resolve MiniCPM image boundary token ids from tokenizer. "
|
||||
"Expected attributes (im_start_id/slice_start_id/im_end_id/slice_end_id) "
|
||||
"or corresponding special tokens (<image>, <slice>, </image>, </slice>)."
|
||||
)
|
||||
|
||||
# image bound
|
||||
image_bounds_list = []
|
||||
valid_image_nums_ls = []
|
||||
for i, input_ids in enumerate(batch_ids):
|
||||
input_ids_ = torch.tensor(input_ids)
|
||||
start_cond = (input_ids_ == im_start_id) | (input_ids_ == slice_start_id)
|
||||
end_cond = (input_ids_ == im_end_id) | (input_ids_ == slice_end_id)
|
||||
start_cond = (input_ids_ == processor.tokenizer.im_start_id) | (
|
||||
input_ids_ == processor.tokenizer.slice_start_id
|
||||
)
|
||||
end_cond = (input_ids_ == processor.tokenizer.im_end_id) | (input_ids_ == processor.tokenizer.slice_end_id)
|
||||
image_start_tokens = torch.where(start_cond)[0]
|
||||
image_start_tokens += 1
|
||||
image_end_tokens = torch.where(end_cond)[0]
|
||||
@@ -1466,16 +1414,6 @@ class MiniCPMVPlugin(BasePlugin):
|
||||
mm_inputs.update({"image_bound": image_bounds_list})
|
||||
|
||||
if len(audios) > 0:
|
||||
audio_start_id = self._resolve_token_id(tokenizer, "audio_start_id", "<audio>")
|
||||
audio_end_id = self._resolve_token_id(tokenizer, "audio_end_id", "</audio>")
|
||||
spk_start_id = self._resolve_token_id(tokenizer, "spk_start_id", "<spk>")
|
||||
spk_end_id = self._resolve_token_id(tokenizer, "spk_end_id", "</spk>")
|
||||
if None in (audio_start_id, audio_end_id, spk_start_id, spk_end_id):
|
||||
raise AttributeError(
|
||||
"Cannot resolve MiniCPM audio/speaker boundary token ids from tokenizer. "
|
||||
"Expected *_id attributes or corresponding special tokens."
|
||||
)
|
||||
|
||||
# audio bound
|
||||
audio_bounds_ls = []
|
||||
spk_bounds_ls = []
|
||||
@@ -1483,15 +1421,15 @@ class MiniCPMVPlugin(BasePlugin):
|
||||
|
||||
for input_ids, audiolen in zip(batch_ids, audlens):
|
||||
input_ids_ = torch.tensor(input_ids)
|
||||
audio_start_idx = torch.where(input_ids_ == audio_start_id)[0]
|
||||
audio_end_idx = torch.where(input_ids_ == audio_end_id)[0]
|
||||
audio_start_idx = torch.where(input_ids_ == processor.tokenizer.audio_start_id)[0]
|
||||
audio_end_idx = torch.where(input_ids_ == processor.tokenizer.audio_end_id)[0]
|
||||
assert len(audio_start_idx) == len(audio_end_idx)
|
||||
audio_bounds = torch.hstack([(audio_start_idx + 1).unsqueeze(-1), audio_end_idx.unsqueeze(-1)])
|
||||
audio_bounds_ls.append(audio_bounds)
|
||||
valid_audio_nums_ls.append(audiolen)
|
||||
|
||||
spk_start_idx = torch.where(input_ids_ == spk_start_id)[0]
|
||||
spk_end_idx = torch.where(input_ids_ == spk_end_id)[0]
|
||||
spk_start_idx = torch.where(input_ids_ == processor.tokenizer.spk_start_id)[0]
|
||||
spk_end_idx = torch.where(input_ids_ == processor.tokenizer.spk_end_id)[0]
|
||||
assert len(spk_start_idx) == len(spk_end_idx)
|
||||
spk_bounds = torch.hstack([(spk_start_idx + 1).unsqueeze(-1), spk_end_idx.unsqueeze(-1)])
|
||||
spk_bounds_ls.append(spk_bounds)
|
||||
@@ -1500,8 +1438,6 @@ class MiniCPMVPlugin(BasePlugin):
|
||||
mm_inputs.update(audio_inputs)
|
||||
mm_inputs.update({"audio_bounds": audio_bounds_ls, "spk_bounds": spk_bounds_ls})
|
||||
|
||||
return mm_inputs
|
||||
|
||||
|
||||
@dataclass
|
||||
class MiniCPMV4_6Plugin(BasePlugin):
|
||||
@@ -1518,59 +1454,23 @@ class MiniCPMV4_6Plugin(BasePlugin):
|
||||
image_processor = getattr(processor, "image_processor")
|
||||
video_processor = getattr(processor, "video_processor", None)
|
||||
mm_inputs = {}
|
||||
preprocess_params = inspect.signature(image_processor.preprocess).parameters
|
||||
downsample_mode = os.getenv("DOWNSAMPLE_MODE", "16x") if "downsample_mode" in preprocess_params else None
|
||||
|
||||
if len(images) != 0:
|
||||
images = self._regularize_images(
|
||||
images,
|
||||
image_max_pixels=getattr(processor, "image_max_pixels", 768 * 768),
|
||||
image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32),
|
||||
)["images"]
|
||||
image_processor_kwargs = {
|
||||
"max_slice_nums": getattr(image_processor, "max_slice_nums", 9),
|
||||
"return_tensors": "pt",
|
||||
}
|
||||
if downsample_mode is not None:
|
||||
image_processor_kwargs["downsample_mode"] = downsample_mode
|
||||
image_inputs = image_processor(images, **image_processor_kwargs)
|
||||
mm_inputs.update(image_inputs)
|
||||
# The image_processor ignores downsample_mode; target_sizes are always based on patch_size.
|
||||
# downsample_mode only affects the token divisor in _build_v4_6_placeholder and model forward.
|
||||
mm_inputs.update(image_processor(images, return_tensors="pt"))
|
||||
|
||||
if len(videos) != 0:
|
||||
videos = self._regularize_videos(
|
||||
videos,
|
||||
image_max_pixels=getattr(processor, "video_max_pixels", 256 * 256),
|
||||
image_min_pixels=getattr(processor, "video_min_pixels", 16 * 16),
|
||||
video_fps=getattr(processor, "video_fps", 2.0),
|
||||
video_maxlen=getattr(processor, "video_maxlen", 128),
|
||||
)["videos"]
|
||||
if video_processor is not None:
|
||||
video_processor_kwargs = {
|
||||
"max_slice_nums": 2,
|
||||
"return_tensors": "pt",
|
||||
}
|
||||
if downsample_mode is not None:
|
||||
video_processor_kwargs["downsample_mode"] = downsample_mode
|
||||
video_inputs = video_processor(videos, **video_processor_kwargs)
|
||||
video_inputs = video_processor(videos, return_tensors="pt")
|
||||
mm_inputs["pixel_values_videos"] = video_inputs["pixel_values_videos"]
|
||||
mm_inputs["target_sizes_videos"] = video_inputs["target_sizes_videos"]
|
||||
else:
|
||||
# Fallback to image processor for video
|
||||
video_processor_kwargs = {
|
||||
"max_slice_nums": 2,
|
||||
"return_tensors": "pt",
|
||||
}
|
||||
if downsample_mode is not None:
|
||||
video_processor_kwargs["downsample_mode"] = downsample_mode
|
||||
video_inputs = image_processor(videos, **video_processor_kwargs)
|
||||
video_inputs = image_processor(videos, return_tensors="pt")
|
||||
mm_inputs["pixel_values_videos"] = video_inputs["pixel_values"]
|
||||
mm_inputs["target_sizes_videos"] = video_inputs["target_sizes"]
|
||||
|
||||
if len(audios) != 0:
|
||||
audios = self._regularize_audios(
|
||||
audios,
|
||||
sampling_rate=getattr(processor, "audio_sampling_rate", 16000),
|
||||
)["audios"]
|
||||
audio_features, audio_feature_lens, audio_phs = processor.audio_feature_extract(
|
||||
[audios],
|
||||
chunk_input=True,
|
||||
@@ -1745,6 +1645,12 @@ class MiniCPMV4_6Plugin(BasePlugin):
|
||||
if "pixel_values" not in mm_inputs:
|
||||
mm_inputs["pixel_values"] = torch.empty(1, 3, 14, 0)
|
||||
|
||||
# Pass downsample_mode to model forward so it matches the placeholder divisor
|
||||
_ds = os.getenv("DOWNSAMPLE_MODE")
|
||||
if _ds is None:
|
||||
_ds = getattr(getattr(processor, "image_processor", None), "downsample_mode", "16x")
|
||||
mm_inputs["downsample_mode"] = _ds
|
||||
|
||||
if len(audios) > 0:
|
||||
audio_inputs = self._get_mm_inputs([], [], audios, processor)
|
||||
mm_inputs.update(audio_inputs)
|
||||
|
||||
Reference in New Issue
Block a user