fix style

Former-commit-id: 6ddea0f3d3ef568378470ce967a0e8d02eeac5dd
This commit is contained in:
BUAADreamer 2024-09-29 20:30:57 +08:00
parent 96d51325ad
commit 8ee588248e

View File

@ -291,7 +291,7 @@ class LlavaNextPlugin(BasePlugin):
num_image_tokens += 1 num_image_tokens += 1
content = content.replace(self.image_token, "{{image}}" * image_seqlen, 1) content = content.replace(self.image_token, "{{image}}" * image_seqlen, 1)
message['content'] = content.replace("{{image}}", self.image_token) message["content"] = content.replace("{{image}}", self.image_token)
if len(images) != num_image_tokens: if len(images) != num_image_tokens:
raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER)) raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER))
@ -341,7 +341,7 @@ class LlavaNextVideoPlugin(BasePlugin):
num_image_tokens += 1 num_image_tokens += 1
content = content.replace(self.image_token, "{{image}}" * image_seqlen, 1) content = content.replace(self.image_token, "{{image}}" * image_seqlen, 1)
message['content'] = content.replace("{{image}}", self.image_token) message["content"] = content.replace("{{image}}", self.image_token)
if "pixel_values_videos" in mm_inputs: if "pixel_values_videos" in mm_inputs:
one_video = to_numpy_array(mm_inputs.get("pixel_values_videos")[0]) one_video = to_numpy_array(mm_inputs.get("pixel_values_videos")[0])
@ -355,7 +355,7 @@ class LlavaNextVideoPlugin(BasePlugin):
while self.video_token in content: while self.video_token in content:
num_video_tokens += 1 num_video_tokens += 1
content = content.replace(self.video_token, "{{video}}", 1) content = content.replace(self.video_token, "{{video}}", 1)
message['content'] = content.replace("{{video}}", self.video_token * video_seqlen) message["content"] = content.replace("{{video}}", self.video_token * video_seqlen)
if len(images) != num_image_tokens: if len(images) != num_image_tokens:
raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER)) raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER))