mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-01 11:12:50 +08:00
[data] support nested images input for videos (#8264)
This commit is contained in:
parent
6cc247e815
commit
3425bc6e71
@ -51,12 +51,27 @@ class DatasetConverter:
|
||||
else:
|
||||
medias = medias[:]
|
||||
|
||||
if self.dataset_attr.load_from in ["script", "file"] and isinstance(medias[0], str):
|
||||
for i in range(len(medias)):
|
||||
if os.path.isfile(os.path.join(self.data_args.media_dir, medias[i])):
|
||||
medias[i] = os.path.join(self.data_args.media_dir, medias[i])
|
||||
else:
|
||||
logger.warning_rank0_once(f"Media {medias[i]} does not exist in `media_dir`. Use original path.")
|
||||
if self.dataset_attr.load_from in ["script", "file"]:
|
||||
if isinstance(medias[0], str):
|
||||
for i in range(len(medias)):
|
||||
media_path = os.path.join(self.data_args.media_dir, medias[i])
|
||||
if os.path.isfile(media_path):
|
||||
medias[i] = media_path
|
||||
else:
|
||||
logger.warning_rank0_once(
|
||||
f"Media {medias[i]} does not exist in `media_dir`. Use original path."
|
||||
)
|
||||
elif isinstance(medias[0], list): # for processed video frames
|
||||
# medias is a list of lists, e.g., [[frame1.jpg, frame2.jpg], [frame3.jpg, frame4.jpg]]
|
||||
for i in range(len(medias)):
|
||||
for j in range(len(medias[i])):
|
||||
media_path = os.path.join(self.data_args.media_dir, medias[i][j])
|
||||
if os.path.isfile(media_path):
|
||||
medias[i][j] = media_path
|
||||
else:
|
||||
logger.warning_rank0_once(
|
||||
f"Media {medias[i][j]} does not exist in `media_dir`. Use original path."
|
||||
)
|
||||
|
||||
return medias
|
||||
|
||||
|
@ -17,6 +17,7 @@
|
||||
|
||||
import inspect
|
||||
import math
|
||||
import os
|
||||
import re
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass
|
||||
@ -25,7 +26,7 @@ from typing import TYPE_CHECKING, BinaryIO, Literal, Optional, TypedDict, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers.image_utils import get_image_size, to_numpy_array
|
||||
from transformers.image_utils import get_image_size, is_valid_image, to_numpy_array
|
||||
from typing_extensions import override
|
||||
|
||||
from ..extras.constants import AUDIO_PLACEHOLDER, IGNORE_INDEX, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
|
||||
@ -76,7 +77,7 @@ if TYPE_CHECKING:
|
||||
bytes: Optional[bytes]
|
||||
|
||||
ImageInput = Union[str, bytes, EncodedImage, BinaryIO, ImageObject]
|
||||
VideoInput = Union[str, BinaryIO]
|
||||
VideoInput = Union[str, BinaryIO, list[list[ImageInput]]]
|
||||
AudioInput = Union[str, BinaryIO, NDArray]
|
||||
|
||||
class MMProcessor(ProcessorMixin):
|
||||
@ -134,6 +135,11 @@ def _make_batched_images(images: list["ImageObject"], imglens: list[int]) -> lis
|
||||
return batch_images
|
||||
|
||||
|
||||
def _check_video_is_nested_images(video: "VideoInput") -> bool:
|
||||
r"""Check if the video is nested images."""
|
||||
return isinstance(video, list) and all(isinstance(frame, (str, BinaryIO, dict)) for frame in video)
|
||||
|
||||
|
||||
@dataclass
|
||||
class MMPluginMixin:
|
||||
image_token: Optional[str]
|
||||
@ -266,14 +272,20 @@ class MMPluginMixin:
|
||||
r"""Regularizes videos to avoid error. Including reading, resizing and converting."""
|
||||
results = []
|
||||
for video in videos:
|
||||
container = av.open(video, "r")
|
||||
video_stream = next(stream for stream in container.streams if stream.type == "video")
|
||||
sample_indices = self._get_video_sample_indices(video_stream, **kwargs)
|
||||
frames: list[ImageObject] = []
|
||||
container.seek(0)
|
||||
for frame_idx, frame in enumerate(container.decode(video_stream)):
|
||||
if frame_idx in sample_indices:
|
||||
frames.append(frame.to_image())
|
||||
if _check_video_is_nested_images(video):
|
||||
for frame in video:
|
||||
if not is_valid_image(frame) and not isinstance(frame, dict) and not os.path.exists(frame):
|
||||
raise ValueError("Invalid image found in video frames.")
|
||||
frames = video
|
||||
else:
|
||||
container = av.open(video, "r")
|
||||
video_stream = next(stream for stream in container.streams if stream.type == "video")
|
||||
sample_indices = self._get_video_sample_indices(video_stream, **kwargs)
|
||||
container.seek(0)
|
||||
for frame_idx, frame in enumerate(container.decode(video_stream)):
|
||||
if frame_idx in sample_indices:
|
||||
frames.append(frame.to_image())
|
||||
|
||||
frames = self._regularize_images(frames, **kwargs)["images"]
|
||||
results.append(frames)
|
||||
@ -1380,24 +1392,33 @@ class Qwen2VLPlugin(BasePlugin):
|
||||
) -> dict[str, Union[list[list["ImageObject"]], list[float]]]:
|
||||
results, fps_per_video = [], []
|
||||
for video in videos:
|
||||
container = av.open(video, "r")
|
||||
video_stream = next(stream for stream in container.streams if stream.type == "video")
|
||||
sample_indices = self._get_video_sample_indices(video_stream, **kwargs)
|
||||
frames: list[ImageObject] = []
|
||||
container.seek(0)
|
||||
for frame_idx, frame in enumerate(container.decode(video_stream)):
|
||||
if frame_idx in sample_indices:
|
||||
frames.append(frame.to_image())
|
||||
if _check_video_is_nested_images(video):
|
||||
for frame in video:
|
||||
if not is_valid_image(frame) and not isinstance(frame, dict) and not os.path.exists(frame):
|
||||
raise ValueError("Invalid image found in video frames.")
|
||||
|
||||
if len(frames) % 2 != 0: # qwen2-vl requires even number of frames
|
||||
frames = video
|
||||
fps_per_video.append(kwargs.get("video_fps", 2.0))
|
||||
else:
|
||||
container = av.open(video, "r")
|
||||
video_stream = next(stream for stream in container.streams if stream.type == "video")
|
||||
sample_indices = self._get_video_sample_indices(video_stream, **kwargs)
|
||||
container.seek(0)
|
||||
for frame_idx, frame in enumerate(container.decode(video_stream)):
|
||||
if frame_idx in sample_indices:
|
||||
frames.append(frame.to_image())
|
||||
|
||||
if video_stream.duration is None:
|
||||
fps_per_video.append(kwargs.get("video_fps", 2.0))
|
||||
else:
|
||||
fps_per_video.append(len(sample_indices) / float(video_stream.duration * video_stream.time_base))
|
||||
|
||||
if len(frames) % 2 != 0:
|
||||
frames.append(frames[-1])
|
||||
|
||||
frames = self._regularize_images(frames, **kwargs)["images"]
|
||||
results.append(frames)
|
||||
if video_stream.duration is None:
|
||||
fps_per_video.append(2.0)
|
||||
else:
|
||||
fps_per_video.append(len(sample_indices) / float(video_stream.duration * video_stream.time_base))
|
||||
|
||||
return {"videos": results, "fps_per_video": fps_per_video}
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user