Update mm_plugin.py

Former-commit-id: 049c554aee25cf1e29bee88dfb21381b3a4a2947
This commit is contained in:
hoshi-hiyouga 2024-10-29 22:16:22 +08:00 committed by GitHub
parent 54961946ac
commit 49054329d0

View File

@@ -457,19 +457,16 @@ class PixtralPlugin(BasePlugin):
videos: Sequence["VideoInput"],
processor: Optional["ProcessorMixin"],
) -> List[Dict[str, str]]:
patch_size = processor.patch_size
image_token = processor.image_token
image_break_token = processor.image_break_token
image_end_token = processor.image_end_token
self._validate_input(images, videos)
patch_size = getattr(processor, "patch_size")
image_token = getattr(processor, "image_token")
image_break_token = getattr(processor, "image_break_token")
image_end_token = getattr(processor, "image_end_token")
num_image_tokens = 0
img_kwargs = self._get_mm_inputs(images, videos, processor)
image_input_sizes = None
image_input_sizes = img_kwargs.get("image_sizes", None)
messages = deepcopy(messages)
mm_inputs = self._get_mm_inputs(images, videos, processor)
image_input_sizes = mm_inputs.get("image_sizes", None)
for message in messages:
content = message["content"]
while IMAGE_PLACEHOLDER in content:
@@ -483,12 +480,10 @@ class PixtralPlugin(BasePlugin):
num_height_tokens = height // patch_size
num_width_tokens = width // patch_size
replace_tokens = [[image_token] * num_width_tokens + [image_break_token]] * num_height_tokens
# Flatten list
replace_tokens = [item for sublist in replace_tokens for item in sublist]
replace_tokens = [item for sublist in replace_tokens for item in sublist] # flatten list
replace_tokens[-1] = image_end_token
replace_str = "".join(replace_tokens)
content = content.replace(IMAGE_PLACEHOLDER, replace_str, 1)
num_image_tokens += 1
message["content"] = content