mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-23 22:32:54 +08:00
Update mm_plugin.py
Former-commit-id: 049c554aee25cf1e29bee88dfb21381b3a4a2947
This commit is contained in:
parent
54961946ac
commit
49054329d0
@ -457,19 +457,16 @@ class PixtralPlugin(BasePlugin):
|
|||||||
videos: Sequence["VideoInput"],
|
videos: Sequence["VideoInput"],
|
||||||
processor: Optional["ProcessorMixin"],
|
processor: Optional["ProcessorMixin"],
|
||||||
) -> List[Dict[str, str]]:
|
) -> List[Dict[str, str]]:
|
||||||
patch_size = processor.patch_size
|
|
||||||
image_token = processor.image_token
|
|
||||||
image_break_token = processor.image_break_token
|
|
||||||
image_end_token = processor.image_end_token
|
|
||||||
|
|
||||||
self._validate_input(images, videos)
|
self._validate_input(images, videos)
|
||||||
|
patch_size = getattr(processor, "patch_size")
|
||||||
|
image_token = getattr(processor, "image_token")
|
||||||
|
image_break_token = getattr(processor, "image_break_token")
|
||||||
|
image_end_token = getattr(processor, "image_end_token")
|
||||||
|
|
||||||
num_image_tokens = 0
|
num_image_tokens = 0
|
||||||
img_kwargs = self._get_mm_inputs(images, videos, processor)
|
|
||||||
image_input_sizes = None
|
|
||||||
|
|
||||||
image_input_sizes = img_kwargs.get("image_sizes", None)
|
|
||||||
|
|
||||||
messages = deepcopy(messages)
|
messages = deepcopy(messages)
|
||||||
|
mm_inputs = self._get_mm_inputs(images, videos, processor)
|
||||||
|
image_input_sizes = mm_inputs.get("image_sizes", None)
|
||||||
for message in messages:
|
for message in messages:
|
||||||
content = message["content"]
|
content = message["content"]
|
||||||
while IMAGE_PLACEHOLDER in content:
|
while IMAGE_PLACEHOLDER in content:
|
||||||
@ -483,12 +480,10 @@ class PixtralPlugin(BasePlugin):
|
|||||||
num_height_tokens = height // patch_size
|
num_height_tokens = height // patch_size
|
||||||
num_width_tokens = width // patch_size
|
num_width_tokens = width // patch_size
|
||||||
replace_tokens = [[image_token] * num_width_tokens + [image_break_token]] * num_height_tokens
|
replace_tokens = [[image_token] * num_width_tokens + [image_break_token]] * num_height_tokens
|
||||||
# Flatten list
|
replace_tokens = [item for sublist in replace_tokens for item in sublist] # flatten list
|
||||||
replace_tokens = [item for sublist in replace_tokens for item in sublist]
|
|
||||||
replace_tokens[-1] = image_end_token
|
replace_tokens[-1] = image_end_token
|
||||||
replace_str = "".join(replace_tokens)
|
replace_str = "".join(replace_tokens)
|
||||||
content = content.replace(IMAGE_PLACEHOLDER, replace_str, 1)
|
content = content.replace(IMAGE_PLACEHOLDER, replace_str, 1)
|
||||||
|
|
||||||
num_image_tokens += 1
|
num_image_tokens += 1
|
||||||
|
|
||||||
message["content"] = content
|
message["content"] = content
|
||||||
|
Loading…
x
Reference in New Issue
Block a user