diff --git a/src/llamafactory/data/mm_plugin.py b/src/llamafactory/data/mm_plugin.py
index 7103224b..92c673fd 100644
--- a/src/llamafactory/data/mm_plugin.py
+++ b/src/llamafactory/data/mm_plugin.py
@@ -686,19 +686,14 @@ class MllamaPlugin(BasePlugin):
         processor: Optional["ProcessorMixin"],
     ) -> List[Dict[str, str]]:
         self._validate_input(images, videos)
-        num_image_tokens = 0
         messages = deepcopy(messages)
         for message in messages:
             content = message["content"]
-            while IMAGE_PLACEHOLDER in content:
-                num_image_tokens += 1
-                content = content.replace(IMAGE_PLACEHOLDER, "<|image|>", 1)
-
+            # Swap every placeholder for Mllama's "<|image|>" token; a count of 1
+            # would silently leave all but the first placeholder untouched.
+            content = content.replace(IMAGE_PLACEHOLDER, "<|image|>")
             message["content"] = content
 
-        if len(images) != num_image_tokens:
-            raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER))
-
         return messages
 
     def get_mm_inputs(
@@ -710,7 +705,7 @@ class MllamaPlugin(BasePlugin):
         seqlens: Sequence[int],
         processor: Optional["ProcessorMixin"],
     ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
-        super().get_mm_inputs(images, videos, imglens, vidlens, seqlens, processor)
+        self._get_mm_inputs(images, videos, processor)
         if images is not None:
             images = [Image.open(image) if isinstance(image, str) else image for image in images]
         image_features = processor.image_processor(images)
diff --git a/src/llamafactory/model/model_utils/visual.py b/src/llamafactory/model/model_utils/visual.py
index 1ac46e06..440ea8ce 100644
--- a/src/llamafactory/model/model_utils/visual.py
+++ b/src/llamafactory/model/model_utils/visual.py
@@ -159,4 +159,8 @@ def get_image_seqlen(config: "PretrainedConfig") -> int:
         image_seqlen = config.vision_config.num_image_tokens
+    elif model_type == "mllama":
+        # (patch grid per tile + 1 class token) * max number of tiles.
+        # NOTE(review): the original patch appended this branch AFTER `else:`,
+        # which is a SyntaxError; it must precede the `else` fallback.
+        image_seqlen = ((config.vision_config.image_size // config.vision_config.patch_size) ** 2 + 1) * config.vision_config.max_num_tiles
     else:
         image_seqlen = -1
     return image_seqlen