diff --git a/src/llamafactory/data/mm_plugin.py b/src/llamafactory/data/mm_plugin.py
index e9df7e33..b41e84fd 100644
--- a/src/llamafactory/data/mm_plugin.py
+++ b/src/llamafactory/data/mm_plugin.py
@@ -167,16 +167,45 @@ class MMPluginMixin:
)
if self.image_token is not None and processor is None:
- raise ValueError("Processor was not found, please check and update your processor config.")
+ raise ValueError("Processor was not found, please check and update your model file.")
if self.image_token is not None and image_processor is None:
- raise ValueError("Image processor was not found, please check and update your processor config.")
+ raise ValueError("Image processor was not found, please check and update your model file.")
if self.video_token is not None and video_processor is None:
- raise ValueError("Video processor was not found, please check and update your processor config.")
+ raise ValueError("Video processor was not found, please check and update your model file.")
if self.audio_token is not None and feature_extractor is None:
- raise ValueError("Audio feature extractor was not found, please check and update your processor config.")
+ raise ValueError("Audio feature extractor was not found, please check and update your model file.")
+
+ def _validate_messages(
+ self,
+ messages: list[dict[str, str]],
+ images: list["ImageInput"],
+ videos: list["VideoInput"],
+ audios: list["AudioInput"],
+ ):
+ r"""Validate if the number of images, videos and audios match the number of placeholders in messages."""
+ num_image_tokens, num_video_tokens, num_audio_tokens = 0, 0, 0
+ for message in messages:
+ num_image_tokens += message["content"].count(IMAGE_PLACEHOLDER)
+ num_video_tokens += message["content"].count(VIDEO_PLACEHOLDER)
+ num_audio_tokens += message["content"].count(AUDIO_PLACEHOLDER)
+
+ if len(images) != num_image_tokens:
+ raise ValueError(
+ f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens in {messages}."
+ )
+
+ if len(videos) != num_video_tokens:
+ raise ValueError(
+ f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens in {messages}."
+ )
+
+ if len(audios) != num_audio_tokens:
+ raise ValueError(
+ f"The number of audios does not match the number of {AUDIO_PLACEHOLDER} tokens in {messages}."
+ )
def _preprocess_image(
self, image: "ImageObject", image_max_pixels: int, image_min_pixels: int, **kwargs
@@ -364,6 +393,7 @@ class BasePlugin(MMPluginMixin):
) -> list[dict[str, str]]:
r"""Pre-process input messages before tokenization for VLMs."""
self._validate_input(processor, images, videos, audios)
+ self._validate_messages(messages, images, videos, audios)
return messages
def process_token_ids(
@@ -420,6 +450,7 @@ class Gemma3Plugin(BasePlugin):
processor: Optional["MMProcessor"],
) -> list[dict[str, str]]:
self._validate_input(processor, images, videos, audios)
+ self._validate_messages(messages, images, videos, audios)
num_image_tokens = 0
messages = deepcopy(messages)
boi_token: str = getattr(processor, "boi_token")
@@ -446,9 +477,6 @@ class Gemma3Plugin(BasePlugin):
message["content"] = content.replace("{{image}}", image_str)
- if len(images) != num_image_tokens:
- raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
-
return messages
@override
@@ -566,8 +594,8 @@ class InternVLPlugin(BasePlugin):
processor: Optional["ProcessorMixin"],
) -> list[dict[str, str]]:
self._validate_input(processor, images, videos, audios)
- num_image_tokens = 0
- num_video_tokens = 0
+ self._validate_messages(messages, images, videos, audios)
+ num_image_tokens, num_video_tokens = 0, 0
image_seqlen = getattr(processor, "image_seq_length") if self.expand_mm_tokens else 1
messages = deepcopy(messages)
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
@@ -579,9 +607,6 @@ class InternVLPlugin(BasePlugin):
for message in messages:
content = message["content"]
while IMAGE_PLACEHOLDER in content:
- if num_image_tokens >= len(images):
- raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
-
content = content.replace(
IMAGE_PLACEHOLDER,
+                f"<img>{'<IMG_CONTEXT>' * image_seqlen * image_pixel_patch_list[num_image_tokens]}</img>",
@@ -590,9 +615,6 @@ class InternVLPlugin(BasePlugin):
num_image_tokens += 1
while VIDEO_PLACEHOLDER in content:
- if num_video_tokens >= len(videos):
- raise ValueError(f"`len(videos)` is less than the number of {VIDEO_PLACEHOLDER} tokens.")
-
current_patch_index = video_patch_indices[num_video_tokens - 1] if num_video_tokens > 0 else 0
end_patch_index = video_patch_indices[num_video_tokens]
num_patches = list(video_num_patches[current_patch_index:end_patch_index])
@@ -605,12 +627,6 @@ class InternVLPlugin(BasePlugin):
message["content"] = content
- if len(images) != num_image_tokens:
- raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
-
- if len(videos) != num_video_tokens:
- raise ValueError(f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens.")
-
return messages
@override
@@ -637,6 +653,7 @@ class KimiVLPlugin(BasePlugin):
@override
def process_messages(self, messages, images, videos, audios, processor):
self._validate_input(processor, images, videos, audios)
+ self._validate_messages(messages, images, videos, audios)
if self.expand_mm_tokens:
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
@@ -648,9 +665,6 @@ class KimiVLPlugin(BasePlugin):
for message in messages:
content = message["content"]
while IMAGE_PLACEHOLDER in content:
- if num_image_tokens >= len(images):
- raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
-
image_seqlen = image_grid_hws[num_image_tokens].prod() // merge_length if self.expand_mm_tokens else 1
content = content.replace(
IMAGE_PLACEHOLDER,
@@ -661,9 +675,6 @@ class KimiVLPlugin(BasePlugin):
message["content"] = content
- if len(images) != num_image_tokens:
- raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
-
return messages
@@ -679,6 +690,7 @@ class Llama4Plugin(BasePlugin):
processor: Optional["MMProcessor"],
) -> list[dict[str, str]]:
self._validate_input(processor, images, videos, audios)
+ self._validate_messages(messages, images, videos, audios)
if self.expand_mm_tokens:
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
if "pixel_values" in mm_inputs:
@@ -701,9 +713,6 @@ class Llama4Plugin(BasePlugin):
for local_image_index, split_part in enumerate(prompt_splits):
new_content.append(split_part)
if local_image_index < placeholder_count:
- if num_image_tokens >= len(images):
- raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
-
tokens_for_this_image = processor._prompt_split_image(
aspect_ratios[num_image_tokens], num_patches_per_chunk
)
@@ -716,9 +725,6 @@ class Llama4Plugin(BasePlugin):
message["content"] = content
- if len(images) != num_image_tokens:
- raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
-
return messages
@override
@@ -751,7 +757,7 @@ class LlavaPlugin(BasePlugin):
processor: Optional["MMProcessor"],
) -> list[dict[str, str]]:
self._validate_input(processor, images, videos, audios)
- num_image_tokens = 0
+ self._validate_messages(messages, images, videos, audios)
messages = deepcopy(messages)
if self.expand_mm_tokens:
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
@@ -768,17 +774,10 @@ class LlavaPlugin(BasePlugin):
for message in messages:
content = message["content"]
while IMAGE_PLACEHOLDER in content:
- if num_image_tokens >= len(images):
- raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
-
content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)
- num_image_tokens += 1
message["content"] = content.replace("{{image}}", self.image_token)
- if len(images) != num_image_tokens:
- raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
-
return messages
@@ -794,6 +793,7 @@ class LlavaNextPlugin(BasePlugin):
processor: Optional["MMProcessor"],
) -> list[dict[str, str]]:
self._validate_input(processor, images, videos, audios)
+ self._validate_messages(messages, images, videos, audios)
num_image_tokens = 0
messages = deepcopy(messages)
if self.expand_mm_tokens:
@@ -805,9 +805,6 @@ class LlavaNextPlugin(BasePlugin):
for message in messages:
content = message["content"]
while IMAGE_PLACEHOLDER in content:
- if num_image_tokens >= len(images):
- raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
-
if self.expand_mm_tokens:
orig_height, orig_width = next(image_sizes)
image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width)
@@ -821,9 +818,6 @@ class LlavaNextPlugin(BasePlugin):
message["content"] = content.replace("{{image}}", self.image_token)
- if len(images) != num_image_tokens:
- raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
-
return messages
@@ -839,7 +833,7 @@ class LlavaNextVideoPlugin(BasePlugin):
processor: Optional["MMProcessor"],
) -> list[dict[str, str]]:
self._validate_input(processor, images, videos, audios)
- num_image_tokens, num_video_tokens = 0, 0
+ self._validate_messages(messages, images, videos, audios)
messages = deepcopy(messages)
if self.expand_mm_tokens:
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
@@ -850,9 +844,6 @@ class LlavaNextVideoPlugin(BasePlugin):
for message in messages:
content = message["content"]
while IMAGE_PLACEHOLDER in content:
- if num_image_tokens >= len(images):
- raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
-
if self.expand_mm_tokens:
orig_height, orig_width = next(image_sizes)
image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width)
@@ -862,7 +853,6 @@ class LlavaNextVideoPlugin(BasePlugin):
image_seqlen = 1
content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)
- num_image_tokens += 1
message["content"] = content.replace("{{image}}", self.image_token)
@@ -879,20 +869,10 @@ class LlavaNextVideoPlugin(BasePlugin):
for message in messages:
content = message["content"]
while VIDEO_PLACEHOLDER in content:
- if num_video_tokens >= len(videos):
- raise ValueError(f"`len(videos)` is less than the number of {VIDEO_PLACEHOLDER} tokens.")
-
content = content.replace(VIDEO_PLACEHOLDER, "{{video}}" * video_seqlen, 1)
- num_video_tokens += 1
message["content"] = content.replace("{{video}}", self.video_token)
- if len(images) != num_image_tokens:
- raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
-
- if len(videos) != num_video_tokens:
- raise ValueError(f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens.")
-
return messages
@@ -978,6 +958,7 @@ class MiniCPMVPlugin(BasePlugin):
processor: Optional["MMProcessor"],
) -> list[dict[str, str]]:
self._validate_input(processor, images, videos, audios)
+ self._validate_messages(messages, images, videos, audios)
num_image_tokens, num_video_tokens, num_audio_tokens = 0, 0, 0
messages = deepcopy(messages)
image_processor: BaseImageProcessor = getattr(processor, "image_processor")
@@ -996,24 +977,15 @@ class MiniCPMVPlugin(BasePlugin):
for i, message in enumerate(messages):
content = message["content"]
while IMAGE_PLACEHOLDER in content:
- if num_image_tokens >= len(images):
- raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
-
content = content.replace(IMAGE_PLACEHOLDER, "{{image}}", 1)
num_image_tokens += 1
while VIDEO_PLACEHOLDER in content:
- if num_video_tokens >= len(videos):
- raise ValueError(f"`len(videos)` is less than the number of {VIDEO_PLACEHOLDER} tokens.")
-
video_seqlen = len(mm_inputs["pixel_values"][num_video_tokens]) if self.expand_mm_tokens else 1
content = content.replace(VIDEO_PLACEHOLDER, "{{image}}" * video_seqlen, 1)
num_video_tokens += 1
while AUDIO_PLACEHOLDER in content:
- if num_audio_tokens >= len(audios):
- raise ValueError(f"`len(audios)` is less than the number of {AUDIO_PLACEHOLDER} tokens.")
-
content = content.replace(AUDIO_PLACEHOLDER, "{{audio}}", 1)
num_audio_tokens += 1
@@ -1065,15 +1037,6 @@ class MiniCPMVPlugin(BasePlugin):
final_text += text_chunks[-1]
messages[index]["content"] = final_text
- if len(images) != num_image_tokens:
- raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
-
- if len(videos) != num_video_tokens:
- raise ValueError(f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens.")
-
- if len(audios) != num_audio_tokens:
- raise ValueError(f"The number of audios does not match the number of {AUDIO_PLACEHOLDER} tokens.")
-
return messages
@override
@@ -1157,6 +1120,7 @@ class MllamaPlugin(BasePlugin):
processor: Optional["MMProcessor"],
) -> list[dict[str, str]]:
self._validate_input(processor, images, videos, audios)
+ self._validate_messages(messages, images, videos, audios)
num_image_tokens = 0
messages = deepcopy(messages)
for message in messages:
@@ -1164,9 +1128,6 @@ class MllamaPlugin(BasePlugin):
num_image_tokens += content.count(IMAGE_PLACEHOLDER)
message["content"] = content.replace(IMAGE_PLACEHOLDER, self.image_token)
- if len(images) != num_image_tokens:
- raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
-
return messages
@override
@@ -1214,6 +1175,7 @@ class PaliGemmaPlugin(BasePlugin):
processor: Optional["MMProcessor"],
) -> list[dict[str, str]]:
self._validate_input(processor, images, videos, audios)
+ self._validate_messages(messages, images, videos, audios)
num_image_tokens = 0
messages = deepcopy(messages)
for message in messages:
@@ -1224,9 +1186,6 @@ class PaliGemmaPlugin(BasePlugin):
message["content"] = content
- if len(images) != num_image_tokens:
- raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
-
return messages
@override
@@ -1281,7 +1240,7 @@ class PixtralPlugin(BasePlugin):
processor: Optional["MMProcessor"],
) -> list[dict[str, str]]:
self._validate_input(processor, images, videos, audios)
- num_image_tokens = 0
+ self._validate_messages(messages, images, videos, audios)
messages = deepcopy(messages)
if self.expand_mm_tokens:
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
@@ -1291,15 +1250,13 @@ class PixtralPlugin(BasePlugin):
image_sizes = iter(mm_inputs["image_sizes"][0])
else:
image_sizes = iter(mm_inputs["image_sizes"].tolist())
+
image_break_token: str = getattr(processor, "image_break_token")
image_end_token: str = getattr(processor, "image_end_token")
for message in messages:
content = message["content"]
while IMAGE_PLACEHOLDER in content:
- if num_image_tokens >= len(images):
- raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
-
if self.expand_mm_tokens:
height, width = next(image_sizes)
num_height_tokens = height // processor.patch_size
@@ -1312,13 +1269,9 @@ class PixtralPlugin(BasePlugin):
replace_str = self.image_token
content = content.replace(IMAGE_PLACEHOLDER, replace_str, 1)
- num_image_tokens += 1
message["content"] = content
- if len(images) != num_image_tokens:
- raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
-
return messages
@override
@@ -1355,9 +1308,9 @@ class Qwen2AudioPlugin(BasePlugin):
processor: Optional["MMProcessor"],
) -> list[dict[str, str]]:
self._validate_input(processor, images, videos, audios)
+ self._validate_messages(messages, images, videos, audios)
bos_token: str = getattr(processor, "audio_bos_token")
eos_token: str = getattr(processor, "audio_eos_token")
- num_audio_tokens = 0
messages = deepcopy(messages)
if self.expand_mm_tokens:
mm_inputs = self._get_mm_inputs([], [], audios, processor)
@@ -1367,9 +1320,6 @@ class Qwen2AudioPlugin(BasePlugin):
for message in messages:
content = message["content"]
while AUDIO_PLACEHOLDER in content:
- if num_audio_tokens >= len(audios):
- raise ValueError(f"`len(audios)` is less than the number of {AUDIO_PLACEHOLDER} tokens.")
-
if self.expand_mm_tokens:
audio_length = audio_lengths.pop(0)
input_length = (audio_length - 1) // 2 + 1
@@ -1380,13 +1330,9 @@ class Qwen2AudioPlugin(BasePlugin):
content = content.replace(
AUDIO_PLACEHOLDER, f"{bos_token}{self.audio_token * audio_seqlen}{eos_token}", 1
)
- num_audio_tokens += 1
message["content"] = content
- if len(audios) != num_audio_tokens:
- raise ValueError(f"The number of audios does not match the number of {AUDIO_PLACEHOLDER} tokens.")
-
return messages
@override
@@ -1494,6 +1440,7 @@ class Qwen2VLPlugin(BasePlugin):
processor: Optional["MMProcessor"],
) -> list[dict[str, str]]:
self._validate_input(processor, images, videos, audios)
+ self._validate_messages(messages, images, videos, audios)
num_image_tokens, num_video_tokens = 0, 0
messages = deepcopy(messages)
image_processor: BaseImageProcessor = getattr(processor, "image_processor")
@@ -1510,9 +1457,6 @@ class Qwen2VLPlugin(BasePlugin):
for message in messages:
content = message["content"]
while IMAGE_PLACEHOLDER in content:
- if num_image_tokens >= len(images):
- raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
-
image_seqlen = image_grid_thw[num_image_tokens].prod() // merge_length if self.expand_mm_tokens else 1
content = content.replace(
IMAGE_PLACEHOLDER, f"<|vision_start|>{self.image_token * image_seqlen}<|vision_end|>", 1
@@ -1520,9 +1464,6 @@ class Qwen2VLPlugin(BasePlugin):
num_image_tokens += 1
while VIDEO_PLACEHOLDER in content:
- if num_video_tokens >= len(videos):
- raise ValueError(f"`len(videos)` is less than the number of {VIDEO_PLACEHOLDER} tokens.")
-
video_seqlen = video_grid_thw[num_video_tokens].prod() // merge_length if self.expand_mm_tokens else 1
content = content.replace(
VIDEO_PLACEHOLDER, f"<|vision_start|>{self.video_token * video_seqlen}<|vision_end|>", 1
@@ -1531,12 +1472,6 @@ class Qwen2VLPlugin(BasePlugin):
message["content"] = content
- if len(images) != num_image_tokens:
- raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
-
- if len(videos) != num_video_tokens:
- raise ValueError(f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens.")
-
return messages
@@ -1602,6 +1537,7 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
processor: Optional["MMProcessor"],
) -> list[dict[str, str]]:
self._validate_input(processor, images, videos, audios)
+ self._validate_messages(messages, images, videos, audios)
num_image_tokens, num_video_tokens, num_audio_tokens = 0, 0, 0
messages = deepcopy(messages)
image_processor: BaseImageProcessor = getattr(processor, "image_processor", None)
@@ -1624,9 +1560,6 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
for message in messages:
content = message["content"]
while IMAGE_PLACEHOLDER in content:
- if num_image_tokens >= len(images):
- raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
-
image_seqlen = image_grid_thw[num_image_tokens].prod() // merge_length if self.expand_mm_tokens else 1
content = content.replace(
IMAGE_PLACEHOLDER, f"<|vision_bos|>{self.image_token * image_seqlen}<|vision_eos|>", 1
@@ -1642,11 +1575,6 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
)
while VIDEO_PLACEHOLDER in content:
- if num_video_tokens >= len(videos):
- raise ValueError(f"`len(videos)` is less than the number of {VIDEO_PLACEHOLDER} tokens.")
- if num_audio_tokens >= len(audios):
- raise ValueError(f"`len(audios)` is less than the number of {AUDIO_PLACEHOLDER} tokens.")
-
video_pos = content.find(VIDEO_PLACEHOLDER)
audio_pos = content.find(AUDIO_PLACEHOLDER, video_pos)
if audio_pos == -1 or audio_pos < video_pos:
@@ -1688,9 +1616,6 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
num_video_tokens += 1
else:
while AUDIO_PLACEHOLDER in content:
- if num_audio_tokens >= len(audios):
- raise ValueError(f"`len(audios)` is less than the number of {AUDIO_PLACEHOLDER} tokens.")
-
audio_seqlen = audio_lengths[num_audio_tokens] if self.expand_mm_tokens else 1
content = content.replace(
AUDIO_PLACEHOLDER, f"<|audio_bos|>{self.audio_token * audio_seqlen}<|audio_eos|>", 1
@@ -1698,9 +1623,6 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
num_audio_tokens += 1
while VIDEO_PLACEHOLDER in content:
- if num_video_tokens >= len(videos):
- raise ValueError(f"`len(videos)` is less than the number of {VIDEO_PLACEHOLDER} tokens.")
-
video_seqlen = (
video_grid_thw[num_video_tokens].prod() // merge_length if self.expand_mm_tokens else 1
)
@@ -1711,15 +1633,6 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
message["content"] = content
- if len(audios) != num_audio_tokens:
- raise ValueError(f"The number of audios does not match the number of {AUDIO_PLACEHOLDER} tokens.")
-
- if len(images) != num_image_tokens:
- raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
-
- if len(videos) != num_video_tokens:
- raise ValueError(f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens.")
-
return messages
@@ -1735,6 +1648,7 @@ class VideoLlavaPlugin(BasePlugin):
processor: Optional["MMProcessor"],
) -> list[dict[str, str]]:
self._validate_input(processor, images, videos, audios)
+ self._validate_messages(messages, images, videos, audios)
num_image_tokens, num_video_tokens = 0, 0
messages = deepcopy(messages)
num_frames = 0
@@ -1762,28 +1676,16 @@ class VideoLlavaPlugin(BasePlugin):
for message in messages:
content = message["content"]
while IMAGE_PLACEHOLDER in content:
- if num_image_tokens >= len(images):
- raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
-
content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)
num_image_tokens += 1
while VIDEO_PLACEHOLDER in content:
- if num_video_tokens >= len(videos):
- raise ValueError(f"`len(videos)` is less than the number of {VIDEO_PLACEHOLDER} tokens.")
-
content = content.replace(VIDEO_PLACEHOLDER, "{{video}}" * video_seqlen, 1)
num_video_tokens += 1
content = content.replace("{{image}}", self.image_token)
message["content"] = content.replace("{{video}}", self.video_token)
- if len(images) != num_image_tokens:
- raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
-
- if len(videos) != num_video_tokens:
- raise ValueError(f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens.")
-
return messages