[data] improve mmplugin (#7795)

commit 92101f34a1
parent a62cba3d05
@@ -219,7 +219,7 @@ class MMPluginMixin:
         if total_frames == 0:  # infinite video
             return np.linspace(0, video_maxlen - 1, video_maxlen).astype(np.int32)
 
-        sample_frames = math.floor(float(video_stream.duration * video_stream.time_base) * video_fps)
+        sample_frames = max(1, math.floor(float(video_stream.duration * video_stream.time_base) * video_fps))
         sample_frames = min(total_frames, video_maxlen, sample_frames)
         return np.linspace(0, total_frames - 1, sample_frames).astype(np.int32)
 
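Why the max(1, ...) guard: for a clip shorter than one sampling interval, math.floor(duration * video_fps) evaluates to 0, and np.linspace(0, total_frames - 1, 0) returns an empty index array, so no frames would be sampled at all. A minimal standalone sketch of the fixed logic (argument names are simplified; the real method reads duration and time_base from a PyAV video stream):

    import math

    import numpy as np

    def sample_frame_indices(duration_s: float, total_frames: int, video_fps: float, video_maxlen: int) -> "np.ndarray":
        sample_frames = max(1, math.floor(duration_s * video_fps))  # the fix: never 0 for short clips
        sample_frames = min(total_frames, video_maxlen, sample_frames)
        return np.linspace(0, total_frames - 1, sample_frames).astype(np.int32)

    # A 0.3 s clip with 9 frames, sampled at 2 fps:
    # floor(0.3 * 2) == 0 indices before the fix; max(1, 0) == 1 index after.
    print(sample_frame_indices(0.3, 9, 2.0, 64))  # -> [0]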
@@ -588,7 +588,7 @@ class InternVLPlugin(BasePlugin):
         for message in messages:
             content = message["content"]
             while IMAGE_PLACEHOLDER in content:
-                if num_image_tokens >= len(image_pixel_patch_list):
+                if num_image_tokens >= len(images):
                     raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
 
                 content = content.replace(
@@ -599,7 +599,7 @@ class InternVLPlugin(BasePlugin):
                 num_image_tokens += 1
 
             while VIDEO_PLACEHOLDER in content:
-                if num_video_tokens >= len(video_patch_indices):
+                if num_video_tokens >= len(videos):
                     raise ValueError(f"`len(videos)` is less than the number of {VIDEO_PLACEHOLDER} tokens.")
 
                 current_patch_index = video_patch_indices[num_video_tokens - 1] if num_video_tokens > 0 else 0
@@ -657,7 +657,7 @@ class KimiVLPlugin(BasePlugin):
         for message in messages:
             content = message["content"]
             while IMAGE_PLACEHOLDER in content:
-                if num_image_tokens >= len(image_grid_hws):
+                if num_image_tokens >= len(images):
                     raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
 
                 image_seqlen = image_grid_hws[num_image_tokens].prod() // merge_length if self.expand_mm_tokens else 1
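The InternVL, KimiVL, Qwen2VL, and Qwen2Omni hunks all make the same change: the per-placeholder bounds check now counts against the raw images/videos/audios inputs instead of derived structures such as image_pixel_patch_list, image_grid_hws, or video_patch_indices, which can presumably be empty or shorter when self.expand_mm_tokens is disabled and the processed mm_inputs are not computed. The shared pattern, as a minimal sketch with simplified names:

    IMAGE_PLACEHOLDER = "<image>"

    def expand_image_placeholders(content: str, images: list, seqlens: list) -> str:
        num_image_tokens = 0
        while IMAGE_PLACEHOLDER in content:
            # Guard against the user-supplied list, which is always populated,
            # not against processed features that may be absent.
            if num_image_tokens >= len(images):
                raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")

            seqlen = seqlens[num_image_tokens] if seqlens else 1
            content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * seqlen, 1)
            num_image_tokens += 1

        return content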
@@ -710,6 +710,9 @@ class Llama4Plugin(BasePlugin):
                 for local_image_index, split_part in enumerate(prompt_splits):
                     new_content.append(split_part)
                     if local_image_index < placeholder_count:
+                        if num_image_tokens >= len(images):
+                            raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
+
                         tokens_for_this_image = processor._prompt_split_image(
                             aspect_ratios[num_image_tokens], num_patches_per_chunk
                         )
@@ -774,6 +777,9 @@ class LlavaPlugin(BasePlugin):
         for message in messages:
             content = message["content"]
             while IMAGE_PLACEHOLDER in content:
+                if num_image_tokens >= len(images):
+                    raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
+
                 content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)
                 num_image_tokens += 1
 
@@ -808,6 +814,9 @@ class LlavaNextPlugin(BasePlugin):
         for message in messages:
             content = message["content"]
             while IMAGE_PLACEHOLDER in content:
+                if num_image_tokens >= len(images):
+                    raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
+
                 if self.expand_mm_tokens:
                     orig_height, orig_width = next(image_sizes)
                     image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width)
@@ -850,6 +859,9 @@ class LlavaNextVideoPlugin(BasePlugin):
         for message in messages:
             content = message["content"]
             while IMAGE_PLACEHOLDER in content:
+                if num_image_tokens >= len(images):
+                    raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
+
                 if self.expand_mm_tokens:
                     orig_height, orig_width = next(image_sizes)
                     image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width)
@@ -876,6 +888,9 @@ class LlavaNextVideoPlugin(BasePlugin):
         for message in messages:
             content = message["content"]
             while VIDEO_PLACEHOLDER in content:
+                if num_video_tokens >= len(videos):
+                    raise ValueError(f"`len(videos)` is less than the number of {VIDEO_PLACEHOLDER} tokens.")
+
                 content = content.replace(VIDEO_PLACEHOLDER, "{{video}}" * video_seqlen, 1)
                 num_video_tokens += 1
 
@@ -921,15 +936,24 @@ class MiniCPMVPlugin(BasePlugin):
         for i, message in enumerate(messages):
             content = message["content"]
             while IMAGE_PLACEHOLDER in content:
+                if num_image_tokens >= len(images):
+                    raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
+
                 content = content.replace(IMAGE_PLACEHOLDER, "{{image}}", 1)
                 num_image_tokens += 1
 
             while VIDEO_PLACEHOLDER in content:
+                if num_video_tokens >= len(videos):
+                    raise ValueError(f"`len(videos)` is less than the number of {VIDEO_PLACEHOLDER} tokens.")
+
                 video_seqlen = len(mm_inputs["pixel_values"][num_video_tokens]) if self.expand_mm_tokens else 1
                 content = content.replace(VIDEO_PLACEHOLDER, "{{image}}" * video_seqlen, 1)
                 num_video_tokens += 1
 
             while AUDIO_PLACEHOLDER in content:
+                if num_audio_tokens >= len(audios):
+                    raise ValueError(f"`len(audios)` is less than the number of {AUDIO_PLACEHOLDER} tokens.")
+
                 content = content.replace(AUDIO_PLACEHOLDER, "{{audio}}", 1)
                 num_audio_tokens += 1
 
@@ -1283,6 +1307,9 @@ class PixtralPlugin(BasePlugin):
         for message in messages:
             content = message["content"]
             while IMAGE_PLACEHOLDER in content:
+                if num_image_tokens >= len(images):
+                    raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
+
                 if self.expand_mm_tokens:
                     height, width = next(image_sizes)
                     num_height_tokens = height // processor.patch_size
@@ -1350,6 +1377,9 @@ class Qwen2AudioPlugin(BasePlugin):
         for message in messages:
             content = message["content"]
             while AUDIO_PLACEHOLDER in content:
+                if num_audio_tokens >= len(audios):
+                    raise ValueError(f"`len(audios)` is less than the number of {AUDIO_PLACEHOLDER} tokens.")
+
                 if self.expand_mm_tokens:
                     audio_length = audio_lengths.pop(0)
                     input_length = (audio_length - 1) // 2 + 1
@@ -1490,7 +1520,7 @@ class Qwen2VLPlugin(BasePlugin):
         for message in messages:
             content = message["content"]
             while IMAGE_PLACEHOLDER in content:
-                if num_image_tokens >= len(image_grid_thw):
+                if num_image_tokens >= len(images):
                     raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
 
                 image_seqlen = image_grid_thw[num_image_tokens].prod() // merge_length if self.expand_mm_tokens else 1
@@ -1500,7 +1530,7 @@ class Qwen2VLPlugin(BasePlugin):
                 num_image_tokens += 1
 
             while VIDEO_PLACEHOLDER in content:
-                if num_video_tokens >= len(video_grid_thw):
+                if num_video_tokens >= len(videos):
                     raise ValueError(f"`len(videos)` is less than the number of {VIDEO_PLACEHOLDER} tokens.")
 
                 video_seqlen = video_grid_thw[num_video_tokens].prod() // merge_length if self.expand_mm_tokens else 1
@@ -1630,7 +1660,7 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
             content = message["content"]
             # separate with audio-video
             while IMAGE_PLACEHOLDER in content:
-                if num_image_tokens >= len(image_grid_thw):
+                if num_image_tokens >= len(images):
                     raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
 
                 image_token_replace_length = image_grid_thw[num_image_tokens].prod() // merge_length
@@ -1643,7 +1673,7 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
 
             if not use_audio_in_video:
                 while AUDIO_PLACEHOLDER in content:
-                    if num_audio_tokens >= len(audio_lengths):
+                    if num_audio_tokens >= len(audios):
                         raise ValueError(f"`len(audios)` is less than the number of {AUDIO_PLACEHOLDER} tokens.")
 
                     audio_token_replace_length = audio_lengths[num_audio_tokens]
@@ -1656,7 +1686,7 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
 
                 # TODO handle video_input and use_audio_in_video
                 while VIDEO_PLACEHOLDER in content:
-                    if num_video_tokens >= len(video_grid_thw):
+                    if num_video_tokens >= len(videos):
                         raise ValueError(f"`len(videos)` is less than the number of {VIDEO_PLACEHOLDER} tokens.")
 
                     video_replace_length = video_grid_thw[num_video_tokens].prod() // merge_length
@@ -1667,7 +1697,7 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
 
             else:  # if use the audio of video # deal video token and audio token togather
                 while VIDEO_PLACEHOLDER in content:
-                    if num_video_tokens >= len(video_grid_thw):
+                    if num_video_tokens >= len(videos):
                         raise ValueError(f"`len(videos)` is less than the number of {VIDEO_PLACEHOLDER} tokens.")
 
                     audio_t_index = torch.arange(audio_lengths[num_audio_tokens])
@@ -1756,10 +1786,16 @@ class VideoLlavaPlugin(BasePlugin):
         for message in messages:
             content = message["content"]
             while IMAGE_PLACEHOLDER in content:
+                if num_image_tokens >= len(images):
+                    raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
+
                 content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)
                 num_image_tokens += 1
 
             while VIDEO_PLACEHOLDER in content:
+                if num_video_tokens >= len(videos):
+                    raise ValueError(f"`len(videos)` is less than the number of {VIDEO_PLACEHOLDER} tokens.")
+
                 content = content.replace(VIDEO_PLACEHOLDER, "{{video}}" * video_seqlen, 1)
                 num_video_tokens += 1
 
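The Llama4, Llava, LlavaNext, LlavaNextVideo, MiniCPMV, Pixtral, Qwen2Audio, and VideoLlava hunks above insert the same three-line guard into placeholder loops that previously had no bounds check at all. A hypothetical repro of the mismatch the guard now catches:

    # Two image placeholders but only one image supplied:
    messages = [{"role": "user", "content": "<image><image> compare these"}]
    images = ["one.png"]
    # Before this commit, the loop would index past the end of the processed
    # feature list (an opaque IndexError, or a silent mismatch later in
    # collation); now it fails fast with a ValueError naming the placeholder.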
@@ -18,7 +18,6 @@ from typing import TYPE_CHECKING, Optional, Union
 from typing_extensions import override
 
 from ..extras import logging
-from ..extras.misc import check_version
 from .data_utils import Role
 from .formatter import EmptyFormatter, FunctionFormatter, StringFormatter, ToolFormatter
 from .mm_plugin import get_mm_plugin
@@ -518,9 +517,6 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args:
 
     template = TEMPLATES[data_args.template]
 
-    if template.mm_plugin.__class__.__name__ != "BasePlugin":
-        check_version("transformers>=4.45.0")
-
     if data_args.train_on_prompt and template.efficient_eos:
         raise ValueError("Current template does not support `train_on_prompt`.")
 
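Note on the template change: the per-template check_version("transformers>=4.45.0") gate is removed along with its now-unused import, presumably because the project-wide transformers version floor already covers the multimodal plugins, making the runtime check redundant.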
@@ -101,7 +101,7 @@ def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
         processor = AutoProcessor.from_pretrained(model_args.model_name_or_path, **init_kwargs)
         patch_processor(processor, tokenizer, model_args)
     except Exception as e:
-        logger.debug(f"Failed to load processor: {e}.")
+        logger.info_rank0(f"Failed to load processor: {e}.")
         processor = None
 
     # Avoid load tokenizer, see:
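Note on the loader change: a failure to load the processor was previously logged at debug level and was therefore invisible in default runs; logger.info_rank0 (the project's rank-zero-only info logger) surfaces it in standard output while avoiding duplicate messages across distributed workers.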