diff --git a/src/llamafactory/data/mm_plugin.py b/src/llamafactory/data/mm_plugin.py index 0f6f0973..91b801dc 100644 --- a/src/llamafactory/data/mm_plugin.py +++ b/src/llamafactory/data/mm_plugin.py @@ -1885,8 +1885,14 @@ class Qwen2OmniPlugin(Qwen2VLPlugin): image_grid_thw = mm_inputs.get("image_grid_thw", []) video_grid_thw = mm_inputs.get("video_grid_thw", []) if "feature_attention_mask" in mm_inputs: - input_lengths = (mm_inputs["feature_attention_mask"].sum(-1).numpy() - 1) // 2 + 1 - audio_lengths = (input_lengths - 2) // 2 + 1 + if processor.__class__.__name__ == "Qwen3OmniMoeProcessor": # for qwen3omni + input_lengths = mm_inputs["feature_attention_mask"].sum(-1) + input_lengths_leave = input_lengths % 100 + feature_lengths = (input_lengths_leave - 1) // 2 + 1 + audio_lengths = ((feature_lengths - 1) // 2 + 1 - 1) // 2 + 1 + (input_lengths // 100) * 13 + else: + input_lengths = (mm_inputs["feature_attention_mask"].sum(-1).numpy() - 1) // 2 + 1 + audio_lengths = (input_lengths - 2) // 2 + 1 else: mm_inputs = {} image_grid_thw = [None] * len(images)