mirror of
				https://github.com/hiyouga/LLaMA-Factory.git
				synced 2025-11-04 18:02:19 +08:00 
			
		
		
		
	[data] fix qwen2.5 omni template (#7883)
This commit is contained in:
		
							parent
							
								
									3ae5da2a04
								
							
						
					
					
						commit
						db9559456c
					
				@ -1621,30 +1621,8 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
 | 
			
		||||
            video_grid_thw = [None] * len(videos)
 | 
			
		||||
            audio_lengths = [None] * len(audios)
 | 
			
		||||
 | 
			
		||||
        if self.expand_mm_tokens and use_audio_in_video:
 | 
			
		||||
            if "feature_attention_mask" not in mm_inputs:
 | 
			
		||||
                raise ValueError("audio_lengths should exist when use_audio_in_video is `True`.")
 | 
			
		||||
 | 
			
		||||
            if "video_grid_thw" not in mm_inputs:
 | 
			
		||||
                raise ValueError("video_grid_thw should exist when use_audio_in_video is `True`.")
 | 
			
		||||
 | 
			
		||||
            positions_list = []
 | 
			
		||||
            for message in messages:  # get multimodal index when use_audio
 | 
			
		||||
                positions = []
 | 
			
		||||
                for special_token in [self.audio_token, self.image_token, self.video_token]:
 | 
			
		||||
                    start = 0
 | 
			
		||||
                    while True:
 | 
			
		||||
                        pos = message["content"].find(special_token, start)
 | 
			
		||||
                        if pos == -1:
 | 
			
		||||
                            break
 | 
			
		||||
                        positions.append((pos, special_token))
 | 
			
		||||
                        start = pos + len(special_token)
 | 
			
		||||
 | 
			
		||||
                positions_list.append(positions.sort(key=lambda x: x[0]))
 | 
			
		||||
 | 
			
		||||
        for message in messages:
 | 
			
		||||
            content = message["content"]
 | 
			
		||||
            # separate with audio-video
 | 
			
		||||
            while IMAGE_PLACEHOLDER in content:
 | 
			
		||||
                if num_image_tokens >= len(images):
 | 
			
		||||
                    raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
 | 
			
		||||
@ -1655,34 +1633,26 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
 | 
			
		||||
                )
 | 
			
		||||
                num_image_tokens += 1
 | 
			
		||||
 | 
			
		||||
            if not use_audio_in_video:
 | 
			
		||||
                while AUDIO_PLACEHOLDER in content:
 | 
			
		||||
            if (
 | 
			
		||||
                use_audio_in_video and len(audios) and len(videos)
 | 
			
		||||
            ):  # if use the audio of video # deal video token and audio token togather
 | 
			
		||||
                if len(videos) != len(audios):
 | 
			
		||||
                    raise ValueError(
 | 
			
		||||
                        f"Number of videos ({len(videos)}) must match number of audios ({len(audios)}) when using audio in video."
 | 
			
		||||
                    )
 | 
			
		||||
 | 
			
		||||
                while VIDEO_PLACEHOLDER in content:
 | 
			
		||||
                    if num_video_tokens >= len(videos):
 | 
			
		||||
                        raise ValueError(f"`len(videos)` is less than the number of {VIDEO_PLACEHOLDER} tokens.")
 | 
			
		||||
                    if num_audio_tokens >= len(audios):
 | 
			
		||||
                        raise ValueError(f"`len(audios)` is less than the number of {AUDIO_PLACEHOLDER} tokens.")
 | 
			
		||||
 | 
			
		||||
                    audio_seqlen = audio_lengths[num_audio_tokens] if self.expand_mm_tokens else 1
 | 
			
		||||
                    content = content.replace(
 | 
			
		||||
                        AUDIO_PLACEHOLDER, f"<|audio_bos|>{self.audio_token * audio_seqlen}<|audio_eos|>", 1
 | 
			
		||||
                    )
 | 
			
		||||
                    num_audio_tokens += 1
 | 
			
		||||
 | 
			
		||||
                # TODO handle video_input and use_audio_in_video
 | 
			
		||||
                while VIDEO_PLACEHOLDER in content:
 | 
			
		||||
                    if num_video_tokens >= len(videos):
 | 
			
		||||
                        raise ValueError(f"`len(videos)` is less than the number of {VIDEO_PLACEHOLDER} tokens.")
 | 
			
		||||
 | 
			
		||||
                    video_seqlen = (
 | 
			
		||||
                        video_grid_thw[num_video_tokens].prod() // merge_length if self.expand_mm_tokens else 1
 | 
			
		||||
                    )
 | 
			
		||||
                    content = content.replace(
 | 
			
		||||
                        VIDEO_PLACEHOLDER, f"<|vision_bos|>{self.video_token * video_seqlen}<|vision_eos|>", 1
 | 
			
		||||
                    )
 | 
			
		||||
                    num_video_tokens += 1
 | 
			
		||||
 | 
			
		||||
            else:  # if use the audio of video # deal video token and audio token togather
 | 
			
		||||
                while VIDEO_PLACEHOLDER in content:
 | 
			
		||||
                    if num_video_tokens >= len(videos):
 | 
			
		||||
                        raise ValueError(f"`len(videos)` is less than the number of {VIDEO_PLACEHOLDER} tokens.")
 | 
			
		||||
                    video_pos = content.find(VIDEO_PLACEHOLDER)
 | 
			
		||||
                    audio_pos = content.find(AUDIO_PLACEHOLDER, video_pos)
 | 
			
		||||
                    if audio_pos == -1 or audio_pos < video_pos:
 | 
			
		||||
                        raise ValueError(
 | 
			
		||||
                            f"Each {VIDEO_PLACEHOLDER} must be followed by an {AUDIO_PLACEHOLDER} when using audio in video."
 | 
			
		||||
                        )
 | 
			
		||||
 | 
			
		||||
                    audio_t_index = torch.arange(audio_lengths[num_audio_tokens])
 | 
			
		||||
                    video_t_index = (
 | 
			
		||||
@ -1716,6 +1686,28 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
 | 
			
		||||
                    content = content.replace(AUDIO_PLACEHOLDER, "", 1)
 | 
			
		||||
                    num_audio_tokens += 1
 | 
			
		||||
                    num_video_tokens += 1
 | 
			
		||||
            else:
 | 
			
		||||
                while AUDIO_PLACEHOLDER in content:
 | 
			
		||||
                    if num_audio_tokens >= len(audios):
 | 
			
		||||
                        raise ValueError(f"`len(audios)` is less than the number of {AUDIO_PLACEHOLDER} tokens.")
 | 
			
		||||
 | 
			
		||||
                    audio_seqlen = audio_lengths[num_audio_tokens] if self.expand_mm_tokens else 1
 | 
			
		||||
                    content = content.replace(
 | 
			
		||||
                        AUDIO_PLACEHOLDER, f"<|audio_bos|>{self.audio_token * audio_seqlen}<|audio_eos|>", 1
 | 
			
		||||
                    )
 | 
			
		||||
                    num_audio_tokens += 1
 | 
			
		||||
 | 
			
		||||
                while VIDEO_PLACEHOLDER in content:
 | 
			
		||||
                    if num_video_tokens >= len(videos):
 | 
			
		||||
                        raise ValueError(f"`len(videos)` is less than the number of {VIDEO_PLACEHOLDER} tokens.")
 | 
			
		||||
 | 
			
		||||
                    video_seqlen = (
 | 
			
		||||
                        video_grid_thw[num_video_tokens].prod() // merge_length if self.expand_mm_tokens else 1
 | 
			
		||||
                    )
 | 
			
		||||
                    content = content.replace(
 | 
			
		||||
                        VIDEO_PLACEHOLDER, f"<|vision_bos|>{self.video_token * video_seqlen}<|vision_eos|>", 1
 | 
			
		||||
                    )
 | 
			
		||||
                    num_video_tokens += 1
 | 
			
		||||
 | 
			
		||||
            message["content"] = content
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -15,6 +15,7 @@
 | 
			
		||||
import os
 | 
			
		||||
from typing import TYPE_CHECKING, Any
 | 
			
		||||
 | 
			
		||||
import numpy as np
 | 
			
		||||
import pytest
 | 
			
		||||
import torch
 | 
			
		||||
from PIL import Image
 | 
			
		||||
@ -43,11 +44,20 @@ MM_MESSAGES = [
 | 
			
		||||
    {"role": "assistant", "content": "A cat."},
 | 
			
		||||
]
 | 
			
		||||
 | 
			
		||||
OMNI_MESSAGES = [
 | 
			
		||||
    {"role": "user", "content": "<image>What is in this image?"},
 | 
			
		||||
    {"role": "assistant", "content": "A cat."},
 | 
			
		||||
    {"role": "user", "content": "<audio>What is in this audio?"},
 | 
			
		||||
    {"role": "assistant", "content": "Nothing."},
 | 
			
		||||
]
 | 
			
		||||
 | 
			
		||||
TEXT_MESSAGES = [
 | 
			
		||||
    {"role": "user", "content": "How are you"},
 | 
			
		||||
    {"role": "assistant", "content": "I am fine!"},
 | 
			
		||||
]
 | 
			
		||||
 | 
			
		||||
AUDIOS = [np.zeros(1600)]
 | 
			
		||||
 | 
			
		||||
IMAGES = [Image.new("RGB", (32, 32), (255, 255, 255))]
 | 
			
		||||
 | 
			
		||||
NO_IMAGES = []
 | 
			
		||||
@ -58,6 +68,8 @@ NO_AUDIOS = []
 | 
			
		||||
 | 
			
		||||
IMGLENS = [1]
 | 
			
		||||
 | 
			
		||||
AUDLENS = [1]
 | 
			
		||||
 | 
			
		||||
NO_IMGLENS = [0]
 | 
			
		||||
 | 
			
		||||
NO_VIDLENS = [0]
 | 
			
		||||
@ -76,6 +88,25 @@ def _get_mm_inputs(processor: "ProcessorMixin") -> dict[str, "torch.Tensor"]:
 | 
			
		||||
    return image_processor(images=IMAGES, return_tensors="pt")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _get_omni_inputs(processor: "ProcessorMixin") -> dict[str, "torch.Tensor"]:
 | 
			
		||||
    mm_inputs = {}
 | 
			
		||||
    image_processor: BaseImageProcessor = getattr(processor, "image_processor", None)
 | 
			
		||||
    feature_extractor = getattr(processor, "feature_extractor", None)
 | 
			
		||||
 | 
			
		||||
    mm_inputs.update(image_processor(IMAGES, return_tensors="pt"))
 | 
			
		||||
    mm_inputs.update(
 | 
			
		||||
        feature_extractor(
 | 
			
		||||
            AUDIOS,
 | 
			
		||||
            sampling_rate=getattr(processor, "audio_sampling_rate", 16000),
 | 
			
		||||
            return_attention_mask=True,
 | 
			
		||||
            padding="max_length",
 | 
			
		||||
            return_tensors="pt",
 | 
			
		||||
        )
 | 
			
		||||
    )
 | 
			
		||||
    mm_inputs["feature_attention_mask"] = mm_inputs.pop("attention_mask")
 | 
			
		||||
    return mm_inputs
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _is_close(batch_a: dict[str, Any], batch_b: dict[str, Any]) -> None:
 | 
			
		||||
    assert batch_a.keys() == batch_b.keys()
 | 
			
		||||
    for key in batch_a.keys():
 | 
			
		||||
@ -104,6 +135,17 @@ def _check_plugin(
 | 
			
		||||
    expected_mm_inputs: dict[str, Any] = {},
 | 
			
		||||
    expected_no_mm_inputs: dict[str, Any] = {},
 | 
			
		||||
) -> None:
 | 
			
		||||
    # test omni_messages
 | 
			
		||||
    if plugin.__class__.__name__ == "Qwen2OmniPlugin":
 | 
			
		||||
        assert plugin.process_messages(OMNI_MESSAGES, IMAGES, NO_VIDEOS, AUDIOS, processor) == expected_mm_messages
 | 
			
		||||
        assert plugin.process_token_ids(INPUT_IDS, LABELS, IMAGES, NO_VIDEOS, AUDIOS, tokenizer, processor) == (
 | 
			
		||||
            expected_input_ids,
 | 
			
		||||
            expected_labels,
 | 
			
		||||
        )
 | 
			
		||||
        _is_close(
 | 
			
		||||
            plugin.get_mm_inputs(IMAGES, NO_VIDEOS, AUDIOS, IMGLENS, NO_VIDLENS, AUDLENS, BATCH_IDS, processor),
 | 
			
		||||
            expected_mm_inputs,
 | 
			
		||||
        )
 | 
			
		||||
    # test mm_messages
 | 
			
		||||
    if plugin.__class__.__name__ != "BasePlugin":
 | 
			
		||||
        assert plugin.process_messages(MM_MESSAGES, IMAGES, NO_VIDEOS, NO_AUDIOS, processor) == expected_mm_messages
 | 
			
		||||
@ -279,6 +321,30 @@ def test_pixtral_plugin():
 | 
			
		||||
    _check_plugin(**check_inputs)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@pytest.mark.xfail(reason="Unknown error.")
 | 
			
		||||
def test_qwen2_omni_plugin():
 | 
			
		||||
    image_seqlen = 4
 | 
			
		||||
    audio_seqlen = 2
 | 
			
		||||
    tokenizer_module = _load_tokenizer_module(model_name_or_path="Qwen/Qwen2.5-Omni-7B")
 | 
			
		||||
    qwen2_omni_plugin = get_mm_plugin(
 | 
			
		||||
        name="qwen2_omni", audio_token="<|AUDIO|>", image_token="<|IMAGE|>", video_token="<|VIDEO|>"
 | 
			
		||||
    )
 | 
			
		||||
    check_inputs = {"plugin": qwen2_omni_plugin, **tokenizer_module}
 | 
			
		||||
    check_inputs["expected_mm_messages"] = [
 | 
			
		||||
        {
 | 
			
		||||
            key: (
 | 
			
		||||
                value.replace("<image>", f"<|vision_bos|>{'<|IMAGE|>' * image_seqlen}<|vision_eos|>").replace(
 | 
			
		||||
                    "<audio>", f"<|audio_bos|>{'<|AUDIO|>' * audio_seqlen}<|audio_eos|>"
 | 
			
		||||
                )
 | 
			
		||||
            )
 | 
			
		||||
            for key, value in message.items()
 | 
			
		||||
        }
 | 
			
		||||
        for message in OMNI_MESSAGES
 | 
			
		||||
    ]
 | 
			
		||||
    check_inputs["expected_mm_inputs"] = _get_omni_inputs(tokenizer_module["processor"])
 | 
			
		||||
    _check_plugin(**check_inputs)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_qwen2_vl_plugin():
 | 
			
		||||
    image_seqlen = 4
 | 
			
		||||
    tokenizer_module = _load_tokenizer_module(model_name_or_path="Qwen/Qwen2-VL-7B-Instruct")
 | 
			
		||||
 | 
			
		||||
@ -1,2 +1,2 @@
 | 
			
		||||
# change if test fails or cache is outdated
 | 
			
		||||
0.9.3.104
 | 
			
		||||
0.9.3.105
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user