From 49054329d04786e4542c38d1608bec36215162e3 Mon Sep 17 00:00:00 2001 From: hoshi-hiyouga Date: Tue, 29 Oct 2024 22:16:22 +0800 Subject: [PATCH] Update mm_plugin.py Former-commit-id: 049c554aee25cf1e29bee88dfb21381b3a4a2947 --- src/llamafactory/data/mm_plugin.py | 21 ++++++++------------- 1 file changed, 8 insertions(+), 13 deletions(-) diff --git a/src/llamafactory/data/mm_plugin.py b/src/llamafactory/data/mm_plugin.py index de348896..52c65cb7 100644 --- a/src/llamafactory/data/mm_plugin.py +++ b/src/llamafactory/data/mm_plugin.py @@ -457,19 +457,16 @@ class PixtralPlugin(BasePlugin): videos: Sequence["VideoInput"], processor: Optional["ProcessorMixin"], ) -> List[Dict[str, str]]: - patch_size = processor.patch_size - image_token = processor.image_token - image_break_token = processor.image_break_token - image_end_token = processor.image_end_token - self._validate_input(images, videos) + patch_size = getattr(processor, "patch_size") + image_token = getattr(processor, "image_token") + image_break_token = getattr(processor, "image_break_token") + image_end_token = getattr(processor, "image_end_token") + num_image_tokens = 0 - img_kwargs = self._get_mm_inputs(images, videos, processor) - image_input_sizes = None - - image_input_sizes = img_kwargs.get("image_sizes", None) - messages = deepcopy(messages) + mm_inputs = self._get_mm_inputs(images, videos, processor) + image_input_sizes = mm_inputs.get("image_sizes", None) for message in messages: content = message["content"] while IMAGE_PLACEHOLDER in content: @@ -483,12 +480,10 @@ class PixtralPlugin(BasePlugin): num_height_tokens = height // patch_size num_width_tokens = width // patch_size replace_tokens = [[image_token] * num_width_tokens + [image_break_token]] * num_height_tokens - # Flatten list - replace_tokens = [item for sublist in replace_tokens for item in sublist] + replace_tokens = [item for sublist in replace_tokens for item in sublist] # flatten list replace_tokens[-1] = image_end_token replace_str = "".join(replace_tokens) content = content.replace(IMAGE_PLACEHOLDER, replace_str, 1) - num_image_tokens += 1 message["content"] = content