mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-12-16 11:50:35 +08:00
fix inputs
@@ -9,7 +9,7 @@ from transformers.image_utils import get_image_size, to_numpy_array
 from typing_extensions import override

 from ..extras.constants import IGNORE_INDEX, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
-from ..extras.packages import is_pillow_available, is_pyav_available
+from ..extras.packages import is_pillow_available, is_pyav_available, is_transformers_version_greater_than


 if is_pillow_available():
@@ -21,6 +21,13 @@ if is_pyav_available():
     import av


+if is_transformers_version_greater_than("4.45.0"):
+    from transformers.models.mllama.processing_mllama import (
+        convert_sparse_cross_attention_mask_to_dense,
+        get_cross_attention_token_mask,
+    )
+
+
 if TYPE_CHECKING:
     from av.stream import Stream
     from transformers import PreTrainedTokenizer, ProcessorMixin
@@ -75,8 +82,8 @@ class BasePlugin:
         Pre-processes a single image.
         """
         image_resolution: int = kwargs.get("image_resolution")
-        if max(image.width, image.height) > image_resolution:
-            resize_factor = image_resolution / max(image.width, image.height)
+        if image.width * image.height > image_resolution:
+            resize_factor = math.sqrt(image_resolution / (image.width * image.height))
             width, height = int(image.width * resize_factor), int(image.height * resize_factor)
             image = image.resize((width, height), resample=Image.NEAREST)
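The resize check above now treats image_resolution as a pixel-area budget (512 * 512 by default, see the next hunk) instead of a maximum edge length, and the square-root factor shrinks both sides by the same ratio, so the aspect ratio is preserved while the pixel count drops under the budget. A minimal standalone sketch of the new behaviour (the helper name is made up; the plugin does this inside _preprocess_image):

    import math

    from PIL import Image

    def shrink_to_area(image: "Image.Image", image_resolution: int = 512 * 512) -> "Image.Image":
        # Only shrink when the pixel count exceeds the budget; never upscale.
        if image.width * image.height > image_resolution:
            resize_factor = math.sqrt(image_resolution / (image.width * image.height))
            width, height = int(image.width * resize_factor), int(image.height * resize_factor)
            image = image.resize((width, height), resample=Image.NEAREST)
        return image

    # A 1024x768 image (786432 px) gets factor sqrt(262144 / 786432) ~= 0.577 -> about 591x443.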
@@ -165,15 +172,15 @@ class BasePlugin:
         if len(images) != 0:
             images = self._regularize_images(
                 images,
-                image_resolution=getattr(processor, "image_resolution", 512),
+                image_resolution=getattr(processor, "image_resolution", 512 * 512),
             )
             input_dict["images"] = images

         if len(videos) != 0:
             videos = self._regularize_videos(
                 videos,
-                image_resolution=getattr(processor, "video_resolution", 128),
-                video_fps=getattr(processor, "video_fps", 1.0),
+                image_resolution=getattr(processor, "video_resolution", 128 * 128),
+                video_fps=getattr(processor, "video_fps", 2.0),
                 video_maxlen=getattr(processor, "video_maxlen", 64),
             )
             input_dict["videos"] = videos
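The defaults above are read off the processor with getattr, so the area budgets and video sampling knobs can be tuned by attaching plain attributes to the processor object before the plugin runs. A hedged sketch (the attribute names are the ones used in the diff; the checkpoint and values are illustrative):

    from transformers import AutoProcessor

    processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")  # illustrative checkpoint
    setattr(processor, "image_resolution", 768 * 768)  # pixel-area budget per image
    setattr(processor, "video_resolution", 256 * 256)  # pixel-area budget per video frame
    setattr(processor, "video_fps", 2.0)               # frames sampled per second
    setattr(processor, "video_maxlen", 64)             # cap on sampled frames per video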
@@ -223,7 +230,7 @@ class BasePlugin:
         videos: Sequence["VideoInput"],
         imglens: Sequence[int],
         vidlens: Sequence[int],
-        seqlens: Sequence[int],
+        batch_ids: Sequence[List[int]],
         processor: Optional["ProcessorMixin"],
     ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
         r"""
@@ -234,7 +241,7 @@
             videos: a list of video inputs, shape (num_videos,)
             imglens: number of images in each sample, shape (batch_size,)
             vidlens: number of videos in each sample, shape (batch_size,)
-            seqlens: number of tokens in each sample, shape (batch_size,)
+            batch_ids: input ids of samples, shape (batch_size, seq_len)
             processor: a processor for pre-processing images and videos
         """
         self._validate_input(images, videos)
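The signature change above replaces the precomputed seqlens with the raw batch_ids, so a plugin that still needs per-sample lengths can derive them locally (the PaliGemma hunk later in this diff does exactly that). A tiny illustration with made-up token ids:

    batch_ids = [[1, 2, 3, 4], [1, 2, 3]]                   # tokenized prompts, one list per sample
    seqlens = [len(input_ids) for input_ids in batch_ids]   # [4, 3]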
@@ -258,12 +265,12 @@ class LlavaPlugin(BasePlugin):
             content = message["content"]
             while IMAGE_PLACEHOLDER in content:
                 num_image_tokens += 1
-                content = content.replace(IMAGE_PLACEHOLDER, "{{image}}", 1)
+                content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)

-            message["content"] = content.replace("{{image}}", self.image_token * image_seqlen)
+            message["content"] = content.replace("{{image}}", self.image_token)

         if len(images) != num_image_tokens:
-            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens")
+            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")

         return messages
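The loop above now expands each image placeholder into image_seqlen copies of the temporary {{image}} marker and only afterwards rewrites the marker to the real image token; going through {{image}} helps avoid re-matching when the final token string equals the placeholder, and the count-1 replace keeps multiple placeholders independent. A toy run of the same two steps (the token string and length are made up):

    IMAGE_PLACEHOLDER, image_token, image_seqlen = "<image>", "<image_patch>", 3

    content = f"Compare {IMAGE_PLACEHOLDER} with {IMAGE_PLACEHOLDER}."
    while IMAGE_PLACEHOLDER in content:
        content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)

    content = content.replace("{{image}}", image_token)
    # -> "Compare <image_patch><image_patch><image_patch> with ..." (3 tokens per image)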
@@ -274,7 +281,7 @@ class LlavaPlugin(BasePlugin):
         videos: Sequence["VideoInput"],
         imglens: Sequence[int],
         vidlens: Sequence[int],
-        seqlens: Sequence[int],
+        batch_ids: Sequence[List[int]],
         processor: Optional["ProcessorMixin"],
     ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
         self._validate_input(images, videos)
@@ -296,23 +303,27 @@ class LlavaNextPlugin(BasePlugin):
         mm_inputs = self._get_mm_inputs(images, videos, processor)
         if "image_sizes" in mm_inputs:
             image_sizes = iter(mm_inputs["image_sizes"])

         if "pixel_values" in mm_inputs:
             height, width = get_image_size(to_numpy_array(mm_inputs["pixel_values"][0][0]))

         for message in messages:
             content = message["content"]
-            while self.image_token in content:
+            while IMAGE_PLACEHOLDER in content:
                 image_size = next(image_sizes)
                 orig_height, orig_width = image_size
                 image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width)
-                if processor.vision_feature_select_strategy == "default":
+                if getattr(processor, "vision_feature_select_strategy") == "default":
                     image_seqlen -= 1

                 num_image_tokens += 1
-                content = content.replace(self.image_token, "{{image}}" * image_seqlen, 1)
+                content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)

             message["content"] = content.replace("{{image}}", self.image_token)

         if len(images) != num_image_tokens:
-            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens")
+            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")

         return messages

     @override
@@ -322,12 +333,11 @@ class LlavaNextPlugin(BasePlugin):
         videos: Sequence["VideoInput"],
         imglens: Sequence[int],
         vidlens: Sequence[int],
-        seqlens: Sequence[int],
+        batch_ids: Sequence[List[int]],
         processor: Optional["ProcessorMixin"],
     ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
         self._validate_input(images, videos)
-        res = self._get_mm_inputs(images, videos, processor)
-        return res
+        return self._get_mm_inputs(images, videos, processor)


 class LlavaNextVideoPlugin(BasePlugin):
@@ -340,8 +350,7 @@ class LlavaNextVideoPlugin(BasePlugin):
         processor: Optional["ProcessorMixin"],
     ) -> List[Dict[str, str]]:
         self._validate_input(images, videos)
-        num_image_tokens = 0
-        num_video_tokens = 0
+        num_image_tokens, num_video_tokens = 0, 0
         messages = deepcopy(messages)
         mm_inputs = self._get_mm_inputs(images, videos, processor)
         if "pixel_values" in mm_inputs:
@@ -349,15 +358,15 @@ class LlavaNextVideoPlugin(BasePlugin):
             height, width = get_image_size(to_numpy_array(mm_inputs["pixel_values"][0][0]))
             for message in messages:
                 content = message["content"]

-                while self.image_token in content:
+                while IMAGE_PLACEHOLDER in content:
                     image_size = next(image_sizes)
                     orig_height, orig_width = image_size
                     image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width)
-                    if processor.vision_feature_select_strategy == "default":
+                    if getattr(processor, "vision_feature_select_strategy") == "default":
                         image_seqlen -= 1

                     num_image_tokens += 1
-                    content = content.replace(self.image_token, "{{image}}" * image_seqlen, 1)
+                    content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)

                 message["content"] = content.replace("{{image}}", self.image_token)
@@ -367,19 +376,19 @@ class LlavaNextVideoPlugin(BasePlugin):
             num_frames = pixel_values_video.shape[0]  # frame dim is always after batch dim
             image_seqlen = (height // processor.patch_size) * (width // processor.patch_size)
             video_seqlen = image_seqlen // 4 * num_frames  # divide by 4 needed for avg pooling layer

             for message in messages:
                 content = message["content"]
-                while self.video_token in content:
+                while VIDEO_PLACEHOLDER in content:
                     num_video_tokens += 1
-                    content = content.replace(self.video_token, "{{video}}", 1)
-                message["content"] = content.replace("{{video}}", self.video_token * video_seqlen)
+                    content = content.replace(VIDEO_PLACEHOLDER, "{{video}}" * video_seqlen, 1)
+
+                message["content"] = content.replace("{{video}}", self.video_token)

         if len(images) != num_image_tokens:
-            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens")
+            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")

         if len(videos) != num_video_tokens:
-            raise ValueError(f"The number of videos does not match the number of {IMAGE_PLACEHOLDER} tokens")
+            raise ValueError(f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens.")

         return messages
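video_seqlen here is the per-frame patch count divided by 4 (the average-pooling stage) times the number of sampled frames. Worked numbers, purely illustrative:

    # Illustrative: a 336x336 frame with patch_size 14, 8 sampled frames.
    height = width = 336
    patch_size = 14
    num_frames = 8

    image_seqlen = (height // patch_size) * (width // patch_size)  # 24 * 24 = 576
    video_seqlen = image_seqlen // 4 * num_frames                  # 144 * 8 = 1152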
@@ -390,7 +399,7 @@ class LlavaNextVideoPlugin(BasePlugin):
         videos: Sequence["VideoInput"],
         imglens: Sequence[int],
         vidlens: Sequence[int],
-        seqlens: Sequence[int],
+        batch_ids: Sequence[List[int]],
         processor: Optional["ProcessorMixin"],
     ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
         self._validate_input(images, videos)
@@ -418,7 +427,7 @@ class PaliGemmaPlugin(BasePlugin):
             message["content"] = content.replace("{{image}}", "")

         if len(images) != num_image_tokens:
-            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens")
+            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")

         return messages
@@ -449,10 +458,11 @@ class PaliGemmaPlugin(BasePlugin):
         videos: Sequence["VideoInput"],
         imglens: Sequence[int],
         vidlens: Sequence[int],
-        seqlens: Sequence[int],
+        batch_ids: Sequence[List[int]],
         processor: Optional["ProcessorMixin"],
     ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
         self._validate_input(images, videos)
+        seqlens = [len(input_ids) for input_ids in batch_ids]
         mm_inputs = self._get_mm_inputs(images, videos, processor)
         mm_inputs["token_type_ids"] = _get_paligemma_token_type_ids(imglens, seqlens, processor)
         return mm_inputs
@@ -481,7 +491,7 @@ class PixtralPlugin(BasePlugin):
             content = message["content"]
             while IMAGE_PLACEHOLDER in content:
                 if image_input_sizes is None:
-                    raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens")
+                    raise ValueError("Cannot get image input sizes.")

                 image_size = image_input_sizes[0][num_image_tokens]
                 height, width = image_size
@@ -497,7 +507,7 @@ class PixtralPlugin(BasePlugin):
             message["content"] = content

         if len(images) != num_image_tokens:
-            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens")
+            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")

         return messages
@@ -508,7 +518,7 @@ class PixtralPlugin(BasePlugin):
         videos: Sequence["VideoInput"],
         imglens: Sequence[int],
         vidlens: Sequence[int],
-        seqlens: Sequence[int],
+        batch_ids: Sequence[List[int]],
         processor: Optional["ProcessorMixin"],
     ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
         self._validate_input(images, videos)
@@ -592,10 +602,10 @@ class Qwen2vlPlugin(BasePlugin):
             message["content"] = content

         if len(images) != num_image_tokens:
-            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens")
+            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")

         if len(videos) != num_video_tokens:
-            raise ValueError(f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens")
+            raise ValueError(f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens.")

         return messages
@@ -606,7 +616,7 @@ class Qwen2vlPlugin(BasePlugin):
         videos: Sequence["VideoInput"],
         imglens: Sequence[int],
         vidlens: Sequence[int],
-        seqlens: Sequence[int],
+        batch_ids: Sequence[List[int]],
         processor: Optional["ProcessorMixin"],
     ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
         self._validate_input(images, videos)
@@ -623,42 +633,45 @@ class VideoLlavaPlugin(BasePlugin):
         processor: Optional["ProcessorMixin"],
     ) -> List[Dict[str, str]]:
         self._validate_input(images, videos)
-        num_image_tokens = 0
-        num_video_tokens = 0
+        num_image_tokens, num_video_tokens = 0, 0
         messages = deepcopy(messages)
         mm_inputs = self._get_mm_inputs(images, videos, processor)
         num_frames = 0
-        exist_images = "pixel_values_images" in mm_inputs
-        exist_videos = "pixel_values_videos" in mm_inputs
-        if exist_videos or exist_images:
-            if exist_images:
+        has_images = "pixel_values_images" in mm_inputs
+        has_videos = "pixel_values_videos" in mm_inputs
+        if has_images or has_videos:
+            if has_images:
                 height, width = get_image_size(to_numpy_array(mm_inputs.get("pixel_values_images")[0]))
                 num_frames = 1
-            if exist_videos:
+
+            if has_videos:
                 pixel_values_video = to_numpy_array(mm_inputs.get("pixel_values_videos")[0])
                 height, width = get_image_size(pixel_values_video[0])
                 num_frames = pixel_values_video.shape[0]  # frame dim is always after batch dim

             image_seqlen = (height // processor.patch_size) * (width // processor.patch_size) + 1
             video_seqlen = image_seqlen * num_frames
-            if processor.vision_feature_select_strategy == "default":
+            if getattr(processor, "vision_feature_select_strategy") == "default":
                 image_seqlen -= 1

             for message in messages:
                 content = message["content"]
-                while self.image_token in content:
-                    num_image_tokens += 1
-                    content = content.replace(self.image_token, "{{image}}", 1)
-                while self.video_token in content:
-                    num_video_tokens += 1
-                    content = content.replace(self.video_token, "{{video}}", 1)
-                content = content.replace("{{image}}", self.image_token * image_seqlen)
-                message["content"] = content.replace("{{video}}", self.video_token * video_seqlen)
+                while IMAGE_PLACEHOLDER in content:
+                    num_image_tokens += 1
+                    content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)
+
+                while VIDEO_PLACEHOLDER in content:
+                    num_video_tokens += 1
+                    content = content.replace(VIDEO_PLACEHOLDER, "{{video}}" * video_seqlen, 1)
+
+                content = content.replace("{{image}}", self.image_token)
+                message["content"] = content.replace("{{video}}", self.video_token)

         if len(images) != num_image_tokens:
-            raise ValueError(f"The number of images does not match the number of {self.image_token} tokens")
+            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")

         if len(videos) != num_video_tokens:
-            raise ValueError(f"The number of videos does not match the number of {self.video_token} tokens")
+            raise ValueError(f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens.")

         return messages
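Note the ordering in the block above: video_seqlen is computed from image_seqlen before the "default" strategy subtracts one token, so only the per-image count loses the extra position. Worked numbers, purely illustrative:

    # Illustrative: a 224x224 frame with patch_size 14, 8 sampled frames.
    height = width = 224
    patch_size = 14
    num_frames = 8

    image_seqlen = (height // patch_size) * (width // patch_size) + 1  # 16 * 16 + 1 = 257
    video_seqlen = image_seqlen * num_frames                           # 257 * 8 = 2056
    image_seqlen -= 1                                                  # 256 under the "default" strategy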
@@ -669,7 +682,7 @@ class VideoLlavaPlugin(BasePlugin):
         videos: Sequence["VideoInput"],
         imglens: Sequence[int],
         vidlens: Sequence[int],
-        seqlens: Sequence[int],
+        batch_ids: Sequence[List[int]],
         processor: Optional["ProcessorMixin"],
     ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
         self._validate_input(images, videos)
@@ -686,30 +699,67 @@ class MllamaPlugin(BasePlugin):
         processor: Optional["ProcessorMixin"],
     ) -> List[Dict[str, str]]:
         self._validate_input(images, videos)
+        num_image_tokens = 0
         messages = deepcopy(messages)
         for message in messages:
             content = message["content"]
-            content = content.replace(IMAGE_PLACEHOLDER, "<|image|>", 1)
-            message["content"] = content
+            num_image_tokens += content.count(IMAGE_PLACEHOLDER)
+            message["content"] = content.replace(IMAGE_PLACEHOLDER, self.image_token)
+
+        if len(images) != num_image_tokens:
+            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")

         return messages

+    @override
+    def _get_mm_inputs(
+        self,
+        images: Sequence["ImageInput"],
+        videos: Sequence["VideoInput"],
+        processor: "ProcessorMixin",
+    ) -> Dict[str, "torch.Tensor"]:
+        r"""
+        Processes visual inputs for mllama because its image processor only accepts List[List[ImageInput]].
+
+        Returns:
+            pixel_values: tensor with shape
+                          (batch_size, max_num_images, max_image_tiles, channels, tile_height, tile_width)
+                          For example, (2, 1, 4, 3, 560, 560).
+            aspect_ratio_ids: tensor with shape (batch_size, max_num_images). For example, (2, 1).
+            aspect_ratio_mask: tensor with shape (batch_size, max_num_images, max_image_tiles). For example, (2, 1, 4).
+            num_tiles: List[List[int]] with shape (batch_size, num_images_in_batch). For example, (2, 1).
+        """
+        image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
+        images = self._regularize_images(images, image_resolution=getattr(processor, "image_resolution", 512 * 512))
+        return image_processor([[image] for image in images], return_tensors="pt")
+
     def get_mm_inputs(
         self,
         images: Sequence["ImageInput"],
         videos: Sequence["VideoInput"],
         imglens: Sequence[int],
         vidlens: Sequence[int],
-        seqlens: Sequence[int],
+        batch_ids: Sequence[List[int]],
         processor: Optional["ProcessorMixin"],
     ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
-        self._get_mm_inputs(images, videos, processor)
-        if images is not None:
-            images = [Image.open(image) if isinstance(image, str) else image for image in images]
-        image_features = processor.image_processor(images)
-        _ = image_features.pop("num_tiles")
-        image_features = {k: v if isinstance(v, torch.Tensor) else torch.tensor(v) for k, v in image_features.items()}
-        return image_features
+        self._validate_input(images, videos)
+        if len(images) != len(batch_ids):
+            raise ValueError("Mllama only supports one image per sample.")
+
+        mm_inputs = self._get_mm_inputs(images, videos, processor)
+        num_tiles = mm_inputs.pop("num_tiles")
+        image_token_id = getattr(processor, "image_token_id")
+        max_image_tiles = getattr(processor.image_processor, "max_image_tiles")
+        cross_attention_token_mask = [
+            get_cross_attention_token_mask(input_ids, image_token_id) for input_ids in batch_ids
+        ]
+        mm_inputs["cross_attention_mask"] = convert_sparse_cross_attention_mask_to_dense(
+            cross_attention_token_mask,
+            num_tiles=num_tiles,
+            max_num_tiles=max_image_tiles,
+            length=max(len(input_ids) for input_ids in batch_ids),
+        )
+        return mm_inputs


 PLUGINS = {