[data] fix qwen2.5 omni plugin (#7578)

* specific entry

* Update mm_plugin.py

* fix fps cal

---------

Co-authored-by: hoshi-hiyouga <hiyouga@buaa.edu.cn>
This commit is contained in:
Kingsley 2025-04-02 23:58:39 +08:00 committed by GitHub
parent 80f8d037d0
commit 32cb086be1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 10 additions and 25 deletions

View File

@ -192,6 +192,8 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
}
if "second_per_grid_ts" in mm_inputs: # for qwen2vl
rope_index_kwargs["second_per_grid_ts"] = mm_inputs.get("second_per_grid_ts")
if "video_second_per_grid" in mm_inputs: # for qwen2omni
rope_index_kwargs["second_per_grids"] = mm_inputs.get("video_second_per_grid")
if getattr(self.model.config, "model_type", None) == "qwen2_5_omni_thinker": # for qwen2omni
feature_attention_mask = mm_inputs.get("feature_attention_mask", None)

View File

@ -1179,7 +1179,9 @@ class Qwen2VLPlugin(BasePlugin):
video_maxlen=getattr(processor, "video_maxlen", 128),
)
mm_inputs.update(image_processor(images=None, videos=video_data["videos"], return_tensors="pt"))
mm_inputs["fps_per_video"] = video_data["fps_per_video"]
temporal_patch_size: int = getattr(image_processor, "temporal_patch_size", 2)
if "second_per_grid_ts" in processor.model_input_names:
mm_inputs["second_per_grid_ts"] = [temporal_patch_size / fps for fps in video_data["fps_per_video"]]
return mm_inputs
@ -1238,28 +1240,6 @@ class Qwen2VLPlugin(BasePlugin):
return messages
@override
def get_mm_inputs(
self,
images: list["ImageInput"],
videos: list["VideoInput"],
audios: list["AudioInput"],
imglens: list[int],
vidlens: list[int],
audlens: list[int],
batch_ids: list[list[int]],
processor: Optional["MMProcessor"],
) -> dict[str, Union[list[int], "torch.Tensor"]]:
self._validate_input(processor, images, videos, audios)
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
fps_per_video = mm_inputs.pop("fps_per_video", [])
image_processor: BaseImageProcessor = getattr(processor, "image_processor")
temporal_patch_size: int = getattr(image_processor, "temporal_patch_size", 2)
if "second_per_grid_ts" in processor.model_input_names and fps_per_video:
mm_inputs["second_per_grid_ts"] = [temporal_patch_size / fps for fps in fps_per_video]
return mm_inputs
class Qwen2OmniPlugin(Qwen2VLPlugin):
@override
@ -1290,7 +1270,10 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
video_maxlen=getattr(processor, "video_maxlen", 128),
)
mm_inputs.update(image_processor(images=None, videos=video_dict["videos"], return_tensors="pt"))
mm_inputs["fps_per_video"] = video_dict["fps_per_video"]
temporal_patch_size: int = getattr(image_processor, "temporal_patch_size", 2)
mm_inputs["video_second_per_grid"] = torch.tensor(
[temporal_patch_size / fps for fps in video_dict["fps_per_video"]]
)
if len(audios) != 0:
audios = self._regularize_audios(
@ -1405,7 +1388,7 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
video_grid_thw[num_video_tokens][2] // self.image_processor.merge_size,
)
.flatten()
* mm_inputs["second_per_grid_ts"][num_video_tokens]
* mm_inputs["video_second_per_grid"][num_video_tokens]
* 25 # FIXME hardcode of position_id_per_seconds=25
).long()
t_ntoken_per_chunk = 50 # FIXME hardcode: [25 * 2]