Mirror of https://github.com/hiyouga/LLaMA-Factory.git, synced 2025-08-23 06:12:50 +08:00

commit 3295519099 (parent: 20faaf3418)

Tiny fix.

Former-commit-id: 8372c5e3771c42f225d7bd80a758af920f80e893
@@ -686,19 +686,12 @@ class MllamaPlugin(BasePlugin):
         processor: Optional["ProcessorMixin"],
     ) -> List[Dict[str, str]]:
         self._validate_input(images, videos)
-        num_image_tokens = 0
         messages = deepcopy(messages)
         for message in messages:
             content = message["content"]
-            while IMAGE_PLACEHOLDER in content:
-                num_image_tokens += 1
-                content = content.replace(IMAGE_PLACEHOLDER, "<|image|>", 1)
-
+            content = content.replace(IMAGE_PLACEHOLDER, "<|image|>", 1)
             message["content"] = content
 
-        if len(images) != num_image_tokens:
-            raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER))
-
         return messages
 
     def get_mm_inputs(
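For quick reference, here is a minimal standalone sketch of the replacement logic in this hunk after the change. It assumes IMAGE_PLACEHOLDER is "<image>" (its default in LLaMA-Factory's extras/constants.py); because replace is called with count 1, only the first placeholder in each message is rewritten to the mllama <|image|> token.

# Minimal sketch of MllamaPlugin.process_messages after this fix.
# Assumption: IMAGE_PLACEHOLDER == "<image>", as in extras/constants.py.
from copy import deepcopy

IMAGE_PLACEHOLDER = "<image>"

def process_messages(messages):
    messages = deepcopy(messages)  # avoid mutating the caller's list
    for message in messages:
        # rewrite only the first placeholder per message to the mllama image token
        message["content"] = message["content"].replace(IMAGE_PLACEHOLDER, "<|image|>", 1)
    return messages

print(process_messages([{"role": "user", "content": "<image>Describe this."}]))
# [{'role': 'user', 'content': '<|image|>Describe this.'}]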
@@ -710,7 +703,7 @@ class MllamaPlugin(BasePlugin):
         seqlens: Sequence[int],
         processor: Optional["ProcessorMixin"],
     ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
-        super().get_mm_inputs(images, videos, imglens, vidlens, seqlens, processor)
+        self._get_mm_inputs(images, videos, processor)
         if images is not None:
             images = [Image.open(image) if isinstance(image, str) else image for image in images]
             image_features = processor.image_processor(images)
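The context lines above normalize mixed inputs before they reach the image processor: each entry may be a file-system path or an already-decoded PIL image. A hedged sketch of that normalization, assuming only str paths and PIL.Image.Image objects appear in the list:

# Sketch of the str-path vs. PIL-object normalization from the context above.
# Assumption: every entry is either a file path (str) or a PIL.Image.Image.
from typing import List, Union

from PIL import Image

def load_images(images: List[Union[str, "Image.Image"]]) -> List["Image.Image"]:
    # open paths with Pillow; pass decoded images through unchanged
    return [Image.open(image) if isinstance(image, str) else image for image in images]

Accepting both forms keeps dataset converters simple: they can hand over raw paths and let the plugin defer decoding until the batch is built.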
@@ -159,6 +159,8 @@ def get_image_seqlen(config: "PretrainedConfig") -> int:
         image_seqlen = config.vision_config.num_image_tokens
+    elif model_type == "mllama":
+        image_seqlen = ((config.vision_config.image_size // config.vision_config.patch_size) ** 2 + 1) * config.vision_config.max_num_tiles
     else:
         image_seqlen = -1
 
     return image_seqlen
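As a worked example of the new mllama branch: with the vision_config values shipped with Llama 3.2 Vision (image_size=560, patch_size=14, max_num_tiles=4; quoted here as an assumption, not taken from this diff), each 560x560 tile yields (560 // 14) ** 2 = 1600 patches plus one extra (class) token, and four tiles give 6404 image tokens:

# Worked example of the mllama image_seqlen formula; the config values below
# are the ones shipped with Llama 3.2 Vision (an assumption, not from this diff).
image_size, patch_size, max_num_tiles = 560, 14, 4

patches_per_tile = (image_size // patch_size) ** 2  # (560 // 14) ** 2 = 1600
tokens_per_tile = patches_per_tile + 1              # +1 extra token -> 1601
image_seqlen = tokens_per_tile * max_num_tiles      # 1601 * 4 = 6404

print(image_seqlen)  # 6404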