[data] fix qwen2.5 omni collator (#7553)

This commit is contained in:
hoshi-hiyouga 2025-04-01 00:15:12 +08:00 committed by GitHub
parent 185c76f6ad
commit 2d421c57bf
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -193,7 +193,7 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
if "second_per_grid_ts" in mm_inputs: # for qwen2vl
rope_index_kwargs["second_per_grid_ts"] = mm_inputs.get("second_per_grid_ts")
- if getattr(self.model.config, "model_type", None) == "qwen2_5_omni": # for qwen2omni
+ if getattr(self.model.config, "model_type", None) == "qwen2_5_omni_thinker": # for qwen2omni
feature_attention_mask = mm_inputs.get("feature_attention_mask", None)
if feature_attention_mask is not None:
audio_feature_lengths = torch.sum(