diff --git a/src/llamafactory/data/mm_plugin.py b/src/llamafactory/data/mm_plugin.py index eed8b4fc..d6e72159 100644 --- a/src/llamafactory/data/mm_plugin.py +++ b/src/llamafactory/data/mm_plugin.py @@ -1,3 +1,4 @@ +import inspect import math import re from copy import deepcopy @@ -117,16 +118,19 @@ class BasePlugin: return image - def _get_video_sample_frames(self, video_stream: "Stream", **kwargs) -> int: + def _get_video_sample_indices(self, video_stream: "Stream", **kwargs) -> List[int]: r""" - Computes video sample frames according to fps. + Computes video sample indices according to fps. """ video_fps: float = kwargs["video_fps"] video_maxlen: int = kwargs["video_maxlen"] total_frames = video_stream.frames - sample_frames = float(video_stream.duration * video_stream.time_base) * video_fps + if total_frames == 0: # infinite video + return np.linspace(0, video_maxlen - 1, video_maxlen).astype(np.int32) + + sample_frames = math.floor(float(video_stream.duration * video_stream.time_base) * video_fps) sample_frames = min(total_frames, video_maxlen, sample_frames) - return math.floor(sample_frames) + return np.linspace(0, total_frames - 1, sample_frames).astype(np.int32) def _regularize_images(self, images: Sequence["ImageInput"], **kwargs) -> List["ImageObject"]: r""" @@ -159,9 +163,7 @@ class BasePlugin: for video in videos: container = av.open(video, "r") video_stream = next(stream for stream in container.streams if stream.type == "video") - total_frames = video_stream.frames - sample_frames = self._get_video_sample_frames(video_stream, **kwargs) - sample_indices = np.linspace(0, total_frames - 1, sample_frames).astype(np.int32) + sample_indices = self._get_video_sample_indices(video_stream, **kwargs) frames: List["ImageObject"] = [] container.seek(0) for frame_idx, frame in enumerate(container.decode(video_stream)): @@ -228,7 +230,10 @@ class BasePlugin: video_fps=getattr(processor, "video_fps", 2.0), video_maxlen=getattr(processor, "video_maxlen", 128), ) - mm_inputs.update(video_processor(videos, return_tensors="pt")) + if "videos" in inspect.signature(video_processor.preprocess).parameters: # qwen2vl processor + mm_inputs.update(video_processor(images=None, videos=videos, return_tensors="pt")) + else: + mm_inputs.update(video_processor(videos, return_tensors="pt")) if len(audios) != 0: audios = self._regularize_audios( @@ -1011,9 +1016,7 @@ class Qwen2vlPlugin(BasePlugin): for video in videos: container = av.open(video, "r") video_stream = next(stream for stream in container.streams if stream.type == "video") - total_frames = video_stream.frames - sample_frames = self._get_video_sample_frames(video_stream, **kwargs) - sample_indices = np.linspace(0, total_frames - 1, sample_frames).astype(np.int32) + sample_indices = self._get_video_sample_indices(video_stream, **kwargs) frames: List["ImageObject"] = [] container.seek(0) for frame_idx, frame in enumerate(container.decode(video_stream)): diff --git a/src/llamafactory/extras/constants.py b/src/llamafactory/extras/constants.py index 3df6d60c..5470f791 100644 --- a/src/llamafactory/extras/constants.py +++ b/src/llamafactory/extras/constants.py @@ -2156,6 +2156,18 @@ register_model_group( register_model_group( models={ + "Qwen2-VL-2B": { + DownloadSource.DEFAULT: "Qwen/Qwen2-VL-2B", + DownloadSource.MODELSCOPE: "Qwen/Qwen2-VL-2B", + }, + "Qwen2-VL-7B": { + DownloadSource.DEFAULT: "Qwen/Qwen2-VL-7B", + DownloadSource.MODELSCOPE: "Qwen/Qwen2-VL-7B", + }, + "Qwen2-VL-72B": { + DownloadSource.DEFAULT: "Qwen/Qwen2-VL-72B", + DownloadSource.MODELSCOPE: "Qwen/Qwen2-VL-72B", + }, "Qwen2-VL-2B-Instruct": { DownloadSource.DEFAULT: "Qwen/Qwen2-VL-2B-Instruct", DownloadSource.MODELSCOPE: "Qwen/Qwen2-VL-2B-Instruct",