mirror of
				https://github.com/hiyouga/LLaMA-Factory.git
				synced 2025-11-04 18:02:19 +08:00 
			
		
		
		
	[data] Fix Qwen3VL plugin (#9297)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Co-authored-by: Yaowei Zheng <hiyouga@buaa.edu.cn> Co-authored-by: kingsley <kingsleydodonow@gmail.com>
This commit is contained in:
		
							parent
							
								
									9c0d033a15
								
							
						
					
					
						commit
						129e918106
					
				@ -16,6 +16,7 @@ import gc
 | 
			
		||||
import json
 | 
			
		||||
from typing import Optional
 | 
			
		||||
 | 
			
		||||
import av
 | 
			
		||||
import fire
 | 
			
		||||
from tqdm import tqdm
 | 
			
		||||
from transformers import Seq2SeqTrainingArguments
 | 
			
		||||
@ -33,6 +34,14 @@ if is_vllm_available():
 | 
			
		||||
    from vllm.lora.request import LoRARequest
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _need_video_kwargs(template):
 | 
			
		||||
    NEEDED_TEMPLATE = ["qwen3_vl", "glm4v"]
 | 
			
		||||
    if any(t in template for t in NEEDED_TEMPLATE):
 | 
			
		||||
        return True
 | 
			
		||||
 | 
			
		||||
    return False
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def vllm_infer(
 | 
			
		||||
    model_name_or_path: str,
 | 
			
		||||
    adapter_name_or_path: str = None,
 | 
			
		||||
@ -132,6 +141,7 @@ def vllm_infer(
 | 
			
		||||
 | 
			
		||||
    # Store all results in these lists
 | 
			
		||||
    all_prompts, all_preds, all_labels = [], [], []
 | 
			
		||||
    need_video_kwargs = _need_video_kwargs(template)
 | 
			
		||||
 | 
			
		||||
    # Add batch process to avoid the issue of too many files opened
 | 
			
		||||
    for i in tqdm(range(0, len(train_dataset), batch_size), desc="Processing batched inference"):
 | 
			
		||||
@ -147,6 +157,7 @@ def vllm_infer(
 | 
			
		||||
                    )["images"]
 | 
			
		||||
                }
 | 
			
		||||
            elif batch["videos"][j] is not None:
 | 
			
		||||
                video_metadata, video_metadata_kwargs = None, None
 | 
			
		||||
                video = batch["videos"][j]
 | 
			
		||||
                multi_modal_data = {
 | 
			
		||||
                    "video": template_obj.mm_plugin._regularize_videos(
 | 
			
		||||
@ -157,6 +168,25 @@ def vllm_infer(
 | 
			
		||||
                        video_maxlen=video_maxlen,
 | 
			
		||||
                    )["videos"]
 | 
			
		||||
                }
 | 
			
		||||
                if need_video_kwargs:
 | 
			
		||||
                    container = av.open(video[0], "r")
 | 
			
		||||
                    video_stream = next(stream for stream in container.streams if stream.type == "video")
 | 
			
		||||
                    sampling_indices = template_obj.mm_plugin._get_video_sample_indices(
 | 
			
		||||
                        video_stream, video_fps, video_maxlen
 | 
			
		||||
                    )
 | 
			
		||||
                    total_frames = video_stream.frames
 | 
			
		||||
                    video_metadata_kwargs = {
 | 
			
		||||
                        "fps": getattr(tokenizer_module["processor"], "video_fps", 24.0),
 | 
			
		||||
                        "do_sample_frames": False,
 | 
			
		||||
                        "total_num_frames": total_frames,
 | 
			
		||||
                    }
 | 
			
		||||
                    video_metadata = dict(
 | 
			
		||||
                        fps=video_fps,
 | 
			
		||||
                        frames_indices=sampling_indices,
 | 
			
		||||
                        total_num_frames=total_frames,
 | 
			
		||||
                        video_backend="opencv",
 | 
			
		||||
                    )
 | 
			
		||||
                    multi_modal_data["video"] = (multi_modal_data["video"], video_metadata)
 | 
			
		||||
            elif batch["audios"][j] is not None:
 | 
			
		||||
                audio = batch["audios"][j]
 | 
			
		||||
                audio_data = template_obj.mm_plugin._regularize_audios(
 | 
			
		||||
@ -167,7 +197,11 @@ def vllm_infer(
 | 
			
		||||
            else:
 | 
			
		||||
                multi_modal_data = None
 | 
			
		||||
 | 
			
		||||
            vllm_inputs.append({"prompt_token_ids": batch["input_ids"][j], "multi_modal_data": multi_modal_data})
 | 
			
		||||
            vllm_input_data = {"prompt_token_ids": batch["input_ids"][j], "multi_modal_data": multi_modal_data}
 | 
			
		||||
            if "video_metadata_kwargs" in locals() and video_metadata_kwargs is not None:
 | 
			
		||||
                vllm_input_data["mm_processor_kwargs"] = video_metadata_kwargs
 | 
			
		||||
 | 
			
		||||
            vllm_inputs.append(vllm_input_data)
 | 
			
		||||
            prompts.append(tokenizer.decode(batch["input_ids"][j], skip_special_tokens=skip_special_tokens))
 | 
			
		||||
            labels.append(
 | 
			
		||||
                tokenizer.decode(
 | 
			
		||||
 | 
			
		||||
@ -31,7 +31,7 @@ from transformers.models.mllama.processing_mllama import (
 | 
			
		||||
    convert_sparse_cross_attention_mask_to_dense,
 | 
			
		||||
    get_cross_attention_token_mask,
 | 
			
		||||
)
 | 
			
		||||
from typing_extensions import override
 | 
			
		||||
from typing_extensions import NotRequired, override
 | 
			
		||||
 | 
			
		||||
from ..extras.constants import AUDIO_PLACEHOLDER, IGNORE_INDEX, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
 | 
			
		||||
from ..extras.packages import (
 | 
			
		||||
@ -77,6 +77,18 @@ if TYPE_CHECKING:
 | 
			
		||||
    VideoInput = Union[str, BinaryIO, list[list[ImageInput]]]
 | 
			
		||||
    AudioInput = Union[str, BinaryIO, NDArray]
 | 
			
		||||
 | 
			
		||||
    class RegularizedImageOutput(TypedDict):
 | 
			
		||||
        images: list[ImageObject]
 | 
			
		||||
 | 
			
		||||
    class RegularizedVideoOutput(TypedDict):
 | 
			
		||||
        videos: list[list[ImageObject]]
 | 
			
		||||
        durations: list[float]
 | 
			
		||||
        fps_per_video: NotRequired[list[float]]
 | 
			
		||||
 | 
			
		||||
    class RegularizedAudioOutput(TypedDict):
 | 
			
		||||
        audios: list[NDArray]
 | 
			
		||||
        sampling_rates: list[float]
 | 
			
		||||
 | 
			
		||||
    class MMProcessor(ProcessorMixin):
 | 
			
		||||
        patch_size: int
 | 
			
		||||
        image_seq_length: int
 | 
			
		||||
@ -244,7 +256,7 @@ class MMPluginMixin:
 | 
			
		||||
        sample_frames = min(total_frames, video_maxlen, sample_frames)
 | 
			
		||||
        return np.linspace(0, total_frames - 1, sample_frames).astype(np.int32)
 | 
			
		||||
 | 
			
		||||
    def _regularize_images(self, images: list["ImageInput"], **kwargs) -> dict[str, list["ImageObject"]]:
 | 
			
		||||
    def _regularize_images(self, images: list["ImageInput"], **kwargs) -> "RegularizedImageOutput":
 | 
			
		||||
        r"""Regularize images to avoid error. Including reading and pre-processing."""
 | 
			
		||||
        results = []
 | 
			
		||||
        for image in images:
 | 
			
		||||
@ -265,9 +277,10 @@ class MMPluginMixin:
 | 
			
		||||
 | 
			
		||||
        return {"images": results}
 | 
			
		||||
 | 
			
		||||
    def _regularize_videos(self, videos: list["VideoInput"], **kwargs) -> dict[str, list[list["ImageObject"]]]:
 | 
			
		||||
    def _regularize_videos(self, videos: list["VideoInput"], **kwargs) -> "RegularizedVideoOutput":
 | 
			
		||||
        r"""Regularizes videos to avoid error. Including reading, resizing and converting."""
 | 
			
		||||
        results = []
 | 
			
		||||
        durations = []
 | 
			
		||||
        for video in videos:
 | 
			
		||||
            frames: list[ImageObject] = []
 | 
			
		||||
            if _check_video_is_nested_images(video):
 | 
			
		||||
@ -275,6 +288,7 @@ class MMPluginMixin:
 | 
			
		||||
                    if not is_valid_image(frame) and not isinstance(frame, dict) and not os.path.exists(frame):
 | 
			
		||||
                        raise ValueError("Invalid image found in video frames.")
 | 
			
		||||
                frames = video
 | 
			
		||||
                durations.append(len(frames) / kwargs.get("video_fps", 2.0))
 | 
			
		||||
            else:
 | 
			
		||||
                container = av.open(video, "r")
 | 
			
		||||
                video_stream = next(stream for stream in container.streams if stream.type == "video")
 | 
			
		||||
@ -284,14 +298,19 @@ class MMPluginMixin:
 | 
			
		||||
                    if frame_idx in sample_indices:
 | 
			
		||||
                        frames.append(frame.to_image())
 | 
			
		||||
 | 
			
		||||
                if video_stream.duration is None:
 | 
			
		||||
                    durations.append(len(frames) / kwargs.get("video_fps", 2.0))
 | 
			
		||||
                else:
 | 
			
		||||
                    durations.append(float(video_stream.duration * video_stream.time_base))
 | 
			
		||||
 | 
			
		||||
            frames = self._regularize_images(frames, **kwargs)["images"]
 | 
			
		||||
            results.append(frames)
 | 
			
		||||
 | 
			
		||||
        return {"videos": results}
 | 
			
		||||
        return {"videos": results, "durations": durations}
 | 
			
		||||
 | 
			
		||||
    def _regularize_audios(
 | 
			
		||||
        self, audios: list["AudioInput"], sampling_rate: float, **kwargs
 | 
			
		||||
    ) -> dict[str, Union[list["NDArray"], list[float]]]:
 | 
			
		||||
    ) -> "RegularizedAudioOutput":
 | 
			
		||||
        r"""Regularizes audios to avoid error. Including reading and resampling."""
 | 
			
		||||
        results, sampling_rates = [], []
 | 
			
		||||
        for audio in audios:
 | 
			
		||||
@ -1418,10 +1437,8 @@ class Qwen2VLPlugin(BasePlugin):
 | 
			
		||||
        return image
 | 
			
		||||
 | 
			
		||||
    @override
 | 
			
		||||
    def _regularize_videos(
 | 
			
		||||
        self, videos: list["VideoInput"], **kwargs
 | 
			
		||||
    ) -> dict[str, Union[list[list["ImageObject"]], list[float]]]:
 | 
			
		||||
        results, fps_per_video = [], []
 | 
			
		||||
    def _regularize_videos(self, videos: list["VideoInput"], **kwargs) -> "RegularizedVideoOutput":
 | 
			
		||||
        results, fps_per_video, durations = [], [], []
 | 
			
		||||
        for video in videos:
 | 
			
		||||
            frames: list[ImageObject] = []
 | 
			
		||||
            if _check_video_is_nested_images(video):
 | 
			
		||||
@ -1431,6 +1448,7 @@ class Qwen2VLPlugin(BasePlugin):
 | 
			
		||||
 | 
			
		||||
                frames = video
 | 
			
		||||
                fps_per_video.append(kwargs.get("video_fps", 2.0))
 | 
			
		||||
                durations.append(len(frames) / kwargs.get("video_fps", 2.0))
 | 
			
		||||
            else:
 | 
			
		||||
                container = av.open(video, "r")
 | 
			
		||||
                video_stream = next(stream for stream in container.streams if stream.type == "video")
 | 
			
		||||
@ -1442,8 +1460,10 @@ class Qwen2VLPlugin(BasePlugin):
 | 
			
		||||
 | 
			
		||||
                if video_stream.duration is None:
 | 
			
		||||
                    fps_per_video.append(kwargs.get("video_fps", 2.0))
 | 
			
		||||
                    durations.append(len(frames) / kwargs.get("video_fps", 2.0))
 | 
			
		||||
                else:
 | 
			
		||||
                    fps_per_video.append(len(sample_indices) / float(video_stream.duration * video_stream.time_base))
 | 
			
		||||
                    durations.append(float(video_stream.duration * video_stream.time_base))
 | 
			
		||||
 | 
			
		||||
            if len(frames) % 2 != 0:
 | 
			
		||||
                frames.append(frames[-1])
 | 
			
		||||
@ -1451,7 +1471,7 @@ class Qwen2VLPlugin(BasePlugin):
 | 
			
		||||
            frames = self._regularize_images(frames, **kwargs)["images"]
 | 
			
		||||
            results.append(frames)
 | 
			
		||||
 | 
			
		||||
        return {"videos": results, "fps_per_video": fps_per_video}
 | 
			
		||||
        return {"videos": results, "fps_per_video": fps_per_video, "durations": durations}
 | 
			
		||||
 | 
			
		||||
    @override
 | 
			
		||||
    def _get_mm_inputs(
 | 
			
		||||
@ -1565,8 +1585,8 @@ class Qwen3VLPlugin(Qwen2VLPlugin):
 | 
			
		||||
                video_maxlen=getattr(processor, "video_maxlen", 128),
 | 
			
		||||
            )
 | 
			
		||||
            video_metadata = [
 | 
			
		||||
                {"fps": getattr(processor, "video_fps", 24.0), "duration": len(video), "total_num_frames": len(video)}
 | 
			
		||||
                for video in videos["videos"]
 | 
			
		||||
                {"fps": getattr(processor, "video_fps", 24.0), "duration": duration, "total_num_frames": len(video)}
 | 
			
		||||
                for video, duration in zip(videos["videos"], videos["durations"])
 | 
			
		||||
            ]
 | 
			
		||||
            mm_inputs.update(
 | 
			
		||||
                video_processor(videos=videos["videos"], video_metadata=video_metadata, return_metadata=True)
 | 
			
		||||
@ -1622,27 +1642,27 @@ class Qwen3VLPlugin(Qwen2VLPlugin):
 | 
			
		||||
                num_image_tokens += 1
 | 
			
		||||
 | 
			
		||||
            while VIDEO_PLACEHOLDER in content:
 | 
			
		||||
                metadata = video_metadata[idx]
 | 
			
		||||
                timestamps = processor._calculate_timestamps(
 | 
			
		||||
                    metadata.frames_indices,
 | 
			
		||||
                    metadata.fps,
 | 
			
		||||
                    video_processor.merge_size,
 | 
			
		||||
                )
 | 
			
		||||
                video_structure = ""
 | 
			
		||||
                for frame_index in range(num_frames):
 | 
			
		||||
                    video_seqlen = (
 | 
			
		||||
                        video_grid_thw[num_video_tokens][1:].prod() // video_merge_length
 | 
			
		||||
                        if self.expand_mm_tokens
 | 
			
		||||
                        else 1
 | 
			
		||||
                if self.expand_mm_tokens:
 | 
			
		||||
                    metadata = video_metadata[idx]
 | 
			
		||||
                    timestamps = processor._calculate_timestamps(
 | 
			
		||||
                        metadata.frames_indices,
 | 
			
		||||
                        metadata.fps,
 | 
			
		||||
                        video_processor.merge_size,
 | 
			
		||||
                    )
 | 
			
		||||
                    timestamp_sec = timestamps[frame_index]
 | 
			
		||||
                    frame_structure = (
 | 
			
		||||
                        f"<{timestamp_sec:.1f} seconds>"
 | 
			
		||||
                        f"{self.vision_bos_token}{self.video_token * video_seqlen}{self.vision_eos_token}"
 | 
			
		||||
                    )
 | 
			
		||||
                    video_structure += frame_structure
 | 
			
		||||
 | 
			
		||||
                if not self.expand_mm_tokens:
 | 
			
		||||
                    video_structure = ""
 | 
			
		||||
                    for frame_index in range(num_frames):
 | 
			
		||||
                        video_seqlen = (
 | 
			
		||||
                            video_grid_thw[num_video_tokens][1:].prod() // video_merge_length
 | 
			
		||||
                            if self.expand_mm_tokens
 | 
			
		||||
                            else 1
 | 
			
		||||
                        )
 | 
			
		||||
                        timestamp_sec = timestamps[frame_index]
 | 
			
		||||
                        frame_structure = (
 | 
			
		||||
                            f"<{timestamp_sec:.1f} seconds>"
 | 
			
		||||
                            f"{self.vision_bos_token}{self.video_token * video_seqlen}{self.vision_eos_token}"
 | 
			
		||||
                        )
 | 
			
		||||
                        video_structure += frame_structure
 | 
			
		||||
                else:
 | 
			
		||||
                    video_structure = f"{self.vision_bos_token}{self.video_token}{self.vision_eos_token}"
 | 
			
		||||
 | 
			
		||||
                content = content.replace(VIDEO_PLACEHOLDER, video_structure, 1)
 | 
			
		||||
@ -1684,7 +1704,8 @@ class GLM4VPlugin(Qwen2VLPlugin):
 | 
			
		||||
            )
 | 
			
		||||
            # prepare video metadata
 | 
			
		||||
            video_metadata = [
 | 
			
		||||
                {"fps": 2, "duration": len(video), "total_frames": len(video)} for video in video_data["videos"]
 | 
			
		||||
                {"fps": 2, "duration": duration, "total_frames": len(video)}
 | 
			
		||||
                for video, duration in zip(video_data["videos"], video_data["durations"])
 | 
			
		||||
            ]
 | 
			
		||||
            mm_inputs.update(video_processor(images=None, videos=video_data["videos"], video_metadata=video_metadata))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user