diff --git a/src/llamafactory/data/mm_plugin.py b/src/llamafactory/data/mm_plugin.py index 6846a78d..2837f69f 100644 --- a/src/llamafactory/data/mm_plugin.py +++ b/src/llamafactory/data/mm_plugin.py @@ -655,8 +655,10 @@ class KimiVLPlugin(BasePlugin): self._validate_messages(messages, images, videos, audios) if self.expand_mm_tokens: mm_inputs = self._get_mm_inputs(images, videos, audios, processor) + image_grid_hws = mm_inputs.get("image_grid_hws", []) + else: + image_grid_hws = [None] * len(images) - image_grid_hws = mm_inputs.get("image_grid_hws", []) num_image_tokens = 0 image_processor: BaseImageProcessor = getattr(processor, "image_processor") merge_length = math.prod(image_processor.merge_kernel_size)