fix qwen2vl preprocess

Former-commit-id: c93795ae14b1d5a0a3440d18f8197fd53cd013da
This commit is contained in:
hiyouga 2024-09-09 22:33:33 +08:00
parent 3aefdad4ec
commit c52eeb70e7
2 changed files with 143 additions and 110 deletions

View File

@ -1,3 +1,4 @@
import math
from copy import deepcopy from copy import deepcopy
from io import BytesIO from io import BytesIO
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, TypedDict, Union from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, TypedDict, Union
@ -11,7 +12,6 @@ from ..extras.packages import is_pillow_available, is_pyav_available
if is_pillow_available(): if is_pillow_available():
from PIL import Image from PIL import Image
from PIL.Image import Image as ImageObject
if is_pyav_available(): if is_pyav_available():
@ -20,6 +20,8 @@ if is_pyav_available():
if TYPE_CHECKING: if TYPE_CHECKING:
import torch import torch
from av.stream import Stream
from PIL.Image import Image as ImageObject
from transformers import PreTrainedTokenizer, ProcessorMixin from transformers import PreTrainedTokenizer, ProcessorMixin
from transformers.image_processing_utils import BaseImageProcessor from transformers.image_processing_utils import BaseImageProcessor
@ -31,107 +33,6 @@ if TYPE_CHECKING:
VideoInput = str VideoInput = str
def _regularize_images(
images: Sequence["ImageInput"],
processor: "ProcessorMixin",
max_resolution: Optional[int] = None,
) -> List["ImageObject"]:
r"""
Regularizes images to avoid error. Including reading, resizing and converting.
"""
if max_resolution is None:
max_resolution: int = getattr(processor, "image_resolution", 512)
results = []
for image in images:
if isinstance(image, str):
image = Image.open(image)
elif isinstance(image, dict):
if image["bytes"] is not None:
image = Image.open(BytesIO(image["bytes"]))
else:
image = Image.open(image["path"])
if not isinstance(image, ImageObject):
raise ValueError("Expect input is a list of Images, but got {}.".format(type(image)))
if max(image.width, image.height) > max_resolution:
factor = max_resolution / max(image.width, image.height)
image = image.resize((int(image.width * factor), int(image.height * factor)), resample=Image.NEAREST)
if image.mode != "RGB":
image = image.convert("RGB")
results.append(image)
return results
def _regularize_videos(
videos: Sequence["VideoInput"],
processor: "ProcessorMixin",
) -> List[List["ImageObject"]]:
r"""
Regularizes videos to avoid error. Including reading, resizing and converting.
"""
video_resolution: int = getattr(processor, "video_resolution", 128)
video_fps: float = getattr(processor, "video_fps", 1.0)
video_maxlen: int = getattr(processor, "video_maxlen", 64)
video_factor: int = getattr(processor, "video_factor", 1)
results = []
for video in videos:
container = av.open(video, "r")
video_stream = next(stream for stream in container.streams if stream.type == "video")
total_frames = video_stream.frames
sample_frames = float(video_stream.duration * video_stream.time_base) * video_fps
sample_frames = min(video_maxlen, sample_frames) # reduce length <= maxlen
sample_frames = round(sample_frames / video_factor) * video_factor # for qwen2_vl
sample_indices = np.linspace(0, total_frames - 1, sample_frames).astype(np.int32)
frames: List["ImageObject"] = []
container.seek(0)
for frame_idx, frame in enumerate(container.decode(video_stream)):
if frame_idx in sample_indices:
frames.append(frame.to_image())
frames = _regularize_images(frames, processor, video_resolution)
results.append(frames)
return results
def _get_mm_inputs(
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
processor: "ProcessorMixin",
) -> Dict[str, "torch.Tensor"]:
r"""
Processes visual inputs.
Returns: (llava and paligemma)
pixel_values: tensor with shape (B, C, H, W)
Returns: (qwen2-vl)
pixel_values: tensor with shape (num_patches, patch_dim)
image_grid_thw: tensor with shape (num_images, 3), where the three numbers are time, width, height
It holds num_patches == torch.prod(image_grid_thw)
"""
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
input_dict = {"images": None} # default key
if len(images) != 0:
images = _regularize_images(images, processor)
input_dict["images"] = images
if len(videos) != 0:
videos = _regularize_videos(videos, processor)
input_dict["videos"] = videos
if input_dict.get("images", None) is not None or input_dict.get("videos", None) is not None:
return image_processor(**input_dict, return_tensors="pt")
else:
return {}
def _get_paligemma_token_type_ids( def _get_paligemma_token_type_ids(
imglens: Sequence[int], seqlens: Sequence[int], processor: "ProcessorMixin" imglens: Sequence[int], seqlens: Sequence[int], processor: "ProcessorMixin"
) -> List[List[int]]: ) -> List[List[int]]:
@ -159,12 +60,125 @@ class BasePlugin:
images: Sequence["ImageInput"], images: Sequence["ImageInput"],
videos: Sequence["VideoInput"], videos: Sequence["VideoInput"],
) -> None: ) -> None:
r"""
Validates if this model accepts the input modalities.
"""
if len(images) != 0 and self.image_token is None: if len(images) != 0 and self.image_token is None:
raise ValueError("This model does not support image input.") raise ValueError("This model does not support image input.")
if len(videos) != 0 and self.video_token is None: if len(videos) != 0 and self.video_token is None:
raise ValueError("This model does not support video input.") raise ValueError("This model does not support video input.")
def _preprocess_image(self, image: "ImageObject", **kwargs) -> "ImageObject":
r"""
Pre-processes a single image.
"""
image_resolution: int = kwargs.get("image_resolution")
if max(image.width, image.height) > image_resolution:
resize_factor = image_resolution / max(image.width, image.height)
width, height = int(image.width * resize_factor), int(image.height * resize_factor)
image = image.resize((width, height), resample=Image.NEAREST)
if image.mode != "RGB":
image = image.convert("RGB")
return image
def _get_video_sample_frames(self, video_stream: "Stream", **kwargs) -> int:
r"""
Computes video sample frames according to fps.
"""
video_fps: float = kwargs.get("video_fps")
video_maxlen: int = kwargs.get("video_maxlen")
total_frames = video_stream.frames
sample_frames = float(video_stream.duration * video_stream.time_base) * video_fps
sample_frames = min(total_frames, video_maxlen, sample_frames)
return math.floor(sample_frames)
def _regularize_images(self, images: Sequence["ImageInput"], **kwargs) -> List["ImageObject"]:
r"""
Regularizes images to avoid error. Including reading and pre-processing.
"""
results = []
for image in images:
if isinstance(image, str):
image = Image.open(image)
elif isinstance(image, dict):
if image["bytes"] is not None:
image = Image.open(BytesIO(image["bytes"]))
else:
image = Image.open(image["path"])
if not isinstance(image, ImageObject):
raise ValueError("Expect input is a list of Images, but got {}.".format(type(image)))
results.append(self._preprocess_image(image, **kwargs))
return results
def _regularize_videos(self, videos: Sequence["VideoInput"], **kwargs) -> List[List["ImageObject"]]:
r"""
Regularizes videos to avoid error. Including reading, resizing and converting.
"""
results = []
for video in videos:
container = av.open(video, "r")
video_stream = next(stream for stream in container.streams if stream.type == "video")
total_frames = video_stream.frames
sample_frames = self._get_video_sample_frames(video_stream, **kwargs)
sample_indices = np.linspace(0, total_frames - 1, sample_frames).astype(np.int32)
frames: List["ImageObject"] = []
container.seek(0)
for frame_idx, frame in enumerate(container.decode(video_stream)):
if frame_idx in sample_indices:
frames.append(frame.to_image())
frames = self._regularize_images(frames, **kwargs)
results.append(frames)
return results
def _get_mm_inputs(
self,
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
processor: "ProcessorMixin",
) -> Dict[str, "torch.Tensor"]:
r"""
Processes visual inputs.
Returns: (llava and paligemma)
pixel_values: tensor with shape (B, C, H, W)
Returns: (qwen2-vl)
pixel_values: tensor with shape (num_patches, patch_dim)
image_grid_thw: tensor with shape (num_images, 3), where the three numbers are time, width, height
It holds num_patches == torch.prod(image_grid_thw)
"""
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
input_dict = {"images": None} # default key
if len(images) != 0:
images = self._regularize_images(
images,
image_resolution=getattr(processor, "image_resolution", 512),
)
input_dict["images"] = images
if len(videos) != 0:
videos = self._regularize_videos(
videos,
image_resolution=getattr(processor, "video_resolution", 128),
video_fps=getattr(processor, "video_fps", 1.0),
video_maxlen=getattr(processor, "video_maxlen", 64),
)
input_dict["videos"] = videos
if input_dict.get("images", None) is not None or input_dict.get("videos", None) is not None:
return image_processor(**input_dict, return_tensors="pt")
else:
return {}
def process_messages( def process_messages(
self, self,
messages: Sequence[Dict[str, str]], messages: Sequence[Dict[str, str]],
@ -246,7 +260,7 @@ class LlavaPlugin(BasePlugin):
processor: Optional["ProcessorMixin"], processor: Optional["ProcessorMixin"],
) -> Dict[str, Union[List[int], "torch.Tensor"]]: ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
self._validate_input(images, videos) self._validate_input(images, videos)
return _get_mm_inputs(images, videos, processor) return self._get_mm_inputs(images, videos, processor)
class PaliGemmaPlugin(BasePlugin): class PaliGemmaPlugin(BasePlugin):
@ -305,12 +319,35 @@ class PaliGemmaPlugin(BasePlugin):
processor: Optional["ProcessorMixin"], processor: Optional["ProcessorMixin"],
) -> Dict[str, Union[List[int], "torch.Tensor"]]: ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
self._validate_input(images, videos) self._validate_input(images, videos)
mm_inputs = _get_mm_inputs(images, videos, processor) mm_inputs = self._get_mm_inputs(images, videos, processor)
mm_inputs["token_type_ids"] = _get_paligemma_token_type_ids(imglens, seqlens, processor) mm_inputs["token_type_ids"] = _get_paligemma_token_type_ids(imglens, seqlens, processor)
return mm_inputs return mm_inputs
class Qwen2vlPlugin(BasePlugin): class Qwen2vlPlugin(BasePlugin):
@override
def _preprocess_image(self, image: "ImageObject", **kwargs) -> "ImageObject":
image = super()._preprocess_image(image, **kwargs)
if min(image.width, image.height) < 28:
width, height = max(image.width, 28), max(image.height, 28)
image = image.resize((width, height), resample=Image.NEAREST)
if image.width / image.height > 200:
width, height = image.height * 180, image.height
image = image.resize((width, height), resample=Image.NEAREST)
if image.height / image.width > 200:
width, height = image.width, image.width * 180
image = image.resize((width, height), resample=Image.NEAREST)
return image
@override
def _get_video_sample_frames(self, video_stream: "Stream", **kwargs) -> int:
sample_frames = super()._get_video_sample_frames(video_stream, **kwargs)
sample_frames = sample_frames // 2 * 2
return sample_frames
@override @override
def process_messages( def process_messages(
self, self,
@ -322,7 +359,7 @@ class Qwen2vlPlugin(BasePlugin):
self._validate_input(images, videos) self._validate_input(images, videos)
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor") image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
merge_length: int = getattr(image_processor, "merge_size") ** 2 merge_length: int = getattr(image_processor, "merge_size") ** 2
mm_inputs = _get_mm_inputs(images, videos, processor) mm_inputs = self._get_mm_inputs(images, videos, processor)
image_grid_thw = mm_inputs.get("image_grid_thw", []) image_grid_thw = mm_inputs.get("image_grid_thw", [])
video_grid_thw = mm_inputs.get("video_grid_thw", []) video_grid_thw = mm_inputs.get("video_grid_thw", [])
@ -377,7 +414,7 @@ class Qwen2vlPlugin(BasePlugin):
processor: Optional["ProcessorMixin"], processor: Optional["ProcessorMixin"],
) -> Dict[str, Union[List[int], "torch.Tensor"]]: ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
self._validate_input(images, videos) self._validate_input(images, videos)
return _get_mm_inputs(images, videos, processor) return self._get_mm_inputs(images, videos, processor)
PLUGINS = { PLUGINS = {

View File

@ -103,10 +103,6 @@ def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
setattr(processor, "video_resolution", model_args.video_resolution) setattr(processor, "video_resolution", model_args.video_resolution)
setattr(processor, "video_fps", model_args.video_fps) setattr(processor, "video_fps", model_args.video_fps)
setattr(processor, "video_maxlen", model_args.video_maxlen) setattr(processor, "video_maxlen", model_args.video_maxlen)
if getattr(config, "model_type", None) == "qwen2_vl":
setattr(processor, "video_factor", 2)
else:
setattr(processor, "video_factor", 1)
except Exception: except Exception:
processor = None processor = None