[data] fix qwen_2_5_vl video processing (#6868)

* fix qwen_2_5_vl video processing

* Update mm_plugin.py

* Update mm_plugin.py

---------

Co-authored-by: hoshi-hiyouga <hiyouga@buaa.edu.cn>
Former-commit-id: 9153a7bd832cdae84b63a4d7d1f2b12239e84b61
This commit is contained in:
HJ 2025-02-11 16:14:50 +08:00 committed by GitHub
parent 703bb9cc18
commit 188f22d8a7

View File

@ -230,9 +230,9 @@ class BasePlugin:
video_fps=getattr(processor, "video_fps", 2.0),
video_maxlen=getattr(processor, "video_maxlen", 128),
)
if "videos" in inspect.signature(video_processor.preprocess).parameters: # qwen2vl processor
if "videos" in inspect.signature(video_processor.preprocess).parameters: # for qwen2_vl and video_llava
mm_inputs.update(video_processor(images=None, videos=videos, return_tensors="pt"))
else:
else: # for llava_next_video
mm_inputs.update(video_processor(videos, return_tensors="pt"))
if len(audios) != 0:
@ -1017,8 +1017,8 @@ class Qwen2vlPlugin(BasePlugin):
return image
@override
def _regularize_videos(self, videos: Sequence["VideoInput"], **kwargs) -> List[List["ImageObject"]]:
results = []
def _regularize_videos(self, videos: Sequence["VideoInput"], **kwargs) -> Tuple[List[List["ImageObject"]], List[float]]:
results, fps_per_video = [], []
for video in videos:
container = av.open(video, "r")
video_stream = next(stream for stream in container.streams if stream.type == "video")
@ -1034,8 +1034,41 @@ class Qwen2vlPlugin(BasePlugin):
frames = self._regularize_images(frames, **kwargs)
results.append(frames)
if video_stream.duration is None:
fps_per_video.append(2.0)
else:
fps_per_video.append(len(sample_indices) / float(video_stream.duration * video_stream.time_base))
return results
return results, fps_per_video
@override
def _get_mm_inputs(
self,
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
processor: "ProcessorMixin",
) -> Dict[str, "torch.Tensor"]:
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor", None)
mm_inputs = {}
if len(images) != 0:
images = self._regularize_images(
images,
image_resolution=getattr(processor, "image_resolution", 768 * 768),
)
mm_inputs.update(image_processor(images, return_tensors="pt"))
if len(videos) != 0:
videos, fps_per_video = self._regularize_videos(
videos,
image_resolution=getattr(processor, "video_resolution", 256 * 256),
video_fps=getattr(processor, "video_fps", 2.0),
video_maxlen=getattr(processor, "video_maxlen", 128),
)
mm_inputs.update(image_processor(images=None, videos=videos, return_tensors="pt"))
mm_inputs["fps_per_video"] = fps_per_video
return mm_inputs
@override
def process_messages(
@ -1101,12 +1134,10 @@ class Qwen2vlPlugin(BasePlugin):
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
self._validate_input(images, videos, audios)
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
fps_per_video = mm_inputs.pop("fps_per_video", [])
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
if "second_per_grid_ts" in getattr(image_processor, "model_input_names", []) and "video_grid_thw" in mm_inputs:
video_fps = getattr(processor, "video_fps", 2.0)
mm_inputs["second_per_grid_ts"] = [image_processor.temporal_patch_size / video_fps] * len(
mm_inputs["video_grid_thw"]
)
if "second_per_grid_ts" in getattr(image_processor, "model_input_names", []) and fps_per_video:
mm_inputs["second_per_grid_ts"] = [image_processor.temporal_patch_size / fps for fps in fps_per_video]
return mm_inputs