mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-22 22:02:51 +08:00
[data] improve mm plugin (#7910)
This commit is contained in:
parent
ae392e054c
commit
77c569e071
@ -167,16 +167,45 @@ class MMPluginMixin:
|
||||
)
|
||||
|
||||
if self.image_token is not None and processor is None:
|
||||
raise ValueError("Processor was not found, please check and update your processor config.")
|
||||
raise ValueError("Processor was not found, please check and update your model file.")
|
||||
|
||||
if self.image_token is not None and image_processor is None:
|
||||
raise ValueError("Image processor was not found, please check and update your processor config.")
|
||||
raise ValueError("Image processor was not found, please check and update your model file.")
|
||||
|
||||
if self.video_token is not None and video_processor is None:
|
||||
raise ValueError("Video processor was not found, please check and update your processor config.")
|
||||
raise ValueError("Video processor was not found, please check and update your model file.")
|
||||
|
||||
if self.audio_token is not None and feature_extractor is None:
|
||||
raise ValueError("Audio feature extractor was not found, please check and update your processor config.")
|
||||
raise ValueError("Audio feature extractor was not found, please check and update your model file.")
|
||||
|
||||
def _validate_messages(
|
||||
self,
|
||||
messages: list[dict[str, str]],
|
||||
images: list["ImageInput"],
|
||||
videos: list["VideoInput"],
|
||||
audios: list["AudioInput"],
|
||||
):
|
||||
r"""Validate if the number of images, videos and audios match the number of placeholders in messages."""
|
||||
num_image_tokens, num_video_tokens, num_audio_tokens = 0, 0, 0
|
||||
for message in messages:
|
||||
num_image_tokens += message["content"].count(IMAGE_PLACEHOLDER)
|
||||
num_video_tokens += message["content"].count(VIDEO_PLACEHOLDER)
|
||||
num_audio_tokens += message["content"].count(AUDIO_PLACEHOLDER)
|
||||
|
||||
if len(images) != num_image_tokens:
|
||||
raise ValueError(
|
||||
f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens in {messages}."
|
||||
)
|
||||
|
||||
if len(videos) != num_video_tokens:
|
||||
raise ValueError(
|
||||
f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens in {messages}."
|
||||
)
|
||||
|
||||
if len(audios) != num_audio_tokens:
|
||||
raise ValueError(
|
||||
f"The number of audios does not match the number of {AUDIO_PLACEHOLDER} tokens in {messages}."
|
||||
)
|
||||
|
||||
def _preprocess_image(
|
||||
self, image: "ImageObject", image_max_pixels: int, image_min_pixels: int, **kwargs
|
||||
@ -364,6 +393,7 @@ class BasePlugin(MMPluginMixin):
|
||||
) -> list[dict[str, str]]:
|
||||
r"""Pre-process input messages before tokenization for VLMs."""
|
||||
self._validate_input(processor, images, videos, audios)
|
||||
self._validate_messages(messages, images, videos, audios)
|
||||
return messages
|
||||
|
||||
def process_token_ids(
|
||||
@ -420,6 +450,7 @@ class Gemma3Plugin(BasePlugin):
|
||||
processor: Optional["MMProcessor"],
|
||||
) -> list[dict[str, str]]:
|
||||
self._validate_input(processor, images, videos, audios)
|
||||
self._validate_messages(messages, images, videos, audios)
|
||||
num_image_tokens = 0
|
||||
messages = deepcopy(messages)
|
||||
boi_token: str = getattr(processor, "boi_token")
|
||||
@ -446,9 +477,6 @@ class Gemma3Plugin(BasePlugin):
|
||||
|
||||
message["content"] = content.replace("{{image}}", image_str)
|
||||
|
||||
if len(images) != num_image_tokens:
|
||||
raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
|
||||
|
||||
return messages
|
||||
|
||||
@override
|
||||
@ -566,8 +594,8 @@ class InternVLPlugin(BasePlugin):
|
||||
processor: Optional["ProcessorMixin"],
|
||||
) -> list[dict[str, str]]:
|
||||
self._validate_input(processor, images, videos, audios)
|
||||
num_image_tokens = 0
|
||||
num_video_tokens = 0
|
||||
self._validate_messages(messages, images, videos, audios)
|
||||
num_image_tokens, num_video_tokens = 0, 0
|
||||
image_seqlen = getattr(processor, "image_seq_length") if self.expand_mm_tokens else 1
|
||||
messages = deepcopy(messages)
|
||||
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
|
||||
@ -579,9 +607,6 @@ class InternVLPlugin(BasePlugin):
|
||||
for message in messages:
|
||||
content = message["content"]
|
||||
while IMAGE_PLACEHOLDER in content:
|
||||
if num_image_tokens >= len(images):
|
||||
raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
|
||||
|
||||
content = content.replace(
|
||||
IMAGE_PLACEHOLDER,
|
||||
f"<img>{'<IMG_CONTEXT>' * image_seqlen * image_pixel_patch_list[num_image_tokens]}</img>",
|
||||
@ -590,9 +615,6 @@ class InternVLPlugin(BasePlugin):
|
||||
num_image_tokens += 1
|
||||
|
||||
while VIDEO_PLACEHOLDER in content:
|
||||
if num_video_tokens >= len(videos):
|
||||
raise ValueError(f"`len(videos)` is less than the number of {VIDEO_PLACEHOLDER} tokens.")
|
||||
|
||||
current_patch_index = video_patch_indices[num_video_tokens - 1] if num_video_tokens > 0 else 0
|
||||
end_patch_index = video_patch_indices[num_video_tokens]
|
||||
num_patches = list(video_num_patches[current_patch_index:end_patch_index])
|
||||
@ -605,12 +627,6 @@ class InternVLPlugin(BasePlugin):
|
||||
|
||||
message["content"] = content
|
||||
|
||||
if len(images) != num_image_tokens:
|
||||
raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
|
||||
|
||||
if len(videos) != num_video_tokens:
|
||||
raise ValueError(f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens.")
|
||||
|
||||
return messages
|
||||
|
||||
@override
|
||||
@ -637,6 +653,7 @@ class KimiVLPlugin(BasePlugin):
|
||||
@override
|
||||
def process_messages(self, messages, images, videos, audios, processor):
|
||||
self._validate_input(processor, images, videos, audios)
|
||||
self._validate_messages(messages, images, videos, audios)
|
||||
if self.expand_mm_tokens:
|
||||
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
|
||||
|
||||
@ -648,9 +665,6 @@ class KimiVLPlugin(BasePlugin):
|
||||
for message in messages:
|
||||
content = message["content"]
|
||||
while IMAGE_PLACEHOLDER in content:
|
||||
if num_image_tokens >= len(images):
|
||||
raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
|
||||
|
||||
image_seqlen = image_grid_hws[num_image_tokens].prod() // merge_length if self.expand_mm_tokens else 1
|
||||
content = content.replace(
|
||||
IMAGE_PLACEHOLDER,
|
||||
@ -661,9 +675,6 @@ class KimiVLPlugin(BasePlugin):
|
||||
|
||||
message["content"] = content
|
||||
|
||||
if len(images) != num_image_tokens:
|
||||
raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
|
||||
|
||||
return messages
|
||||
|
||||
|
||||
@ -679,6 +690,7 @@ class Llama4Plugin(BasePlugin):
|
||||
processor: Optional["MMProcessor"],
|
||||
) -> list[dict[str, str]]:
|
||||
self._validate_input(processor, images, videos, audios)
|
||||
self._validate_messages(messages, images, videos, audios)
|
||||
if self.expand_mm_tokens:
|
||||
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
|
||||
if "pixel_values" in mm_inputs:
|
||||
@ -701,9 +713,6 @@ class Llama4Plugin(BasePlugin):
|
||||
for local_image_index, split_part in enumerate(prompt_splits):
|
||||
new_content.append(split_part)
|
||||
if local_image_index < placeholder_count:
|
||||
if num_image_tokens >= len(images):
|
||||
raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
|
||||
|
||||
tokens_for_this_image = processor._prompt_split_image(
|
||||
aspect_ratios[num_image_tokens], num_patches_per_chunk
|
||||
)
|
||||
@ -716,9 +725,6 @@ class Llama4Plugin(BasePlugin):
|
||||
|
||||
message["content"] = content
|
||||
|
||||
if len(images) != num_image_tokens:
|
||||
raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
|
||||
|
||||
return messages
|
||||
|
||||
@override
|
||||
@ -751,7 +757,7 @@ class LlavaPlugin(BasePlugin):
|
||||
processor: Optional["MMProcessor"],
|
||||
) -> list[dict[str, str]]:
|
||||
self._validate_input(processor, images, videos, audios)
|
||||
num_image_tokens = 0
|
||||
self._validate_messages(messages, images, videos, audios)
|
||||
messages = deepcopy(messages)
|
||||
if self.expand_mm_tokens:
|
||||
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
|
||||
@ -768,17 +774,10 @@ class LlavaPlugin(BasePlugin):
|
||||
for message in messages:
|
||||
content = message["content"]
|
||||
while IMAGE_PLACEHOLDER in content:
|
||||
if num_image_tokens >= len(images):
|
||||
raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
|
||||
|
||||
content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)
|
||||
num_image_tokens += 1
|
||||
|
||||
message["content"] = content.replace("{{image}}", self.image_token)
|
||||
|
||||
if len(images) != num_image_tokens:
|
||||
raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
|
||||
|
||||
return messages
|
||||
|
||||
|
||||
@ -794,6 +793,7 @@ class LlavaNextPlugin(BasePlugin):
|
||||
processor: Optional["MMProcessor"],
|
||||
) -> list[dict[str, str]]:
|
||||
self._validate_input(processor, images, videos, audios)
|
||||
self._validate_messages(messages, images, videos, audios)
|
||||
num_image_tokens = 0
|
||||
messages = deepcopy(messages)
|
||||
if self.expand_mm_tokens:
|
||||
@ -805,9 +805,6 @@ class LlavaNextPlugin(BasePlugin):
|
||||
for message in messages:
|
||||
content = message["content"]
|
||||
while IMAGE_PLACEHOLDER in content:
|
||||
if num_image_tokens >= len(images):
|
||||
raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
|
||||
|
||||
if self.expand_mm_tokens:
|
||||
orig_height, orig_width = next(image_sizes)
|
||||
image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width)
|
||||
@ -821,9 +818,6 @@ class LlavaNextPlugin(BasePlugin):
|
||||
|
||||
message["content"] = content.replace("{{image}}", self.image_token)
|
||||
|
||||
if len(images) != num_image_tokens:
|
||||
raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
|
||||
|
||||
return messages
|
||||
|
||||
|
||||
@ -839,7 +833,7 @@ class LlavaNextVideoPlugin(BasePlugin):
|
||||
processor: Optional["MMProcessor"],
|
||||
) -> list[dict[str, str]]:
|
||||
self._validate_input(processor, images, videos, audios)
|
||||
num_image_tokens, num_video_tokens = 0, 0
|
||||
self._validate_messages(messages, images, videos, audios)
|
||||
messages = deepcopy(messages)
|
||||
if self.expand_mm_tokens:
|
||||
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
|
||||
@ -850,9 +844,6 @@ class LlavaNextVideoPlugin(BasePlugin):
|
||||
for message in messages:
|
||||
content = message["content"]
|
||||
while IMAGE_PLACEHOLDER in content:
|
||||
if num_image_tokens >= len(images):
|
||||
raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
|
||||
|
||||
if self.expand_mm_tokens:
|
||||
orig_height, orig_width = next(image_sizes)
|
||||
image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width)
|
||||
@ -862,7 +853,6 @@ class LlavaNextVideoPlugin(BasePlugin):
|
||||
image_seqlen = 1
|
||||
|
||||
content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)
|
||||
num_image_tokens += 1
|
||||
|
||||
message["content"] = content.replace("{{image}}", self.image_token)
|
||||
|
||||
@ -879,20 +869,10 @@ class LlavaNextVideoPlugin(BasePlugin):
|
||||
for message in messages:
|
||||
content = message["content"]
|
||||
while VIDEO_PLACEHOLDER in content:
|
||||
if num_video_tokens >= len(videos):
|
||||
raise ValueError(f"`len(videos)` is less than the number of {VIDEO_PLACEHOLDER} tokens.")
|
||||
|
||||
content = content.replace(VIDEO_PLACEHOLDER, "{{video}}" * video_seqlen, 1)
|
||||
num_video_tokens += 1
|
||||
|
||||
message["content"] = content.replace("{{video}}", self.video_token)
|
||||
|
||||
if len(images) != num_image_tokens:
|
||||
raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
|
||||
|
||||
if len(videos) != num_video_tokens:
|
||||
raise ValueError(f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens.")
|
||||
|
||||
return messages
|
||||
|
||||
|
||||
@ -978,6 +958,7 @@ class MiniCPMVPlugin(BasePlugin):
|
||||
processor: Optional["MMProcessor"],
|
||||
) -> list[dict[str, str]]:
|
||||
self._validate_input(processor, images, videos, audios)
|
||||
self._validate_messages(messages, images, videos, audios)
|
||||
num_image_tokens, num_video_tokens, num_audio_tokens = 0, 0, 0
|
||||
messages = deepcopy(messages)
|
||||
image_processor: BaseImageProcessor = getattr(processor, "image_processor")
|
||||
@ -996,24 +977,15 @@ class MiniCPMVPlugin(BasePlugin):
|
||||
for i, message in enumerate(messages):
|
||||
content = message["content"]
|
||||
while IMAGE_PLACEHOLDER in content:
|
||||
if num_image_tokens >= len(images):
|
||||
raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
|
||||
|
||||
content = content.replace(IMAGE_PLACEHOLDER, "{{image}}", 1)
|
||||
num_image_tokens += 1
|
||||
|
||||
while VIDEO_PLACEHOLDER in content:
|
||||
if num_video_tokens >= len(videos):
|
||||
raise ValueError(f"`len(videos)` is less than the number of {VIDEO_PLACEHOLDER} tokens.")
|
||||
|
||||
video_seqlen = len(mm_inputs["pixel_values"][num_video_tokens]) if self.expand_mm_tokens else 1
|
||||
content = content.replace(VIDEO_PLACEHOLDER, "{{image}}" * video_seqlen, 1)
|
||||
num_video_tokens += 1
|
||||
|
||||
while AUDIO_PLACEHOLDER in content:
|
||||
if num_audio_tokens >= len(audios):
|
||||
raise ValueError(f"`len(audios)` is less than the number of {AUDIO_PLACEHOLDER} tokens.")
|
||||
|
||||
content = content.replace(AUDIO_PLACEHOLDER, "{{audio}}", 1)
|
||||
num_audio_tokens += 1
|
||||
|
||||
@ -1065,15 +1037,6 @@ class MiniCPMVPlugin(BasePlugin):
|
||||
final_text += text_chunks[-1]
|
||||
messages[index]["content"] = final_text
|
||||
|
||||
if len(images) != num_image_tokens:
|
||||
raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
|
||||
|
||||
if len(videos) != num_video_tokens:
|
||||
raise ValueError(f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens.")
|
||||
|
||||
if len(audios) != num_audio_tokens:
|
||||
raise ValueError(f"The number of audios does not match the number of {AUDIO_PLACEHOLDER} tokens.")
|
||||
|
||||
return messages
|
||||
|
||||
@override
|
||||
@ -1157,6 +1120,7 @@ class MllamaPlugin(BasePlugin):
|
||||
processor: Optional["MMProcessor"],
|
||||
) -> list[dict[str, str]]:
|
||||
self._validate_input(processor, images, videos, audios)
|
||||
self._validate_messages(messages, images, videos, audios)
|
||||
num_image_tokens = 0
|
||||
messages = deepcopy(messages)
|
||||
for message in messages:
|
||||
@ -1164,9 +1128,6 @@ class MllamaPlugin(BasePlugin):
|
||||
num_image_tokens += content.count(IMAGE_PLACEHOLDER)
|
||||
message["content"] = content.replace(IMAGE_PLACEHOLDER, self.image_token)
|
||||
|
||||
if len(images) != num_image_tokens:
|
||||
raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
|
||||
|
||||
return messages
|
||||
|
||||
@override
|
||||
@ -1214,6 +1175,7 @@ class PaliGemmaPlugin(BasePlugin):
|
||||
processor: Optional["MMProcessor"],
|
||||
) -> list[dict[str, str]]:
|
||||
self._validate_input(processor, images, videos, audios)
|
||||
self._validate_messages(messages, images, videos, audios)
|
||||
num_image_tokens = 0
|
||||
messages = deepcopy(messages)
|
||||
for message in messages:
|
||||
@ -1224,9 +1186,6 @@ class PaliGemmaPlugin(BasePlugin):
|
||||
|
||||
message["content"] = content
|
||||
|
||||
if len(images) != num_image_tokens:
|
||||
raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
|
||||
|
||||
return messages
|
||||
|
||||
@override
|
||||
@ -1281,7 +1240,7 @@ class PixtralPlugin(BasePlugin):
|
||||
processor: Optional["MMProcessor"],
|
||||
) -> list[dict[str, str]]:
|
||||
self._validate_input(processor, images, videos, audios)
|
||||
num_image_tokens = 0
|
||||
self._validate_messages(messages, images, videos, audios)
|
||||
messages = deepcopy(messages)
|
||||
if self.expand_mm_tokens:
|
||||
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
|
||||
@ -1291,15 +1250,13 @@ class PixtralPlugin(BasePlugin):
|
||||
image_sizes = iter(mm_inputs["image_sizes"][0])
|
||||
else:
|
||||
image_sizes = iter(mm_inputs["image_sizes"].tolist())
|
||||
|
||||
image_break_token: str = getattr(processor, "image_break_token")
|
||||
image_end_token: str = getattr(processor, "image_end_token")
|
||||
|
||||
for message in messages:
|
||||
content = message["content"]
|
||||
while IMAGE_PLACEHOLDER in content:
|
||||
if num_image_tokens >= len(images):
|
||||
raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
|
||||
|
||||
if self.expand_mm_tokens:
|
||||
height, width = next(image_sizes)
|
||||
num_height_tokens = height // processor.patch_size
|
||||
@ -1312,13 +1269,9 @@ class PixtralPlugin(BasePlugin):
|
||||
replace_str = self.image_token
|
||||
|
||||
content = content.replace(IMAGE_PLACEHOLDER, replace_str, 1)
|
||||
num_image_tokens += 1
|
||||
|
||||
message["content"] = content
|
||||
|
||||
if len(images) != num_image_tokens:
|
||||
raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
|
||||
|
||||
return messages
|
||||
|
||||
@override
|
||||
@ -1355,9 +1308,9 @@ class Qwen2AudioPlugin(BasePlugin):
|
||||
processor: Optional["MMProcessor"],
|
||||
) -> list[dict[str, str]]:
|
||||
self._validate_input(processor, images, videos, audios)
|
||||
self._validate_messages(messages, images, videos, audios)
|
||||
bos_token: str = getattr(processor, "audio_bos_token")
|
||||
eos_token: str = getattr(processor, "audio_eos_token")
|
||||
num_audio_tokens = 0
|
||||
messages = deepcopy(messages)
|
||||
if self.expand_mm_tokens:
|
||||
mm_inputs = self._get_mm_inputs([], [], audios, processor)
|
||||
@ -1367,9 +1320,6 @@ class Qwen2AudioPlugin(BasePlugin):
|
||||
for message in messages:
|
||||
content = message["content"]
|
||||
while AUDIO_PLACEHOLDER in content:
|
||||
if num_audio_tokens >= len(audios):
|
||||
raise ValueError(f"`len(audios)` is less than the number of {AUDIO_PLACEHOLDER} tokens.")
|
||||
|
||||
if self.expand_mm_tokens:
|
||||
audio_length = audio_lengths.pop(0)
|
||||
input_length = (audio_length - 1) // 2 + 1
|
||||
@ -1380,13 +1330,9 @@ class Qwen2AudioPlugin(BasePlugin):
|
||||
content = content.replace(
|
||||
AUDIO_PLACEHOLDER, f"{bos_token}{self.audio_token * audio_seqlen}{eos_token}", 1
|
||||
)
|
||||
num_audio_tokens += 1
|
||||
|
||||
message["content"] = content
|
||||
|
||||
if len(audios) != num_audio_tokens:
|
||||
raise ValueError(f"The number of audios does not match the number of {AUDIO_PLACEHOLDER} tokens.")
|
||||
|
||||
return messages
|
||||
|
||||
@override
|
||||
@ -1494,6 +1440,7 @@ class Qwen2VLPlugin(BasePlugin):
|
||||
processor: Optional["MMProcessor"],
|
||||
) -> list[dict[str, str]]:
|
||||
self._validate_input(processor, images, videos, audios)
|
||||
self._validate_messages(messages, images, videos, audios)
|
||||
num_image_tokens, num_video_tokens = 0, 0
|
||||
messages = deepcopy(messages)
|
||||
image_processor: BaseImageProcessor = getattr(processor, "image_processor")
|
||||
@ -1510,9 +1457,6 @@ class Qwen2VLPlugin(BasePlugin):
|
||||
for message in messages:
|
||||
content = message["content"]
|
||||
while IMAGE_PLACEHOLDER in content:
|
||||
if num_image_tokens >= len(images):
|
||||
raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
|
||||
|
||||
image_seqlen = image_grid_thw[num_image_tokens].prod() // merge_length if self.expand_mm_tokens else 1
|
||||
content = content.replace(
|
||||
IMAGE_PLACEHOLDER, f"<|vision_start|>{self.image_token * image_seqlen}<|vision_end|>", 1
|
||||
@ -1520,9 +1464,6 @@ class Qwen2VLPlugin(BasePlugin):
|
||||
num_image_tokens += 1
|
||||
|
||||
while VIDEO_PLACEHOLDER in content:
|
||||
if num_video_tokens >= len(videos):
|
||||
raise ValueError(f"`len(videos)` is less than the number of {VIDEO_PLACEHOLDER} tokens.")
|
||||
|
||||
video_seqlen = video_grid_thw[num_video_tokens].prod() // merge_length if self.expand_mm_tokens else 1
|
||||
content = content.replace(
|
||||
VIDEO_PLACEHOLDER, f"<|vision_start|>{self.video_token * video_seqlen}<|vision_end|>", 1
|
||||
@ -1531,12 +1472,6 @@ class Qwen2VLPlugin(BasePlugin):
|
||||
|
||||
message["content"] = content
|
||||
|
||||
if len(images) != num_image_tokens:
|
||||
raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
|
||||
|
||||
if len(videos) != num_video_tokens:
|
||||
raise ValueError(f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens.")
|
||||
|
||||
return messages
|
||||
|
||||
|
||||
@ -1602,6 +1537,7 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
|
||||
processor: Optional["MMProcessor"],
|
||||
) -> list[dict[str, str]]:
|
||||
self._validate_input(processor, images, videos, audios)
|
||||
self._validate_messages(messages, images, videos, audios)
|
||||
num_image_tokens, num_video_tokens, num_audio_tokens = 0, 0, 0
|
||||
messages = deepcopy(messages)
|
||||
image_processor: BaseImageProcessor = getattr(processor, "image_processor", None)
|
||||
@ -1624,9 +1560,6 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
|
||||
for message in messages:
|
||||
content = message["content"]
|
||||
while IMAGE_PLACEHOLDER in content:
|
||||
if num_image_tokens >= len(images):
|
||||
raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
|
||||
|
||||
image_seqlen = image_grid_thw[num_image_tokens].prod() // merge_length if self.expand_mm_tokens else 1
|
||||
content = content.replace(
|
||||
IMAGE_PLACEHOLDER, f"<|vision_bos|>{self.image_token * image_seqlen}<|vision_eos|>", 1
|
||||
@ -1642,11 +1575,6 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
|
||||
)
|
||||
|
||||
while VIDEO_PLACEHOLDER in content:
|
||||
if num_video_tokens >= len(videos):
|
||||
raise ValueError(f"`len(videos)` is less than the number of {VIDEO_PLACEHOLDER} tokens.")
|
||||
if num_audio_tokens >= len(audios):
|
||||
raise ValueError(f"`len(audios)` is less than the number of {AUDIO_PLACEHOLDER} tokens.")
|
||||
|
||||
video_pos = content.find(VIDEO_PLACEHOLDER)
|
||||
audio_pos = content.find(AUDIO_PLACEHOLDER, video_pos)
|
||||
if audio_pos == -1 or audio_pos < video_pos:
|
||||
@ -1688,9 +1616,6 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
|
||||
num_video_tokens += 1
|
||||
else:
|
||||
while AUDIO_PLACEHOLDER in content:
|
||||
if num_audio_tokens >= len(audios):
|
||||
raise ValueError(f"`len(audios)` is less than the number of {AUDIO_PLACEHOLDER} tokens.")
|
||||
|
||||
audio_seqlen = audio_lengths[num_audio_tokens] if self.expand_mm_tokens else 1
|
||||
content = content.replace(
|
||||
AUDIO_PLACEHOLDER, f"<|audio_bos|>{self.audio_token * audio_seqlen}<|audio_eos|>", 1
|
||||
@ -1698,9 +1623,6 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
|
||||
num_audio_tokens += 1
|
||||
|
||||
while VIDEO_PLACEHOLDER in content:
|
||||
if num_video_tokens >= len(videos):
|
||||
raise ValueError(f"`len(videos)` is less than the number of {VIDEO_PLACEHOLDER} tokens.")
|
||||
|
||||
video_seqlen = (
|
||||
video_grid_thw[num_video_tokens].prod() // merge_length if self.expand_mm_tokens else 1
|
||||
)
|
||||
@ -1711,15 +1633,6 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
|
||||
|
||||
message["content"] = content
|
||||
|
||||
if len(audios) != num_audio_tokens:
|
||||
raise ValueError(f"The number of audios does not match the number of {AUDIO_PLACEHOLDER} tokens.")
|
||||
|
||||
if len(images) != num_image_tokens:
|
||||
raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
|
||||
|
||||
if len(videos) != num_video_tokens:
|
||||
raise ValueError(f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens.")
|
||||
|
||||
return messages
|
||||
|
||||
|
||||
@ -1735,6 +1648,7 @@ class VideoLlavaPlugin(BasePlugin):
|
||||
processor: Optional["MMProcessor"],
|
||||
) -> list[dict[str, str]]:
|
||||
self._validate_input(processor, images, videos, audios)
|
||||
self._validate_messages(messages, images, videos, audios)
|
||||
num_image_tokens, num_video_tokens = 0, 0
|
||||
messages = deepcopy(messages)
|
||||
num_frames = 0
|
||||
@ -1762,28 +1676,16 @@ class VideoLlavaPlugin(BasePlugin):
|
||||
for message in messages:
|
||||
content = message["content"]
|
||||
while IMAGE_PLACEHOLDER in content:
|
||||
if num_image_tokens >= len(images):
|
||||
raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
|
||||
|
||||
content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)
|
||||
num_image_tokens += 1
|
||||
|
||||
while VIDEO_PLACEHOLDER in content:
|
||||
if num_video_tokens >= len(videos):
|
||||
raise ValueError(f"`len(videos)` is less than the number of {VIDEO_PLACEHOLDER} tokens.")
|
||||
|
||||
content = content.replace(VIDEO_PLACEHOLDER, "{{video}}" * video_seqlen, 1)
|
||||
num_video_tokens += 1
|
||||
|
||||
content = content.replace("{{image}}", self.image_token)
|
||||
message["content"] = content.replace("{{video}}", self.video_token)
|
||||
|
||||
if len(images) != num_image_tokens:
|
||||
raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
|
||||
|
||||
if len(videos) != num_video_tokens:
|
||||
raise ValueError(f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens.")
|
||||
|
||||
return messages
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user