From c224d17cb27d8b96b3efa2a7a7c4835e889520bf Mon Sep 17 00:00:00 2001 From: Kingsley Date: Tue, 3 Jun 2025 20:26:29 +0800 Subject: [PATCH] [data] support nested images input for videos (#8264) --- src/llamafactory/data/converter.py | 27 ++++++++++--- src/llamafactory/data/mm_plugin.py | 63 ++++++++++++++++++++---------- 2 files changed, 63 insertions(+), 27 deletions(-) diff --git a/src/llamafactory/data/converter.py b/src/llamafactory/data/converter.py index f0c791b7..2a284f19 100644 --- a/src/llamafactory/data/converter.py +++ b/src/llamafactory/data/converter.py @@ -51,12 +51,27 @@ class DatasetConverter: else: medias = medias[:] - if self.dataset_attr.load_from in ["script", "file"] and isinstance(medias[0], str): - for i in range(len(medias)): - if os.path.isfile(os.path.join(self.data_args.media_dir, medias[i])): - medias[i] = os.path.join(self.data_args.media_dir, medias[i]) - else: - logger.warning_rank0_once(f"Media {medias[i]} does not exist in `media_dir`. Use original path.") + if self.dataset_attr.load_from in ["script", "file"]: + if isinstance(medias[0], str): + for i in range(len(medias)): + media_path = os.path.join(self.data_args.media_dir, medias[i]) + if os.path.isfile(media_path): + medias[i] = media_path + else: + logger.warning_rank0_once( + f"Media {medias[i]} does not exist in `media_dir`. Use original path." + ) + elif isinstance(medias[0], list): # for processed video frames + # medias is a list of lists, e.g., [[frame1.jpg, frame2.jpg], [frame3.jpg, frame4.jpg]] + for i in range(len(medias)): + for j in range(len(medias[i])): + media_path = os.path.join(self.data_args.media_dir, medias[i][j]) + if os.path.isfile(media_path): + medias[i][j] = media_path + else: + logger.warning_rank0_once( + f"Media {medias[i][j]} does not exist in `media_dir`. Use original path." + ) return medias diff --git a/src/llamafactory/data/mm_plugin.py b/src/llamafactory/data/mm_plugin.py index 40378f45..446fbe60 100644 --- a/src/llamafactory/data/mm_plugin.py +++ b/src/llamafactory/data/mm_plugin.py @@ -17,6 +17,7 @@ import inspect import math +import os import re from copy import deepcopy from dataclasses import dataclass @@ -25,7 +26,7 @@ from typing import TYPE_CHECKING, BinaryIO, Literal, Optional, TypedDict, Union import numpy as np import torch -from transformers.image_utils import get_image_size, to_numpy_array +from transformers.image_utils import get_image_size, is_valid_image, to_numpy_array from typing_extensions import override from ..extras.constants import AUDIO_PLACEHOLDER, IGNORE_INDEX, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER @@ -76,7 +77,7 @@ if TYPE_CHECKING: bytes: Optional[bytes] ImageInput = Union[str, bytes, EncodedImage, BinaryIO, ImageObject] - VideoInput = Union[str, BinaryIO] + VideoInput = Union[str, BinaryIO, list[list[ImageInput]]] AudioInput = Union[str, BinaryIO, NDArray] class MMProcessor(ProcessorMixin): @@ -134,6 +135,11 @@ def _make_batched_images(images: list["ImageObject"], imglens: list[int]) -> lis return batch_images +def _check_video_is_nested_images(video: "VideoInput") -> bool: + r"""Check if the video is nested images.""" + return isinstance(video, list) and all(isinstance(frame, (str, BinaryIO, dict)) for frame in video) + + @dataclass class MMPluginMixin: image_token: Optional[str] @@ -266,14 +272,20 @@ class MMPluginMixin: r"""Regularizes videos to avoid error. Including reading, resizing and converting.""" results = [] for video in videos: - container = av.open(video, "r") - video_stream = next(stream for stream in container.streams if stream.type == "video") - sample_indices = self._get_video_sample_indices(video_stream, **kwargs) frames: list[ImageObject] = [] - container.seek(0) - for frame_idx, frame in enumerate(container.decode(video_stream)): - if frame_idx in sample_indices: - frames.append(frame.to_image()) + if _check_video_is_nested_images(video): + for frame in video: + if not is_valid_image(frame) and not isinstance(frame, dict) and not os.path.exists(frame): + raise ValueError("Invalid image found in video frames.") + frames = video + else: + container = av.open(video, "r") + video_stream = next(stream for stream in container.streams if stream.type == "video") + sample_indices = self._get_video_sample_indices(video_stream, **kwargs) + container.seek(0) + for frame_idx, frame in enumerate(container.decode(video_stream)): + if frame_idx in sample_indices: + frames.append(frame.to_image()) frames = self._regularize_images(frames, **kwargs)["images"] results.append(frames) @@ -1380,24 +1392,33 @@ class Qwen2VLPlugin(BasePlugin): ) -> dict[str, Union[list[list["ImageObject"]], list[float]]]: results, fps_per_video = [], [] for video in videos: - container = av.open(video, "r") - video_stream = next(stream for stream in container.streams if stream.type == "video") - sample_indices = self._get_video_sample_indices(video_stream, **kwargs) frames: list[ImageObject] = [] - container.seek(0) - for frame_idx, frame in enumerate(container.decode(video_stream)): - if frame_idx in sample_indices: - frames.append(frame.to_image()) + if _check_video_is_nested_images(video): + for frame in video: + if not is_valid_image(frame) and not isinstance(frame, dict) and not os.path.exists(frame): + raise ValueError("Invalid image found in video frames.") - if len(frames) % 2 != 0: # qwen2-vl requires even number of frames + frames = video + fps_per_video.append(kwargs.get("video_fps", 2.0)) + else: + container = av.open(video, "r") + video_stream = next(stream for stream in container.streams if stream.type == "video") + sample_indices = self._get_video_sample_indices(video_stream, **kwargs) + container.seek(0) + for frame_idx, frame in enumerate(container.decode(video_stream)): + if frame_idx in sample_indices: + frames.append(frame.to_image()) + + if video_stream.duration is None: + fps_per_video.append(kwargs.get("video_fps", 2.0)) + else: + fps_per_video.append(len(sample_indices) / float(video_stream.duration * video_stream.time_base)) + + if len(frames) % 2 != 0: frames.append(frames[-1]) frames = self._regularize_images(frames, **kwargs)["images"] results.append(frames) - if video_stream.duration is None: - fps_per_video.append(2.0) - else: - fps_per_video.append(len(sample_indices) / float(video_stream.duration * video_stream.time_base)) return {"videos": results, "fps_per_video": fps_per_video}