diff --git a/src/llamafactory/data/mm_plugin.py b/src/llamafactory/data/mm_plugin.py index a3636737..47682732 100644 --- a/src/llamafactory/data/mm_plugin.py +++ b/src/llamafactory/data/mm_plugin.py @@ -210,12 +210,17 @@ class Qwen2vlPlugin(BasePlugin): merge_length: int = getattr(image_processor, "merge_size") ** 2 if len(images) > 0: image_grid_thw = _get_mm_inputs(images, processor)["image_grid_thw"] + else: + image_grid_thw = [] num_images = 0 messages = deepcopy(messages) for message in messages: content = message["content"] while IMAGE_PLACEHOLDER in content: + if num_images >= len(image_grid_thw): + raise ValueError("`len(images)` is less than the number of {} tokens.".format(IMAGE_PLACEHOLDER)) + content = content.replace( IMAGE_PLACEHOLDER, "<|vision_start|>{}<|vision_end|>".format(