mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-09-13 16:42:48 +08:00)

fix qwen2vl preprocess

Former-commit-id: c93795ae14b1d5a0a3440d18f8197fd53cd013da

parent 3aefdad4ec
commit c52eeb70e7
@@ -1,3 +1,4 @@
+import math
 from copy import deepcopy
 from io import BytesIO
 from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, TypedDict, Union
@@ -11,7 +12,6 @@ from ..extras.packages import is_pillow_available, is_pyav_available
 
 if is_pillow_available():
     from PIL import Image
-    from PIL.Image import Image as ImageObject
 
 
 if is_pyav_available():
@@ -20,6 +20,8 @@ if is_pyav_available():
 
 if TYPE_CHECKING:
     import torch
+    from av.stream import Stream
+    from PIL.Image import Image as ImageObject
     from transformers import PreTrainedTokenizer, ProcessorMixin
     from transformers.image_processing_utils import BaseImageProcessor
 
@@ -31,107 +33,6 @@ if TYPE_CHECKING:
     VideoInput = str
 
 
-def _regularize_images(
-    images: Sequence["ImageInput"],
-    processor: "ProcessorMixin",
-    max_resolution: Optional[int] = None,
-) -> List["ImageObject"]:
-    r"""
-    Regularizes images to avoid error. Including reading, resizing and converting.
-    """
-    if max_resolution is None:
-        max_resolution: int = getattr(processor, "image_resolution", 512)
-
-    results = []
-    for image in images:
-        if isinstance(image, str):
-            image = Image.open(image)
-        elif isinstance(image, dict):
-            if image["bytes"] is not None:
-                image = Image.open(BytesIO(image["bytes"]))
-            else:
-                image = Image.open(image["path"])
-
-        if not isinstance(image, ImageObject):
-            raise ValueError("Expect input is a list of Images, but got {}.".format(type(image)))
-
-        if max(image.width, image.height) > max_resolution:
-            factor = max_resolution / max(image.width, image.height)
-            image = image.resize((int(image.width * factor), int(image.height * factor)), resample=Image.NEAREST)
-
-        if image.mode != "RGB":
-            image = image.convert("RGB")
-
-        results.append(image)
-
-    return results
-
-
-def _regularize_videos(
-    videos: Sequence["VideoInput"],
-    processor: "ProcessorMixin",
-) -> List[List["ImageObject"]]:
-    r"""
-    Regularizes videos to avoid error. Including reading, resizing and converting.
-    """
-    video_resolution: int = getattr(processor, "video_resolution", 128)
-    video_fps: float = getattr(processor, "video_fps", 1.0)
-    video_maxlen: int = getattr(processor, "video_maxlen", 64)
-    video_factor: int = getattr(processor, "video_factor", 1)
-    results = []
-    for video in videos:
-        container = av.open(video, "r")
-        video_stream = next(stream for stream in container.streams if stream.type == "video")
-        total_frames = video_stream.frames
-        sample_frames = float(video_stream.duration * video_stream.time_base) * video_fps
-        sample_frames = min(video_maxlen, sample_frames)  # reduce length <= maxlen
-        sample_frames = round(sample_frames / video_factor) * video_factor  # for qwen2_vl
-        sample_indices = np.linspace(0, total_frames - 1, sample_frames).astype(np.int32)
-        frames: List["ImageObject"] = []
-        container.seek(0)
-        for frame_idx, frame in enumerate(container.decode(video_stream)):
-            if frame_idx in sample_indices:
-                frames.append(frame.to_image())
-
-        frames = _regularize_images(frames, processor, video_resolution)
-        results.append(frames)
-
-    return results
-
-
-def _get_mm_inputs(
-    images: Sequence["ImageInput"],
-    videos: Sequence["VideoInput"],
-    processor: "ProcessorMixin",
-) -> Dict[str, "torch.Tensor"]:
-    r"""
-    Processes visual inputs.
-
-    Returns: (llava and paligemma)
-        pixel_values: tensor with shape (B, C, H, W)
-
-    Returns: (qwen2-vl)
-        pixel_values: tensor with shape (num_patches, patch_dim)
-        image_grid_thw: tensor with shape (num_images, 3), where the three numbers are time, width, height
-
-    It holds num_patches == torch.prod(image_grid_thw)
-    """
-    image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
-    input_dict = {"images": None}  # default key
-    if len(images) != 0:
-        images = _regularize_images(images, processor)
-        input_dict["images"] = images
-
-    if len(videos) != 0:
-        videos = _regularize_videos(videos, processor)
-        input_dict["videos"] = videos
-
-    if input_dict.get("images", None) is not None or input_dict.get("videos", None) is not None:
-        return image_processor(**input_dict, return_tensors="pt")
-    else:
-        return {}
-
-
 def _get_paligemma_token_type_ids(
     imglens: Sequence[int], seqlens: Sequence[int], processor: "ProcessorMixin"
 ) -> List[List[int]]:
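Note on the deleted helpers: the same logic reappears below as BasePlugin methods, and the only behavioral change is to the `video_factor` rounding, which moves into Qwen2vlPlugin. A minimal standalone sketch of the old sampling arithmetic, using the defaults shown above; `duration_sec` stands in for `float(video_stream.duration * video_stream.time_base)`:

def old_sample_count(duration_sec: float, video_fps: float = 1.0,
                     video_maxlen: int = 64, video_factor: int = 2) -> int:
    sample_frames = duration_sec * video_fps          # frames implied by the target fps
    sample_frames = min(video_maxlen, sample_frames)  # reduce length <= maxlen
    return round(sample_frames / video_factor) * video_factor  # multiple of video_factor

# e.g. a 30 s clip at 1 fps yields 30 samples; a 300 s clip is capped at video_maxlen:
print(old_sample_count(30.0), old_sample_count(300.0))  # 30 64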
@@ -159,12 +60,125 @@ class BasePlugin:
         images: Sequence["ImageInput"],
         videos: Sequence["VideoInput"],
     ) -> None:
+        r"""
+        Validates if this model accepts the input modalities.
+        """
         if len(images) != 0 and self.image_token is None:
             raise ValueError("This model does not support image input.")
 
         if len(videos) != 0 and self.video_token is None:
             raise ValueError("This model does not support video input.")
 
+    def _preprocess_image(self, image: "ImageObject", **kwargs) -> "ImageObject":
+        r"""
+        Pre-processes a single image.
+        """
+        image_resolution: int = kwargs.get("image_resolution")
+        if max(image.width, image.height) > image_resolution:
+            resize_factor = image_resolution / max(image.width, image.height)
+            width, height = int(image.width * resize_factor), int(image.height * resize_factor)
+            image = image.resize((width, height), resample=Image.NEAREST)
+
+        if image.mode != "RGB":
+            image = image.convert("RGB")
+
+        return image
+
+    def _get_video_sample_frames(self, video_stream: "Stream", **kwargs) -> int:
+        r"""
+        Computes video sample frames according to fps.
+        """
+        video_fps: float = kwargs.get("video_fps")
+        video_maxlen: int = kwargs.get("video_maxlen")
+        total_frames = video_stream.frames
+        sample_frames = float(video_stream.duration * video_stream.time_base) * video_fps
+        sample_frames = min(total_frames, video_maxlen, sample_frames)
+        return math.floor(sample_frames)
+
+    def _regularize_images(self, images: Sequence["ImageInput"], **kwargs) -> List["ImageObject"]:
+        r"""
+        Regularizes images to avoid error. Including reading and pre-processing.
+        """
+        results = []
+        for image in images:
+            if isinstance(image, str):
+                image = Image.open(image)
+            elif isinstance(image, dict):
+                if image["bytes"] is not None:
+                    image = Image.open(BytesIO(image["bytes"]))
+                else:
+                    image = Image.open(image["path"])
+
+            if not isinstance(image, ImageObject):
+                raise ValueError("Expect input is a list of Images, but got {}.".format(type(image)))
+
+            results.append(self._preprocess_image(image, **kwargs))
+
+        return results
+
+    def _regularize_videos(self, videos: Sequence["VideoInput"], **kwargs) -> List[List["ImageObject"]]:
+        r"""
+        Regularizes videos to avoid error. Including reading, resizing and converting.
+        """
+        results = []
+        for video in videos:
+            container = av.open(video, "r")
+            video_stream = next(stream for stream in container.streams if stream.type == "video")
+            total_frames = video_stream.frames
+            sample_frames = self._get_video_sample_frames(video_stream, **kwargs)
+            sample_indices = np.linspace(0, total_frames - 1, sample_frames).astype(np.int32)
+            frames: List["ImageObject"] = []
+            container.seek(0)
+            for frame_idx, frame in enumerate(container.decode(video_stream)):
+                if frame_idx in sample_indices:
+                    frames.append(frame.to_image())
+
+            frames = self._regularize_images(frames, **kwargs)
+            results.append(frames)
+
+        return results
+
+    def _get_mm_inputs(
+        self,
+        images: Sequence["ImageInput"],
+        videos: Sequence["VideoInput"],
+        processor: "ProcessorMixin",
+    ) -> Dict[str, "torch.Tensor"]:
+        r"""
+        Processes visual inputs.
+
+        Returns: (llava and paligemma)
+            pixel_values: tensor with shape (B, C, H, W)
+
+        Returns: (qwen2-vl)
+            pixel_values: tensor with shape (num_patches, patch_dim)
+            image_grid_thw: tensor with shape (num_images, 3), where the three numbers are time, width, height
+
+        It holds num_patches == torch.prod(image_grid_thw)
+        """
+        image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
+        input_dict = {"images": None}  # default key
+        if len(images) != 0:
+            images = self._regularize_images(
+                images,
+                image_resolution=getattr(processor, "image_resolution", 512),
+            )
+            input_dict["images"] = images
+
+        if len(videos) != 0:
+            videos = self._regularize_videos(
+                videos,
+                image_resolution=getattr(processor, "video_resolution", 128),
+                video_fps=getattr(processor, "video_fps", 1.0),
+                video_maxlen=getattr(processor, "video_maxlen", 64),
+            )
+            input_dict["videos"] = videos
+
+        if input_dict.get("images", None) is not None or input_dict.get("videos", None) is not None:
+            return image_processor(**input_dict, return_tensors="pt")
+        else:
+            return {}
+
     def process_messages(
         self,
         messages: Sequence[Dict[str, str]],
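A self-contained sketch of the new sampling path (`_get_video_sample_frames` plus the linspace selection in `_regularize_videos`), using plain numbers in place of an av stream; the function name and example values are illustrative:

import math

import numpy as np


def pick_frame_indices(total_frames: int, duration_sec: float,
                       video_fps: float = 1.0, video_maxlen: int = 64) -> np.ndarray:
    sample_frames = duration_sec * video_fps  # frames implied by the target fps
    sample_frames = min(total_frames, video_maxlen, sample_frames)  # new: also capped by total_frames
    sample_frames = math.floor(sample_frames)  # new: floor instead of factor rounding
    return np.linspace(0, total_frames - 1, sample_frames).astype(np.int32)


# A 10 s clip with 300 frames sampled at video_fps=2.0 -> 20 evenly spaced indices:
print(pick_frame_indices(300, 10.0, video_fps=2.0))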
@@ -246,7 +260,7 @@ class LlavaPlugin(BasePlugin):
         processor: Optional["ProcessorMixin"],
     ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
         self._validate_input(images, videos)
-        return _get_mm_inputs(images, videos, processor)
+        return self._get_mm_inputs(images, videos, processor)
 
 
 class PaliGemmaPlugin(BasePlugin):
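The one-line change above is the point of the refactor: `_get_mm_inputs` is now resolved through `self`, so a subclass overrides only the small hooks and inherits the shared pipeline. A toy sketch of that template-method dispatch (hypothetical classes, not the real plugins):

class Base:
    def _preprocess_image(self, size: int) -> int:
        return min(size, 512)  # shared default: cap the resolution

    def get_mm_inputs(self, size: int) -> int:
        return self._preprocess_image(size)  # dispatches through self


class Qwen2vl(Base):
    def _preprocess_image(self, size: int) -> int:
        return max(super()._preprocess_image(size), 28)  # extra model-specific floor


print(Base().get_mm_inputs(10), Qwen2vl().get_mm_inputs(10))  # 10 28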
@@ -305,12 +319,35 @@ class PaliGemmaPlugin(BasePlugin):
         processor: Optional["ProcessorMixin"],
     ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
         self._validate_input(images, videos)
-        mm_inputs = _get_mm_inputs(images, videos, processor)
+        mm_inputs = self._get_mm_inputs(images, videos, processor)
         mm_inputs["token_type_ids"] = _get_paligemma_token_type_ids(imglens, seqlens, processor)
         return mm_inputs
 
 
 class Qwen2vlPlugin(BasePlugin):
+    @override
+    def _preprocess_image(self, image: "ImageObject", **kwargs) -> "ImageObject":
+        image = super()._preprocess_image(image, **kwargs)
+        if min(image.width, image.height) < 28:
+            width, height = max(image.width, 28), max(image.height, 28)
+            image = image.resize((width, height), resample=Image.NEAREST)
+
+        if image.width / image.height > 200:
+            width, height = image.height * 180, image.height
+            image = image.resize((width, height), resample=Image.NEAREST)
+
+        if image.height / image.width > 200:
+            width, height = image.width, image.width * 180
+            image = image.resize((width, height), resample=Image.NEAREST)
+
+        return image
+
+    @override
+    def _get_video_sample_frames(self, video_stream: "Stream", **kwargs) -> int:
+        sample_frames = super()._get_video_sample_frames(video_stream, **kwargs)
+        sample_frames = sample_frames // 2 * 2
+        return sample_frames
+
     @override
     def process_messages(
         self,
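The overrides above encode two Qwen2-VL-specific requirements: each image side must stay at or above 28 px (presumably the processor's minimum patch grid) with extreme aspect ratios clamped below 200, and the sampled frame count must be even (`sample_frames // 2 * 2`), which replaces the deleted `video_factor` machinery. A standalone mirror of the image constraints (illustrative; the real method first applies the base resolution cap):

from PIL import Image


def enforce_qwen2vl_constraints(image: Image.Image) -> Image.Image:
    if min(image.width, image.height) < 28:  # keep both sides >= 28 px
        image = image.resize((max(image.width, 28), max(image.height, 28)), resample=Image.NEAREST)

    if image.width / image.height > 200:  # clamp extreme landscape images to 180:1
        image = image.resize((image.height * 180, image.height), resample=Image.NEAREST)

    if image.height / image.width > 200:  # clamp extreme portrait images to 1:180
        image = image.resize((image.width, image.width * 180), resample=Image.NEAREST)

    return image


wide = Image.new("RGB", (8000, 30))  # aspect ratio ~267:1
print(enforce_qwen2vl_constraints(wide).size)  # (5400, 30), ratio 180:1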
@@ -322,7 +359,7 @@ class Qwen2vlPlugin(BasePlugin):
         self._validate_input(images, videos)
         image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
         merge_length: int = getattr(image_processor, "merge_size") ** 2
-        mm_inputs = _get_mm_inputs(images, videos, processor)
+        mm_inputs = self._get_mm_inputs(images, videos, processor)
         image_grid_thw = mm_inputs.get("image_grid_thw", [])
         video_grid_thw = mm_inputs.get("video_grid_thw", [])
 
@@ -377,7 +414,7 @@ class Qwen2vlPlugin(BasePlugin):
         processor: Optional["ProcessorMixin"],
     ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
         self._validate_input(images, videos)
-        return _get_mm_inputs(images, videos, processor)
+        return self._get_mm_inputs(images, videos, processor)
 
 
 PLUGINS = {
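A hypothetical end-to-end call against the refactored internals, assuming this file's get_mm_plugin factory and the qwen2_vl pad tokens used elsewhere in LLaMA-Factory; the checkpoint id and image path are placeholders:

from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")  # placeholder checkpoint
plugin = get_mm_plugin(name="qwen2_vl", image_token="<|image_pad|>", video_token="<|video_pad|>")

# _get_mm_inputs regularizes the inputs and defers to the HF image processor:
mm_inputs = plugin._get_mm_inputs(images=["example.jpg"], videos=[], processor=processor)
print(mm_inputs["pixel_values"].shape)  # (num_patches, patch_dim)
print(mm_inputs["image_grid_thw"])      # (num_images, 3): time, width, height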
@@ -103,10 +103,6 @@ def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
         setattr(processor, "video_resolution", model_args.video_resolution)
         setattr(processor, "video_fps", model_args.video_fps)
         setattr(processor, "video_maxlen", model_args.video_maxlen)
-        if getattr(config, "model_type", None) == "qwen2_vl":
-            setattr(processor, "video_factor", 2)
-        else:
-            setattr(processor, "video_factor", 1)
     except Exception:
         processor = None
 
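With `video_factor` gone, load_tokenizer only forwards the user-facing knobs onto the processor, and each plugin reads them back with getattr defaults; the evenness requirement for Qwen2-VL now lives in the plugin override above. A minimal sketch of that getattr-with-default handshake (DummyProcessor is illustrative):

class DummyProcessor:
    pass


processor = DummyProcessor()
setattr(processor, "video_fps", 2.0)  # as load_tokenizer does for model_args.video_fps

# _get_mm_inputs side: an explicitly set value wins, otherwise the default applies.
video_fps = getattr(processor, "video_fps", 1.0)       # 2.0
video_maxlen = getattr(processor, "video_maxlen", 64)  # 64
print(video_fps, video_maxlen)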