diff --git a/src/llamafactory/data/collator.py b/src/llamafactory/data/collator.py index cfeecd86..48fbb3c5 100644 --- a/src/llamafactory/data/collator.py +++ b/src/llamafactory/data/collator.py @@ -194,7 +194,7 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq): elif "video_second_per_grid" in mm_inputs: # for qwen2.5 omni rope_index_kwargs["second_per_grids"] = mm_inputs.get("video_second_per_grid") - if getattr(self.model.config, "model_type", None) == "qwen2_5_omni_thinker": # for qwen2.5 omni + if getattr(self.model.config, "model_type", None) in ["qwen2_5_omni_thinker", "qwen3_omni_moe_thinker"]: # for qwen2.5 omni rope_index_kwargs["use_audio_in_video"] = getattr(self.processor, "use_audio_in_video", False) feature_attention_mask = mm_inputs.get("feature_attention_mask", None) if feature_attention_mask is not None: # FIXME: need to get video image lengths @@ -211,7 +211,7 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq): if ( self.model is not None and getattr(self.model.config, "model_type", None) - in ["glm4v", "Keye", "qwen2_vl", "qwen2_5_vl", "qwen2_5_omni_thinker"] + in ["glm4v", "Keye", "qwen2_vl", "qwen2_5_vl", "qwen2_5_omni_thinker", "qwen3_omni_moe_thinker"] and ("position_ids" not in features or features["position_ids"].dim() != 3) ): raise ValueError(f"{self.model.config.model_type} requires 3D position ids for mrope.")