mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-01 11:12:50 +08:00
Merge branch 'hiyouga:main' into main
Former-commit-id: 4643089a7dc6a88c391663131333f35b5da5015b
This commit is contained in:
commit
4b6606832c
@ -1,6 +1,6 @@
|
||||
transformers>=4.41.2,<=4.45.0
|
||||
datasets>=2.16.0,<=2.21.0
|
||||
accelerate>=0.30.1,<=0.33.0
|
||||
accelerate>=0.30.1,<=0.34.2
|
||||
peft>=0.11.1,<=0.12.0
|
||||
trl>=0.8.6,<=0.9.6
|
||||
gradio>=4.0.0
|
||||
|
@ -22,7 +22,7 @@ Dependency graph:
|
||||
main:
|
||||
transformers>=4.41.2,<=4.45.0
|
||||
datasets>=2.16.0,<=2.21.0
|
||||
accelerate>=0.30.1,<=0.33.0
|
||||
accelerate>=0.30.1,<=0.34.2
|
||||
peft>=0.11.1,<=0.12.0
|
||||
trl>=0.8.6,<=0.9.6
|
||||
attention:
|
||||
|
@ -1,3 +1,4 @@
|
||||
import math
|
||||
from copy import deepcopy
|
||||
from io import BytesIO
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, TypedDict, Union
|
||||
@ -20,6 +21,7 @@ if is_pyav_available():
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import torch
|
||||
from av.stream import Stream
|
||||
from transformers import PreTrainedTokenizer, ProcessorMixin
|
||||
from transformers.image_processing_utils import BaseImageProcessor
|
||||
|
||||
@ -31,107 +33,6 @@ if TYPE_CHECKING:
|
||||
VideoInput = str
|
||||
|
||||
|
||||
def _regularize_images(
|
||||
images: Sequence["ImageInput"],
|
||||
processor: "ProcessorMixin",
|
||||
max_resolution: Optional[int] = None,
|
||||
) -> List["ImageObject"]:
|
||||
r"""
|
||||
Regularizes images to avoid error. Including reading, resizing and converting.
|
||||
"""
|
||||
if max_resolution is None:
|
||||
max_resolution: int = getattr(processor, "image_resolution", 512)
|
||||
|
||||
results = []
|
||||
for image in images:
|
||||
if isinstance(image, str):
|
||||
image = Image.open(image)
|
||||
elif isinstance(image, dict):
|
||||
if image["bytes"] is not None:
|
||||
image = Image.open(BytesIO(image["bytes"]))
|
||||
else:
|
||||
image = Image.open(image["path"])
|
||||
|
||||
if not isinstance(image, ImageObject):
|
||||
raise ValueError("Expect input is a list of Images, but got {}.".format(type(image)))
|
||||
|
||||
if max(image.width, image.height) > max_resolution:
|
||||
factor = max_resolution / max(image.width, image.height)
|
||||
image = image.resize((int(image.width * factor), int(image.height * factor)), resample=Image.NEAREST)
|
||||
|
||||
if image.mode != "RGB":
|
||||
image = image.convert("RGB")
|
||||
|
||||
results.append(image)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def _regularize_videos(
|
||||
videos: Sequence["VideoInput"],
|
||||
processor: "ProcessorMixin",
|
||||
) -> List[List["ImageObject"]]:
|
||||
r"""
|
||||
Regularizes videos to avoid error. Including reading, resizing and converting.
|
||||
"""
|
||||
video_resolution: int = getattr(processor, "video_resolution", 128)
|
||||
video_fps: float = getattr(processor, "video_fps", 1.0)
|
||||
video_maxlen: int = getattr(processor, "video_maxlen", 64)
|
||||
video_factor: int = getattr(processor, "video_factor", 1)
|
||||
results = []
|
||||
for video in videos:
|
||||
container = av.open(video, "r")
|
||||
video_stream = next(stream for stream in container.streams if stream.type == "video")
|
||||
total_frames = video_stream.frames
|
||||
sample_frames = float(video_stream.duration * video_stream.time_base) * video_fps
|
||||
sample_frames = min(video_maxlen, sample_frames) # reduce length <= maxlen
|
||||
sample_frames = round(sample_frames / video_factor) * video_factor # for qwen2_vl
|
||||
sample_indices = np.linspace(0, total_frames - 1, sample_frames).astype(np.int32)
|
||||
frames: List["ImageObject"] = []
|
||||
container.seek(0)
|
||||
for frame_idx, frame in enumerate(container.decode(video_stream)):
|
||||
if frame_idx in sample_indices:
|
||||
frames.append(frame.to_image())
|
||||
|
||||
frames = _regularize_images(frames, processor, video_resolution)
|
||||
results.append(frames)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def _get_mm_inputs(
|
||||
images: Sequence["ImageInput"],
|
||||
videos: Sequence["VideoInput"],
|
||||
processor: "ProcessorMixin",
|
||||
) -> Dict[str, "torch.Tensor"]:
|
||||
r"""
|
||||
Processes visual inputs.
|
||||
|
||||
Returns: (llava and paligemma)
|
||||
pixel_values: tensor with shape (B, C, H, W)
|
||||
|
||||
Returns: (qwen2-vl)
|
||||
pixel_values: tensor with shape (num_patches, patch_dim)
|
||||
image_grid_thw: tensor with shape (num_images, 3), where the three numbers are time, width, height
|
||||
|
||||
It holds num_patches == torch.prod(image_grid_thw)
|
||||
"""
|
||||
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
|
||||
input_dict = {"images": None} # default key
|
||||
if len(images) != 0:
|
||||
images = _regularize_images(images, processor)
|
||||
input_dict["images"] = images
|
||||
|
||||
if len(videos) != 0:
|
||||
videos = _regularize_videos(videos, processor)
|
||||
input_dict["videos"] = videos
|
||||
|
||||
if input_dict.get("images", None) is not None or input_dict.get("videos", None) is not None:
|
||||
return image_processor(**input_dict, return_tensors="pt")
|
||||
else:
|
||||
return {}
|
||||
|
||||
|
||||
def _get_paligemma_token_type_ids(
|
||||
imglens: Sequence[int], seqlens: Sequence[int], processor: "ProcessorMixin"
|
||||
) -> List[List[int]]:
|
||||
@ -159,12 +60,125 @@ class BasePlugin:
|
||||
images: Sequence["ImageInput"],
|
||||
videos: Sequence["VideoInput"],
|
||||
) -> None:
|
||||
r"""
|
||||
Validates if this model accepts the input modalities.
|
||||
"""
|
||||
if len(images) != 0 and self.image_token is None:
|
||||
raise ValueError("This model does not support image input.")
|
||||
|
||||
if len(videos) != 0 and self.video_token is None:
|
||||
raise ValueError("This model does not support video input.")
|
||||
|
||||
def _preprocess_image(self, image: "ImageObject", **kwargs) -> "ImageObject":
|
||||
r"""
|
||||
Pre-processes a single image.
|
||||
"""
|
||||
image_resolution: int = kwargs.get("image_resolution")
|
||||
if max(image.width, image.height) > image_resolution:
|
||||
resize_factor = image_resolution / max(image.width, image.height)
|
||||
width, height = int(image.width * resize_factor), int(image.height * resize_factor)
|
||||
image = image.resize((width, height), resample=Image.NEAREST)
|
||||
|
||||
if image.mode != "RGB":
|
||||
image = image.convert("RGB")
|
||||
|
||||
return image
|
||||
|
||||
def _get_video_sample_frames(self, video_stream: "Stream", **kwargs) -> int:
|
||||
r"""
|
||||
Computes video sample frames according to fps.
|
||||
"""
|
||||
video_fps: float = kwargs.get("video_fps")
|
||||
video_maxlen: int = kwargs.get("video_maxlen")
|
||||
total_frames = video_stream.frames
|
||||
sample_frames = float(video_stream.duration * video_stream.time_base) * video_fps
|
||||
sample_frames = min(total_frames, video_maxlen, sample_frames)
|
||||
return math.floor(sample_frames)
|
||||
|
||||
def _regularize_images(self, images: Sequence["ImageInput"], **kwargs) -> List["ImageObject"]:
|
||||
r"""
|
||||
Regularizes images to avoid error. Including reading and pre-processing.
|
||||
"""
|
||||
results = []
|
||||
for image in images:
|
||||
if isinstance(image, str):
|
||||
image = Image.open(image)
|
||||
elif isinstance(image, dict):
|
||||
if image["bytes"] is not None:
|
||||
image = Image.open(BytesIO(image["bytes"]))
|
||||
else:
|
||||
image = Image.open(image["path"])
|
||||
|
||||
if not isinstance(image, ImageObject):
|
||||
raise ValueError("Expect input is a list of Images, but got {}.".format(type(image)))
|
||||
|
||||
results.append(self._preprocess_image(image, **kwargs))
|
||||
|
||||
return results
|
||||
|
||||
def _regularize_videos(self, videos: Sequence["VideoInput"], **kwargs) -> List[List["ImageObject"]]:
|
||||
r"""
|
||||
Regularizes videos to avoid error. Including reading, resizing and converting.
|
||||
"""
|
||||
results = []
|
||||
for video in videos:
|
||||
container = av.open(video, "r")
|
||||
video_stream = next(stream for stream in container.streams if stream.type == "video")
|
||||
total_frames = video_stream.frames
|
||||
sample_frames = self._get_video_sample_frames(video_stream, **kwargs)
|
||||
sample_indices = np.linspace(0, total_frames - 1, sample_frames).astype(np.int32)
|
||||
frames: List["ImageObject"] = []
|
||||
container.seek(0)
|
||||
for frame_idx, frame in enumerate(container.decode(video_stream)):
|
||||
if frame_idx in sample_indices:
|
||||
frames.append(frame.to_image())
|
||||
|
||||
frames = self._regularize_images(frames, **kwargs)
|
||||
results.append(frames)
|
||||
|
||||
return results
|
||||
|
||||
def _get_mm_inputs(
|
||||
self,
|
||||
images: Sequence["ImageInput"],
|
||||
videos: Sequence["VideoInput"],
|
||||
processor: "ProcessorMixin",
|
||||
) -> Dict[str, "torch.Tensor"]:
|
||||
r"""
|
||||
Processes visual inputs.
|
||||
|
||||
Returns: (llava and paligemma)
|
||||
pixel_values: tensor with shape (B, C, H, W)
|
||||
|
||||
Returns: (qwen2-vl)
|
||||
pixel_values: tensor with shape (num_patches, patch_dim)
|
||||
image_grid_thw: tensor with shape (num_images, 3), where the three numbers are time, width, height
|
||||
|
||||
It holds num_patches == torch.prod(image_grid_thw)
|
||||
"""
|
||||
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
|
||||
input_dict = {"images": None} # default key
|
||||
if len(images) != 0:
|
||||
images = self._regularize_images(
|
||||
images,
|
||||
image_resolution=getattr(processor, "image_resolution", 512),
|
||||
)
|
||||
input_dict["images"] = images
|
||||
|
||||
if len(videos) != 0:
|
||||
videos = self._regularize_videos(
|
||||
videos,
|
||||
image_resolution=getattr(processor, "video_resolution", 128),
|
||||
video_fps=getattr(processor, "video_fps", 1.0),
|
||||
video_maxlen=getattr(processor, "video_maxlen", 64),
|
||||
)
|
||||
input_dict["videos"] = videos
|
||||
|
||||
if input_dict.get("images", None) is not None or input_dict.get("videos", None) is not None:
|
||||
return image_processor(**input_dict, return_tensors="pt")
|
||||
else:
|
||||
return {}
|
||||
|
||||
def process_messages(
|
||||
self,
|
||||
messages: Sequence[Dict[str, str]],
|
||||
@ -290,7 +304,7 @@ class LlavaPlugin(BasePlugin):
|
||||
processor: Optional["ProcessorMixin"],
|
||||
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
|
||||
self._validate_input(images, videos)
|
||||
return _get_mm_inputs(images, videos, processor)
|
||||
return self._get_mm_inputs(images, videos, processor)
|
||||
|
||||
|
||||
class LlavaNextPlugin(BasePlugin):
|
||||
@ -436,12 +450,35 @@ class PaliGemmaPlugin(BasePlugin):
|
||||
processor: Optional["ProcessorMixin"],
|
||||
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
|
||||
self._validate_input(images, videos)
|
||||
mm_inputs = _get_mm_inputs(images, videos, processor)
|
||||
mm_inputs = self._get_mm_inputs(images, videos, processor)
|
||||
mm_inputs["token_type_ids"] = _get_paligemma_token_type_ids(imglens, seqlens, processor)
|
||||
return mm_inputs
|
||||
|
||||
|
||||
class Qwen2vlPlugin(BasePlugin):
|
||||
@override
|
||||
def _preprocess_image(self, image: "ImageObject", **kwargs) -> "ImageObject":
|
||||
image = super()._preprocess_image(image, **kwargs)
|
||||
if min(image.width, image.height) < 28:
|
||||
width, height = max(image.width, 28), max(image.height, 28)
|
||||
image = image.resize((width, height), resample=Image.NEAREST)
|
||||
|
||||
if image.width / image.height > 200:
|
||||
width, height = image.height * 180, image.height
|
||||
image = image.resize((width, height), resample=Image.NEAREST)
|
||||
|
||||
if image.height / image.width > 200:
|
||||
width, height = image.width, image.width * 180
|
||||
image = image.resize((width, height), resample=Image.NEAREST)
|
||||
|
||||
return image
|
||||
|
||||
@override
|
||||
def _get_video_sample_frames(self, video_stream: "Stream", **kwargs) -> int:
|
||||
sample_frames = super()._get_video_sample_frames(video_stream, **kwargs)
|
||||
sample_frames = sample_frames // 2 * 2
|
||||
return sample_frames
|
||||
|
||||
@override
|
||||
def process_messages(
|
||||
self,
|
||||
@ -453,7 +490,7 @@ class Qwen2vlPlugin(BasePlugin):
|
||||
self._validate_input(images, videos)
|
||||
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
|
||||
merge_length: int = getattr(image_processor, "merge_size") ** 2
|
||||
mm_inputs = _get_mm_inputs(images, videos, processor)
|
||||
mm_inputs = self._get_mm_inputs(images, videos, processor)
|
||||
image_grid_thw = mm_inputs.get("image_grid_thw", [])
|
||||
video_grid_thw = mm_inputs.get("video_grid_thw", [])
|
||||
|
||||
@ -508,7 +545,7 @@ class Qwen2vlPlugin(BasePlugin):
|
||||
processor: Optional["ProcessorMixin"],
|
||||
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
|
||||
self._validate_input(images, videos)
|
||||
return _get_mm_inputs(images, videos, processor)
|
||||
return self._get_mm_inputs(images, videos, processor)
|
||||
|
||||
|
||||
class VideoLlavaPlugin(BasePlugin):
|
||||
|
@ -357,6 +357,7 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args:
|
||||
require_version(
|
||||
"transformers>=4.45.0.dev0", "To fix: pip install git+https://github.com/huggingface/transformers.git"
|
||||
)
|
||||
require_version("accelerate>=0.34.0", "To fix: pip install accelerate>=0.34.0")
|
||||
|
||||
if data_args.template is None:
|
||||
template = TEMPLATES["empty"] # placeholder
|
||||
|
@ -81,7 +81,7 @@ def check_dependencies() -> None:
|
||||
else:
|
||||
require_version("transformers>=4.41.2,<=4.45.0", "To fix: pip install transformers>=4.41.2,<=4.45.0")
|
||||
require_version("datasets>=2.16.0,<=2.21.0", "To fix: pip install datasets>=2.16.0,<=2.21.0")
|
||||
require_version("accelerate>=0.30.1,<=0.33.0", "To fix: pip install accelerate>=0.30.1,<=0.33.0")
|
||||
require_version("accelerate>=0.30.1,<=0.34.2", "To fix: pip install accelerate>=0.30.1,<=0.34.2")
|
||||
require_version("peft>=0.11.1,<=0.12.0", "To fix: pip install peft>=0.11.1,<=0.12.0")
|
||||
require_version("trl>=0.8.6,<=0.9.6", "To fix: pip install trl>=0.8.6,<=0.9.6")
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user