mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-12-17 12:20:37 +08:00
[breaking] bump transformers to 4.45.0 & improve ci (#7746)
* update ci
* fix
* fix
* fix
* fix
* fix
@@ -25,12 +25,7 @@ from typing import TYPE_CHECKING, BinaryIO, Literal, Optional, TypedDict, Union
 import numpy as np
 import torch
-from transformers.image_utils import (
-    get_image_size,
-    make_batched_videos,
-    make_flat_list_of_images,
-    to_numpy_array,
-)
+from transformers.image_utils import get_image_size, to_numpy_array
 from typing_extensions import override
 
 from ..extras.constants import AUDIO_PLACEHOLDER, IGNORE_INDEX, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
@@ -62,6 +57,10 @@ if is_transformers_version_greater_than("4.45.0"):
     )
 
 
+if is_transformers_version_greater_than("4.49.0"):
+    from transformers.image_utils import make_batched_videos, make_flat_list_of_images
+
+
 if TYPE_CHECKING:
     from av.stream import Stream
     from numpy.typing import NDArray
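For context, `make_batched_videos` and `make_flat_list_of_images` only exist in newer transformers releases, which is why the import moves behind a 4.49.0 gate. A minimal sketch of what a gate like `is_transformers_version_greater_than` can look like, assuming the standard `packaging` library (the project's actual helper lives in its extras module and may differ in detail):

# Minimal sketch of a version gate; LLaMA-Factory's real helper may differ.
import transformers
from packaging import version

def is_transformers_version_greater_than(target: str) -> bool:
    # Despite the name, gates like this typically check "at least target".
    return version.parse(transformers.__version__) >= version.parse(target)

# The guarded import then fails gracefully on older transformers:
if is_transformers_version_greater_than("4.49.0"):
    from transformers.image_utils import make_batched_videos, make_flat_list_of_images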
@@ -487,61 +486,6 @@ class Gemma3Plugin(BasePlugin):
 
 @dataclass
 class InternVLPlugin(BasePlugin):
-    @override
-    def process_messages(
-        self,
-        messages: list[dict[str, str]],
-        images: list["ImageInput"],
-        videos: list["VideoInput"],
-        audios: list["AudioInput"],
-        processor: Optional["ProcessorMixin"],
-    ) -> list[dict[str, str]]:
-        self._validate_input(processor, images, videos, audios)
-        num_image_tokens = 0
-        num_video_tokens = 0
-        image_seqlen = getattr(processor, "image_seq_length") if self.expand_mm_tokens else 1
-        messages = deepcopy(messages)
-        mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
-
-        image_pixel_patch_list = mm_inputs.get("image_num_patches", None)  # number of patches per image
-        video_num_patches = mm_inputs.get("video_num_patches", None)  # patch counts for every frame across all videos
-        video_patch_indices = mm_inputs.get("video_patch_indices", None)  # cumulative frame counts per video
-
-        for message in messages:
-            content = message["content"]
-            while IMAGE_PLACEHOLDER in content:
-                if num_image_tokens >= len(image_pixel_patch_list):
-                    raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
-                content = content.replace(
-                    IMAGE_PLACEHOLDER,
-                    f"<img>{'<IMG_CONTEXT>' * image_seqlen * image_pixel_patch_list[num_image_tokens]}</img>",
-                    1,
-                )
-                num_image_tokens += 1
-            message["content"] = content
-
-            while VIDEO_PLACEHOLDER in content:
-                if num_video_tokens >= len(video_patch_indices):
-                    raise ValueError(f"`len(videos)` is less than the number of {VIDEO_PLACEHOLDER} tokens.")
-                current_patch_index = video_patch_indices[num_video_tokens - 1] if num_video_tokens > 0 else 0
-                end_patch_index = video_patch_indices[num_video_tokens]
-                num_patches = list(video_num_patches[current_patch_index:end_patch_index])
-                video_replaced_prompt = "\n".join(
-                    f"Frame{i + 1}: <img>{'<IMG_CONTEXT>' * image_seqlen * num_patches[i]}</img>"
-                    for i in range(len(num_patches))
-                )
-                content = content.replace(VIDEO_PLACEHOLDER, video_replaced_prompt, 1)
-                num_video_tokens += 1
-            message["content"] = content
-
-        if len(images) != num_image_tokens:
-            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
-
-        if len(videos) != num_video_tokens:
-            raise ValueError(f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens.")
-
-        return messages
-
     @override
     def _get_mm_inputs(
         self,
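The expansion loop itself is unchanged by this move: each image placeholder becomes an <img>...</img> block containing image_seq_length times the image's patch count copies of <IMG_CONTEXT>. A self-contained toy run of that replacement, with invented values, behaves like this:

# Toy demo of the placeholder expansion above; all numbers are made up.
IMAGE_PLACEHOLDER = "<image>"
image_seqlen = 4            # tokens per patch (processor.image_seq_length)
image_num_patches = [2, 1]  # patches per input image

content = f"Compare {IMAGE_PLACEHOLDER} with {IMAGE_PLACEHOLDER}."
for num_patches in image_num_patches:
    # Replace only the first remaining placeholder each iteration.
    content = content.replace(
        IMAGE_PLACEHOLDER,
        f"<img>{'<IMG_CONTEXT>' * image_seqlen * num_patches}</img>",
        1,
    )
print(content)  # first image expands to 8 <IMG_CONTEXT> tokens, second to 4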
@@ -621,6 +565,63 @@ class InternVLPlugin(BasePlugin):
 
         return mm_inputs
 
+    @override
+    def process_messages(
+        self,
+        messages: list[dict[str, str]],
+        images: list["ImageInput"],
+        videos: list["VideoInput"],
+        audios: list["AudioInput"],
+        processor: Optional["ProcessorMixin"],
+    ) -> list[dict[str, str]]:
+        self._validate_input(processor, images, videos, audios)
+        num_image_tokens = 0
+        num_video_tokens = 0
+        image_seqlen = getattr(processor, "image_seq_length") if self.expand_mm_tokens else 1
+        messages = deepcopy(messages)
+        mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
+
+        image_pixel_patch_list = mm_inputs.get("image_num_patches")  # number of patches per image
+        video_num_patches = mm_inputs.get("video_num_patches")  # patch counts for every frame across all videos
+        video_patch_indices = mm_inputs.get("video_patch_indices")  # cumulative frame counts per video
+
+        for message in messages:
+            content = message["content"]
+            while IMAGE_PLACEHOLDER in content:
+                if num_image_tokens >= len(image_pixel_patch_list):
+                    raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
+
+                content = content.replace(
+                    IMAGE_PLACEHOLDER,
+                    f"<img>{'<IMG_CONTEXT>' * image_seqlen * image_pixel_patch_list[num_image_tokens]}</img>",
+                    1,
+                )
+                num_image_tokens += 1
+
+            while VIDEO_PLACEHOLDER in content:
+                if num_video_tokens >= len(video_patch_indices):
+                    raise ValueError(f"`len(videos)` is less than the number of {VIDEO_PLACEHOLDER} tokens.")
+
+                current_patch_index = video_patch_indices[num_video_tokens - 1] if num_video_tokens > 0 else 0
+                end_patch_index = video_patch_indices[num_video_tokens]
+                num_patches = list(video_num_patches[current_patch_index:end_patch_index])
+                video_replaced_prompt = "\n".join(
+                    f"Frame{i + 1}: <img>{'<IMG_CONTEXT>' * image_seqlen * num_patches[i]}</img>"
+                    for i in range(len(num_patches))
+                )
+                content = content.replace(VIDEO_PLACEHOLDER, video_replaced_prompt, 1)
+                num_video_tokens += 1
+
+            message["content"] = content
+
+        if len(images) != num_image_tokens:
+            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
+
+        if len(videos) != num_video_tokens:
+            raise ValueError(f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens.")
+
+        return messages
+
     @override
     def get_mm_inputs(
         self,
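The video branch indexes `video_num_patches` with cumulative offsets taken from `video_patch_indices`: the slice for video N runs from the previous cumulative index to the current one. A toy example with invented values shows the slicing:

# Toy illustration of the cumulative indexing used above; all values invented.
video_num_patches = [1, 1, 2, 1, 1]  # per-frame patch counts, videos concatenated
video_patch_indices = [2, 5]         # cumulative frame counts: video 0 owns frames 0:2, video 1 owns frames 2:5

for num_video_tokens in range(len(video_patch_indices)):
    start = video_patch_indices[num_video_tokens - 1] if num_video_tokens > 0 else 0
    end = video_patch_indices[num_video_tokens]
    print(num_video_tokens, video_num_patches[start:end])
# 0 [1, 1]
# 1 [2, 1, 1]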
@@ -634,12 +635,10 @@ class InternVLPlugin(BasePlugin):
        processor: Optional["ProcessorMixin"],
    ) -> dict[str, Union[list[int], "torch.Tensor"]]:
        self._validate_input(processor, images, videos, audios)

        mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
        mm_inputs.pop("image_num_patches", None)
        mm_inputs.pop("video_patch_indices", None)
        mm_inputs.pop("video_num_patches", None)

        return mm_inputs
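The three pop calls strip the bookkeeping entries that process_messages consumed, so only model-ready inputs are returned; `dict.pop` with a `None` default keeps this safe when a batch carries no images or videos. A sketch with a made-up mm_inputs dict:

# Illustration only; the dict contents are invented.
mm_inputs = {"pixel_values": "tensor...", "image_num_patches": [2, 1]}
mm_inputs.pop("image_num_patches", None)    # present: removed
mm_inputs.pop("video_patch_indices", None)  # absent: no KeyError thanks to the default
mm_inputs.pop("video_num_patches", None)
print(mm_inputs)  # {'pixel_values': 'tensor...'}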