[data] improve mmplugin (#7795)

This commit is contained in:
hoshi-hiyouga 2025-04-22 01:25:33 +08:00 committed by GitHub
parent a62cba3d05
commit 92101f34a1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 47 additions and 15 deletions

View File

@ -219,7 +219,7 @@ class MMPluginMixin:
if total_frames == 0: # infinite video if total_frames == 0: # infinite video
return np.linspace(0, video_maxlen - 1, video_maxlen).astype(np.int32) return np.linspace(0, video_maxlen - 1, video_maxlen).astype(np.int32)
sample_frames = math.floor(float(video_stream.duration * video_stream.time_base) * video_fps) sample_frames = max(1, math.floor(float(video_stream.duration * video_stream.time_base) * video_fps))
sample_frames = min(total_frames, video_maxlen, sample_frames) sample_frames = min(total_frames, video_maxlen, sample_frames)
return np.linspace(0, total_frames - 1, sample_frames).astype(np.int32) return np.linspace(0, total_frames - 1, sample_frames).astype(np.int32)
@ -588,7 +588,7 @@ class InternVLPlugin(BasePlugin):
for message in messages: for message in messages:
content = message["content"] content = message["content"]
while IMAGE_PLACEHOLDER in content: while IMAGE_PLACEHOLDER in content:
if num_image_tokens >= len(image_pixel_patch_list): if num_image_tokens >= len(images):
raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.") raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
content = content.replace( content = content.replace(
@ -599,7 +599,7 @@ class InternVLPlugin(BasePlugin):
num_image_tokens += 1 num_image_tokens += 1
while VIDEO_PLACEHOLDER in content: while VIDEO_PLACEHOLDER in content:
if num_video_tokens >= len(video_patch_indices): if num_video_tokens >= len(videos):
raise ValueError(f"`len(videos)` is less than the number of {VIDEO_PLACEHOLDER} tokens.") raise ValueError(f"`len(videos)` is less than the number of {VIDEO_PLACEHOLDER} tokens.")
current_patch_index = video_patch_indices[num_video_tokens - 1] if num_video_tokens > 0 else 0 current_patch_index = video_patch_indices[num_video_tokens - 1] if num_video_tokens > 0 else 0
@ -657,7 +657,7 @@ class KimiVLPlugin(BasePlugin):
for message in messages: for message in messages:
content = message["content"] content = message["content"]
while IMAGE_PLACEHOLDER in content: while IMAGE_PLACEHOLDER in content:
if num_image_tokens >= len(image_grid_hws): if num_image_tokens >= len(images):
raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.") raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
image_seqlen = image_grid_hws[num_image_tokens].prod() // merge_length if self.expand_mm_tokens else 1 image_seqlen = image_grid_hws[num_image_tokens].prod() // merge_length if self.expand_mm_tokens else 1
@ -710,6 +710,9 @@ class Llama4Plugin(BasePlugin):
for local_image_index, split_part in enumerate(prompt_splits): for local_image_index, split_part in enumerate(prompt_splits):
new_content.append(split_part) new_content.append(split_part)
if local_image_index < placeholder_count: if local_image_index < placeholder_count:
if num_image_tokens >= len(images):
raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
tokens_for_this_image = processor._prompt_split_image( tokens_for_this_image = processor._prompt_split_image(
aspect_ratios[num_image_tokens], num_patches_per_chunk aspect_ratios[num_image_tokens], num_patches_per_chunk
) )
@ -774,6 +777,9 @@ class LlavaPlugin(BasePlugin):
for message in messages: for message in messages:
content = message["content"] content = message["content"]
while IMAGE_PLACEHOLDER in content: while IMAGE_PLACEHOLDER in content:
if num_image_tokens >= len(images):
raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1) content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)
num_image_tokens += 1 num_image_tokens += 1
@ -808,6 +814,9 @@ class LlavaNextPlugin(BasePlugin):
for message in messages: for message in messages:
content = message["content"] content = message["content"]
while IMAGE_PLACEHOLDER in content: while IMAGE_PLACEHOLDER in content:
if num_image_tokens >= len(images):
raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
if self.expand_mm_tokens: if self.expand_mm_tokens:
orig_height, orig_width = next(image_sizes) orig_height, orig_width = next(image_sizes)
image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width) image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width)
@ -850,6 +859,9 @@ class LlavaNextVideoPlugin(BasePlugin):
for message in messages: for message in messages:
content = message["content"] content = message["content"]
while IMAGE_PLACEHOLDER in content: while IMAGE_PLACEHOLDER in content:
if num_image_tokens >= len(images):
raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
if self.expand_mm_tokens: if self.expand_mm_tokens:
orig_height, orig_width = next(image_sizes) orig_height, orig_width = next(image_sizes)
image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width) image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width)
@ -876,6 +888,9 @@ class LlavaNextVideoPlugin(BasePlugin):
for message in messages: for message in messages:
content = message["content"] content = message["content"]
while VIDEO_PLACEHOLDER in content: while VIDEO_PLACEHOLDER in content:
if num_video_tokens >= len(videos):
raise ValueError(f"`len(videos)` is less than the number of {VIDEO_PLACEHOLDER} tokens.")
content = content.replace(VIDEO_PLACEHOLDER, "{{video}}" * video_seqlen, 1) content = content.replace(VIDEO_PLACEHOLDER, "{{video}}" * video_seqlen, 1)
num_video_tokens += 1 num_video_tokens += 1
@ -921,15 +936,24 @@ class MiniCPMVPlugin(BasePlugin):
for i, message in enumerate(messages): for i, message in enumerate(messages):
content = message["content"] content = message["content"]
while IMAGE_PLACEHOLDER in content: while IMAGE_PLACEHOLDER in content:
if num_image_tokens >= len(images):
raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
content = content.replace(IMAGE_PLACEHOLDER, "{{image}}", 1) content = content.replace(IMAGE_PLACEHOLDER, "{{image}}", 1)
num_image_tokens += 1 num_image_tokens += 1
while VIDEO_PLACEHOLDER in content: while VIDEO_PLACEHOLDER in content:
if num_video_tokens >= len(videos):
raise ValueError(f"`len(videos)` is less than the number of {VIDEO_PLACEHOLDER} tokens.")
video_seqlen = len(mm_inputs["pixel_values"][num_video_tokens]) if self.expand_mm_tokens else 1 video_seqlen = len(mm_inputs["pixel_values"][num_video_tokens]) if self.expand_mm_tokens else 1
content = content.replace(VIDEO_PLACEHOLDER, "{{image}}" * video_seqlen, 1) content = content.replace(VIDEO_PLACEHOLDER, "{{image}}" * video_seqlen, 1)
num_video_tokens += 1 num_video_tokens += 1
while AUDIO_PLACEHOLDER in content: while AUDIO_PLACEHOLDER in content:
if num_audio_tokens >= len(audios):
raise ValueError(f"`len(audios)` is less than the number of {AUDIO_PLACEHOLDER} tokens.")
content = content.replace(AUDIO_PLACEHOLDER, "{{audio}}", 1) content = content.replace(AUDIO_PLACEHOLDER, "{{audio}}", 1)
num_audio_tokens += 1 num_audio_tokens += 1
@ -1283,6 +1307,9 @@ class PixtralPlugin(BasePlugin):
for message in messages: for message in messages:
content = message["content"] content = message["content"]
while IMAGE_PLACEHOLDER in content: while IMAGE_PLACEHOLDER in content:
if num_image_tokens >= len(images):
raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
if self.expand_mm_tokens: if self.expand_mm_tokens:
height, width = next(image_sizes) height, width = next(image_sizes)
num_height_tokens = height // processor.patch_size num_height_tokens = height // processor.patch_size
@ -1350,6 +1377,9 @@ class Qwen2AudioPlugin(BasePlugin):
for message in messages: for message in messages:
content = message["content"] content = message["content"]
while AUDIO_PLACEHOLDER in content: while AUDIO_PLACEHOLDER in content:
if num_audio_tokens >= len(audios):
raise ValueError(f"`len(audios)` is less than the number of {AUDIO_PLACEHOLDER} tokens.")
if self.expand_mm_tokens: if self.expand_mm_tokens:
audio_length = audio_lengths.pop(0) audio_length = audio_lengths.pop(0)
input_length = (audio_length - 1) // 2 + 1 input_length = (audio_length - 1) // 2 + 1
@ -1490,7 +1520,7 @@ class Qwen2VLPlugin(BasePlugin):
for message in messages: for message in messages:
content = message["content"] content = message["content"]
while IMAGE_PLACEHOLDER in content: while IMAGE_PLACEHOLDER in content:
if num_image_tokens >= len(image_grid_thw): if num_image_tokens >= len(images):
raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.") raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
image_seqlen = image_grid_thw[num_image_tokens].prod() // merge_length if self.expand_mm_tokens else 1 image_seqlen = image_grid_thw[num_image_tokens].prod() // merge_length if self.expand_mm_tokens else 1
@ -1500,7 +1530,7 @@ class Qwen2VLPlugin(BasePlugin):
num_image_tokens += 1 num_image_tokens += 1
while VIDEO_PLACEHOLDER in content: while VIDEO_PLACEHOLDER in content:
if num_video_tokens >= len(video_grid_thw): if num_video_tokens >= len(videos):
raise ValueError(f"`len(videos)` is less than the number of {VIDEO_PLACEHOLDER} tokens.") raise ValueError(f"`len(videos)` is less than the number of {VIDEO_PLACEHOLDER} tokens.")
video_seqlen = video_grid_thw[num_video_tokens].prod() // merge_length if self.expand_mm_tokens else 1 video_seqlen = video_grid_thw[num_video_tokens].prod() // merge_length if self.expand_mm_tokens else 1
@ -1630,7 +1660,7 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
content = message["content"] content = message["content"]
# separate with audio-video # separate with audio-video
while IMAGE_PLACEHOLDER in content: while IMAGE_PLACEHOLDER in content:
if num_image_tokens >= len(image_grid_thw): if num_image_tokens >= len(images):
raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.") raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
image_token_replace_length = image_grid_thw[num_image_tokens].prod() // merge_length image_token_replace_length = image_grid_thw[num_image_tokens].prod() // merge_length
@ -1643,7 +1673,7 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
if not use_audio_in_video: if not use_audio_in_video:
while AUDIO_PLACEHOLDER in content: while AUDIO_PLACEHOLDER in content:
if num_audio_tokens >= len(audio_lengths): if num_audio_tokens >= len(audios):
raise ValueError(f"`len(audios)` is less than the number of {AUDIO_PLACEHOLDER} tokens.") raise ValueError(f"`len(audios)` is less than the number of {AUDIO_PLACEHOLDER} tokens.")
audio_token_replace_length = audio_lengths[num_audio_tokens] audio_token_replace_length = audio_lengths[num_audio_tokens]
@ -1656,7 +1686,7 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
# TODO handle video_input and use_audio_in_video # TODO handle video_input and use_audio_in_video
while VIDEO_PLACEHOLDER in content: while VIDEO_PLACEHOLDER in content:
if num_video_tokens >= len(video_grid_thw): if num_video_tokens >= len(videos):
raise ValueError(f"`len(videos)` is less than the number of {VIDEO_PLACEHOLDER} tokens.") raise ValueError(f"`len(videos)` is less than the number of {VIDEO_PLACEHOLDER} tokens.")
video_replace_length = video_grid_thw[num_video_tokens].prod() // merge_length video_replace_length = video_grid_thw[num_video_tokens].prod() // merge_length
@ -1667,7 +1697,7 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
else: # if use the audio of video # deal video token and audio token togather else: # if use the audio of video # deal video token and audio token togather
while VIDEO_PLACEHOLDER in content: while VIDEO_PLACEHOLDER in content:
if num_video_tokens >= len(video_grid_thw): if num_video_tokens >= len(videos):
raise ValueError(f"`len(videos)` is less than the number of {VIDEO_PLACEHOLDER} tokens.") raise ValueError(f"`len(videos)` is less than the number of {VIDEO_PLACEHOLDER} tokens.")
audio_t_index = torch.arange(audio_lengths[num_audio_tokens]) audio_t_index = torch.arange(audio_lengths[num_audio_tokens])
@ -1756,10 +1786,16 @@ class VideoLlavaPlugin(BasePlugin):
for message in messages: for message in messages:
content = message["content"] content = message["content"]
while IMAGE_PLACEHOLDER in content: while IMAGE_PLACEHOLDER in content:
if num_image_tokens >= len(images):
raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1) content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)
num_image_tokens += 1 num_image_tokens += 1
while VIDEO_PLACEHOLDER in content: while VIDEO_PLACEHOLDER in content:
if num_video_tokens >= len(videos):
raise ValueError(f"`len(videos)` is less than the number of {VIDEO_PLACEHOLDER} tokens.")
content = content.replace(VIDEO_PLACEHOLDER, "{{video}}" * video_seqlen, 1) content = content.replace(VIDEO_PLACEHOLDER, "{{video}}" * video_seqlen, 1)
num_video_tokens += 1 num_video_tokens += 1

View File

@ -18,7 +18,6 @@ from typing import TYPE_CHECKING, Optional, Union
from typing_extensions import override from typing_extensions import override
from ..extras import logging from ..extras import logging
from ..extras.misc import check_version
from .data_utils import Role from .data_utils import Role
from .formatter import EmptyFormatter, FunctionFormatter, StringFormatter, ToolFormatter from .formatter import EmptyFormatter, FunctionFormatter, StringFormatter, ToolFormatter
from .mm_plugin import get_mm_plugin from .mm_plugin import get_mm_plugin
@ -518,9 +517,6 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args:
template = TEMPLATES[data_args.template] template = TEMPLATES[data_args.template]
if template.mm_plugin.__class__.__name__ != "BasePlugin":
check_version("transformers>=4.45.0")
if data_args.train_on_prompt and template.efficient_eos: if data_args.train_on_prompt and template.efficient_eos:
raise ValueError("Current template does not support `train_on_prompt`.") raise ValueError("Current template does not support `train_on_prompt`.")

View File

@ -101,7 +101,7 @@ def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
processor = AutoProcessor.from_pretrained(model_args.model_name_or_path, **init_kwargs) processor = AutoProcessor.from_pretrained(model_args.model_name_or_path, **init_kwargs)
patch_processor(processor, tokenizer, model_args) patch_processor(processor, tokenizer, model_args)
except Exception as e: except Exception as e:
logger.debug(f"Failed to load processor: {e}.") logger.info_rank0(f"Failed to load processor: {e}.")
processor = None processor = None
# Avoid load tokenizer, see: # Avoid load tokenizer, see: