diff --git a/src/llamafactory/data/mm_plugin.py b/src/llamafactory/data/mm_plugin.py
index 26216c4a..f9fc2b72 100644
--- a/src/llamafactory/data/mm_plugin.py
+++ b/src/llamafactory/data/mm_plugin.py
@@ -230,9 +230,9 @@ class BasePlugin:
                 video_fps=getattr(processor, "video_fps", 2.0),
                 video_maxlen=getattr(processor, "video_maxlen", 128),
             )
-            if "videos" in inspect.signature(video_processor.preprocess).parameters:  # qwen2vl processor
+            if "videos" in inspect.signature(video_processor.preprocess).parameters:  # for qwen2_vl and video_llava
                 mm_inputs.update(video_processor(images=None, videos=videos, return_tensors="pt"))
-            else:
+            else:  # for llava_next_video
                 mm_inputs.update(video_processor(videos, return_tensors="pt"))

         if len(audios) != 0:
@@ -1017,8 +1017,8 @@ class Qwen2vlPlugin(BasePlugin):
         return image

     @override
-    def _regularize_videos(self, videos: Sequence["VideoInput"], **kwargs) -> List[List["ImageObject"]]:
-        results = []
+    def _regularize_videos(self, videos: Sequence["VideoInput"], **kwargs) -> Tuple[List[List["ImageObject"]], List[float]]:
+        results, fps_per_video = [], []
         for video in videos:
             container = av.open(video, "r")
             video_stream = next(stream for stream in container.streams if stream.type == "video")
@@ -1034,8 +1034,41 @@ class Qwen2vlPlugin(BasePlugin):

             frames = self._regularize_images(frames, **kwargs)
             results.append(frames)
+            if video_stream.duration is None:
+                fps_per_video.append(2.0)
+            else:
+                fps_per_video.append(len(sample_indices) / float(video_stream.duration * video_stream.time_base))

-        return results
+        return results, fps_per_video
+
+    @override
+    def _get_mm_inputs(
+        self,
+        images: Sequence["ImageInput"],
+        videos: Sequence["VideoInput"],
+        audios: Sequence["AudioInput"],
+        processor: "ProcessorMixin",
+    ) -> Dict[str, "torch.Tensor"]:
+        image_processor: "BaseImageProcessor" = getattr(processor, "image_processor", None)
+        mm_inputs = {}
+        if len(images) != 0:
+            images = self._regularize_images(
+                images,
+                image_resolution=getattr(processor, "image_resolution", 768 * 768),
+            )
+            mm_inputs.update(image_processor(images, return_tensors="pt"))
+
+        if len(videos) != 0:
+            videos, fps_per_video = self._regularize_videos(
+                videos,
+                image_resolution=getattr(processor, "video_resolution", 256 * 256),
+                video_fps=getattr(processor, "video_fps", 2.0),
+                video_maxlen=getattr(processor, "video_maxlen", 128),
+            )
+            mm_inputs.update(image_processor(images=None, videos=videos, return_tensors="pt"))
+            mm_inputs["fps_per_video"] = fps_per_video
+
+        return mm_inputs

     @override
     def process_messages(
@@ -1101,12 +1134,10 @@ class Qwen2vlPlugin(BasePlugin):
     ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
         self._validate_input(images, videos, audios)
         mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
+        fps_per_video = mm_inputs.pop("fps_per_video", [])
         image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
-        if "second_per_grid_ts" in getattr(image_processor, "model_input_names", []) and "video_grid_thw" in mm_inputs:
-            video_fps = getattr(processor, "video_fps", 2.0)
-            mm_inputs["second_per_grid_ts"] = [image_processor.temporal_patch_size / video_fps] * len(
-                mm_inputs["video_grid_thw"]
-            )
+        if "second_per_grid_ts" in getattr(image_processor, "model_input_names", []) and fps_per_video:
+            mm_inputs["second_per_grid_ts"] = [image_processor.temporal_patch_size / fps for fps in fps_per_video]

         return mm_inputs
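For reference, a minimal standalone sketch of the arithmetic this diff introduces: the effective fps of each sampled clip and the resulting `second_per_grid_ts` entry. The file name `sample.mp4`, the constants `video_fps`, `video_maxlen`, and `temporal_patch_size`, and the simplified frame-count estimate below are illustrative placeholders, not code from the repository.

```python
import av

video_fps, video_maxlen = 2.0, 128   # plugin defaults assumed for illustration
temporal_patch_size = 2              # typical Qwen2-VL image processor value (assumption)

container = av.open("sample.mp4", "r")  # hypothetical input video
video_stream = next(s for s in container.streams if s.type == "video")

if video_stream.duration is None:
    # Same fallback as the diff: assume the configured default fps.
    fps_of_video = 2.0
else:
    # Stream duration in seconds: duration is in time_base units.
    duration = float(video_stream.duration * video_stream.time_base)
    # Simplified stand-in for len(sample_indices): sample at video_fps, capped at video_maxlen.
    num_sampled = min(video_maxlen, int(duration * video_fps))
    # Effective fps of the *sampled* clip, as computed in _regularize_videos.
    fps_of_video = num_sampled / duration

# Each temporal grid cell spans temporal_patch_size sampled frames,
# hence this many seconds of wall-clock time.
second_per_grid_t = temporal_patch_size / fps_of_video
print(fps_of_video, second_per_grid_t)
```

Because the fps is now measured per video rather than taken from the static `video_fps` setting, `second_per_grid_ts` reflects the true temporal spacing of each clip's sampled frames.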