mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-23 22:32:54 +08:00
[data] improve mm plugin (#7910)
This commit is contained in:
parent
ae392e054c
commit
77c569e071
@ -167,16 +167,45 @@ class MMPluginMixin:
|
|||||||
)
|
)
|
||||||
|
|
||||||
if self.image_token is not None and processor is None:
|
if self.image_token is not None and processor is None:
|
||||||
raise ValueError("Processor was not found, please check and update your processor config.")
|
raise ValueError("Processor was not found, please check and update your model file.")
|
||||||
|
|
||||||
if self.image_token is not None and image_processor is None:
|
if self.image_token is not None and image_processor is None:
|
||||||
raise ValueError("Image processor was not found, please check and update your processor config.")
|
raise ValueError("Image processor was not found, please check and update your model file.")
|
||||||
|
|
||||||
if self.video_token is not None and video_processor is None:
|
if self.video_token is not None and video_processor is None:
|
||||||
raise ValueError("Video processor was not found, please check and update your processor config.")
|
raise ValueError("Video processor was not found, please check and update your model file.")
|
||||||
|
|
||||||
if self.audio_token is not None and feature_extractor is None:
|
if self.audio_token is not None and feature_extractor is None:
|
||||||
raise ValueError("Audio feature extractor was not found, please check and update your processor config.")
|
raise ValueError("Audio feature extractor was not found, please check and update your model file.")
|
||||||
|
|
||||||
|
def _validate_messages(
|
||||||
|
self,
|
||||||
|
messages: list[dict[str, str]],
|
||||||
|
images: list["ImageInput"],
|
||||||
|
videos: list["VideoInput"],
|
||||||
|
audios: list["AudioInput"],
|
||||||
|
):
|
||||||
|
r"""Validate if the number of images, videos and audios match the number of placeholders in messages."""
|
||||||
|
num_image_tokens, num_video_tokens, num_audio_tokens = 0, 0, 0
|
||||||
|
for message in messages:
|
||||||
|
num_image_tokens += message["content"].count(IMAGE_PLACEHOLDER)
|
||||||
|
num_video_tokens += message["content"].count(VIDEO_PLACEHOLDER)
|
||||||
|
num_audio_tokens += message["content"].count(AUDIO_PLACEHOLDER)
|
||||||
|
|
||||||
|
if len(images) != num_image_tokens:
|
||||||
|
raise ValueError(
|
||||||
|
f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens in {messages}."
|
||||||
|
)
|
||||||
|
|
||||||
|
if len(videos) != num_video_tokens:
|
||||||
|
raise ValueError(
|
||||||
|
f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens in {messages}."
|
||||||
|
)
|
||||||
|
|
||||||
|
if len(audios) != num_audio_tokens:
|
||||||
|
raise ValueError(
|
||||||
|
f"The number of audios does not match the number of {AUDIO_PLACEHOLDER} tokens in {messages}."
|
||||||
|
)
|
||||||
|
|
||||||
def _preprocess_image(
|
def _preprocess_image(
|
||||||
self, image: "ImageObject", image_max_pixels: int, image_min_pixels: int, **kwargs
|
self, image: "ImageObject", image_max_pixels: int, image_min_pixels: int, **kwargs
|
||||||
@ -364,6 +393,7 @@ class BasePlugin(MMPluginMixin):
|
|||||||
) -> list[dict[str, str]]:
|
) -> list[dict[str, str]]:
|
||||||
r"""Pre-process input messages before tokenization for VLMs."""
|
r"""Pre-process input messages before tokenization for VLMs."""
|
||||||
self._validate_input(processor, images, videos, audios)
|
self._validate_input(processor, images, videos, audios)
|
||||||
|
self._validate_messages(messages, images, videos, audios)
|
||||||
return messages
|
return messages
|
||||||
|
|
||||||
def process_token_ids(
|
def process_token_ids(
|
||||||
@ -420,6 +450,7 @@ class Gemma3Plugin(BasePlugin):
|
|||||||
processor: Optional["MMProcessor"],
|
processor: Optional["MMProcessor"],
|
||||||
) -> list[dict[str, str]]:
|
) -> list[dict[str, str]]:
|
||||||
self._validate_input(processor, images, videos, audios)
|
self._validate_input(processor, images, videos, audios)
|
||||||
|
self._validate_messages(messages, images, videos, audios)
|
||||||
num_image_tokens = 0
|
num_image_tokens = 0
|
||||||
messages = deepcopy(messages)
|
messages = deepcopy(messages)
|
||||||
boi_token: str = getattr(processor, "boi_token")
|
boi_token: str = getattr(processor, "boi_token")
|
||||||
@ -446,9 +477,6 @@ class Gemma3Plugin(BasePlugin):
|
|||||||
|
|
||||||
message["content"] = content.replace("{{image}}", image_str)
|
message["content"] = content.replace("{{image}}", image_str)
|
||||||
|
|
||||||
if len(images) != num_image_tokens:
|
|
||||||
raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
|
|
||||||
|
|
||||||
return messages
|
return messages
|
||||||
|
|
||||||
@override
|
@override
|
||||||
@ -566,8 +594,8 @@ class InternVLPlugin(BasePlugin):
|
|||||||
processor: Optional["ProcessorMixin"],
|
processor: Optional["ProcessorMixin"],
|
||||||
) -> list[dict[str, str]]:
|
) -> list[dict[str, str]]:
|
||||||
self._validate_input(processor, images, videos, audios)
|
self._validate_input(processor, images, videos, audios)
|
||||||
num_image_tokens = 0
|
self._validate_messages(messages, images, videos, audios)
|
||||||
num_video_tokens = 0
|
num_image_tokens, num_video_tokens = 0, 0
|
||||||
image_seqlen = getattr(processor, "image_seq_length") if self.expand_mm_tokens else 1
|
image_seqlen = getattr(processor, "image_seq_length") if self.expand_mm_tokens else 1
|
||||||
messages = deepcopy(messages)
|
messages = deepcopy(messages)
|
||||||
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
|
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
|
||||||
@ -579,9 +607,6 @@ class InternVLPlugin(BasePlugin):
|
|||||||
for message in messages:
|
for message in messages:
|
||||||
content = message["content"]
|
content = message["content"]
|
||||||
while IMAGE_PLACEHOLDER in content:
|
while IMAGE_PLACEHOLDER in content:
|
||||||
if num_image_tokens >= len(images):
|
|
||||||
raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
|
|
||||||
|
|
||||||
content = content.replace(
|
content = content.replace(
|
||||||
IMAGE_PLACEHOLDER,
|
IMAGE_PLACEHOLDER,
|
||||||
f"<img>{'<IMG_CONTEXT>' * image_seqlen * image_pixel_patch_list[num_image_tokens]}</img>",
|
f"<img>{'<IMG_CONTEXT>' * image_seqlen * image_pixel_patch_list[num_image_tokens]}</img>",
|
||||||
@ -590,9 +615,6 @@ class InternVLPlugin(BasePlugin):
|
|||||||
num_image_tokens += 1
|
num_image_tokens += 1
|
||||||
|
|
||||||
while VIDEO_PLACEHOLDER in content:
|
while VIDEO_PLACEHOLDER in content:
|
||||||
if num_video_tokens >= len(videos):
|
|
||||||
raise ValueError(f"`len(videos)` is less than the number of {VIDEO_PLACEHOLDER} tokens.")
|
|
||||||
|
|
||||||
current_patch_index = video_patch_indices[num_video_tokens - 1] if num_video_tokens > 0 else 0
|
current_patch_index = video_patch_indices[num_video_tokens - 1] if num_video_tokens > 0 else 0
|
||||||
end_patch_index = video_patch_indices[num_video_tokens]
|
end_patch_index = video_patch_indices[num_video_tokens]
|
||||||
num_patches = list(video_num_patches[current_patch_index:end_patch_index])
|
num_patches = list(video_num_patches[current_patch_index:end_patch_index])
|
||||||
@ -605,12 +627,6 @@ class InternVLPlugin(BasePlugin):
|
|||||||
|
|
||||||
message["content"] = content
|
message["content"] = content
|
||||||
|
|
||||||
if len(images) != num_image_tokens:
|
|
||||||
raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
|
|
||||||
|
|
||||||
if len(videos) != num_video_tokens:
|
|
||||||
raise ValueError(f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens.")
|
|
||||||
|
|
||||||
return messages
|
return messages
|
||||||
|
|
||||||
@override
|
@override
|
||||||
@ -637,6 +653,7 @@ class KimiVLPlugin(BasePlugin):
|
|||||||
@override
|
@override
|
||||||
def process_messages(self, messages, images, videos, audios, processor):
|
def process_messages(self, messages, images, videos, audios, processor):
|
||||||
self._validate_input(processor, images, videos, audios)
|
self._validate_input(processor, images, videos, audios)
|
||||||
|
self._validate_messages(messages, images, videos, audios)
|
||||||
if self.expand_mm_tokens:
|
if self.expand_mm_tokens:
|
||||||
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
|
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
|
||||||
|
|
||||||
@ -648,9 +665,6 @@ class KimiVLPlugin(BasePlugin):
|
|||||||
for message in messages:
|
for message in messages:
|
||||||
content = message["content"]
|
content = message["content"]
|
||||||
while IMAGE_PLACEHOLDER in content:
|
while IMAGE_PLACEHOLDER in content:
|
||||||
if num_image_tokens >= len(images):
|
|
||||||
raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
|
|
||||||
|
|
||||||
image_seqlen = image_grid_hws[num_image_tokens].prod() // merge_length if self.expand_mm_tokens else 1
|
image_seqlen = image_grid_hws[num_image_tokens].prod() // merge_length if self.expand_mm_tokens else 1
|
||||||
content = content.replace(
|
content = content.replace(
|
||||||
IMAGE_PLACEHOLDER,
|
IMAGE_PLACEHOLDER,
|
||||||
@ -661,9 +675,6 @@ class KimiVLPlugin(BasePlugin):
|
|||||||
|
|
||||||
message["content"] = content
|
message["content"] = content
|
||||||
|
|
||||||
if len(images) != num_image_tokens:
|
|
||||||
raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
|
|
||||||
|
|
||||||
return messages
|
return messages
|
||||||
|
|
||||||
|
|
||||||
@ -679,6 +690,7 @@ class Llama4Plugin(BasePlugin):
|
|||||||
processor: Optional["MMProcessor"],
|
processor: Optional["MMProcessor"],
|
||||||
) -> list[dict[str, str]]:
|
) -> list[dict[str, str]]:
|
||||||
self._validate_input(processor, images, videos, audios)
|
self._validate_input(processor, images, videos, audios)
|
||||||
|
self._validate_messages(messages, images, videos, audios)
|
||||||
if self.expand_mm_tokens:
|
if self.expand_mm_tokens:
|
||||||
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
|
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
|
||||||
if "pixel_values" in mm_inputs:
|
if "pixel_values" in mm_inputs:
|
||||||
@ -701,9 +713,6 @@ class Llama4Plugin(BasePlugin):
|
|||||||
for local_image_index, split_part in enumerate(prompt_splits):
|
for local_image_index, split_part in enumerate(prompt_splits):
|
||||||
new_content.append(split_part)
|
new_content.append(split_part)
|
||||||
if local_image_index < placeholder_count:
|
if local_image_index < placeholder_count:
|
||||||
if num_image_tokens >= len(images):
|
|
||||||
raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
|
|
||||||
|
|
||||||
tokens_for_this_image = processor._prompt_split_image(
|
tokens_for_this_image = processor._prompt_split_image(
|
||||||
aspect_ratios[num_image_tokens], num_patches_per_chunk
|
aspect_ratios[num_image_tokens], num_patches_per_chunk
|
||||||
)
|
)
|
||||||
@ -716,9 +725,6 @@ class Llama4Plugin(BasePlugin):
|
|||||||
|
|
||||||
message["content"] = content
|
message["content"] = content
|
||||||
|
|
||||||
if len(images) != num_image_tokens:
|
|
||||||
raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
|
|
||||||
|
|
||||||
return messages
|
return messages
|
||||||
|
|
||||||
@override
|
@override
|
||||||
@ -751,7 +757,7 @@ class LlavaPlugin(BasePlugin):
|
|||||||
processor: Optional["MMProcessor"],
|
processor: Optional["MMProcessor"],
|
||||||
) -> list[dict[str, str]]:
|
) -> list[dict[str, str]]:
|
||||||
self._validate_input(processor, images, videos, audios)
|
self._validate_input(processor, images, videos, audios)
|
||||||
num_image_tokens = 0
|
self._validate_messages(messages, images, videos, audios)
|
||||||
messages = deepcopy(messages)
|
messages = deepcopy(messages)
|
||||||
if self.expand_mm_tokens:
|
if self.expand_mm_tokens:
|
||||||
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
|
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
|
||||||
@ -768,17 +774,10 @@ class LlavaPlugin(BasePlugin):
|
|||||||
for message in messages:
|
for message in messages:
|
||||||
content = message["content"]
|
content = message["content"]
|
||||||
while IMAGE_PLACEHOLDER in content:
|
while IMAGE_PLACEHOLDER in content:
|
||||||
if num_image_tokens >= len(images):
|
|
||||||
raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
|
|
||||||
|
|
||||||
content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)
|
content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)
|
||||||
num_image_tokens += 1
|
|
||||||
|
|
||||||
message["content"] = content.replace("{{image}}", self.image_token)
|
message["content"] = content.replace("{{image}}", self.image_token)
|
||||||
|
|
||||||
if len(images) != num_image_tokens:
|
|
||||||
raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
|
|
||||||
|
|
||||||
return messages
|
return messages
|
||||||
|
|
||||||
|
|
||||||
@ -794,6 +793,7 @@ class LlavaNextPlugin(BasePlugin):
|
|||||||
processor: Optional["MMProcessor"],
|
processor: Optional["MMProcessor"],
|
||||||
) -> list[dict[str, str]]:
|
) -> list[dict[str, str]]:
|
||||||
self._validate_input(processor, images, videos, audios)
|
self._validate_input(processor, images, videos, audios)
|
||||||
|
self._validate_messages(messages, images, videos, audios)
|
||||||
num_image_tokens = 0
|
num_image_tokens = 0
|
||||||
messages = deepcopy(messages)
|
messages = deepcopy(messages)
|
||||||
if self.expand_mm_tokens:
|
if self.expand_mm_tokens:
|
||||||
@ -805,9 +805,6 @@ class LlavaNextPlugin(BasePlugin):
|
|||||||
for message in messages:
|
for message in messages:
|
||||||
content = message["content"]
|
content = message["content"]
|
||||||
while IMAGE_PLACEHOLDER in content:
|
while IMAGE_PLACEHOLDER in content:
|
||||||
if num_image_tokens >= len(images):
|
|
||||||
raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
|
|
||||||
|
|
||||||
if self.expand_mm_tokens:
|
if self.expand_mm_tokens:
|
||||||
orig_height, orig_width = next(image_sizes)
|
orig_height, orig_width = next(image_sizes)
|
||||||
image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width)
|
image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width)
|
||||||
@ -821,9 +818,6 @@ class LlavaNextPlugin(BasePlugin):
|
|||||||
|
|
||||||
message["content"] = content.replace("{{image}}", self.image_token)
|
message["content"] = content.replace("{{image}}", self.image_token)
|
||||||
|
|
||||||
if len(images) != num_image_tokens:
|
|
||||||
raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
|
|
||||||
|
|
||||||
return messages
|
return messages
|
||||||
|
|
||||||
|
|
||||||
@ -839,7 +833,7 @@ class LlavaNextVideoPlugin(BasePlugin):
|
|||||||
processor: Optional["MMProcessor"],
|
processor: Optional["MMProcessor"],
|
||||||
) -> list[dict[str, str]]:
|
) -> list[dict[str, str]]:
|
||||||
self._validate_input(processor, images, videos, audios)
|
self._validate_input(processor, images, videos, audios)
|
||||||
num_image_tokens, num_video_tokens = 0, 0
|
self._validate_messages(messages, images, videos, audios)
|
||||||
messages = deepcopy(messages)
|
messages = deepcopy(messages)
|
||||||
if self.expand_mm_tokens:
|
if self.expand_mm_tokens:
|
||||||
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
|
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
|
||||||
@ -850,9 +844,6 @@ class LlavaNextVideoPlugin(BasePlugin):
|
|||||||
for message in messages:
|
for message in messages:
|
||||||
content = message["content"]
|
content = message["content"]
|
||||||
while IMAGE_PLACEHOLDER in content:
|
while IMAGE_PLACEHOLDER in content:
|
||||||
if num_image_tokens >= len(images):
|
|
||||||
raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
|
|
||||||
|
|
||||||
if self.expand_mm_tokens:
|
if self.expand_mm_tokens:
|
||||||
orig_height, orig_width = next(image_sizes)
|
orig_height, orig_width = next(image_sizes)
|
||||||
image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width)
|
image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width)
|
||||||
@ -862,7 +853,6 @@ class LlavaNextVideoPlugin(BasePlugin):
|
|||||||
image_seqlen = 1
|
image_seqlen = 1
|
||||||
|
|
||||||
content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)
|
content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)
|
||||||
num_image_tokens += 1
|
|
||||||
|
|
||||||
message["content"] = content.replace("{{image}}", self.image_token)
|
message["content"] = content.replace("{{image}}", self.image_token)
|
||||||
|
|
||||||
@ -879,20 +869,10 @@ class LlavaNextVideoPlugin(BasePlugin):
|
|||||||
for message in messages:
|
for message in messages:
|
||||||
content = message["content"]
|
content = message["content"]
|
||||||
while VIDEO_PLACEHOLDER in content:
|
while VIDEO_PLACEHOLDER in content:
|
||||||
if num_video_tokens >= len(videos):
|
|
||||||
raise ValueError(f"`len(videos)` is less than the number of {VIDEO_PLACEHOLDER} tokens.")
|
|
||||||
|
|
||||||
content = content.replace(VIDEO_PLACEHOLDER, "{{video}}" * video_seqlen, 1)
|
content = content.replace(VIDEO_PLACEHOLDER, "{{video}}" * video_seqlen, 1)
|
||||||
num_video_tokens += 1
|
|
||||||
|
|
||||||
message["content"] = content.replace("{{video}}", self.video_token)
|
message["content"] = content.replace("{{video}}", self.video_token)
|
||||||
|
|
||||||
if len(images) != num_image_tokens:
|
|
||||||
raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
|
|
||||||
|
|
||||||
if len(videos) != num_video_tokens:
|
|
||||||
raise ValueError(f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens.")
|
|
||||||
|
|
||||||
return messages
|
return messages
|
||||||
|
|
||||||
|
|
||||||
@ -978,6 +958,7 @@ class MiniCPMVPlugin(BasePlugin):
|
|||||||
processor: Optional["MMProcessor"],
|
processor: Optional["MMProcessor"],
|
||||||
) -> list[dict[str, str]]:
|
) -> list[dict[str, str]]:
|
||||||
self._validate_input(processor, images, videos, audios)
|
self._validate_input(processor, images, videos, audios)
|
||||||
|
self._validate_messages(messages, images, videos, audios)
|
||||||
num_image_tokens, num_video_tokens, num_audio_tokens = 0, 0, 0
|
num_image_tokens, num_video_tokens, num_audio_tokens = 0, 0, 0
|
||||||
messages = deepcopy(messages)
|
messages = deepcopy(messages)
|
||||||
image_processor: BaseImageProcessor = getattr(processor, "image_processor")
|
image_processor: BaseImageProcessor = getattr(processor, "image_processor")
|
||||||
@ -996,24 +977,15 @@ class MiniCPMVPlugin(BasePlugin):
|
|||||||
for i, message in enumerate(messages):
|
for i, message in enumerate(messages):
|
||||||
content = message["content"]
|
content = message["content"]
|
||||||
while IMAGE_PLACEHOLDER in content:
|
while IMAGE_PLACEHOLDER in content:
|
||||||
if num_image_tokens >= len(images):
|
|
||||||
raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
|
|
||||||
|
|
||||||
content = content.replace(IMAGE_PLACEHOLDER, "{{image}}", 1)
|
content = content.replace(IMAGE_PLACEHOLDER, "{{image}}", 1)
|
||||||
num_image_tokens += 1
|
num_image_tokens += 1
|
||||||
|
|
||||||
while VIDEO_PLACEHOLDER in content:
|
while VIDEO_PLACEHOLDER in content:
|
||||||
if num_video_tokens >= len(videos):
|
|
||||||
raise ValueError(f"`len(videos)` is less than the number of {VIDEO_PLACEHOLDER} tokens.")
|
|
||||||
|
|
||||||
video_seqlen = len(mm_inputs["pixel_values"][num_video_tokens]) if self.expand_mm_tokens else 1
|
video_seqlen = len(mm_inputs["pixel_values"][num_video_tokens]) if self.expand_mm_tokens else 1
|
||||||
content = content.replace(VIDEO_PLACEHOLDER, "{{image}}" * video_seqlen, 1)
|
content = content.replace(VIDEO_PLACEHOLDER, "{{image}}" * video_seqlen, 1)
|
||||||
num_video_tokens += 1
|
num_video_tokens += 1
|
||||||
|
|
||||||
while AUDIO_PLACEHOLDER in content:
|
while AUDIO_PLACEHOLDER in content:
|
||||||
if num_audio_tokens >= len(audios):
|
|
||||||
raise ValueError(f"`len(audios)` is less than the number of {AUDIO_PLACEHOLDER} tokens.")
|
|
||||||
|
|
||||||
content = content.replace(AUDIO_PLACEHOLDER, "{{audio}}", 1)
|
content = content.replace(AUDIO_PLACEHOLDER, "{{audio}}", 1)
|
||||||
num_audio_tokens += 1
|
num_audio_tokens += 1
|
||||||
|
|
||||||
@ -1065,15 +1037,6 @@ class MiniCPMVPlugin(BasePlugin):
|
|||||||
final_text += text_chunks[-1]
|
final_text += text_chunks[-1]
|
||||||
messages[index]["content"] = final_text
|
messages[index]["content"] = final_text
|
||||||
|
|
||||||
if len(images) != num_image_tokens:
|
|
||||||
raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
|
|
||||||
|
|
||||||
if len(videos) != num_video_tokens:
|
|
||||||
raise ValueError(f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens.")
|
|
||||||
|
|
||||||
if len(audios) != num_audio_tokens:
|
|
||||||
raise ValueError(f"The number of audios does not match the number of {AUDIO_PLACEHOLDER} tokens.")
|
|
||||||
|
|
||||||
return messages
|
return messages
|
||||||
|
|
||||||
@override
|
@override
|
||||||
@ -1157,6 +1120,7 @@ class MllamaPlugin(BasePlugin):
|
|||||||
processor: Optional["MMProcessor"],
|
processor: Optional["MMProcessor"],
|
||||||
) -> list[dict[str, str]]:
|
) -> list[dict[str, str]]:
|
||||||
self._validate_input(processor, images, videos, audios)
|
self._validate_input(processor, images, videos, audios)
|
||||||
|
self._validate_messages(messages, images, videos, audios)
|
||||||
num_image_tokens = 0
|
num_image_tokens = 0
|
||||||
messages = deepcopy(messages)
|
messages = deepcopy(messages)
|
||||||
for message in messages:
|
for message in messages:
|
||||||
@ -1164,9 +1128,6 @@ class MllamaPlugin(BasePlugin):
|
|||||||
num_image_tokens += content.count(IMAGE_PLACEHOLDER)
|
num_image_tokens += content.count(IMAGE_PLACEHOLDER)
|
||||||
message["content"] = content.replace(IMAGE_PLACEHOLDER, self.image_token)
|
message["content"] = content.replace(IMAGE_PLACEHOLDER, self.image_token)
|
||||||
|
|
||||||
if len(images) != num_image_tokens:
|
|
||||||
raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
|
|
||||||
|
|
||||||
return messages
|
return messages
|
||||||
|
|
||||||
@override
|
@override
|
||||||
@ -1214,6 +1175,7 @@ class PaliGemmaPlugin(BasePlugin):
|
|||||||
processor: Optional["MMProcessor"],
|
processor: Optional["MMProcessor"],
|
||||||
) -> list[dict[str, str]]:
|
) -> list[dict[str, str]]:
|
||||||
self._validate_input(processor, images, videos, audios)
|
self._validate_input(processor, images, videos, audios)
|
||||||
|
self._validate_messages(messages, images, videos, audios)
|
||||||
num_image_tokens = 0
|
num_image_tokens = 0
|
||||||
messages = deepcopy(messages)
|
messages = deepcopy(messages)
|
||||||
for message in messages:
|
for message in messages:
|
||||||
@ -1224,9 +1186,6 @@ class PaliGemmaPlugin(BasePlugin):
|
|||||||
|
|
||||||
message["content"] = content
|
message["content"] = content
|
||||||
|
|
||||||
if len(images) != num_image_tokens:
|
|
||||||
raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
|
|
||||||
|
|
||||||
return messages
|
return messages
|
||||||
|
|
||||||
@override
|
@override
|
||||||
@ -1281,7 +1240,7 @@ class PixtralPlugin(BasePlugin):
|
|||||||
processor: Optional["MMProcessor"],
|
processor: Optional["MMProcessor"],
|
||||||
) -> list[dict[str, str]]:
|
) -> list[dict[str, str]]:
|
||||||
self._validate_input(processor, images, videos, audios)
|
self._validate_input(processor, images, videos, audios)
|
||||||
num_image_tokens = 0
|
self._validate_messages(messages, images, videos, audios)
|
||||||
messages = deepcopy(messages)
|
messages = deepcopy(messages)
|
||||||
if self.expand_mm_tokens:
|
if self.expand_mm_tokens:
|
||||||
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
|
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
|
||||||
@ -1291,15 +1250,13 @@ class PixtralPlugin(BasePlugin):
|
|||||||
image_sizes = iter(mm_inputs["image_sizes"][0])
|
image_sizes = iter(mm_inputs["image_sizes"][0])
|
||||||
else:
|
else:
|
||||||
image_sizes = iter(mm_inputs["image_sizes"].tolist())
|
image_sizes = iter(mm_inputs["image_sizes"].tolist())
|
||||||
|
|
||||||
image_break_token: str = getattr(processor, "image_break_token")
|
image_break_token: str = getattr(processor, "image_break_token")
|
||||||
image_end_token: str = getattr(processor, "image_end_token")
|
image_end_token: str = getattr(processor, "image_end_token")
|
||||||
|
|
||||||
for message in messages:
|
for message in messages:
|
||||||
content = message["content"]
|
content = message["content"]
|
||||||
while IMAGE_PLACEHOLDER in content:
|
while IMAGE_PLACEHOLDER in content:
|
||||||
if num_image_tokens >= len(images):
|
|
||||||
raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
|
|
||||||
|
|
||||||
if self.expand_mm_tokens:
|
if self.expand_mm_tokens:
|
||||||
height, width = next(image_sizes)
|
height, width = next(image_sizes)
|
||||||
num_height_tokens = height // processor.patch_size
|
num_height_tokens = height // processor.patch_size
|
||||||
@ -1312,13 +1269,9 @@ class PixtralPlugin(BasePlugin):
|
|||||||
replace_str = self.image_token
|
replace_str = self.image_token
|
||||||
|
|
||||||
content = content.replace(IMAGE_PLACEHOLDER, replace_str, 1)
|
content = content.replace(IMAGE_PLACEHOLDER, replace_str, 1)
|
||||||
num_image_tokens += 1
|
|
||||||
|
|
||||||
message["content"] = content
|
message["content"] = content
|
||||||
|
|
||||||
if len(images) != num_image_tokens:
|
|
||||||
raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
|
|
||||||
|
|
||||||
return messages
|
return messages
|
||||||
|
|
||||||
@override
|
@override
|
||||||
@ -1355,9 +1308,9 @@ class Qwen2AudioPlugin(BasePlugin):
|
|||||||
processor: Optional["MMProcessor"],
|
processor: Optional["MMProcessor"],
|
||||||
) -> list[dict[str, str]]:
|
) -> list[dict[str, str]]:
|
||||||
self._validate_input(processor, images, videos, audios)
|
self._validate_input(processor, images, videos, audios)
|
||||||
|
self._validate_messages(messages, images, videos, audios)
|
||||||
bos_token: str = getattr(processor, "audio_bos_token")
|
bos_token: str = getattr(processor, "audio_bos_token")
|
||||||
eos_token: str = getattr(processor, "audio_eos_token")
|
eos_token: str = getattr(processor, "audio_eos_token")
|
||||||
num_audio_tokens = 0
|
|
||||||
messages = deepcopy(messages)
|
messages = deepcopy(messages)
|
||||||
if self.expand_mm_tokens:
|
if self.expand_mm_tokens:
|
||||||
mm_inputs = self._get_mm_inputs([], [], audios, processor)
|
mm_inputs = self._get_mm_inputs([], [], audios, processor)
|
||||||
@ -1367,9 +1320,6 @@ class Qwen2AudioPlugin(BasePlugin):
|
|||||||
for message in messages:
|
for message in messages:
|
||||||
content = message["content"]
|
content = message["content"]
|
||||||
while AUDIO_PLACEHOLDER in content:
|
while AUDIO_PLACEHOLDER in content:
|
||||||
if num_audio_tokens >= len(audios):
|
|
||||||
raise ValueError(f"`len(audios)` is less than the number of {AUDIO_PLACEHOLDER} tokens.")
|
|
||||||
|
|
||||||
if self.expand_mm_tokens:
|
if self.expand_mm_tokens:
|
||||||
audio_length = audio_lengths.pop(0)
|
audio_length = audio_lengths.pop(0)
|
||||||
input_length = (audio_length - 1) // 2 + 1
|
input_length = (audio_length - 1) // 2 + 1
|
||||||
@ -1380,13 +1330,9 @@ class Qwen2AudioPlugin(BasePlugin):
|
|||||||
content = content.replace(
|
content = content.replace(
|
||||||
AUDIO_PLACEHOLDER, f"{bos_token}{self.audio_token * audio_seqlen}{eos_token}", 1
|
AUDIO_PLACEHOLDER, f"{bos_token}{self.audio_token * audio_seqlen}{eos_token}", 1
|
||||||
)
|
)
|
||||||
num_audio_tokens += 1
|
|
||||||
|
|
||||||
message["content"] = content
|
message["content"] = content
|
||||||
|
|
||||||
if len(audios) != num_audio_tokens:
|
|
||||||
raise ValueError(f"The number of audios does not match the number of {AUDIO_PLACEHOLDER} tokens.")
|
|
||||||
|
|
||||||
return messages
|
return messages
|
||||||
|
|
||||||
@override
|
@override
|
||||||
@ -1494,6 +1440,7 @@ class Qwen2VLPlugin(BasePlugin):
|
|||||||
processor: Optional["MMProcessor"],
|
processor: Optional["MMProcessor"],
|
||||||
) -> list[dict[str, str]]:
|
) -> list[dict[str, str]]:
|
||||||
self._validate_input(processor, images, videos, audios)
|
self._validate_input(processor, images, videos, audios)
|
||||||
|
self._validate_messages(messages, images, videos, audios)
|
||||||
num_image_tokens, num_video_tokens = 0, 0
|
num_image_tokens, num_video_tokens = 0, 0
|
||||||
messages = deepcopy(messages)
|
messages = deepcopy(messages)
|
||||||
image_processor: BaseImageProcessor = getattr(processor, "image_processor")
|
image_processor: BaseImageProcessor = getattr(processor, "image_processor")
|
||||||
@ -1510,9 +1457,6 @@ class Qwen2VLPlugin(BasePlugin):
|
|||||||
for message in messages:
|
for message in messages:
|
||||||
content = message["content"]
|
content = message["content"]
|
||||||
while IMAGE_PLACEHOLDER in content:
|
while IMAGE_PLACEHOLDER in content:
|
||||||
if num_image_tokens >= len(images):
|
|
||||||
raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
|
|
||||||
|
|
||||||
image_seqlen = image_grid_thw[num_image_tokens].prod() // merge_length if self.expand_mm_tokens else 1
|
image_seqlen = image_grid_thw[num_image_tokens].prod() // merge_length if self.expand_mm_tokens else 1
|
||||||
content = content.replace(
|
content = content.replace(
|
||||||
IMAGE_PLACEHOLDER, f"<|vision_start|>{self.image_token * image_seqlen}<|vision_end|>", 1
|
IMAGE_PLACEHOLDER, f"<|vision_start|>{self.image_token * image_seqlen}<|vision_end|>", 1
|
||||||
@ -1520,9 +1464,6 @@ class Qwen2VLPlugin(BasePlugin):
|
|||||||
num_image_tokens += 1
|
num_image_tokens += 1
|
||||||
|
|
||||||
while VIDEO_PLACEHOLDER in content:
|
while VIDEO_PLACEHOLDER in content:
|
||||||
if num_video_tokens >= len(videos):
|
|
||||||
raise ValueError(f"`len(videos)` is less than the number of {VIDEO_PLACEHOLDER} tokens.")
|
|
||||||
|
|
||||||
video_seqlen = video_grid_thw[num_video_tokens].prod() // merge_length if self.expand_mm_tokens else 1
|
video_seqlen = video_grid_thw[num_video_tokens].prod() // merge_length if self.expand_mm_tokens else 1
|
||||||
content = content.replace(
|
content = content.replace(
|
||||||
VIDEO_PLACEHOLDER, f"<|vision_start|>{self.video_token * video_seqlen}<|vision_end|>", 1
|
VIDEO_PLACEHOLDER, f"<|vision_start|>{self.video_token * video_seqlen}<|vision_end|>", 1
|
||||||
@ -1531,12 +1472,6 @@ class Qwen2VLPlugin(BasePlugin):
|
|||||||
|
|
||||||
message["content"] = content
|
message["content"] = content
|
||||||
|
|
||||||
if len(images) != num_image_tokens:
|
|
||||||
raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
|
|
||||||
|
|
||||||
if len(videos) != num_video_tokens:
|
|
||||||
raise ValueError(f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens.")
|
|
||||||
|
|
||||||
return messages
|
return messages
|
||||||
|
|
||||||
|
|
||||||
@ -1602,6 +1537,7 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
|
|||||||
processor: Optional["MMProcessor"],
|
processor: Optional["MMProcessor"],
|
||||||
) -> list[dict[str, str]]:
|
) -> list[dict[str, str]]:
|
||||||
self._validate_input(processor, images, videos, audios)
|
self._validate_input(processor, images, videos, audios)
|
||||||
|
self._validate_messages(messages, images, videos, audios)
|
||||||
num_image_tokens, num_video_tokens, num_audio_tokens = 0, 0, 0
|
num_image_tokens, num_video_tokens, num_audio_tokens = 0, 0, 0
|
||||||
messages = deepcopy(messages)
|
messages = deepcopy(messages)
|
||||||
image_processor: BaseImageProcessor = getattr(processor, "image_processor", None)
|
image_processor: BaseImageProcessor = getattr(processor, "image_processor", None)
|
||||||
@ -1624,9 +1560,6 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
|
|||||||
for message in messages:
|
for message in messages:
|
||||||
content = message["content"]
|
content = message["content"]
|
||||||
while IMAGE_PLACEHOLDER in content:
|
while IMAGE_PLACEHOLDER in content:
|
||||||
if num_image_tokens >= len(images):
|
|
||||||
raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
|
|
||||||
|
|
||||||
image_seqlen = image_grid_thw[num_image_tokens].prod() // merge_length if self.expand_mm_tokens else 1
|
image_seqlen = image_grid_thw[num_image_tokens].prod() // merge_length if self.expand_mm_tokens else 1
|
||||||
content = content.replace(
|
content = content.replace(
|
||||||
IMAGE_PLACEHOLDER, f"<|vision_bos|>{self.image_token * image_seqlen}<|vision_eos|>", 1
|
IMAGE_PLACEHOLDER, f"<|vision_bos|>{self.image_token * image_seqlen}<|vision_eos|>", 1
|
||||||
@ -1642,11 +1575,6 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
|
|||||||
)
|
)
|
||||||
|
|
||||||
while VIDEO_PLACEHOLDER in content:
|
while VIDEO_PLACEHOLDER in content:
|
||||||
if num_video_tokens >= len(videos):
|
|
||||||
raise ValueError(f"`len(videos)` is less than the number of {VIDEO_PLACEHOLDER} tokens.")
|
|
||||||
if num_audio_tokens >= len(audios):
|
|
||||||
raise ValueError(f"`len(audios)` is less than the number of {AUDIO_PLACEHOLDER} tokens.")
|
|
||||||
|
|
||||||
video_pos = content.find(VIDEO_PLACEHOLDER)
|
video_pos = content.find(VIDEO_PLACEHOLDER)
|
||||||
audio_pos = content.find(AUDIO_PLACEHOLDER, video_pos)
|
audio_pos = content.find(AUDIO_PLACEHOLDER, video_pos)
|
||||||
if audio_pos == -1 or audio_pos < video_pos:
|
if audio_pos == -1 or audio_pos < video_pos:
|
||||||
@ -1688,9 +1616,6 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
|
|||||||
num_video_tokens += 1
|
num_video_tokens += 1
|
||||||
else:
|
else:
|
||||||
while AUDIO_PLACEHOLDER in content:
|
while AUDIO_PLACEHOLDER in content:
|
||||||
if num_audio_tokens >= len(audios):
|
|
||||||
raise ValueError(f"`len(audios)` is less than the number of {AUDIO_PLACEHOLDER} tokens.")
|
|
||||||
|
|
||||||
audio_seqlen = audio_lengths[num_audio_tokens] if self.expand_mm_tokens else 1
|
audio_seqlen = audio_lengths[num_audio_tokens] if self.expand_mm_tokens else 1
|
||||||
content = content.replace(
|
content = content.replace(
|
||||||
AUDIO_PLACEHOLDER, f"<|audio_bos|>{self.audio_token * audio_seqlen}<|audio_eos|>", 1
|
AUDIO_PLACEHOLDER, f"<|audio_bos|>{self.audio_token * audio_seqlen}<|audio_eos|>", 1
|
||||||
@ -1698,9 +1623,6 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
|
|||||||
num_audio_tokens += 1
|
num_audio_tokens += 1
|
||||||
|
|
||||||
while VIDEO_PLACEHOLDER in content:
|
while VIDEO_PLACEHOLDER in content:
|
||||||
if num_video_tokens >= len(videos):
|
|
||||||
raise ValueError(f"`len(videos)` is less than the number of {VIDEO_PLACEHOLDER} tokens.")
|
|
||||||
|
|
||||||
video_seqlen = (
|
video_seqlen = (
|
||||||
video_grid_thw[num_video_tokens].prod() // merge_length if self.expand_mm_tokens else 1
|
video_grid_thw[num_video_tokens].prod() // merge_length if self.expand_mm_tokens else 1
|
||||||
)
|
)
|
||||||
@ -1711,15 +1633,6 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
|
|||||||
|
|
||||||
message["content"] = content
|
message["content"] = content
|
||||||
|
|
||||||
if len(audios) != num_audio_tokens:
|
|
||||||
raise ValueError(f"The number of audios does not match the number of {AUDIO_PLACEHOLDER} tokens.")
|
|
||||||
|
|
||||||
if len(images) != num_image_tokens:
|
|
||||||
raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
|
|
||||||
|
|
||||||
if len(videos) != num_video_tokens:
|
|
||||||
raise ValueError(f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens.")
|
|
||||||
|
|
||||||
return messages
|
return messages
|
||||||
|
|
||||||
|
|
||||||
@ -1735,6 +1648,7 @@ class VideoLlavaPlugin(BasePlugin):
|
|||||||
processor: Optional["MMProcessor"],
|
processor: Optional["MMProcessor"],
|
||||||
) -> list[dict[str, str]]:
|
) -> list[dict[str, str]]:
|
||||||
self._validate_input(processor, images, videos, audios)
|
self._validate_input(processor, images, videos, audios)
|
||||||
|
self._validate_messages(messages, images, videos, audios)
|
||||||
num_image_tokens, num_video_tokens = 0, 0
|
num_image_tokens, num_video_tokens = 0, 0
|
||||||
messages = deepcopy(messages)
|
messages = deepcopy(messages)
|
||||||
num_frames = 0
|
num_frames = 0
|
||||||
@ -1762,28 +1676,16 @@ class VideoLlavaPlugin(BasePlugin):
|
|||||||
for message in messages:
|
for message in messages:
|
||||||
content = message["content"]
|
content = message["content"]
|
||||||
while IMAGE_PLACEHOLDER in content:
|
while IMAGE_PLACEHOLDER in content:
|
||||||
if num_image_tokens >= len(images):
|
|
||||||
raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
|
|
||||||
|
|
||||||
content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)
|
content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)
|
||||||
num_image_tokens += 1
|
num_image_tokens += 1
|
||||||
|
|
||||||
while VIDEO_PLACEHOLDER in content:
|
while VIDEO_PLACEHOLDER in content:
|
||||||
if num_video_tokens >= len(videos):
|
|
||||||
raise ValueError(f"`len(videos)` is less than the number of {VIDEO_PLACEHOLDER} tokens.")
|
|
||||||
|
|
||||||
content = content.replace(VIDEO_PLACEHOLDER, "{{video}}" * video_seqlen, 1)
|
content = content.replace(VIDEO_PLACEHOLDER, "{{video}}" * video_seqlen, 1)
|
||||||
num_video_tokens += 1
|
num_video_tokens += 1
|
||||||
|
|
||||||
content = content.replace("{{image}}", self.image_token)
|
content = content.replace("{{image}}", self.image_token)
|
||||||
message["content"] = content.replace("{{video}}", self.video_token)
|
message["content"] = content.replace("{{video}}", self.video_token)
|
||||||
|
|
||||||
if len(images) != num_image_tokens:
|
|
||||||
raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
|
|
||||||
|
|
||||||
if len(videos) != num_video_tokens:
|
|
||||||
raise ValueError(f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens.")
|
|
||||||
|
|
||||||
return messages
|
return messages
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user