From 32cb086be148733e771fd73b9a3343f79bfe8f35 Mon Sep 17 00:00:00 2001
From: Kingsley <82590017+Kuangdd01@users.noreply.github.com>
Date: Wed, 2 Apr 2025 23:58:39 +0800
Subject: [PATCH] [data] fix qwen2.5 omni plugin (#7578)

* specific entry

* Update mm_plugin.py

* fix fps cal

---------

Co-authored-by: hoshi-hiyouga
---
 src/llamafactory/data/collator.py  |  2 ++
 src/llamafactory/data/mm_plugin.py | 33 ++++++++-------------------------
 2 files changed, 10 insertions(+), 25 deletions(-)

diff --git a/src/llamafactory/data/collator.py b/src/llamafactory/data/collator.py
index ca5d5492..b79d22b4 100644
--- a/src/llamafactory/data/collator.py
+++ b/src/llamafactory/data/collator.py
@@ -192,6 +192,8 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
             }
             if "second_per_grid_ts" in mm_inputs:  # for qwen2vl
                 rope_index_kwargs["second_per_grid_ts"] = mm_inputs.get("second_per_grid_ts")
+            if "video_second_per_grid" in mm_inputs:  # for qwen2omni
+                rope_index_kwargs["second_per_grids"] = mm_inputs.get("video_second_per_grid")
 
             if getattr(self.model.config, "model_type", None) == "qwen2_5_omni_thinker":  # for qwen2omni
                 feature_attention_mask = mm_inputs.get("feature_attention_mask", None)
diff --git a/src/llamafactory/data/mm_plugin.py b/src/llamafactory/data/mm_plugin.py
index aae240b3..d5114eec 100644
--- a/src/llamafactory/data/mm_plugin.py
+++ b/src/llamafactory/data/mm_plugin.py
@@ -1179,7 +1179,9 @@ class Qwen2VLPlugin(BasePlugin):
                 video_maxlen=getattr(processor, "video_maxlen", 128),
             )
             mm_inputs.update(image_processor(images=None, videos=video_data["videos"], return_tensors="pt"))
-            mm_inputs["fps_per_video"] = video_data["fps_per_video"]
+            temporal_patch_size: int = getattr(image_processor, "temporal_patch_size", 2)
+            if "second_per_grid_ts" in processor.model_input_names:
+                mm_inputs["second_per_grid_ts"] = [temporal_patch_size / fps for fps in video_data["fps_per_video"]]
 
         return mm_inputs
 
@@ -1238,28 +1240,6 @@ class Qwen2VLPlugin(BasePlugin):
 
         return messages
 
-    @override
-    def get_mm_inputs(
-        self,
-        images: list["ImageInput"],
-        videos: list["VideoInput"],
-        audios: list["AudioInput"],
-        imglens: list[int],
-        vidlens: list[int],
-        audlens: list[int],
-        batch_ids: list[list[int]],
-        processor: Optional["MMProcessor"],
-    ) -> dict[str, Union[list[int], "torch.Tensor"]]:
-        self._validate_input(processor, images, videos, audios)
-        mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
-        fps_per_video = mm_inputs.pop("fps_per_video", [])
-        image_processor: BaseImageProcessor = getattr(processor, "image_processor")
-        temporal_patch_size: int = getattr(image_processor, "temporal_patch_size", 2)
-        if "second_per_grid_ts" in processor.model_input_names and fps_per_video:
-            mm_inputs["second_per_grid_ts"] = [temporal_patch_size / fps for fps in fps_per_video]
-
-        return mm_inputs
-
 
 class Qwen2OmniPlugin(Qwen2VLPlugin):
     @override
@@ -1290,7 +1270,10 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
                 video_maxlen=getattr(processor, "video_maxlen", 128),
             )
             mm_inputs.update(image_processor(images=None, videos=video_dict["videos"], return_tensors="pt"))
-            mm_inputs["fps_per_video"] = video_dict["fps_per_video"]
+            temporal_patch_size: int = getattr(image_processor, "temporal_patch_size", 2)
+            mm_inputs["video_second_per_grid"] = torch.tensor(
+                [temporal_patch_size / fps for fps in video_dict["fps_per_video"]]
+            )
 
         if len(audios) != 0:
             audios = self._regularize_audios(
@@ -1405,7 +1388,7 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
                                 video_grid_thw[num_video_tokens][2] // self.image_processor.merge_size,
                             )
                             .flatten()
-                            * mm_inputs["second_per_grid_ts"][num_video_tokens]
+                            * mm_inputs["video_second_per_grid"][num_video_tokens]
                             * 25  # FIXME hardcode of position_id_per_seconds=25
                         ).long()
                         t_ntoken_per_chunk = 50  # FIXME hardcode: [25 * 2]
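
Note (not part of the patch): a minimal, self-contained sketch of the fps-to-grid-seconds conversion that this change moves into _get_mm_inputs. The helper name and the sample fps values are hypothetical illustrations; only the temporal_patch_size / fps formula mirrors the patched code.

import torch

def seconds_per_temporal_grid(fps_per_video: list[float], temporal_patch_size: int = 2) -> torch.Tensor:
    # Each temporal grid cell stacks `temporal_patch_size` sampled frames, so it
    # spans temporal_patch_size / fps seconds of real time for that video.
    return torch.tensor([temporal_patch_size / fps for fps in fps_per_video])

# Hypothetical example: two videos sampled at 2.0 fps and 0.5 fps.
print(seconds_per_temporal_grid([2.0, 0.5]))  # tensor([1., 4.])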