From 32955190991902ac7b8454468f8fb7aa77a79b46 Mon Sep 17 00:00:00 2001 From: marko1616 Date: Thu, 26 Sep 2024 11:06:21 -0400 Subject: [PATCH] Tiny fix. Former-commit-id: 8372c5e3771c42f225d7bd80a758af920f80e893 --- src/llamafactory/data/mm_plugin.py | 11 ++--------- src/llamafactory/model/model_utils/visual.py | 2 ++ 2 files changed, 4 insertions(+), 9 deletions(-) diff --git a/src/llamafactory/data/mm_plugin.py b/src/llamafactory/data/mm_plugin.py index 7103224b..92c673fd 100644 --- a/src/llamafactory/data/mm_plugin.py +++ b/src/llamafactory/data/mm_plugin.py @@ -686,19 +686,12 @@ class MllamaPlugin(BasePlugin): processor: Optional["ProcessorMixin"], ) -> List[Dict[str, str]]: self._validate_input(images, videos) - num_image_tokens = 0 messages = deepcopy(messages) for message in messages: content = message["content"] - while IMAGE_PLACEHOLDER in content: - num_image_tokens += 1 - content = content.replace(IMAGE_PLACEHOLDER, "<|image|>", 1) - + content = content.replace(IMAGE_PLACEHOLDER, "<|image|>", 1) message["content"] = content - if len(images) != num_image_tokens: - raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER)) - return messages def get_mm_inputs( @@ -710,7 +703,7 @@ class MllamaPlugin(BasePlugin): seqlens: Sequence[int], processor: Optional["ProcessorMixin"], ) -> Dict[str, Union[List[int], "torch.Tensor"]]: - super().get_mm_inputs(images, videos, imglens, vidlens, seqlens, processor) + self._get_mm_inputs(images, videos, processor) if images is not None: images = [Image.open(image) if isinstance(image, str) else image for image in images] image_features = processor.image_processor(images) diff --git a/src/llamafactory/model/model_utils/visual.py b/src/llamafactory/model/model_utils/visual.py index 1ac46e06..440ea8ce 100644 --- a/src/llamafactory/model/model_utils/visual.py +++ b/src/llamafactory/model/model_utils/visual.py @@ -159,6 +159,8 @@ def get_image_seqlen(config: "PretrainedConfig") -> int: image_seqlen = config.vision_config.num_image_tokens else: image_seqlen = -1 + elif model_type == "mllama": + image_seqlen = ((config.vision_config.image_size // config.vision_config.patch_size) ** 2 + 1) * config.vision_config.max_num_tiles return image_seqlen