[data] support nested images input for videos (#8264)

Kingsley authored on 2025-06-03 20:26:29 +08:00, committed by GitHub
parent 6cc247e815
commit 3425bc6e71
2 changed files with 63 additions and 27 deletions
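
In practice, this change lets a dataset row supply a video as a list of already-extracted frame images instead of a single video file. A rough sketch of such a row, written as a Python literal (the field names follow the sharegpt-style layout this converter appears to consume and should be treated as an assumption, not something the diff itself defines):

example_row = {
    "messages": [
        {"role": "user", "content": "<video>What happens in this clip?"},
        {"role": "assistant", "content": "A person waters a plant."},
    ],
    # One entry per <video> placeholder; with this commit the entry may be a
    # nested list of frame image paths rather than a path to a video container.
    "videos": [
        ["clip1/frame_0001.jpg", "clip1/frame_0002.jpg", "clip1/frame_0003.jpg", "clip1/frame_0004.jpg"],
    ],
}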


@@ -51,12 +51,27 @@ class DatasetConverter:
         else:
             medias = medias[:]

-        if self.dataset_attr.load_from in ["script", "file"] and isinstance(medias[0], str):
-            for i in range(len(medias)):
-                if os.path.isfile(os.path.join(self.data_args.media_dir, medias[i])):
-                    medias[i] = os.path.join(self.data_args.media_dir, medias[i])
-                else:
-                    logger.warning_rank0_once(f"Media {medias[i]} does not exist in `media_dir`. Use original path.")
+        if self.dataset_attr.load_from in ["script", "file"]:
+            if isinstance(medias[0], str):
+                for i in range(len(medias)):
+                    media_path = os.path.join(self.data_args.media_dir, medias[i])
+                    if os.path.isfile(media_path):
+                        medias[i] = media_path
+                    else:
+                        logger.warning_rank0_once(
+                            f"Media {medias[i]} does not exist in `media_dir`. Use original path."
+                        )
+            elif isinstance(medias[0], list):  # for processed video frames
+                # medias is a list of lists, e.g., [[frame1.jpg, frame2.jpg], [frame3.jpg, frame4.jpg]]
+                for i in range(len(medias)):
+                    for j in range(len(medias[i])):
+                        media_path = os.path.join(self.data_args.media_dir, medias[i][j])
+                        if os.path.isfile(media_path):
+                            medias[i][j] = media_path
+                        else:
+                            logger.warning_rank0_once(
+                                f"Media {medias[i][j]} does not exist in `media_dir`. Use original path."
+                            )

         return medias
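
The hunk above resolves media paths against media_dir for both flat path lists and the new nested frame lists, keeping the original string (with a warning) when the joined path is missing. A self-contained sketch of that resolution logic, with the logging left out and all paths invented:

import os


def resolve_media_paths(medias, media_dir):
    """Mirror the converter's branch: prefix media_dir onto flat paths or nested frame lists,
    falling back to the original string when the joined path does not point to a file."""

    def resolve(path):
        candidate = os.path.join(media_dir, path)
        return candidate if os.path.isfile(candidate) else path

    if medias and isinstance(medias[0], list):  # nested video frames
        return [[resolve(frame) for frame in frames] for frames in medias]

    return [resolve(media) for media in medias]


# resolve_media_paths([["clip1/0001.jpg", "clip1/0002.jpg"]], "data/media")
# -> [["data/media/clip1/0001.jpg", "data/media/clip1/0002.jpg"]] when those files exist.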


@@ -17,6 +17,7 @@
 import inspect
 import math
+import os
 import re
 from copy import deepcopy
 from dataclasses import dataclass
@@ -25,7 +26,7 @@ from typing import TYPE_CHECKING, BinaryIO, Literal, Optional, TypedDict, Union
 import numpy as np
 import torch
-from transformers.image_utils import get_image_size, to_numpy_array
+from transformers.image_utils import get_image_size, is_valid_image, to_numpy_array
 from typing_extensions import override

 from ..extras.constants import AUDIO_PLACEHOLDER, IGNORE_INDEX, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
@ -76,7 +77,7 @@ if TYPE_CHECKING:
bytes: Optional[bytes]
ImageInput = Union[str, bytes, EncodedImage, BinaryIO, ImageObject]
VideoInput = Union[str, BinaryIO]
VideoInput = Union[str, BinaryIO, list[list[ImageInput]]]
AudioInput = Union[str, BinaryIO, NDArray]
class MMProcessor(ProcessorMixin):
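
Because the regularization helpers iterate over a list of these inputs (the for video in videos: loops further down), a single batch can now mix container files and pre-extracted frame lists. An illustrative value (paths invented):

videos = [
    "videos/cooking.mp4",                      # str: opened and decoded with PyAV, as before
    ["clip2/f_0001.jpg", "clip2/f_0002.jpg"],  # nested images: the frames are used directly
]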
@@ -134,6 +135,11 @@ def _make_batched_images(images: list["ImageObject"], imglens: list[int]) -> list[list["ImageObject"]]:
     return batch_images


+def _check_video_is_nested_images(video: "VideoInput") -> bool:
+    r"""Check if the video is nested images."""
+    return isinstance(video, list) and all(isinstance(frame, (str, BinaryIO, dict)) for frame in video)
+
+
 @dataclass
 class MMPluginMixin:
     image_token: Optional[str]
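
The new module-level helper is what routes each entry down one branch or the other. A quick self-contained demonstration, with the helper body copied from the hunk above (minus the type annotation) and the sample paths invented:

from typing import BinaryIO


def _check_video_is_nested_images(video) -> bool:
    r"""Check if the video is nested images."""
    return isinstance(video, list) and all(isinstance(frame, (str, BinaryIO, dict)) for frame in video)


print(_check_video_is_nested_images("videos/cooking.mp4"))          # False: a path, decoded with PyAV
print(_check_video_is_nested_images(["f_0001.jpg", "f_0002.jpg"]))  # True: pre-extracted frame paths
print(_check_video_is_nested_images([{"path": "f_0001.jpg"}]))      # True: EncodedImage-style dict frames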
@@ -266,14 +272,20 @@ class MMPluginMixin:
         r"""Regularizes videos to avoid error. Including reading, resizing and converting."""
         results = []
         for video in videos:
-            container = av.open(video, "r")
-            video_stream = next(stream for stream in container.streams if stream.type == "video")
-            sample_indices = self._get_video_sample_indices(video_stream, **kwargs)
             frames: list[ImageObject] = []
-            container.seek(0)
-            for frame_idx, frame in enumerate(container.decode(video_stream)):
-                if frame_idx in sample_indices:
-                    frames.append(frame.to_image())
+            if _check_video_is_nested_images(video):
+                for frame in video:
+                    if not is_valid_image(frame) and not isinstance(frame, dict) and not os.path.exists(frame):
+                        raise ValueError("Invalid image found in video frames.")
+                frames = video
+            else:
+                container = av.open(video, "r")
+                video_stream = next(stream for stream in container.streams if stream.type == "video")
+                sample_indices = self._get_video_sample_indices(video_stream, **kwargs)
+                container.seek(0)
+                for frame_idx, frame in enumerate(container.decode(video_stream)):
+                    if frame_idx in sample_indices:
+                        frames.append(frame.to_image())

             frames = self._regularize_images(frames, **kwargs)["images"]
             results.append(frames)
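
For the nested-images branch, a frame passes validation if it is already a valid in-memory image (checked with transformers' is_valid_image), an EncodedImage-style dict, or a path that exists on disk; anything else raises. A small sketch of that gate in isolation (helper name and file names invented):

import os

from PIL import Image
from transformers.image_utils import is_valid_image


def validate_frame(frame) -> None:
    # Mirrors the check in the hunk above: in-memory images, dict frames, and existing paths pass.
    if not is_valid_image(frame) and not isinstance(frame, dict) and not os.path.exists(frame):
        raise ValueError("Invalid image found in video frames.")


validate_frame(Image.new("RGB", (32, 32)))     # passes: a real PIL image
validate_frame({"path": "frames/f_0001.jpg"})  # passes: dict-style frame
try:
    validate_frame("frames/f_0001.jpg")        # passes only if this path exists on disk
except ValueError:
    pass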
@@ -1380,24 +1392,33 @@ class Qwen2VLPlugin(BasePlugin):
     ) -> dict[str, Union[list[list["ImageObject"]], list[float]]]:
         results, fps_per_video = [], []
         for video in videos:
-            container = av.open(video, "r")
-            video_stream = next(stream for stream in container.streams if stream.type == "video")
-            sample_indices = self._get_video_sample_indices(video_stream, **kwargs)
             frames: list[ImageObject] = []
-            container.seek(0)
-            for frame_idx, frame in enumerate(container.decode(video_stream)):
-                if frame_idx in sample_indices:
-                    frames.append(frame.to_image())
+            if _check_video_is_nested_images(video):
+                for frame in video:
+                    if not is_valid_image(frame) and not isinstance(frame, dict) and not os.path.exists(frame):
+                        raise ValueError("Invalid image found in video frames.")

-            if len(frames) % 2 != 0:  # qwen2-vl requires even number of frames
+                frames = video
+                fps_per_video.append(kwargs.get("video_fps", 2.0))
+            else:
+                container = av.open(video, "r")
+                video_stream = next(stream for stream in container.streams if stream.type == "video")
+                sample_indices = self._get_video_sample_indices(video_stream, **kwargs)
+                container.seek(0)
+                for frame_idx, frame in enumerate(container.decode(video_stream)):
+                    if frame_idx in sample_indices:
+                        frames.append(frame.to_image())
+
+                if video_stream.duration is None:
+                    fps_per_video.append(kwargs.get("video_fps", 2.0))
+                else:
+                    fps_per_video.append(len(sample_indices) / float(video_stream.duration * video_stream.time_base))
+
+            if len(frames) % 2 != 0:
                 frames.append(frames[-1])

             frames = self._regularize_images(frames, **kwargs)["images"]
             results.append(frames)
-            if video_stream.duration is None:
-                fps_per_video.append(2.0)
-            else:
-                fps_per_video.append(len(sample_indices) / float(video_stream.duration * video_stream.time_base))

         return {"videos": results, "fps_per_video": fps_per_video}