Tiny fix.

Former-commit-id: 8372c5e3771c42f225d7bd80a758af920f80e893
This commit is contained in:
marko1616 2024-09-26 11:06:21 -04:00 committed by hiyouga
parent 20faaf3418
commit 3295519099
2 changed files with 4 additions and 9 deletions

View File

@ -686,19 +686,12 @@ class MllamaPlugin(BasePlugin):
processor: Optional["ProcessorMixin"], processor: Optional["ProcessorMixin"],
) -> List[Dict[str, str]]: ) -> List[Dict[str, str]]:
self._validate_input(images, videos) self._validate_input(images, videos)
num_image_tokens = 0
messages = deepcopy(messages) messages = deepcopy(messages)
for message in messages: for message in messages:
content = message["content"] content = message["content"]
while IMAGE_PLACEHOLDER in content: content = content.replace(IMAGE_PLACEHOLDER, "<|image|>", 1)
num_image_tokens += 1
content = content.replace(IMAGE_PLACEHOLDER, "<|image|>", 1)
message["content"] = content message["content"] = content
if len(images) != num_image_tokens:
raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER))
return messages return messages
def get_mm_inputs( def get_mm_inputs(
@ -710,7 +703,7 @@ class MllamaPlugin(BasePlugin):
seqlens: Sequence[int], seqlens: Sequence[int],
processor: Optional["ProcessorMixin"], processor: Optional["ProcessorMixin"],
) -> Dict[str, Union[List[int], "torch.Tensor"]]: ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
super().get_mm_inputs(images, videos, imglens, vidlens, seqlens, processor) self._get_mm_inputs(images, videos, processor)
if images is not None: if images is not None:
images = [Image.open(image) if isinstance(image, str) else image for image in images] images = [Image.open(image) if isinstance(image, str) else image for image in images]
image_features = processor.image_processor(images) image_features = processor.image_processor(images)

View File

@ -159,6 +159,8 @@ def get_image_seqlen(config: "PretrainedConfig") -> int:
image_seqlen = config.vision_config.num_image_tokens image_seqlen = config.vision_config.num_image_tokens
else: else:
image_seqlen = -1 image_seqlen = -1
elif model_type == "mllama":
image_seqlen = ((config.vision_config.image_size // config.vision_config.patch_size) ** 2 + 1) * config.vision_config.max_num_tiles
return image_seqlen return image_seqlen