mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-10-14 23:58:11 +08:00
[data] support nested images input for videos (#8264)
This commit is contained in:
parent
c4e51d40e0
commit
c224d17cb2
@ -51,12 +51,27 @@ class DatasetConverter:
|
|||||||
else:
|
else:
|
||||||
medias = medias[:]
|
medias = medias[:]
|
||||||
|
|
||||||
if self.dataset_attr.load_from in ["script", "file"] and isinstance(medias[0], str):
|
if self.dataset_attr.load_from in ["script", "file"]:
|
||||||
for i in range(len(medias)):
|
if isinstance(medias[0], str):
|
||||||
if os.path.isfile(os.path.join(self.data_args.media_dir, medias[i])):
|
for i in range(len(medias)):
|
||||||
medias[i] = os.path.join(self.data_args.media_dir, medias[i])
|
media_path = os.path.join(self.data_args.media_dir, medias[i])
|
||||||
else:
|
if os.path.isfile(media_path):
|
||||||
logger.warning_rank0_once(f"Media {medias[i]} does not exist in `media_dir`. Use original path.")
|
medias[i] = media_path
|
||||||
|
else:
|
||||||
|
logger.warning_rank0_once(
|
||||||
|
f"Media {medias[i]} does not exist in `media_dir`. Use original path."
|
||||||
|
)
|
||||||
|
elif isinstance(medias[0], list): # for processed video frames
|
||||||
|
# medias is a list of lists, e.g., [[frame1.jpg, frame2.jpg], [frame3.jpg, frame4.jpg]]
|
||||||
|
for i in range(len(medias)):
|
||||||
|
for j in range(len(medias[i])):
|
||||||
|
media_path = os.path.join(self.data_args.media_dir, medias[i][j])
|
||||||
|
if os.path.isfile(media_path):
|
||||||
|
medias[i][j] = media_path
|
||||||
|
else:
|
||||||
|
logger.warning_rank0_once(
|
||||||
|
f"Media {medias[i][j]} does not exist in `media_dir`. Use original path."
|
||||||
|
)
|
||||||
|
|
||||||
return medias
|
return medias
|
||||||
|
|
||||||
|
@ -17,6 +17,7 @@
|
|||||||
|
|
||||||
import inspect
|
import inspect
|
||||||
import math
|
import math
|
||||||
|
import os
|
||||||
import re
|
import re
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
@ -25,7 +26,7 @@ from typing import TYPE_CHECKING, BinaryIO, Literal, Optional, TypedDict, Union
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from transformers.image_utils import get_image_size, to_numpy_array
|
from transformers.image_utils import get_image_size, is_valid_image, to_numpy_array
|
||||||
from typing_extensions import override
|
from typing_extensions import override
|
||||||
|
|
||||||
from ..extras.constants import AUDIO_PLACEHOLDER, IGNORE_INDEX, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
|
from ..extras.constants import AUDIO_PLACEHOLDER, IGNORE_INDEX, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
|
||||||
@ -76,7 +77,7 @@ if TYPE_CHECKING:
|
|||||||
bytes: Optional[bytes]
|
bytes: Optional[bytes]
|
||||||
|
|
||||||
ImageInput = Union[str, bytes, EncodedImage, BinaryIO, ImageObject]
|
ImageInput = Union[str, bytes, EncodedImage, BinaryIO, ImageObject]
|
||||||
VideoInput = Union[str, BinaryIO]
|
VideoInput = Union[str, BinaryIO, list[list[ImageInput]]]
|
||||||
AudioInput = Union[str, BinaryIO, NDArray]
|
AudioInput = Union[str, BinaryIO, NDArray]
|
||||||
|
|
||||||
class MMProcessor(ProcessorMixin):
|
class MMProcessor(ProcessorMixin):
|
||||||
@ -134,6 +135,11 @@ def _make_batched_images(images: list["ImageObject"], imglens: list[int]) -> lis
|
|||||||
return batch_images
|
return batch_images
|
||||||
|
|
||||||
|
|
||||||
|
def _check_video_is_nested_images(video: "VideoInput") -> bool:
|
||||||
|
r"""Check if the video is nested images."""
|
||||||
|
return isinstance(video, list) and all(isinstance(frame, (str, BinaryIO, dict)) for frame in video)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class MMPluginMixin:
|
class MMPluginMixin:
|
||||||
image_token: Optional[str]
|
image_token: Optional[str]
|
||||||
@ -266,14 +272,20 @@ class MMPluginMixin:
|
|||||||
r"""Regularizes videos to avoid error. Including reading, resizing and converting."""
|
r"""Regularizes videos to avoid error. Including reading, resizing and converting."""
|
||||||
results = []
|
results = []
|
||||||
for video in videos:
|
for video in videos:
|
||||||
container = av.open(video, "r")
|
|
||||||
video_stream = next(stream for stream in container.streams if stream.type == "video")
|
|
||||||
sample_indices = self._get_video_sample_indices(video_stream, **kwargs)
|
|
||||||
frames: list[ImageObject] = []
|
frames: list[ImageObject] = []
|
||||||
container.seek(0)
|
if _check_video_is_nested_images(video):
|
||||||
for frame_idx, frame in enumerate(container.decode(video_stream)):
|
for frame in video:
|
||||||
if frame_idx in sample_indices:
|
if not is_valid_image(frame) and not isinstance(frame, dict) and not os.path.exists(frame):
|
||||||
frames.append(frame.to_image())
|
raise ValueError("Invalid image found in video frames.")
|
||||||
|
frames = video
|
||||||
|
else:
|
||||||
|
container = av.open(video, "r")
|
||||||
|
video_stream = next(stream for stream in container.streams if stream.type == "video")
|
||||||
|
sample_indices = self._get_video_sample_indices(video_stream, **kwargs)
|
||||||
|
container.seek(0)
|
||||||
|
for frame_idx, frame in enumerate(container.decode(video_stream)):
|
||||||
|
if frame_idx in sample_indices:
|
||||||
|
frames.append(frame.to_image())
|
||||||
|
|
||||||
frames = self._regularize_images(frames, **kwargs)["images"]
|
frames = self._regularize_images(frames, **kwargs)["images"]
|
||||||
results.append(frames)
|
results.append(frames)
|
||||||
@ -1380,24 +1392,33 @@ class Qwen2VLPlugin(BasePlugin):
|
|||||||
) -> dict[str, Union[list[list["ImageObject"]], list[float]]]:
|
) -> dict[str, Union[list[list["ImageObject"]], list[float]]]:
|
||||||
results, fps_per_video = [], []
|
results, fps_per_video = [], []
|
||||||
for video in videos:
|
for video in videos:
|
||||||
container = av.open(video, "r")
|
|
||||||
video_stream = next(stream for stream in container.streams if stream.type == "video")
|
|
||||||
sample_indices = self._get_video_sample_indices(video_stream, **kwargs)
|
|
||||||
frames: list[ImageObject] = []
|
frames: list[ImageObject] = []
|
||||||
container.seek(0)
|
if _check_video_is_nested_images(video):
|
||||||
for frame_idx, frame in enumerate(container.decode(video_stream)):
|
for frame in video:
|
||||||
if frame_idx in sample_indices:
|
if not is_valid_image(frame) and not isinstance(frame, dict) and not os.path.exists(frame):
|
||||||
frames.append(frame.to_image())
|
raise ValueError("Invalid image found in video frames.")
|
||||||
|
|
||||||
if len(frames) % 2 != 0: # qwen2-vl requires even number of frames
|
frames = video
|
||||||
|
fps_per_video.append(kwargs.get("video_fps", 2.0))
|
||||||
|
else:
|
||||||
|
container = av.open(video, "r")
|
||||||
|
video_stream = next(stream for stream in container.streams if stream.type == "video")
|
||||||
|
sample_indices = self._get_video_sample_indices(video_stream, **kwargs)
|
||||||
|
container.seek(0)
|
||||||
|
for frame_idx, frame in enumerate(container.decode(video_stream)):
|
||||||
|
if frame_idx in sample_indices:
|
||||||
|
frames.append(frame.to_image())
|
||||||
|
|
||||||
|
if video_stream.duration is None:
|
||||||
|
fps_per_video.append(kwargs.get("video_fps", 2.0))
|
||||||
|
else:
|
||||||
|
fps_per_video.append(len(sample_indices) / float(video_stream.duration * video_stream.time_base))
|
||||||
|
|
||||||
|
if len(frames) % 2 != 0:
|
||||||
frames.append(frames[-1])
|
frames.append(frames[-1])
|
||||||
|
|
||||||
frames = self._regularize_images(frames, **kwargs)["images"]
|
frames = self._regularize_images(frames, **kwargs)["images"]
|
||||||
results.append(frames)
|
results.append(frames)
|
||||||
if video_stream.duration is None:
|
|
||||||
fps_per_video.append(2.0)
|
|
||||||
else:
|
|
||||||
fps_per_video.append(len(sample_indices) / float(video_stream.duration * video_stream.time_base))
|
|
||||||
|
|
||||||
return {"videos": results, "fps_per_video": fps_per_video}
|
return {"videos": results, "fps_per_video": fps_per_video}
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user