[breaking] bump transformers to 4.45.0 & improve ci (#7746)

* update ci

* fix

* fix

* fix

* fix

* fix
This commit is contained in:
hoshi-hiyouga
2025-04-17 02:36:48 +08:00
committed by GitHub
parent d222f63cb7
commit 86ebb219d6
23 changed files with 211 additions and 140 deletions

View File

@@ -25,12 +25,7 @@ from typing import TYPE_CHECKING, BinaryIO, Literal, Optional, TypedDict, Union
import numpy as np
import torch
from transformers.image_utils import (
get_image_size,
make_batched_videos,
make_flat_list_of_images,
to_numpy_array,
)
from transformers.image_utils import get_image_size, to_numpy_array
from typing_extensions import override
from ..extras.constants import AUDIO_PLACEHOLDER, IGNORE_INDEX, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
@@ -62,6 +57,10 @@ if is_transformers_version_greater_than("4.45.0"):
)
if is_transformers_version_greater_than("4.49.0"):
from transformers.image_utils import make_batched_videos, make_flat_list_of_images
if TYPE_CHECKING:
from av.stream import Stream
from numpy.typing import NDArray
@@ -487,61 +486,6 @@ class Gemma3Plugin(BasePlugin):
@dataclass
class InternVLPlugin(BasePlugin):
@override
def process_messages(
self,
messages: list[dict[str, str]],
images: list["ImageInput"],
videos: list["VideoInput"],
audios: list["AudioInput"],
processor: Optional["ProcessorMixin"],
) -> list[dict[str, str]]:
self._validate_input(processor, images, videos, audios)
num_image_tokens = 0
num_video_tokens = 0
image_seqlen = getattr(processor, "image_seq_length") if self.expand_mm_tokens else 1
messages = deepcopy(messages)
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
image_pixel_patch_list = mm_inputs.get("image_num_patches", None) # pathes of images
video_num_patches = mm_inputs.get("video_num_patches", None) # all patches for frames of videos
video_patch_indices = mm_inputs.get("video_patch_indices", None) # num frames of per video
for message in messages:
content = message["content"]
while IMAGE_PLACEHOLDER in content:
if num_image_tokens >= len(image_pixel_patch_list):
raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
content = content.replace(
IMAGE_PLACEHOLDER,
f"<img>{'<IMG_CONTEXT>' * image_seqlen * image_pixel_patch_list[num_image_tokens]}</img>",
1,
)
num_image_tokens += 1
message["content"] = content
while VIDEO_PLACEHOLDER in content:
if num_video_tokens >= len(video_patch_indices):
raise ValueError(f"`len(videos)` is less than the number of {VIDEO_PLACEHOLDER} tokens.")
current_patch_index = video_patch_indices[num_video_tokens - 1] if num_video_tokens > 0 else 0
end_patch_index = video_patch_indices[num_video_tokens]
num_patches = list(video_num_patches[current_patch_index:end_patch_index])
video_replaced_prompt = "\n".join(
f"Frame{i + 1}: <img>{'<IMG_CONTEXT>' * image_seqlen * num_patches[i]}</img>"
for i in range(len(num_patches))
)
content = content.replace(VIDEO_PLACEHOLDER, video_replaced_prompt, 1)
num_video_tokens += 1
message["content"] = content
if len(images) != num_image_tokens:
raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
if len(videos) != num_video_tokens:
raise ValueError(f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens.")
return messages
@override
def _get_mm_inputs(
self,
@@ -621,6 +565,63 @@ class InternVLPlugin(BasePlugin):
return mm_inputs
@override
def process_messages(
self,
messages: list[dict[str, str]],
images: list["ImageInput"],
videos: list["VideoInput"],
audios: list["AudioInput"],
processor: Optional["ProcessorMixin"],
) -> list[dict[str, str]]:
self._validate_input(processor, images, videos, audios)
num_image_tokens = 0
num_video_tokens = 0
image_seqlen = getattr(processor, "image_seq_length") if self.expand_mm_tokens else 1
messages = deepcopy(messages)
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
image_pixel_patch_list = mm_inputs.get("image_num_patches") # pathes of images
video_num_patches = mm_inputs.get("video_num_patches") # all patches for frames of videos
video_patch_indices = mm_inputs.get("video_patch_indices") # num frames of per video
for message in messages:
content = message["content"]
while IMAGE_PLACEHOLDER in content:
if num_image_tokens >= len(image_pixel_patch_list):
raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
content = content.replace(
IMAGE_PLACEHOLDER,
f"<img>{'<IMG_CONTEXT>' * image_seqlen * image_pixel_patch_list[num_image_tokens]}</img>",
1,
)
num_image_tokens += 1
while VIDEO_PLACEHOLDER in content:
if num_video_tokens >= len(video_patch_indices):
raise ValueError(f"`len(videos)` is less than the number of {VIDEO_PLACEHOLDER} tokens.")
current_patch_index = video_patch_indices[num_video_tokens - 1] if num_video_tokens > 0 else 0
end_patch_index = video_patch_indices[num_video_tokens]
num_patches = list(video_num_patches[current_patch_index:end_patch_index])
video_replaced_prompt = "\n".join(
f"Frame{i + 1}: <img>{'<IMG_CONTEXT>' * image_seqlen * num_patches[i]}</img>"
for i in range(len(num_patches))
)
content = content.replace(VIDEO_PLACEHOLDER, video_replaced_prompt, 1)
num_video_tokens += 1
message["content"] = content
if len(images) != num_image_tokens:
raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
if len(videos) != num_video_tokens:
raise ValueError(f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens.")
return messages
@override
def get_mm_inputs(
self,
@@ -634,12 +635,10 @@ class InternVLPlugin(BasePlugin):
processor: Optional["ProcessorMixin"],
) -> dict[str, Union[list[int], "torch.Tensor"]]:
self._validate_input(processor, images, videos, audios)
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
mm_inputs.pop("image_num_patches", None)
mm_inputs.pop("video_patch_indices", None)
mm_inputs.pop("video_num_patches", None)
return mm_inputs