|
|
|
|
@@ -1209,6 +1209,23 @@ class LlavaNextVideoPlugin(BasePlugin):
|
|
|
|
|
|
|
|
|
|
@dataclass
|
|
|
|
|
class MiniCPMVPlugin(BasePlugin):
|
|
|
|
|
def _resolve_token_id(self, tokenizer: Any, attr_name: str, token_text: str | None = None) -> int | None:
|
|
|
|
|
token_id = getattr(tokenizer, attr_name, None)
|
|
|
|
|
if isinstance(token_id, int) and token_id >= 0:
|
|
|
|
|
return token_id
|
|
|
|
|
|
|
|
|
|
if token_text is None or not hasattr(tokenizer, "convert_tokens_to_ids"):
|
|
|
|
|
return None
|
|
|
|
|
|
|
|
|
|
converted_id = tokenizer.convert_tokens_to_ids(token_text)
|
|
|
|
|
if isinstance(converted_id, list):
|
|
|
|
|
converted_id = converted_id[0] if len(converted_id) else None
|
|
|
|
|
|
|
|
|
|
if isinstance(converted_id, int) and converted_id >= 0:
|
|
|
|
|
return converted_id
|
|
|
|
|
|
|
|
|
|
return None
|
|
|
|
|
|
|
|
|
|
@override
|
|
|
|
|
def _get_mm_inputs(
|
|
|
|
|
self,
|
|
|
|
|
@@ -1220,6 +1237,8 @@ class MiniCPMVPlugin(BasePlugin):
|
|
|
|
|
) -> dict[str, "torch.Tensor"]:
|
|
|
|
|
image_processor: BaseImageProcessor = getattr(processor, "image_processor")
|
|
|
|
|
mm_inputs = {}
|
|
|
|
|
preprocess_params = inspect.signature(image_processor.preprocess).parameters
|
|
|
|
|
downsample_mode = os.getenv("DOWNSAMPLE_MODE", "16x") if "downsample_mode" in preprocess_params else None
|
|
|
|
|
if len(images) != 0:
|
|
|
|
|
images = self._regularize_images(
|
|
|
|
|
images,
|
|
|
|
|
@@ -1236,9 +1255,15 @@ class MiniCPMVPlugin(BasePlugin):
|
|
|
|
|
|
|
|
|
|
images = new_images
|
|
|
|
|
|
|
|
|
|
image_inputs = image_processor(
|
|
|
|
|
images, do_pad=True, max_slice_nums=image_processor.max_slice_nums, return_tensors="pt"
|
|
|
|
|
)
|
|
|
|
|
image_processor_kwargs = {
|
|
|
|
|
"do_pad": True,
|
|
|
|
|
"max_slice_nums": image_processor.max_slice_nums,
|
|
|
|
|
"return_tensors": "pt",
|
|
|
|
|
}
|
|
|
|
|
if downsample_mode is not None:
|
|
|
|
|
image_processor_kwargs["downsample_mode"] = downsample_mode
|
|
|
|
|
|
|
|
|
|
image_inputs = image_processor(images, **image_processor_kwargs)
|
|
|
|
|
mm_inputs.update(image_inputs)
|
|
|
|
|
|
|
|
|
|
if len(videos) != 0:
|
|
|
|
|
@@ -1249,7 +1274,15 @@ class MiniCPMVPlugin(BasePlugin):
|
|
|
|
|
video_fps=getattr(processor, "video_fps", 2.0),
|
|
|
|
|
video_maxlen=getattr(processor, "video_maxlen", 128),
|
|
|
|
|
)["videos"]
|
|
|
|
|
video_inputs = image_processor(videos, do_pad=True, max_slice_nums=2, return_tensors="pt")
|
|
|
|
|
video_processor_kwargs = {
|
|
|
|
|
"do_pad": True,
|
|
|
|
|
"max_slice_nums": 2,
|
|
|
|
|
"return_tensors": "pt",
|
|
|
|
|
}
|
|
|
|
|
if downsample_mode is not None:
|
|
|
|
|
video_processor_kwargs["downsample_mode"] = downsample_mode
|
|
|
|
|
|
|
|
|
|
video_inputs = image_processor(videos, **video_processor_kwargs)
|
|
|
|
|
mm_inputs.update(video_inputs)
|
|
|
|
|
|
|
|
|
|
if len(audios) != 0:
|
|
|
|
|
@@ -1334,7 +1367,8 @@ class MiniCPMVPlugin(BasePlugin):
|
|
|
|
|
|
|
|
|
|
if self.expand_mm_tokens and mm_inputs:
|
|
|
|
|
pattern = "(<image>./</image>)"
|
|
|
|
|
image_sizes = mm_inputs["image_sizes"]
|
|
|
|
|
image_sizes = mm_inputs.get("image_sizes")
|
|
|
|
|
image_grids = mm_inputs.get("grids")
|
|
|
|
|
idx = 0
|
|
|
|
|
for index, message in enumerate(messages):
|
|
|
|
|
text = message["content"]
|
|
|
|
|
@@ -1342,13 +1376,21 @@ class MiniCPMVPlugin(BasePlugin):
|
|
|
|
|
text_chunks = text.split(pattern)
|
|
|
|
|
final_text = ""
|
|
|
|
|
for i in range(len(image_tags)):
|
|
|
|
|
final_text = (
|
|
|
|
|
final_text
|
|
|
|
|
+ text_chunks[i]
|
|
|
|
|
+ image_processor.get_slice_image_placeholder(
|
|
|
|
|
image_sizes[0][idx], idx, max_slice_nums, use_image_id
|
|
|
|
|
grid = image_grids[0][idx] if image_grids and len(image_grids[0]) > idx else [1, 1]
|
|
|
|
|
image_size = image_sizes[0][idx] if image_sizes and len(image_sizes[0]) > idx else None
|
|
|
|
|
|
|
|
|
|
placeholder_fn = image_processor.get_slice_image_placeholder
|
|
|
|
|
if image_size is not None:
|
|
|
|
|
image_placeholder = placeholder_fn(
|
|
|
|
|
image_size,
|
|
|
|
|
image_idx=idx,
|
|
|
|
|
max_slice_nums=max_slice_nums,
|
|
|
|
|
use_image_id=use_image_id,
|
|
|
|
|
)
|
|
|
|
|
)
|
|
|
|
|
else:
|
|
|
|
|
image_placeholder = placeholder_fn(grid)
|
|
|
|
|
|
|
|
|
|
final_text = final_text + text_chunks[i] + image_placeholder
|
|
|
|
|
idx += 1
|
|
|
|
|
|
|
|
|
|
final_text += text_chunks[-1]
|
|
|
|
|
@@ -1385,15 +1427,25 @@ class MiniCPMVPlugin(BasePlugin):
|
|
|
|
|
processor: Optional["MMProcessor"],
|
|
|
|
|
) -> dict[str, Union[list[int], "torch.Tensor"]]:
|
|
|
|
|
self._validate_input(processor, images, videos, audios)
|
|
|
|
|
tokenizer = processor.tokenizer
|
|
|
|
|
im_start_id = self._resolve_token_id(tokenizer, "im_start_id", "<image>")
|
|
|
|
|
slice_start_id = self._resolve_token_id(tokenizer, "slice_start_id", "<slice>")
|
|
|
|
|
im_end_id = self._resolve_token_id(tokenizer, "im_end_id", "</image>")
|
|
|
|
|
slice_end_id = self._resolve_token_id(tokenizer, "slice_end_id", "</slice>")
|
|
|
|
|
if None in (im_start_id, slice_start_id, im_end_id, slice_end_id):
|
|
|
|
|
raise AttributeError(
|
|
|
|
|
"Cannot resolve MiniCPM image boundary token ids from tokenizer. "
|
|
|
|
|
"Expected attributes (im_start_id/slice_start_id/im_end_id/slice_end_id) "
|
|
|
|
|
"or corresponding special tokens (<image>, <slice>, </image>, </slice>)."
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# image bound
|
|
|
|
|
image_bounds_list = []
|
|
|
|
|
valid_image_nums_ls = []
|
|
|
|
|
for i, input_ids in enumerate(batch_ids):
|
|
|
|
|
input_ids_ = torch.tensor(input_ids)
|
|
|
|
|
start_cond = (input_ids_ == processor.tokenizer.im_start_id) | (
|
|
|
|
|
input_ids_ == processor.tokenizer.slice_start_id
|
|
|
|
|
)
|
|
|
|
|
end_cond = (input_ids_ == processor.tokenizer.im_end_id) | (input_ids_ == processor.tokenizer.slice_end_id)
|
|
|
|
|
start_cond = (input_ids_ == im_start_id) | (input_ids_ == slice_start_id)
|
|
|
|
|
end_cond = (input_ids_ == im_end_id) | (input_ids_ == slice_end_id)
|
|
|
|
|
image_start_tokens = torch.where(start_cond)[0]
|
|
|
|
|
image_start_tokens += 1
|
|
|
|
|
image_end_tokens = torch.where(end_cond)[0]
|
|
|
|
|
@@ -1414,6 +1466,16 @@ class MiniCPMVPlugin(BasePlugin):
|
|
|
|
|
mm_inputs.update({"image_bound": image_bounds_list})
|
|
|
|
|
|
|
|
|
|
if len(audios) > 0:
|
|
|
|
|
audio_start_id = self._resolve_token_id(tokenizer, "audio_start_id", "<audio>")
|
|
|
|
|
audio_end_id = self._resolve_token_id(tokenizer, "audio_end_id", "</audio>")
|
|
|
|
|
spk_start_id = self._resolve_token_id(tokenizer, "spk_start_id", "<spk>")
|
|
|
|
|
spk_end_id = self._resolve_token_id(tokenizer, "spk_end_id", "</spk>")
|
|
|
|
|
if None in (audio_start_id, audio_end_id, spk_start_id, spk_end_id):
|
|
|
|
|
raise AttributeError(
|
|
|
|
|
"Cannot resolve MiniCPM audio/speaker boundary token ids from tokenizer. "
|
|
|
|
|
"Expected *_id attributes or corresponding special tokens."
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# audio bound
|
|
|
|
|
audio_bounds_ls = []
|
|
|
|
|
spk_bounds_ls = []
|
|
|
|
|
@@ -1421,15 +1483,15 @@ class MiniCPMVPlugin(BasePlugin):
|
|
|
|
|
|
|
|
|
|
for input_ids, audiolen in zip(batch_ids, audlens):
|
|
|
|
|
input_ids_ = torch.tensor(input_ids)
|
|
|
|
|
audio_start_idx = torch.where(input_ids_ == processor.tokenizer.audio_start_id)[0]
|
|
|
|
|
audio_end_idx = torch.where(input_ids_ == processor.tokenizer.audio_end_id)[0]
|
|
|
|
|
audio_start_idx = torch.where(input_ids_ == audio_start_id)[0]
|
|
|
|
|
audio_end_idx = torch.where(input_ids_ == audio_end_id)[0]
|
|
|
|
|
assert len(audio_start_idx) == len(audio_end_idx)
|
|
|
|
|
audio_bounds = torch.hstack([(audio_start_idx + 1).unsqueeze(-1), audio_end_idx.unsqueeze(-1)])
|
|
|
|
|
audio_bounds_ls.append(audio_bounds)
|
|
|
|
|
valid_audio_nums_ls.append(audiolen)
|
|
|
|
|
|
|
|
|
|
spk_start_idx = torch.where(input_ids_ == processor.tokenizer.spk_start_id)[0]
|
|
|
|
|
spk_end_idx = torch.where(input_ids_ == processor.tokenizer.spk_end_id)[0]
|
|
|
|
|
spk_start_idx = torch.where(input_ids_ == spk_start_id)[0]
|
|
|
|
|
spk_end_idx = torch.where(input_ids_ == spk_end_id)[0]
|
|
|
|
|
assert len(spk_start_idx) == len(spk_end_idx)
|
|
|
|
|
spk_bounds = torch.hstack([(spk_start_idx + 1).unsqueeze(-1), spk_end_idx.unsqueeze(-1)])
|
|
|
|
|
spk_bounds_ls.append(spk_bounds)
|
|
|
|
|
@@ -1441,6 +1503,255 @@ class MiniCPMVPlugin(BasePlugin):
|
|
|
|
|
return mm_inputs
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
|
|
|
|
|
class MiniCPMV4_6Plugin(BasePlugin):
|
|
|
|
|
"""Plugin for MiniCPM-V-4.6 with new transformers (NaViT vision + get_placeholder_mask API)."""
|
|
|
|
|
|
|
|
|
|
def _get_mm_inputs(
|
|
|
|
|
self,
|
|
|
|
|
images: list["ImageInput"],
|
|
|
|
|
videos: list["VideoInput"],
|
|
|
|
|
audios: list["AudioInput"],
|
|
|
|
|
processor: "MMProcessor",
|
|
|
|
|
**kwargs,
|
|
|
|
|
) -> dict[str, "torch.Tensor"]:
|
|
|
|
|
image_processor = getattr(processor, "image_processor")
|
|
|
|
|
video_processor = getattr(processor, "video_processor", None)
|
|
|
|
|
mm_inputs = {}
|
|
|
|
|
preprocess_params = inspect.signature(image_processor.preprocess).parameters
|
|
|
|
|
downsample_mode = os.getenv("DOWNSAMPLE_MODE", "16x") if "downsample_mode" in preprocess_params else None
|
|
|
|
|
|
|
|
|
|
if len(images) != 0:
|
|
|
|
|
images = self._regularize_images(
|
|
|
|
|
images,
|
|
|
|
|
image_max_pixels=getattr(processor, "image_max_pixels", 768 * 768),
|
|
|
|
|
image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32),
|
|
|
|
|
)["images"]
|
|
|
|
|
image_processor_kwargs = {
|
|
|
|
|
"max_slice_nums": getattr(image_processor, "max_slice_nums", 9),
|
|
|
|
|
"return_tensors": "pt",
|
|
|
|
|
}
|
|
|
|
|
if downsample_mode is not None:
|
|
|
|
|
image_processor_kwargs["downsample_mode"] = downsample_mode
|
|
|
|
|
image_inputs = image_processor(images, **image_processor_kwargs)
|
|
|
|
|
mm_inputs.update(image_inputs)
|
|
|
|
|
|
|
|
|
|
if len(videos) != 0:
|
|
|
|
|
videos = self._regularize_videos(
|
|
|
|
|
videos,
|
|
|
|
|
image_max_pixels=getattr(processor, "video_max_pixels", 256 * 256),
|
|
|
|
|
image_min_pixels=getattr(processor, "video_min_pixels", 16 * 16),
|
|
|
|
|
video_fps=getattr(processor, "video_fps", 2.0),
|
|
|
|
|
video_maxlen=getattr(processor, "video_maxlen", 128),
|
|
|
|
|
)["videos"]
|
|
|
|
|
if video_processor is not None:
|
|
|
|
|
video_processor_kwargs = {
|
|
|
|
|
"max_slice_nums": 2,
|
|
|
|
|
"return_tensors": "pt",
|
|
|
|
|
}
|
|
|
|
|
if downsample_mode is not None:
|
|
|
|
|
video_processor_kwargs["downsample_mode"] = downsample_mode
|
|
|
|
|
video_inputs = video_processor(videos, **video_processor_kwargs)
|
|
|
|
|
mm_inputs["pixel_values_videos"] = video_inputs["pixel_values_videos"]
|
|
|
|
|
mm_inputs["target_sizes_videos"] = video_inputs["target_sizes_videos"]
|
|
|
|
|
else:
|
|
|
|
|
# Fallback to image processor for video
|
|
|
|
|
video_processor_kwargs = {
|
|
|
|
|
"max_slice_nums": 2,
|
|
|
|
|
"return_tensors": "pt",
|
|
|
|
|
}
|
|
|
|
|
if downsample_mode is not None:
|
|
|
|
|
video_processor_kwargs["downsample_mode"] = downsample_mode
|
|
|
|
|
video_inputs = image_processor(videos, **video_processor_kwargs)
|
|
|
|
|
mm_inputs["pixel_values_videos"] = video_inputs["pixel_values"]
|
|
|
|
|
mm_inputs["target_sizes_videos"] = video_inputs["target_sizes"]
|
|
|
|
|
|
|
|
|
|
if len(audios) != 0:
|
|
|
|
|
audios = self._regularize_audios(
|
|
|
|
|
audios,
|
|
|
|
|
sampling_rate=getattr(processor, "audio_sampling_rate", 16000),
|
|
|
|
|
)["audios"]
|
|
|
|
|
audio_features, audio_feature_lens, audio_phs = processor.audio_feature_extract(
|
|
|
|
|
[audios],
|
|
|
|
|
chunk_input=True,
|
|
|
|
|
sampling_rate=getattr(processor, "audio_sampling_rate", 16000),
|
|
|
|
|
)
|
|
|
|
|
audio_feature_lens = [
|
|
|
|
|
x.clone().detach() if isinstance(x, torch.Tensor) else torch.tensor(x) for x in audio_feature_lens
|
|
|
|
|
]
|
|
|
|
|
mm_inputs.update({"audio_features": audio_features, "audio_feature_lens": audio_feature_lens})
|
|
|
|
|
if kwargs.get("ret_phs", False):
|
|
|
|
|
mm_inputs.update({"audio_phs": audio_phs})
|
|
|
|
|
|
|
|
|
|
return mm_inputs
|
|
|
|
|
|
|
|
|
|
def _build_v4_6_placeholder(
|
|
|
|
|
self,
|
|
|
|
|
image_inputs: dict[str, Any],
|
|
|
|
|
image_idx: int,
|
|
|
|
|
use_image_id: bool,
|
|
|
|
|
processor: "MMProcessor",
|
|
|
|
|
) -> str:
|
|
|
|
|
"""Build image placeholder for MiniCPM-V-4.6 using NaViT token count computation."""
|
|
|
|
|
grids = image_inputs.get("grids", [[0, 0]])
|
|
|
|
|
num_patches_per_image = image_inputs.get("num_patches_per_image", [1])
|
|
|
|
|
target_sizes = image_inputs.get("target_sizes")
|
|
|
|
|
|
|
|
|
|
downsample_mode = os.getenv("DOWNSAMPLE_MODE")
|
|
|
|
|
if downsample_mode is None:
|
|
|
|
|
image_processor = getattr(processor, "image_processor")
|
|
|
|
|
downsample_mode = getattr(image_processor, "downsample_mode", "16x")
|
|
|
|
|
token_divisor = 4 if downsample_mode == "4x" else 16
|
|
|
|
|
|
|
|
|
|
flat_index = 0
|
|
|
|
|
for idx in range(image_idx):
|
|
|
|
|
flat_index += num_patches_per_image[idx]
|
|
|
|
|
n_patches = num_patches_per_image[image_idx]
|
|
|
|
|
|
|
|
|
|
img_target_sizes = target_sizes[flat_index : flat_index + n_patches]
|
|
|
|
|
num_tokens_per_patch = img_target_sizes.prod(-1) // token_divisor
|
|
|
|
|
num_rows, num_cols = grids[image_idx]
|
|
|
|
|
|
|
|
|
|
image_start = getattr(processor, "image_start_token", "<image>")
|
|
|
|
|
image_end = getattr(processor, "image_end_token", "</image>")
|
|
|
|
|
slice_start = getattr(processor, "slice_start_token", "<slice>")
|
|
|
|
|
slice_end = getattr(processor, "slice_end_token", "</slice>")
|
|
|
|
|
image_id_start = getattr(processor, "image_id_start_token", "<image_id>")
|
|
|
|
|
image_id_end = getattr(processor, "image_id_end_token", "</image_id>")
|
|
|
|
|
image_token = (
|
|
|
|
|
getattr(processor, "image_token", None)
|
|
|
|
|
or getattr(getattr(processor, "tokenizer", None), "image_token", None)
|
|
|
|
|
or "<image>"
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
image_placeholder = image_start + "<|ph|>" * int(num_tokens_per_patch[0]) + image_end
|
|
|
|
|
if use_image_id:
|
|
|
|
|
image_placeholder = f"{image_id_start}{image_idx}{image_id_end}" + image_placeholder
|
|
|
|
|
|
|
|
|
|
slice_mode = getattr(processor, "slice_mode", True)
|
|
|
|
|
if slice_mode and num_rows > 0 and num_cols > 0:
|
|
|
|
|
per_slice_tokens = int(num_tokens_per_patch[1]) if len(num_tokens_per_patch) > 1 else 0
|
|
|
|
|
slice_placeholder = slice_start + "<|ph|>" * per_slice_tokens + slice_end
|
|
|
|
|
slices = [slice_placeholder * num_cols for _ in range(num_rows)]
|
|
|
|
|
image_placeholder += "\n".join(slices)
|
|
|
|
|
|
|
|
|
|
return image_placeholder.replace("<|ph|>", image_token)
|
|
|
|
|
|
|
|
|
|
@override
|
|
|
|
|
def process_messages(
|
|
|
|
|
self,
|
|
|
|
|
messages: list[dict[str, str]],
|
|
|
|
|
images: list["ImageInput"],
|
|
|
|
|
videos: list["VideoInput"],
|
|
|
|
|
audios: list["AudioInput"],
|
|
|
|
|
processor: Optional["MMProcessor"],
|
|
|
|
|
) -> list[dict[str, str]]:
|
|
|
|
|
self._validate_input(processor, images, videos, audios)
|
|
|
|
|
self._validate_messages(messages, images, videos, audios)
|
|
|
|
|
num_image_tokens, num_video_tokens, num_audio_tokens = 0, 0, 0
|
|
|
|
|
messages = deepcopy(messages)
|
|
|
|
|
mm_inputs, audio_inputs = {}, {}
|
|
|
|
|
if len(images) != 0 and len(videos) != 0:
|
|
|
|
|
raise ValueError("MiniCPM-V model does not support input images and videos at the same time.")
|
|
|
|
|
|
|
|
|
|
use_image_id = getattr(processor, "default_use_image_id", True)
|
|
|
|
|
|
|
|
|
|
if len(videos) != 0:
|
|
|
|
|
use_image_id = False
|
|
|
|
|
mm_inputs = self._get_mm_inputs([], videos, [], processor)
|
|
|
|
|
|
|
|
|
|
for i, message in enumerate(messages):
|
|
|
|
|
content = message["content"]
|
|
|
|
|
while IMAGE_PLACEHOLDER in content:
|
|
|
|
|
content = content.replace(IMAGE_PLACEHOLDER, "{{image}}", 1)
|
|
|
|
|
num_image_tokens += 1
|
|
|
|
|
|
|
|
|
|
while VIDEO_PLACEHOLDER in content:
|
|
|
|
|
num_frames = 1
|
|
|
|
|
if "num_frames_per_video" in mm_inputs:
|
|
|
|
|
num_frames = sum(mm_inputs["num_frames_per_video"])
|
|
|
|
|
content = content.replace(VIDEO_PLACEHOLDER, "{{image}}" * num_frames, 1)
|
|
|
|
|
num_video_tokens += 1
|
|
|
|
|
|
|
|
|
|
while AUDIO_PLACEHOLDER in content:
|
|
|
|
|
content = content.replace(AUDIO_PLACEHOLDER, "{{audio}}", 1)
|
|
|
|
|
num_audio_tokens += 1
|
|
|
|
|
|
|
|
|
|
message["content"] = content.replace("{{image}}", "(<image>./</image>)").replace(
|
|
|
|
|
"{{audio}}", "(<audio>./</audio>)"
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
if len(images):
|
|
|
|
|
mm_inputs = self._get_mm_inputs(images, [], [], processor)
|
|
|
|
|
|
|
|
|
|
if len(audios):
|
|
|
|
|
audio_inputs = self._get_mm_inputs([], [], audios, processor, ret_phs=True)
|
|
|
|
|
|
|
|
|
|
if self.expand_mm_tokens and mm_inputs:
|
|
|
|
|
pattern = "(<image>./</image>)"
|
|
|
|
|
idx = 0
|
|
|
|
|
for index, message in enumerate(messages):
|
|
|
|
|
text = message["content"]
|
|
|
|
|
image_tags = re.findall(pattern, text)
|
|
|
|
|
text_chunks = text.split(pattern)
|
|
|
|
|
final_text = ""
|
|
|
|
|
for i in range(len(image_tags)):
|
|
|
|
|
image_placeholder = self._build_v4_6_placeholder(mm_inputs, idx, use_image_id, processor)
|
|
|
|
|
final_text = final_text + text_chunks[i] + image_placeholder
|
|
|
|
|
idx += 1
|
|
|
|
|
final_text += text_chunks[-1]
|
|
|
|
|
messages[index]["content"] = final_text
|
|
|
|
|
|
|
|
|
|
if self.expand_mm_tokens and audio_inputs:
|
|
|
|
|
pattern = "(<audio>./</audio>)"
|
|
|
|
|
idx = 0
|
|
|
|
|
for index, message in enumerate(messages):
|
|
|
|
|
text = message["content"]
|
|
|
|
|
audio_tags = re.findall(pattern, text)
|
|
|
|
|
text_chunks = text.split(pattern)
|
|
|
|
|
final_text = ""
|
|
|
|
|
for i in range(len(audio_tags)):
|
|
|
|
|
audio_placeholder = audio_inputs["audio_phs"][0][idx]
|
|
|
|
|
final_text = final_text + text_chunks[i] + audio_placeholder
|
|
|
|
|
idx += 1
|
|
|
|
|
final_text += text_chunks[-1]
|
|
|
|
|
messages[index]["content"] = final_text
|
|
|
|
|
|
|
|
|
|
return messages
|
|
|
|
|
|
|
|
|
|
@override
|
|
|
|
|
def get_mm_inputs(
|
|
|
|
|
self,
|
|
|
|
|
images: list["ImageInput"],
|
|
|
|
|
videos: list["VideoInput"],
|
|
|
|
|
audios: list["AudioInput"],
|
|
|
|
|
imglens: list[int],
|
|
|
|
|
vidlens: list[int],
|
|
|
|
|
audlens: list[int],
|
|
|
|
|
batch_ids: list[list[int]],
|
|
|
|
|
processor: Optional["MMProcessor"],
|
|
|
|
|
) -> dict[str, Union[list[int], "torch.Tensor"]]:
|
|
|
|
|
self._validate_input(processor, images, videos, audios)
|
|
|
|
|
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
|
|
|
|
|
|
|
|
|
|
# v4.6 does NOT use image_bound — the model finds image tokens via get_placeholder_mask
|
|
|
|
|
# Ensure target_sizes key name matches the model's expected input
|
|
|
|
|
if "target_sizes" not in mm_inputs and "tgt_sizes" in mm_inputs:
|
|
|
|
|
mm_inputs["target_sizes"] = mm_inputs.pop("tgt_sizes")
|
|
|
|
|
|
|
|
|
|
if "target_sizes" not in mm_inputs:
|
|
|
|
|
mm_inputs["target_sizes"] = torch.empty(0, 2, dtype=torch.int32)
|
|
|
|
|
|
|
|
|
|
if "pixel_values" not in mm_inputs:
|
|
|
|
|
mm_inputs["pixel_values"] = torch.empty(1, 3, 14, 0)
|
|
|
|
|
|
|
|
|
|
if len(audios) > 0:
|
|
|
|
|
audio_inputs = self._get_mm_inputs([], [], audios, processor)
|
|
|
|
|
mm_inputs.update(audio_inputs)
|
|
|
|
|
|
|
|
|
|
return mm_inputs
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
|
|
|
|
|
class MllamaPlugin(BasePlugin):
|
|
|
|
|
@override
|
|
|
|
|
@@ -2695,6 +3006,7 @@ PLUGINS = {
|
|
|
|
|
"llava_next_video": LlavaNextVideoPlugin,
|
|
|
|
|
"lfm2_vl": LFMVLPlugin,
|
|
|
|
|
"minicpm_v": MiniCPMVPlugin,
|
|
|
|
|
"minicpm_v_4_6": MiniCPMV4_6Plugin,
|
|
|
|
|
"mllama": MllamaPlugin,
|
|
|
|
|
"paligemma": PaliGemmaPlugin,
|
|
|
|
|
"pixtral": PixtralPlugin,
|
|
|
|
|
|