diff --git a/src/llamafactory/data/mm_plugin.py b/src/llamafactory/data/mm_plugin.py
index 034fcf99..215fe807 100644
--- a/src/llamafactory/data/mm_plugin.py
+++ b/src/llamafactory/data/mm_plugin.py
@@ -219,7 +219,7 @@ class MMPluginMixin:
         if total_frames == 0:  # infinite video
             return np.linspace(0, video_maxlen - 1, video_maxlen).astype(np.int32)
 
-        sample_frames = math.floor(float(video_stream.duration * video_stream.time_base) * video_fps)
+        sample_frames = max(1, math.floor(float(video_stream.duration * video_stream.time_base) * video_fps))
         sample_frames = min(total_frames, video_maxlen, sample_frames)
         return np.linspace(0, total_frames - 1, sample_frames).astype(np.int32)
 
@@ -588,7 +588,7 @@ class InternVLPlugin(BasePlugin):
         for message in messages:
             content = message["content"]
             while IMAGE_PLACEHOLDER in content:
-                if num_image_tokens >= len(image_pixel_patch_list):
+                if num_image_tokens >= len(images):
                     raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
 
                 content = content.replace(
@@ -599,7 +599,7 @@ class InternVLPlugin(BasePlugin):
                 num_image_tokens += 1
 
             while VIDEO_PLACEHOLDER in content:
-                if num_video_tokens >= len(video_patch_indices):
+                if num_video_tokens >= len(videos):
                     raise ValueError(f"`len(videos)` is less than the number of {VIDEO_PLACEHOLDER} tokens.")
 
                 current_patch_index = video_patch_indices[num_video_tokens - 1] if num_video_tokens > 0 else 0
@@ -657,7 +657,7 @@ class KimiVLPlugin(BasePlugin):
         for message in messages:
             content = message["content"]
             while IMAGE_PLACEHOLDER in content:
-                if num_image_tokens >= len(image_grid_hws):
+                if num_image_tokens >= len(images):
                     raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
 
                 image_seqlen = image_grid_hws[num_image_tokens].prod() // merge_length if self.expand_mm_tokens else 1
@@ -710,6 +710,9 @@ class Llama4Plugin(BasePlugin):
                 for local_image_index, split_part in enumerate(prompt_splits):
                     new_content.append(split_part)
                     if local_image_index < placeholder_count:
+                        if num_image_tokens >= len(images):
+                            raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
+
                         tokens_for_this_image = processor._prompt_split_image(
                             aspect_ratios[num_image_tokens], num_patches_per_chunk
                         )
@@ -774,6 +777,9 @@ class LlavaPlugin(BasePlugin):
         for message in messages:
             content = message["content"]
             while IMAGE_PLACEHOLDER in content:
+                if num_image_tokens >= len(images):
+                    raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
+
                 content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)
                 num_image_tokens += 1
 
@@ -808,6 +814,9 @@ class LlavaNextPlugin(BasePlugin):
         for message in messages:
             content = message["content"]
             while IMAGE_PLACEHOLDER in content:
+                if num_image_tokens >= len(images):
+                    raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
+
                 if self.expand_mm_tokens:
                     orig_height, orig_width = next(image_sizes)
                     image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width)
@@ -850,6 +859,9 @@ class LlavaNextVideoPlugin(BasePlugin):
         for message in messages:
             content = message["content"]
             while IMAGE_PLACEHOLDER in content:
+                if num_image_tokens >= len(images):
+                    raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
+
                 if self.expand_mm_tokens:
                     orig_height, orig_width = next(image_sizes)
                     image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width)
@@ -876,6 +888,9 @@ class LlavaNextVideoPlugin(BasePlugin):
         for message in messages:
             content = message["content"]
             while VIDEO_PLACEHOLDER in content:
+                if num_video_tokens >= len(videos):
+                    raise ValueError(f"`len(videos)` is less than the number of {VIDEO_PLACEHOLDER} tokens.")
+
                 content = content.replace(VIDEO_PLACEHOLDER, "{{video}}" * video_seqlen, 1)
                 num_video_tokens += 1
 
@@ -921,15 +936,24 @@ class MiniCPMVPlugin(BasePlugin):
         for i, message in enumerate(messages):
             content = message["content"]
             while IMAGE_PLACEHOLDER in content:
+                if num_image_tokens >= len(images):
+                    raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
+
                 content = content.replace(IMAGE_PLACEHOLDER, "{{image}}", 1)
                 num_image_tokens += 1
 
             while VIDEO_PLACEHOLDER in content:
+                if num_video_tokens >= len(videos):
+                    raise ValueError(f"`len(videos)` is less than the number of {VIDEO_PLACEHOLDER} tokens.")
+
                 video_seqlen = len(mm_inputs["pixel_values"][num_video_tokens]) if self.expand_mm_tokens else 1
                 content = content.replace(VIDEO_PLACEHOLDER, "{{image}}" * video_seqlen, 1)
                 num_video_tokens += 1
 
             while AUDIO_PLACEHOLDER in content:
+                if num_audio_tokens >= len(audios):
+                    raise ValueError(f"`len(audios)` is less than the number of {AUDIO_PLACEHOLDER} tokens.")
+
                 content = content.replace(AUDIO_PLACEHOLDER, "{{audio}}", 1)
                 num_audio_tokens += 1
 
@@ -1283,6 +1307,9 @@ class PixtralPlugin(BasePlugin):
         for message in messages:
             content = message["content"]
             while IMAGE_PLACEHOLDER in content:
+                if num_image_tokens >= len(images):
+                    raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
+
                 if self.expand_mm_tokens:
                     height, width = next(image_sizes)
                     num_height_tokens = height // processor.patch_size
@@ -1350,6 +1377,9 @@ class Qwen2AudioPlugin(BasePlugin):
         for message in messages:
             content = message["content"]
             while AUDIO_PLACEHOLDER in content:
+                if num_audio_tokens >= len(audios):
+                    raise ValueError(f"`len(audios)` is less than the number of {AUDIO_PLACEHOLDER} tokens.")
+
                 if self.expand_mm_tokens:
                     audio_length = audio_lengths.pop(0)
                     input_length = (audio_length - 1) // 2 + 1
@@ -1490,7 +1520,7 @@ class Qwen2VLPlugin(BasePlugin):
         for message in messages:
             content = message["content"]
             while IMAGE_PLACEHOLDER in content:
-                if num_image_tokens >= len(image_grid_thw):
+                if num_image_tokens >= len(images):
                     raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
 
                 image_seqlen = image_grid_thw[num_image_tokens].prod() // merge_length if self.expand_mm_tokens else 1
@@ -1500,7 +1530,7 @@ class Qwen2VLPlugin(BasePlugin):
                 num_image_tokens += 1
 
             while VIDEO_PLACEHOLDER in content:
-                if num_video_tokens >= len(video_grid_thw):
+                if num_video_tokens >= len(videos):
                     raise ValueError(f"`len(videos)` is less than the number of {VIDEO_PLACEHOLDER} tokens.")
 
                 video_seqlen = video_grid_thw[num_video_tokens].prod() // merge_length if self.expand_mm_tokens else 1
@@ -1630,7 +1660,7 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
             content = message["content"]
             # separate with audio-video
             while IMAGE_PLACEHOLDER in content:
-                if num_image_tokens >= len(image_grid_thw):
+                if num_image_tokens >= len(images):
                     raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
 
                 image_token_replace_length = image_grid_thw[num_image_tokens].prod() // merge_length
@@ -1643,7 +1673,7 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
 
             if not use_audio_in_video:
                 while AUDIO_PLACEHOLDER in content:
-                    if num_audio_tokens >= len(audio_lengths):
+                    if num_audio_tokens >= len(audios):
                         raise ValueError(f"`len(audios)` is less than the number of {AUDIO_PLACEHOLDER} tokens.")
 
                     audio_token_replace_length = audio_lengths[num_audio_tokens]
@@ -1656,7 +1686,7 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
 
                 # TODO handle video_input and use_audio_in_video
                 while VIDEO_PLACEHOLDER in content:
-                    if num_video_tokens >= len(video_grid_thw):
+                    if num_video_tokens >= len(videos):
                         raise ValueError(f"`len(videos)` is less than the number of {VIDEO_PLACEHOLDER} tokens.")
 
                     video_replace_length = video_grid_thw[num_video_tokens].prod() // merge_length
@@ -1667,7 +1697,7 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
             else:  # if use the audio of video
                 # deal video token and audio token togather
                 while VIDEO_PLACEHOLDER in content:
-                    if num_video_tokens >= len(video_grid_thw):
+                    if num_video_tokens >= len(videos):
                         raise ValueError(f"`len(videos)` is less than the number of {VIDEO_PLACEHOLDER} tokens.")
 
                     audio_t_index = torch.arange(audio_lengths[num_audio_tokens])
@@ -1756,10 +1786,16 @@ class VideoLlavaPlugin(BasePlugin):
         for message in messages:
             content = message["content"]
             while IMAGE_PLACEHOLDER in content:
+                if num_image_tokens >= len(images):
+                    raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
+
                 content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)
                 num_image_tokens += 1
 
             while VIDEO_PLACEHOLDER in content:
+                if num_video_tokens >= len(videos):
+                    raise ValueError(f"`len(videos)` is less than the number of {VIDEO_PLACEHOLDER} tokens.")
+
                 content = content.replace(VIDEO_PLACEHOLDER, "{{video}}" * video_seqlen, 1)
                 num_video_tokens += 1
 
diff --git a/src/llamafactory/data/template.py b/src/llamafactory/data/template.py
index 1691a001..553e0285 100644
--- a/src/llamafactory/data/template.py
+++ b/src/llamafactory/data/template.py
@@ -18,7 +18,6 @@ from typing import TYPE_CHECKING, Optional, Union
 from typing_extensions import override
 
 from ..extras import logging
-from ..extras.misc import check_version
 from .data_utils import Role
 from .formatter import EmptyFormatter, FunctionFormatter, StringFormatter, ToolFormatter
 from .mm_plugin import get_mm_plugin
@@ -518,9 +517,6 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args:
 
     template = TEMPLATES[data_args.template]
 
-    if template.mm_plugin.__class__.__name__ != "BasePlugin":
-        check_version("transformers>=4.45.0")
-
     if data_args.train_on_prompt and template.efficient_eos:
         raise ValueError("Current template does not support `train_on_prompt`.")
 
diff --git a/src/llamafactory/model/loader.py b/src/llamafactory/model/loader.py
index c87e796c..d327eecf 100644
--- a/src/llamafactory/model/loader.py
+++ b/src/llamafactory/model/loader.py
@@ -101,7 +101,7 @@ def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
         processor = AutoProcessor.from_pretrained(model_args.model_name_or_path, **init_kwargs)
         patch_processor(processor, tokenizer, model_args)
     except Exception as e:
-        logger.debug(f"Failed to load processor: {e}.")
+        logger.info_rank0(f"Failed to load processor: {e}.")
         processor = None
 
     # Avoid load tokenizer, see:
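For context (not part of the patch), below is a minimal, self-contained sketch of the bounds-check pattern these hunks apply: the placeholder counter is compared against the number of user-supplied media items (len(images), len(videos), len(audios)) rather than against processor-derived lists, so a prompt with more placeholders than media fails with a clear ValueError instead of a later IndexError. The helper name expand_image_placeholders and the "<image>" placeholder string are illustrative assumptions, not identifiers from the patch.

    IMAGE_PLACEHOLDER = "<image>"  # assumed placeholder string, for illustration only


    def expand_image_placeholders(messages: list, images: list, image_seqlen: int = 1) -> list:
        """Replace image placeholders, failing fast when placeholders outnumber the provided images."""
        num_image_tokens = 0
        for message in messages:
            content = message["content"]
            while IMAGE_PLACEHOLDER in content:
                if num_image_tokens >= len(images):  # more placeholders than images supplied by the user
                    raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")

                content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)
                num_image_tokens += 1

            message["content"] = content

        return messages


    # Two placeholders but only one image -> raises ValueError instead of failing later with IndexError.
    try:
        expand_image_placeholders([{"role": "user", "content": "<image><image> compare these"}], images=["img0.png"])
    except ValueError as err:
        print(err)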