Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-08-03 04:02:49 +08:00)
[data] fix qwen2.5 omni plugin (#7578)

* specific entry
* Update mm_plugin.py
* fix fps cal

Co-authored-by: hoshi-hiyouga <hiyouga@buaa.edu.cn>
parent 80f8d037d0
commit 32cb086be1
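Note: the core of this fix is how seconds-per-temporal-grid is derived from the video sampling rate. The image processor stacks `temporal_patch_size` sampled frames into one temporal grid, so one grid spans temporal_patch_size / fps seconds of video. A minimal sketch of that calculation (the helper name and sample values are illustrative, not from this commit):

    # Hedged sketch of the fps calculation fixed here; not the repository's
    # exact code. One temporal grid covers `temporal_patch_size` frames, so
    # at `fps` frames per second it spans temporal_patch_size / fps seconds.
    def seconds_per_grid(temporal_patch_size: int, fps_per_video: list[float]) -> list[float]:
        return [temporal_patch_size / fps for fps in fps_per_video]

    print(seconds_per_grid(2, [2.0, 4.0]))  # [1.0, 0.5]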
@@ -192,6 +192,8 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
             }
             if "second_per_grid_ts" in mm_inputs:  # for qwen2vl
                 rope_index_kwargs["second_per_grid_ts"] = mm_inputs.get("second_per_grid_ts")
+            if "video_second_per_grid" in mm_inputs:  # for qwen2omni
+                rope_index_kwargs["second_per_grids"] = mm_inputs.get("video_second_per_grid")

             if getattr(self.model.config, "model_type", None) == "qwen2_5_omni_thinker":  # for qwen2omni
                 feature_attention_mask = mm_inputs.get("feature_attention_mask", None)
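The collator hunk above forwards the Omni-specific value under the keyword that the Qwen2.5-Omni thinker's rope-index computation appears to expect (`second_per_grids`), while the Qwen2-VL family keeps `second_per_grid_ts`. A toy illustration of the key mapping (the dict contents are invented; only the key names come from the diff):

    # Illustrative only: map plugin output keys to rope-index kwargs.
    mm_inputs = {"video_second_per_grid": [1.0, 0.5]}  # invented sample values
    rope_index_kwargs = {}
    if "second_per_grid_ts" in mm_inputs:  # qwen2vl family
        rope_index_kwargs["second_per_grid_ts"] = mm_inputs["second_per_grid_ts"]
    if "video_second_per_grid" in mm_inputs:  # qwen2.5-omni thinker
        rope_index_kwargs["second_per_grids"] = mm_inputs["video_second_per_grid"]
    print(rope_index_kwargs)  # {'second_per_grids': [1.0, 0.5]}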
@@ -1179,7 +1179,9 @@ class Qwen2VLPlugin(BasePlugin):
                 video_maxlen=getattr(processor, "video_maxlen", 128),
             )
             mm_inputs.update(image_processor(images=None, videos=video_data["videos"], return_tensors="pt"))
-            mm_inputs["fps_per_video"] = video_data["fps_per_video"]
+            temporal_patch_size: int = getattr(image_processor, "temporal_patch_size", 2)
+            if "second_per_grid_ts" in processor.model_input_names:
+                mm_inputs["second_per_grid_ts"] = [temporal_patch_size / fps for fps in video_data["fps_per_video"]]

         return mm_inputs
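With this hunk, `second_per_grid_ts` is computed directly inside `_get_mm_inputs`, guarded by `processor.model_input_names` so the key is only emitted for processors that declare it as a model input (presumably Qwen2.5-VL; plain Qwen2-VL does not accept it). Computing it here also makes the `get_mm_inputs` override deleted in the next hunk redundant. A toy check of the guard, assuming HF processors expose their accepted keys via `model_input_names` (the class below is a stand-in, not a real processor):

    # Toy stand-in for a HF processor; only the attribute name is assumed real.
    class FakeProcessor:
        model_input_names = ["pixel_values", "second_per_grid_ts"]

    processor, mm_inputs = FakeProcessor(), {}
    temporal_patch_size, fps_per_video = 2, [2.0]
    if "second_per_grid_ts" in processor.model_input_names:
        mm_inputs["second_per_grid_ts"] = [temporal_patch_size / fps for fps in fps_per_video]
    print(mm_inputs)  # {'second_per_grid_ts': [1.0]}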
@@ -1238,28 +1240,6 @@ class Qwen2VLPlugin(BasePlugin):

         return messages

-    @override
-    def get_mm_inputs(
-        self,
-        images: list["ImageInput"],
-        videos: list["VideoInput"],
-        audios: list["AudioInput"],
-        imglens: list[int],
-        vidlens: list[int],
-        audlens: list[int],
-        batch_ids: list[list[int]],
-        processor: Optional["MMProcessor"],
-    ) -> dict[str, Union[list[int], "torch.Tensor"]]:
-        self._validate_input(processor, images, videos, audios)
-        mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
-        fps_per_video = mm_inputs.pop("fps_per_video", [])
-        image_processor: BaseImageProcessor = getattr(processor, "image_processor")
-        temporal_patch_size: int = getattr(image_processor, "temporal_patch_size", 2)
-        if "second_per_grid_ts" in processor.model_input_names and fps_per_video:
-            mm_inputs["second_per_grid_ts"] = [temporal_patch_size / fps for fps in fps_per_video]
-
-        return mm_inputs
-
-
 class Qwen2OmniPlugin(Qwen2VLPlugin):
     @override
@@ -1290,7 +1270,10 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
                 video_maxlen=getattr(processor, "video_maxlen", 128),
             )
             mm_inputs.update(image_processor(images=None, videos=video_dict["videos"], return_tensors="pt"))
-            mm_inputs["fps_per_video"] = video_dict["fps_per_video"]
+            temporal_patch_size: int = getattr(image_processor, "temporal_patch_size", 2)
+            mm_inputs["video_second_per_grid"] = torch.tensor(
+                [temporal_patch_size / fps for fps in video_dict["fps_per_video"]]
+            )

         if len(audios) != 0:
             audios = self._regularize_audios(
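Unlike the Qwen2-VL path, the Omni plugin stores the per-video value as a tensor under `video_second_per_grid`, which lets the position-id code below index it per video and multiply it into tensor arithmetic directly. A minimal sketch with invented values:

    # Invented values; mirrors the torch.tensor construction above.
    import torch

    temporal_patch_size, fps_per_video = 2, [2.0, 4.0]
    video_second_per_grid = torch.tensor([temporal_patch_size / fps for fps in fps_per_video])
    print(video_second_per_grid)  # tensor([1.0000, 0.5000])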
@@ -1405,7 +1388,7 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
                         video_grid_thw[num_video_tokens][2] // self.image_processor.merge_size,
                     )
                     .flatten()
-                    * mm_inputs["second_per_grid_ts"][num_video_tokens]
+                    * mm_inputs["video_second_per_grid"][num_video_tokens]
                     * 25  # FIXME hardcode of position_id_per_seconds=25
                 ).long()
                 t_ntoken_per_chunk = 50  # FIXME hardcode: [25 * 2]
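The key rename in this last hunk keeps the temporal position-id scaling consistent with what `_get_mm_inputs` now produces: temporal grid indices are stretched by seconds-per-grid and then by the hardcoded 25 position ids per second. A worked toy example of that scaling, simplified from the expand/flatten expression above (all values invented):

    # Invented values; shows the grid-index * seconds-per-grid * 25 scaling.
    import torch

    grid_t = torch.arange(3)      # temporal grid indices 0, 1, 2
    second_per_grid = 0.5         # e.g. video_second_per_grid[num_video_tokens]
    position_id_per_seconds = 25  # hardcoded in the plugin (see FIXME above)
    print((grid_t * second_per_grid * position_id_per_seconds).long())  # tensor([ 0, 12, 25])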