mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-23 06:12:50 +08:00
Tiny fix.
Former-commit-id: 8372c5e3771c42f225d7bd80a758af920f80e893
This commit is contained in:
parent
20faaf3418
commit
3295519099
@ -686,19 +686,12 @@ class MllamaPlugin(BasePlugin):
|
|||||||
processor: Optional["ProcessorMixin"],
|
processor: Optional["ProcessorMixin"],
|
||||||
) -> List[Dict[str, str]]:
|
) -> List[Dict[str, str]]:
|
||||||
self._validate_input(images, videos)
|
self._validate_input(images, videos)
|
||||||
num_image_tokens = 0
|
|
||||||
messages = deepcopy(messages)
|
messages = deepcopy(messages)
|
||||||
for message in messages:
|
for message in messages:
|
||||||
content = message["content"]
|
content = message["content"]
|
||||||
while IMAGE_PLACEHOLDER in content:
|
content = content.replace(IMAGE_PLACEHOLDER, "<|image|>", 1)
|
||||||
num_image_tokens += 1
|
|
||||||
content = content.replace(IMAGE_PLACEHOLDER, "<|image|>", 1)
|
|
||||||
|
|
||||||
message["content"] = content
|
message["content"] = content
|
||||||
|
|
||||||
if len(images) != num_image_tokens:
|
|
||||||
raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER))
|
|
||||||
|
|
||||||
return messages
|
return messages
|
||||||
|
|
||||||
def get_mm_inputs(
|
def get_mm_inputs(
|
||||||
@ -710,7 +703,7 @@ class MllamaPlugin(BasePlugin):
|
|||||||
seqlens: Sequence[int],
|
seqlens: Sequence[int],
|
||||||
processor: Optional["ProcessorMixin"],
|
processor: Optional["ProcessorMixin"],
|
||||||
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
|
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
|
||||||
super().get_mm_inputs(images, videos, imglens, vidlens, seqlens, processor)
|
self._get_mm_inputs(images, videos, processor)
|
||||||
if images is not None:
|
if images is not None:
|
||||||
images = [Image.open(image) if isinstance(image, str) else image for image in images]
|
images = [Image.open(image) if isinstance(image, str) else image for image in images]
|
||||||
image_features = processor.image_processor(images)
|
image_features = processor.image_processor(images)
|
||||||
|
@ -159,6 +159,8 @@ def get_image_seqlen(config: "PretrainedConfig") -> int:
|
|||||||
image_seqlen = config.vision_config.num_image_tokens
|
image_seqlen = config.vision_config.num_image_tokens
|
||||||
else:
|
else:
|
||||||
image_seqlen = -1
|
image_seqlen = -1
|
||||||
|
elif model_type == "mllama":
|
||||||
|
image_seqlen = ((config.vision_config.image_size // config.vision_config.patch_size) ** 2 + 1) * config.vision_config.max_num_tiles
|
||||||
|
|
||||||
return image_seqlen
|
return image_seqlen
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user