diff --git a/src/llamafactory/data/mm_plugin.py b/src/llamafactory/data/mm_plugin.py index a83dab20..8b5d23d9 100644 --- a/src/llamafactory/data/mm_plugin.py +++ b/src/llamafactory/data/mm_plugin.py @@ -1602,34 +1602,30 @@ class Qwen2OmniPlugin(Qwen2VLPlugin): processor: Optional["MMProcessor"], ) -> list[dict[str, str]]: self._validate_input(processor, images, videos, audios) + num_image_tokens, num_video_tokens, num_audio_tokens = 0, 0, 0 messages = deepcopy(messages) + image_processor: BaseImageProcessor = getattr(processor, "image_processor", None) + + merge_length = processor.image_processor.merge_size**2 + use_audio_in_video = getattr(processor, "use_audio_in_video", False) if self.expand_mm_tokens: mm_inputs = self._get_mm_inputs(images, videos, audios, processor) + image_grid_thw = mm_inputs.get("image_grid_thw", []) + video_grid_thw = mm_inputs.get("video_grid_thw", []) + if "feature_attention_mask" in mm_inputs: + input_lengths = (mm_inputs["feature_attention_mask"].sum(-1).numpy() - 1) // 2 + 1 + audio_lengths = (input_lengths - 2) // 2 + 1 else: mm_inputs = {} + image_grid_thw = [None] * len(images) + video_grid_thw = [None] * len(videos) + audio_lengths = [None] * len(audios) - image_processor: BaseImageProcessor = getattr(processor, "image_processor", None) - num_audio_tokens, num_image_tokens, num_video_tokens = 0, 0, 0 - use_audio_in_video = getattr(processor, "use_audio_in_video", False) - - # get length or size from mm_inputs - if "feature_attention_mask" in mm_inputs: - input_lengths = (mm_inputs["feature_attention_mask"].sum(-1).numpy() - 1) // 2 + 1 - audio_lengths = (input_lengths - 2) // 2 + 1 - - if mm_inputs.get("image_grid_thw", None) is not None: - image_grid_thw = mm_inputs["image_grid_thw"] - merge_length = processor.image_processor.merge_size**2 - - if mm_inputs.get("video_grid_thw", None) is not None: - video_grid_thw = mm_inputs["video_grid_thw"] - merge_length = processor.image_processor.merge_size**2 - - if use_audio_in_video: - if audio_lengths is None: + if self.expand_mm_tokens and use_audio_in_video: + if "feature_attention_mask" not in mm_inputs: raise ValueError("audio_lengths should exist when use_audio_in_video is `True`.") - if mm_inputs.get("video_grid_thw", None) is None: + if "video_grid_thw" not in mm_inputs: raise ValueError("video_grid_thw should exist when use_audio_in_video is `True`.") positions_list = [] @@ -1653,11 +1649,9 @@ class Qwen2OmniPlugin(Qwen2VLPlugin): if num_image_tokens >= len(images): raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.") - image_token_replace_length = image_grid_thw[num_image_tokens].prod() // merge_length + image_seqlen = image_grid_thw[num_image_tokens].prod() // merge_length if self.expand_mm_tokens else 1 content = content.replace( - IMAGE_PLACEHOLDER, - f"<|vision_bos|>{self.image_token * image_token_replace_length}<|vision_eos|>", - 1, + IMAGE_PLACEHOLDER, f"<|vision_bos|>{self.image_token * image_seqlen}<|vision_eos|>", 1 ) num_image_tokens += 1 @@ -1666,11 +1660,9 @@ class Qwen2OmniPlugin(Qwen2VLPlugin): if num_audio_tokens >= len(audios): raise ValueError(f"`len(audios)` is less than the number of {AUDIO_PLACEHOLDER} tokens.") - audio_token_replace_length = audio_lengths[num_audio_tokens] + audio_seqlen = audio_lengths[num_audio_tokens] if self.expand_mm_tokens else 1 content = content.replace( - AUDIO_PLACEHOLDER, - f"<|audio_bos|>{self.audio_token * audio_token_replace_length}<|audio_eos|>", - 1, + AUDIO_PLACEHOLDER, f"<|audio_bos|>{self.audio_token * audio_seqlen}<|audio_eos|>", 1 ) num_audio_tokens += 1 @@ -1679,9 +1671,9 @@ class Qwen2OmniPlugin(Qwen2VLPlugin): if num_video_tokens >= len(videos): raise ValueError(f"`len(videos)` is less than the number of {VIDEO_PLACEHOLDER} tokens.") - video_replace_length = video_grid_thw[num_video_tokens].prod() // merge_length + video_seqlen = video_grid_thw[num_video_tokens].prod() // merge_length if self.expand_mm_tokens else 1 content = content.replace( - VIDEO_PLACEHOLDER, f"<|vision_bos|>{self.video_token * video_replace_length}<|vision_eos|>", 1 + VIDEO_PLACEHOLDER, f"<|vision_bos|>{self.video_token * video_seqlen}<|vision_eos|>", 1 ) num_video_tokens += 1