From c52eeb70e759fc6b65a62e05d8290d815a9ca9f5 Mon Sep 17 00:00:00 2001
From: hiyouga <467089858@qq.com>
Date: Mon, 9 Sep 2024 22:33:33 +0800
Subject: [PATCH] fix qwen2vl preprocess

Former-commit-id: c93795ae14b1d5a0a3440d18f8197fd53cd013da
---
 src/llamafactory/data/mm_plugin.py | 249 +++++++++++++++++------------
 src/llamafactory/model/loader.py   |   4 -
 2 files changed, 143 insertions(+), 110 deletions(-)

diff --git a/src/llamafactory/data/mm_plugin.py b/src/llamafactory/data/mm_plugin.py
index c109d26e..f180c93b 100644
--- a/src/llamafactory/data/mm_plugin.py
+++ b/src/llamafactory/data/mm_plugin.py
@@ -1,3 +1,4 @@
+import math
 from copy import deepcopy
 from io import BytesIO
 from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, TypedDict, Union
@@ -11,7 +12,6 @@ from ..extras.packages import is_pillow_available, is_pyav_available
 
 if is_pillow_available():
     from PIL import Image
-    from PIL.Image import Image as ImageObject
 
 
 if is_pyav_available():
@@ -20,6 +20,8 @@ if is_pyav_available():
 
 if TYPE_CHECKING:
     import torch
+    from av.stream import Stream
+    from PIL.Image import Image as ImageObject
     from transformers import PreTrainedTokenizer, ProcessorMixin
     from transformers.image_processing_utils import BaseImageProcessor
 
@@ -31,107 +33,6 @@ if TYPE_CHECKING:
     VideoInput = str
 
 
-def _regularize_images(
-    images: Sequence["ImageInput"],
-    processor: "ProcessorMixin",
-    max_resolution: Optional[int] = None,
-) -> List["ImageObject"]:
-    r"""
-    Regularizes images to avoid error. Including reading, resizing and converting.
-    """
-    if max_resolution is None:
-        max_resolution: int = getattr(processor, "image_resolution", 512)
-
-    results = []
-    for image in images:
-        if isinstance(image, str):
-            image = Image.open(image)
-        elif isinstance(image, dict):
-            if image["bytes"] is not None:
-                image = Image.open(BytesIO(image["bytes"]))
-            else:
-                image = Image.open(image["path"])
-
-        if not isinstance(image, ImageObject):
-            raise ValueError("Expect input is a list of Images, but got {}.".format(type(image)))
-
-        if max(image.width, image.height) > max_resolution:
-            factor = max_resolution / max(image.width, image.height)
-            image = image.resize((int(image.width * factor), int(image.height * factor)), resample=Image.NEAREST)
-
-        if image.mode != "RGB":
-            image = image.convert("RGB")
-
-        results.append(image)
-
-    return results
-
-
-def _regularize_videos(
-    videos: Sequence["VideoInput"],
-    processor: "ProcessorMixin",
-) -> List[List["ImageObject"]]:
-    r"""
-    Regularizes videos to avoid error. Including reading, resizing and converting.
-    """
-    video_resolution: int = getattr(processor, "video_resolution", 128)
-    video_fps: float = getattr(processor, "video_fps", 1.0)
-    video_maxlen: int = getattr(processor, "video_maxlen", 64)
-    video_factor: int = getattr(processor, "video_factor", 1)
-    results = []
-    for video in videos:
-        container = av.open(video, "r")
-        video_stream = next(stream for stream in container.streams if stream.type == "video")
-        total_frames = video_stream.frames
-        sample_frames = float(video_stream.duration * video_stream.time_base) * video_fps
-        sample_frames = min(video_maxlen, sample_frames)  # reduce length <= maxlen
-        sample_frames = round(sample_frames / video_factor) * video_factor  # for qwen2_vl
-        sample_indices = np.linspace(0, total_frames - 1, sample_frames).astype(np.int32)
-        frames: List["ImageObject"] = []
-        container.seek(0)
-        for frame_idx, frame in enumerate(container.decode(video_stream)):
-            if frame_idx in sample_indices:
-                frames.append(frame.to_image())
-
-        frames = _regularize_images(frames, processor, video_resolution)
-        results.append(frames)
-
-    return results
-
-
-def _get_mm_inputs(
-    images: Sequence["ImageInput"],
-    videos: Sequence["VideoInput"],
-    processor: "ProcessorMixin",
-) -> Dict[str, "torch.Tensor"]:
-    r"""
-    Processes visual inputs.
-
-    Returns: (llava and paligemma)
-        pixel_values: tensor with shape (B, C, H, W)
-
-    Returns: (qwen2-vl)
-        pixel_values: tensor with shape (num_patches, patch_dim)
-        image_grid_thw: tensor with shape (num_images, 3), where the three numbers are time, width, height
-
-    It holds num_patches == torch.prod(image_grid_thw)
-    """
-    image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
-    input_dict = {"images": None}  # default key
-    if len(images) != 0:
-        images = _regularize_images(images, processor)
-        input_dict["images"] = images
-
-    if len(videos) != 0:
-        videos = _regularize_videos(videos, processor)
-        input_dict["videos"] = videos
-
-    if input_dict.get("images", None) is not None or input_dict.get("videos", None) is not None:
-        return image_processor(**input_dict, return_tensors="pt")
-    else:
-        return {}
-
-
 def _get_paligemma_token_type_ids(
     imglens: Sequence[int], seqlens: Sequence[int], processor: "ProcessorMixin"
 ) -> List[List[int]]:
@@ -159,12 +60,125 @@ class BasePlugin:
         images: Sequence["ImageInput"],
         videos: Sequence["VideoInput"],
     ) -> None:
+        r"""
+        Validates if this model accepts the input modalities.
+        """
         if len(images) != 0 and self.image_token is None:
             raise ValueError("This model does not support image input.")
 
         if len(videos) != 0 and self.video_token is None:
             raise ValueError("This model does not support video input.")
 
+    def _preprocess_image(self, image: "ImageObject", **kwargs) -> "ImageObject":
+        r"""
+        Pre-processes a single image.
+        """
+        image_resolution: int = kwargs.get("image_resolution")
+        if max(image.width, image.height) > image_resolution:
+            resize_factor = image_resolution / max(image.width, image.height)
+            width, height = int(image.width * resize_factor), int(image.height * resize_factor)
+            image = image.resize((width, height), resample=Image.NEAREST)
+
+        if image.mode != "RGB":
+            image = image.convert("RGB")
+
+        return image
+
+    def _get_video_sample_frames(self, video_stream: "Stream", **kwargs) -> int:
+        r"""
+        Computes video sample frames according to fps.
+        """
+        video_fps: float = kwargs.get("video_fps")
+        video_maxlen: int = kwargs.get("video_maxlen")
+        total_frames = video_stream.frames
+        sample_frames = float(video_stream.duration * video_stream.time_base) * video_fps
+        sample_frames = min(total_frames, video_maxlen, sample_frames)
+        return math.floor(sample_frames)
+
+    def _regularize_images(self, images: Sequence["ImageInput"], **kwargs) -> List["ImageObject"]:
+        r"""
+        Regularizes images to avoid error. Including reading and pre-processing.
+        """
+        results = []
+        for image in images:
+            if isinstance(image, str):
+                image = Image.open(image)
+            elif isinstance(image, dict):
+                if image["bytes"] is not None:
+                    image = Image.open(BytesIO(image["bytes"]))
+                else:
+                    image = Image.open(image["path"])
+
+            if not isinstance(image, ImageObject):
+                raise ValueError("Expect input is a list of Images, but got {}.".format(type(image)))
+
+            results.append(self._preprocess_image(image, **kwargs))
+
+        return results
+
+    def _regularize_videos(self, videos: Sequence["VideoInput"], **kwargs) -> List[List["ImageObject"]]:
+        r"""
+        Regularizes videos to avoid error. Including reading, resizing and converting.
+        """
+        results = []
+        for video in videos:
+            container = av.open(video, "r")
+            video_stream = next(stream for stream in container.streams if stream.type == "video")
+            total_frames = video_stream.frames
+            sample_frames = self._get_video_sample_frames(video_stream, **kwargs)
+            sample_indices = np.linspace(0, total_frames - 1, sample_frames).astype(np.int32)
+            frames: List["ImageObject"] = []
+            container.seek(0)
+            for frame_idx, frame in enumerate(container.decode(video_stream)):
+                if frame_idx in sample_indices:
+                    frames.append(frame.to_image())
+
+            frames = self._regularize_images(frames, **kwargs)
+            results.append(frames)
+
+        return results
+
+    def _get_mm_inputs(
+        self,
+        images: Sequence["ImageInput"],
+        videos: Sequence["VideoInput"],
+        processor: "ProcessorMixin",
+    ) -> Dict[str, "torch.Tensor"]:
+        r"""
+        Processes visual inputs.
+
+        Returns: (llava and paligemma)
+            pixel_values: tensor with shape (B, C, H, W)
+
+        Returns: (qwen2-vl)
+            pixel_values: tensor with shape (num_patches, patch_dim)
+            image_grid_thw: tensor with shape (num_images, 3), where the three numbers are time, width, height
+
+        It holds num_patches == torch.prod(image_grid_thw)
+        """
+        image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
+        input_dict = {"images": None}  # default key
+        if len(images) != 0:
+            images = self._regularize_images(
+                images,
+                image_resolution=getattr(processor, "image_resolution", 512),
+            )
+            input_dict["images"] = images
+
+        if len(videos) != 0:
+            videos = self._regularize_videos(
+                videos,
+                image_resolution=getattr(processor, "video_resolution", 128),
+                video_fps=getattr(processor, "video_fps", 1.0),
+                video_maxlen=getattr(processor, "video_maxlen", 64),
+            )
+            input_dict["videos"] = videos
+
+        if input_dict.get("images", None) is not None or input_dict.get("videos", None) is not None:
+            return image_processor(**input_dict, return_tensors="pt")
+        else:
+            return {}
+
     def process_messages(
         self,
         messages: Sequence[Dict[str, str]],
@@ -246,7 +260,7 @@ class LlavaPlugin(BasePlugin):
         processor: Optional["ProcessorMixin"],
     ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
         self._validate_input(images, videos)
-        return _get_mm_inputs(images, videos, processor)
+        return self._get_mm_inputs(images, videos, processor)
 
 
 class PaliGemmaPlugin(BasePlugin):
@@ -305,12 +319,35 @@ class PaliGemmaPlugin(BasePlugin):
         processor: Optional["ProcessorMixin"],
     ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
         self._validate_input(images, videos)
-        mm_inputs = _get_mm_inputs(images, videos, processor)
+        mm_inputs = self._get_mm_inputs(images, videos, processor)
         mm_inputs["token_type_ids"] = _get_paligemma_token_type_ids(imglens, seqlens, processor)
         return mm_inputs
 
 
 class Qwen2vlPlugin(BasePlugin):
+    @override
+    def _preprocess_image(self, image: "ImageObject", **kwargs) -> "ImageObject":
+        image = super()._preprocess_image(image, **kwargs)
+        if min(image.width, image.height) < 28:
+            width, height = max(image.width, 28), max(image.height, 28)
+            image = image.resize((width, height), resample=Image.NEAREST)
+
+        if image.width / image.height > 200:
+            width, height = image.height * 180, image.height
+            image = image.resize((width, height), resample=Image.NEAREST)
+
+        if image.height / image.width > 200:
+            width, height = image.width, image.width * 180
+            image = image.resize((width, height), resample=Image.NEAREST)
+
+        return image
+
+    @override
+    def _get_video_sample_frames(self, video_stream: "Stream", **kwargs) -> int:
+        sample_frames = super()._get_video_sample_frames(video_stream, **kwargs)
+        sample_frames = sample_frames // 2 * 2
+        return sample_frames
+
     @override
     def process_messages(
         self,
@@ -322,7 +359,7 @@
         self._validate_input(images, videos)
         image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
         merge_length: int = getattr(image_processor, "merge_size") ** 2
-        mm_inputs = _get_mm_inputs(images, videos, processor)
+        mm_inputs = self._get_mm_inputs(images, videos, processor)
         image_grid_thw = mm_inputs.get("image_grid_thw", [])
         video_grid_thw = mm_inputs.get("video_grid_thw", [])
 
@@ -377,7 +414,7 @@ class Qwen2vlPlugin(BasePlugin):
         processor: Optional["ProcessorMixin"],
     ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
         self._validate_input(images, videos)
-        return _get_mm_inputs(images, videos, processor)
+        return self._get_mm_inputs(images, videos, processor)
 
 
 PLUGINS = {
diff --git a/src/llamafactory/model/loader.py b/src/llamafactory/model/loader.py
index c44468ed..030ce90f 100644
--- a/src/llamafactory/model/loader.py
+++ b/src/llamafactory/model/loader.py
@@ -103,10 +103,6 @@ def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
         setattr(processor, "video_resolution", model_args.video_resolution)
         setattr(processor, "video_fps", model_args.video_fps)
         setattr(processor, "video_maxlen", model_args.video_maxlen)
-        if getattr(config, "model_type", None) == "qwen2_vl":
-            setattr(processor, "video_factor", 2)
-        else:
-            setattr(processor, "video_factor", 1)
     except Exception:
         processor = None
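
Reviewer note (illustrative, not part of the patch): after this change, per-sample preprocessing is driven by the plugin hooks _preprocess_image and _get_video_sample_frames, fed by the processor attributes image_resolution, video_resolution, video_fps, and video_maxlen, while Qwen2vlPlugin enforces a 28 px minimum side, an aspect-ratio clamp, and an even sampled-frame count in place of the removed video_factor attribute. The sketch below mirrors that logic standalone under those assumptions; it requires only Pillow, and the helper names preprocess_image and sample_frames are hypothetical stand-ins for the plugin methods, not symbols from the codebase.

# Illustrative sketch only (not part of the patch). Mirrors the Qwen2-VL
# constraints added above; preprocess_image and sample_frames are hypothetical
# stand-ins for Qwen2vlPlugin._preprocess_image / _get_video_sample_frames.
import math

from PIL import Image


def preprocess_image(image: Image.Image, image_resolution: int = 512) -> Image.Image:
    # Base behavior: downscale so the longer side fits image_resolution, force RGB.
    if max(image.width, image.height) > image_resolution:
        factor = image_resolution / max(image.width, image.height)
        image = image.resize(
            (int(image.width * factor), int(image.height * factor)), resample=Image.NEAREST
        )

    if image.mode != "RGB":
        image = image.convert("RGB")

    # Qwen2-VL extras: each side must be at least 28 px, and extreme aspect
    # ratios (beyond 200:1) are clamped back to roughly 180:1.
    if min(image.width, image.height) < 28:
        image = image.resize((max(image.width, 28), max(image.height, 28)), resample=Image.NEAREST)

    if image.width / image.height > 200:
        image = image.resize((image.height * 180, image.height), resample=Image.NEAREST)

    if image.height / image.width > 200:
        image = image.resize((image.width, image.width * 180), resample=Image.NEAREST)

    return image


def sample_frames(duration_sec: float, total_frames: int, video_fps: float = 1.0, video_maxlen: int = 64) -> int:
    # Base behavior: sample at video_fps, capped by the frame count and video_maxlen.
    frames = min(total_frames, video_maxlen, duration_sec * video_fps)
    # Qwen2-VL extra: round down to an even number of frames (replaces video_factor=2).
    return math.floor(frames) // 2 * 2


if __name__ == "__main__":
    print(preprocess_image(Image.new("RGB", (4000, 10))).size)  # (512, 28): capped at 512, raised to 28
    print(sample_frames(duration_sec=37.0, total_frames=900))  # 36: fps-sampled, rounded down to even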