Mirror of https://github.com/hiyouga/LLaMA-Factory.git, synced 2025-08-02 03:32:50 +08:00
[data] fix qwen2.5 omni plugin (#7578)

* specific entry
* Update mm_plugin.py
* fix fps calculation

---------

Co-authored-by: hoshi-hiyouga <hiyouga@buaa.edu.cn>
Parent: 80f8d037d0
Commit: 32cb086be1
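In short: the commit adds a Qwen2.5-Omni branch to the data collator, moves the fps-to-seconds conversion into `Qwen2VLPlugin._get_mm_inputs`, drops the now-redundant `get_mm_inputs` override, and has `Qwen2OmniPlugin` emit the durations as a `video_second_per_grid` tensor. The "fix fps calculation" part boils down to one relation: each temporal grid step covers `temporal_patch_size` sampled frames, so it spans `temporal_patch_size / fps` seconds of real time. A minimal standalone sketch of that relation (the helper name is ours, not the repo's):

```python
# Illustrative helper (not in the commit): real-time seconds covered by
# one temporal grid position when frames are sampled at `fps`.
def seconds_per_grid(temporal_patch_size: int, fps: float) -> float:
    return temporal_patch_size / fps

# With the default used in the diff (temporal_patch_size=2) and 2.0 fps
# sampling, each temporal grid step spans exactly one second of video.
assert seconds_per_grid(2, 2.0) == 1.0
```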
@@ -192,6 +192,8 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
             }
             if "second_per_grid_ts" in mm_inputs:  # for qwen2vl
                 rope_index_kwargs["second_per_grid_ts"] = mm_inputs.get("second_per_grid_ts")
+            if "video_second_per_grid" in mm_inputs:  # for qwen2omni
+                rope_index_kwargs["second_per_grids"] = mm_inputs.get("video_second_per_grid")
 
             if getattr(self.model.config, "model_type", None) == "qwen2_5_omni_thinker":  # for qwen2omni
                 feature_attention_mask = mm_inputs.get("feature_attention_mask", None)
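The collator change only renames the quantity on its way to the model: Qwen2-VL's rope-index helper takes `second_per_grid_ts`, while Qwen2.5-Omni's takes `second_per_grids`. A toy, model-free mirror of that dispatch (the function name is ours; the key names come from the hunk above):

```python
# Toy mirror of the collator branch above (no model required).
def build_rope_index_kwargs(mm_inputs: dict) -> dict:
    kwargs = {}
    if "second_per_grid_ts" in mm_inputs:  # qwen2vl naming
        kwargs["second_per_grid_ts"] = mm_inputs["second_per_grid_ts"]
    if "video_second_per_grid" in mm_inputs:  # qwen2omni naming
        kwargs["second_per_grids"] = mm_inputs["video_second_per_grid"]
    return kwargs

print(build_rope_index_kwargs({"video_second_per_grid": [1.0]}))
# {'second_per_grids': [1.0]}
```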
@@ -1179,7 +1179,9 @@ class Qwen2VLPlugin(BasePlugin):
                 video_maxlen=getattr(processor, "video_maxlen", 128),
             )
             mm_inputs.update(image_processor(images=None, videos=video_data["videos"], return_tensors="pt"))
-            mm_inputs["fps_per_video"] = video_data["fps_per_video"]
+            temporal_patch_size: int = getattr(image_processor, "temporal_patch_size", 2)
+            if "second_per_grid_ts" in processor.model_input_names:
+                mm_inputs["second_per_grid_ts"] = [temporal_patch_size / fps for fps in video_data["fps_per_video"]]
 
         return mm_inputs
 
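To make the new list comprehension concrete: with `temporal_patch_size = 2`, a video sampled at 2.0 fps yields 1.0 second per grid step, and one sampled at 0.5 fps yields 4.0. A standalone check (fps values illustrative):

```python
temporal_patch_size = 2
fps_per_video = [2.0, 0.5]  # illustrative per-video sampling rates
second_per_grid_ts = [temporal_patch_size / fps for fps in fps_per_video]
assert second_per_grid_ts == [1.0, 4.0]
```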
@@ -1238,28 +1240,6 @@ class Qwen2VLPlugin(BasePlugin):
         return messages
 
-    @override
-    def get_mm_inputs(
-        self,
-        images: list["ImageInput"],
-        videos: list["VideoInput"],
-        audios: list["AudioInput"],
-        imglens: list[int],
-        vidlens: list[int],
-        audlens: list[int],
-        batch_ids: list[list[int]],
-        processor: Optional["MMProcessor"],
-    ) -> dict[str, Union[list[int], "torch.Tensor"]]:
-        self._validate_input(processor, images, videos, audios)
-        mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
-        fps_per_video = mm_inputs.pop("fps_per_video", [])
-        image_processor: BaseImageProcessor = getattr(processor, "image_processor")
-        temporal_patch_size: int = getattr(image_processor, "temporal_patch_size", 2)
-        if "second_per_grid_ts" in processor.model_input_names and fps_per_video:
-            mm_inputs["second_per_grid_ts"] = [temporal_patch_size / fps for fps in fps_per_video]
-
-        return mm_inputs
-
 
 class Qwen2OmniPlugin(Qwen2VLPlugin):
     @override
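The deletion works because the conversion the override used to do (pop `fps_per_video`, derive `second_per_grid_ts`) now happens inside `_get_mm_inputs` itself, so the inherited `get_mm_inputs` is sufficient. A plain-dict sketch of the before/after equivalence (no plugin objects involved):

```python
temporal_patch_size = 2

# Before: _get_mm_inputs returned raw fps and get_mm_inputs converted it.
old = {"pixel_values_videos": "...", "fps_per_video": [2.0]}
fps_per_video = old.pop("fps_per_video", [])
old["second_per_grid_ts"] = [temporal_patch_size / fps for fps in fps_per_video]

# After: _get_mm_inputs returns the converted value directly.
new = {"pixel_values_videos": "...", "second_per_grid_ts": [1.0]}
assert old == new  # same result, one fewer hop
```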
@@ -1290,7 +1270,10 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
                 video_maxlen=getattr(processor, "video_maxlen", 128),
             )
             mm_inputs.update(image_processor(images=None, videos=video_dict["videos"], return_tensors="pt"))
-            mm_inputs["fps_per_video"] = video_dict["fps_per_video"]
+            temporal_patch_size: int = getattr(image_processor, "temporal_patch_size", 2)
+            mm_inputs["video_second_per_grid"] = torch.tensor(
+                [temporal_patch_size / fps for fps in video_dict["fps_per_video"]]
+            )
 
         if len(audios) != 0:
             audios = self._regularize_audios(
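Unlike the Qwen2-VL path, the Omni plugin stores the durations as a tensor under `video_second_per_grid`, the key that both the collator branch and the rope-index math below read. A minimal reproduction of the construction (fps values illustrative):

```python
import torch

temporal_patch_size = 2
fps_per_video = [2.0, 4.0]  # illustrative
video_second_per_grid = torch.tensor([temporal_patch_size / fps for fps in fps_per_video])
print(video_second_per_grid)  # tensor([1.0000, 0.5000])
```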
@@ -1405,7 +1388,7 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
                             video_grid_thw[num_video_tokens][2] // self.image_processor.merge_size,
                         )
                         .flatten()
-                        * mm_inputs["second_per_grid_ts"][num_video_tokens]
+                        * mm_inputs["video_second_per_grid"][num_video_tokens]
                         * 25  # FIXME hardcode of position_id_per_seconds=25
                     ).long()
                     t_ntoken_per_chunk = 50  # FIXME hardcode: [25 * 2]
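To see what the renamed factor does here: each temporal grid index is converted to seconds via `video_second_per_grid` and then to position units via the hardcoded `position_id_per_seconds = 25`. A standalone numeric sketch of the `.long()` expression above (grid count illustrative):

```python
import torch

position_id_per_seconds = 25  # hardcoded upstream, per the FIXME above
video_second_per_grid = torch.tensor([1.0])  # one video, 1 s per grid step
grid_t = torch.arange(3)  # three temporal grid positions (illustrative)

t_index = (grid_t * video_second_per_grid[0] * position_id_per_seconds).long()
print(t_index)  # tensor([ 0, 25, 50]) -> consistent with t_ntoken_per_chunk = 50 (25 * 2)
```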