Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-08-03 04:02:49 +08:00)

Merge pull request #5555 from marko1616/feat/llama3.2vl

Support llama3.2 vision

Former-commit-id: e68ef89600e85b1f067ca6cc70459e9a7ac77b8a
Commit: 8dff6f630c
.gitignore (vendored): 3 changes

@@ -159,6 +159,9 @@ cython_debug/
 # option (not recommended) you can uncomment the following to ignore the entire idea folder.
 .idea/
 
+# vscode
+.vscode/
+
 # custom .gitignore
 ms_cache/
 hf_cache/
@@ -186,6 +186,7 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
 | [Llama](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | - |
 | [Llama 2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 |
 | [Llama 3-3.2](https://huggingface.co/meta-llama) | 1B/3B/8B/70B | llama3 |
+| [Llama 3.2 Vision](https://huggingface.co/meta-llama) | 11B/90B | mllama |
 | [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | llava |
 | [LLaVA-NeXT](https://huggingface.co/llava-hf) | 7B/8B/13B/34B/72B/110B | llava_next |
 | [LLaVA-NeXT-Video](https://huggingface.co/llava-hf) | 7B/34B | llava_next_video |
@@ -187,6 +187,7 @@ https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272
 | [Llama](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | - |
 | [Llama 2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 |
 | [Llama 3-3.2](https://huggingface.co/meta-llama) | 1B/3B/8B/70B | llama3 |
+| [Llama 3.2 Vision](https://huggingface.co/meta-llama) | 11B/90B | mllama |
 | [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | llava |
 | [LLaVA-NeXT](https://huggingface.co/llava-hf) | 7B/8B/13B/34B/72B/110B | llava_next |
 | [LLaVA-NeXT-Video](https://huggingface.co/llava-hf) | 7B/34B | llava_next_video |
@@ -164,7 +164,7 @@ class HuggingfaceEngine(BaseEngine):
             logits_processor=get_logits_processor(),
         )
 
-        mm_inputs = template.mm_plugin.get_mm_inputs(**mm_input_dict, seqlens=[prompt_length], processor=processor)
+        mm_inputs = template.mm_plugin.get_mm_inputs(**mm_input_dict, batch_ids=[prompt_ids], processor=processor)
        for key, value in mm_inputs.items():
             if isinstance(value, list) and all(isinstance(v, torch.Tensor) for v in value):  # for pixtral inputs
                 value = torch.stack(value)  # assume they have same sizes
@@ -79,7 +79,7 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
     processor: Optional["ProcessorMixin"] = None
 
     def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
-        batch_images, batch_videos, batch_imglens, batch_vidlens, batch_seqlens = [], [], [], [], []
+        batch_images, batch_videos, batch_imglens, batch_vidlens, batch_input_ids = [], [], [], [], []
         for feature in features:
             images = feature.pop("images", None) or []
             videos = feature.pop("videos", None) or []
@@ -87,10 +87,10 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
             batch_videos.extend(videos)
             batch_imglens.append(len(images))
             batch_vidlens.append(len(videos))
-            batch_seqlens.append(len(feature["input_ids"]))
+            batch_input_ids.append(feature["input_ids"])
 
         mm_inputs = self.template.mm_plugin.get_mm_inputs(
-            batch_images, batch_videos, batch_imglens, batch_vidlens, batch_seqlens, self.processor
+            batch_images, batch_videos, batch_imglens, batch_vidlens, batch_input_ids, self.processor
         )
         if "token_type_ids" in mm_inputs:
             token_type_ids = mm_inputs.pop("token_type_ids")
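The collator now hands the plugins the raw `input_ids` of each sample instead of pre-computed sequence lengths, since the new mllama plugin needs the token ids themselves to build its cross-attention mask; plugins that only need lengths can still recover them. A minimal sketch of that recovery, with made-up sample data:

```python
# Illustrative only: a plugin can derive per-sample lengths from the
# batch_ids it now receives (the token id values here are hypothetical).
batch_ids = [
    [1, 32000, 32000, 15043],  # sample 1: two image tokens plus text
    [1, 15043, 3186],          # sample 2: text only
]

seqlens = [len(input_ids) for input_ids in batch_ids]
print(seqlens)  # [4, 3]
```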
@@ -4,11 +4,12 @@ from io import BytesIO
 from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, TypedDict, Union
 
 import numpy as np
+import torch
 from transformers.image_utils import get_image_size, to_numpy_array
 from typing_extensions import override
 
 from ..extras.constants import IGNORE_INDEX, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
-from ..extras.packages import is_pillow_available, is_pyav_available
+from ..extras.packages import is_pillow_available, is_pyav_available, is_transformers_version_greater_than
 
 
 if is_pillow_available():
@@ -20,8 +21,14 @@ if is_pyav_available():
     import av
 
 
+if is_transformers_version_greater_than("4.45.0"):
+    from transformers.models.mllama.processing_mllama import (
+        convert_sparse_cross_attention_mask_to_dense,
+        get_cross_attention_token_mask,
+    )
+
+
 if TYPE_CHECKING:
-    import torch
     from av.stream import Stream
     from transformers import PreTrainedTokenizer, ProcessorMixin
     from transformers.image_processing_utils import BaseImageProcessor
@@ -75,8 +82,8 @@ class BasePlugin:
         Pre-processes a single image.
         """
         image_resolution: int = kwargs.get("image_resolution")
-        if max(image.width, image.height) > image_resolution:
-            resize_factor = image_resolution / max(image.width, image.height)
+        if image.width * image.height > image_resolution:
+            resize_factor = math.sqrt(image_resolution / (image.width * image.height))
             width, height = int(image.width * resize_factor), int(image.height * resize_factor)
             image = image.resize((width, height), resample=Image.NEAREST)
 
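`image_resolution` now caps the total pixel count rather than the longest side, so the downscale factor is the square root of the pixel ratio. A small standalone sketch of the same arithmetic (PIL only; the image size and budget are made up):

```python
import math

from PIL import Image

# Assumed inputs for illustration: a 1024x768 image and the new default
# budget of 512 * 512 = 262144 pixels.
image = Image.new("RGB", (1024, 768))
image_resolution = 512 * 512

if image.width * image.height > image_resolution:
    # sqrt keeps the aspect ratio while hitting the pixel budget:
    # 1024 * 768 = 786432 px -> factor ~0.577 -> about 591 x 443.
    resize_factor = math.sqrt(image_resolution / (image.width * image.height))
    width, height = int(image.width * resize_factor), int(image.height * resize_factor)
    image = image.resize((width, height), resample=Image.NEAREST)

print(image.size)  # roughly (591, 443)
```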
@@ -165,15 +172,15 @@ class BasePlugin:
         if len(images) != 0:
             images = self._regularize_images(
                 images,
-                image_resolution=getattr(processor, "image_resolution", 512),
+                image_resolution=getattr(processor, "image_resolution", 512 * 512),
             )
             input_dict["images"] = images
 
         if len(videos) != 0:
             videos = self._regularize_videos(
                 videos,
-                image_resolution=getattr(processor, "video_resolution", 128),
-                video_fps=getattr(processor, "video_fps", 1.0),
+                image_resolution=getattr(processor, "video_resolution", 128 * 128),
+                video_fps=getattr(processor, "video_fps", 2.0),
                 video_maxlen=getattr(processor, "video_maxlen", 64),
             )
             input_dict["videos"] = videos
@@ -223,7 +230,7 @@ class BasePlugin:
         videos: Sequence["VideoInput"],
         imglens: Sequence[int],
         vidlens: Sequence[int],
-        seqlens: Sequence[int],
+        batch_ids: Sequence[List[int]],
         processor: Optional["ProcessorMixin"],
     ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
         r"""
@@ -234,7 +241,7 @@ class BasePlugin:
             videos: a list of video inputs, shape (num_videos,)
             imglens: number of images in each sample, shape (batch_size,)
             vidlens: number of videos in each sample, shape (batch_size,)
-            seqlens: number of tokens in each sample, shape (batch_size,)
+            batch_ids: input ids of samples, shape (batch_size, seq_len)
             processor: a processor for pre-processing images and videos
         """
         self._validate_input(images, videos)
@@ -258,12 +265,12 @@ class LlavaPlugin(BasePlugin):
             content = message["content"]
             while IMAGE_PLACEHOLDER in content:
                 num_image_tokens += 1
-                content = content.replace(IMAGE_PLACEHOLDER, "{{image}}", 1)
+                content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)
 
-            message["content"] = content.replace("{{image}}", self.image_token * image_seqlen)
+            message["content"] = content.replace("{{image}}", self.image_token)
 
         if len(images) != num_image_tokens:
-            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens")
+            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
 
         return messages
 
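The expansion order changes here: each `IMAGE_PLACEHOLDER` is first blown up into `image_seqlen` copies of a temporary `{{image}}` marker, and only then is the marker swapped for the real image token, which keeps multi-image prompts correct. A toy reproduction of the two-step replacement, with made-up values:

```python
# Toy reproduction of the placeholder expansion, not the plugin itself.
IMAGE_PLACEHOLDER = "<image>"   # assumed placeholder string
image_token = "<image_patch>"   # assumed model-specific image token
image_seqlen = 3                # assumed number of tokens per image

content = f"Compare {IMAGE_PLACEHOLDER} with {IMAGE_PLACEHOLDER}."
num_image_tokens = 0
while IMAGE_PLACEHOLDER in content:
    num_image_tokens += 1
    # expand each placeholder into image_seqlen temporary markers
    content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)

# swap the temporary marker for the real token in a single pass
content = content.replace("{{image}}", image_token)
print(num_image_tokens)  # 2
print(content)           # both placeholders expanded to 3 image tokens each
```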
@@ -274,7 +281,7 @@ class LlavaPlugin(BasePlugin):
         videos: Sequence["VideoInput"],
         imglens: Sequence[int],
         vidlens: Sequence[int],
-        seqlens: Sequence[int],
+        batch_ids: Sequence[List[int]],
         processor: Optional["ProcessorMixin"],
     ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
         self._validate_input(images, videos)
@@ -296,23 +303,27 @@ class LlavaNextPlugin(BasePlugin):
         mm_inputs = self._get_mm_inputs(images, videos, processor)
         if "image_sizes" in mm_inputs:
             image_sizes = iter(mm_inputs["image_sizes"])
+
         if "pixel_values" in mm_inputs:
             height, width = get_image_size(to_numpy_array(mm_inputs["pixel_values"][0][0]))
+
         for message in messages:
             content = message["content"]
-            while self.image_token in content:
+            while IMAGE_PLACEHOLDER in content:
                 image_size = next(image_sizes)
                 orig_height, orig_width = image_size
                 image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width)
-                if processor.vision_feature_select_strategy == "default":
+                if getattr(processor, "vision_feature_select_strategy") == "default":
                     image_seqlen -= 1
+
                 num_image_tokens += 1
-                content = content.replace(self.image_token, "{{image}}" * image_seqlen, 1)
+                content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)
+
             message["content"] = content.replace("{{image}}", self.image_token)
 
         if len(images) != num_image_tokens:
-            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens")
+            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
 
         return messages
 
     @override
@@ -322,12 +333,11 @@ class LlavaNextPlugin(BasePlugin):
         videos: Sequence["VideoInput"],
         imglens: Sequence[int],
         vidlens: Sequence[int],
-        seqlens: Sequence[int],
+        batch_ids: Sequence[List[int]],
         processor: Optional["ProcessorMixin"],
     ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
         self._validate_input(images, videos)
-        res = self._get_mm_inputs(images, videos, processor)
-        return res
+        return self._get_mm_inputs(images, videos, processor)
 
 
 class LlavaNextVideoPlugin(BasePlugin):
@@ -340,8 +350,7 @@ class LlavaNextVideoPlugin(BasePlugin):
         processor: Optional["ProcessorMixin"],
     ) -> List[Dict[str, str]]:
         self._validate_input(images, videos)
-        num_image_tokens = 0
-        num_video_tokens = 0
+        num_image_tokens, num_video_tokens = 0, 0
         messages = deepcopy(messages)
         mm_inputs = self._get_mm_inputs(images, videos, processor)
         if "pixel_values" in mm_inputs:
@@ -349,15 +358,15 @@ class LlavaNextVideoPlugin(BasePlugin):
             height, width = get_image_size(to_numpy_array(mm_inputs["pixel_values"][0][0]))
             for message in messages:
                 content = message["content"]
-                while self.image_token in content:
+                while IMAGE_PLACEHOLDER in content:
                     image_size = next(image_sizes)
                     orig_height, orig_width = image_size
                     image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width)
-                    if processor.vision_feature_select_strategy == "default":
+                    if getattr(processor, "vision_feature_select_strategy") == "default":
                         image_seqlen -= 1
 
                     num_image_tokens += 1
-                    content = content.replace(self.image_token, "{{image}}" * image_seqlen, 1)
+                    content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)
 
                 message["content"] = content.replace("{{image}}", self.image_token)
 
@@ -367,19 +376,19 @@ class LlavaNextVideoPlugin(BasePlugin):
             num_frames = pixel_values_video.shape[0]  # frame dim is always after batch dim
             image_seqlen = (height // processor.patch_size) * (width // processor.patch_size)
             video_seqlen = image_seqlen // 4 * num_frames  # divide by 4 needed for avg pooling layer
 
             for message in messages:
                 content = message["content"]
-                while self.video_token in content:
+                while VIDEO_PLACEHOLDER in content:
                     num_video_tokens += 1
-                    content = content.replace(self.video_token, "{{video}}", 1)
-                message["content"] = content.replace("{{video}}", self.video_token * video_seqlen)
+                    content = content.replace(VIDEO_PLACEHOLDER, "{{video}}" * video_seqlen, 1)
+
+                message["content"] = content.replace("{{video}}", self.video_token)
 
         if len(images) != num_image_tokens:
-            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens")
+            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
 
         if len(videos) != num_video_tokens:
-            raise ValueError(f"The number of videos does not match the number of {IMAGE_PLACEHOLDER} tokens")
+            raise ValueError(f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens.")
 
         return messages
 
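For LLaVA-NeXT-Video the number of video tokens per sample follows directly from the patch grid: `image_seqlen = (H // patch) * (W // patch)`, divided by 4 for the average-pooling layer and multiplied by the frame count. A quick numeric sketch with assumed processor values:

```python
# Assumed values for illustration: 336x336 frames, patch size 14, 8 frames.
height = width = 336
patch_size = 14
num_frames = 8

image_seqlen = (height // patch_size) * (width // patch_size)  # 24 * 24 = 576
video_seqlen = image_seqlen // 4 * num_frames                  # 144 * 8 = 1152

print(image_seqlen, video_seqlen)  # 576 1152
```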
@@ -390,7 +399,7 @@ class LlavaNextVideoPlugin(BasePlugin):
         videos: Sequence["VideoInput"],
         imglens: Sequence[int],
         vidlens: Sequence[int],
-        seqlens: Sequence[int],
+        batch_ids: Sequence[List[int]],
         processor: Optional["ProcessorMixin"],
     ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
         self._validate_input(images, videos)
@@ -418,7 +427,7 @@ class PaliGemmaPlugin(BasePlugin):
             message["content"] = content.replace("{{image}}", "")
 
         if len(images) != num_image_tokens:
-            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens")
+            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
 
         return messages
 
@@ -449,10 +458,11 @@ class PaliGemmaPlugin(BasePlugin):
         videos: Sequence["VideoInput"],
         imglens: Sequence[int],
         vidlens: Sequence[int],
-        seqlens: Sequence[int],
+        batch_ids: Sequence[List[int]],
         processor: Optional["ProcessorMixin"],
     ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
         self._validate_input(images, videos)
+        seqlens = [len(input_ids) for input_ids in batch_ids]
         mm_inputs = self._get_mm_inputs(images, videos, processor)
         mm_inputs["token_type_ids"] = _get_paligemma_token_type_ids(imglens, seqlens, processor)
         return mm_inputs
@@ -481,7 +491,7 @@ class PixtralPlugin(BasePlugin):
             content = message["content"]
             while IMAGE_PLACEHOLDER in content:
                 if image_input_sizes is None:
-                    raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens")
+                    raise ValueError("Cannot get image input sizes.")
 
                 image_size = image_input_sizes[0][num_image_tokens]
                 height, width = image_size
@@ -497,7 +507,7 @@ class PixtralPlugin(BasePlugin):
             message["content"] = content
 
         if len(images) != num_image_tokens:
-            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens")
+            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
 
         return messages
 
@@ -508,7 +518,7 @@ class PixtralPlugin(BasePlugin):
         videos: Sequence["VideoInput"],
         imglens: Sequence[int],
         vidlens: Sequence[int],
-        seqlens: Sequence[int],
+        batch_ids: Sequence[List[int]],
         processor: Optional["ProcessorMixin"],
     ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
         self._validate_input(images, videos)
@@ -592,10 +602,10 @@ class Qwen2vlPlugin(BasePlugin):
             message["content"] = content
 
         if len(images) != num_image_tokens:
-            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens")
+            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
 
         if len(videos) != num_video_tokens:
-            raise ValueError(f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens")
+            raise ValueError(f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens.")
 
         return messages
 
@@ -606,7 +616,7 @@ class Qwen2vlPlugin(BasePlugin):
         videos: Sequence["VideoInput"],
         imglens: Sequence[int],
         vidlens: Sequence[int],
-        seqlens: Sequence[int],
+        batch_ids: Sequence[List[int]],
         processor: Optional["ProcessorMixin"],
     ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
         self._validate_input(images, videos)
@@ -623,42 +633,45 @@ class VideoLlavaPlugin(BasePlugin):
         processor: Optional["ProcessorMixin"],
     ) -> List[Dict[str, str]]:
         self._validate_input(images, videos)
-        num_image_tokens = 0
-        num_video_tokens = 0
+        num_image_tokens, num_video_tokens = 0, 0
         messages = deepcopy(messages)
         mm_inputs = self._get_mm_inputs(images, videos, processor)
         num_frames = 0
-        exist_images = "pixel_values_images" in mm_inputs
-        exist_videos = "pixel_values_videos" in mm_inputs
-        if exist_videos or exist_images:
-            if exist_images:
+        has_images = "pixel_values_images" in mm_inputs
+        has_videos = "pixel_values_videos" in mm_inputs
+        if has_images or has_videos:
+            if has_images:
                 height, width = get_image_size(to_numpy_array(mm_inputs.get("pixel_values_images")[0]))
                 num_frames = 1
-            if exist_videos:
+
+            if has_videos:
                 pixel_values_video = to_numpy_array(mm_inputs.get("pixel_values_videos")[0])
                 height, width = get_image_size(pixel_values_video[0])
                 num_frames = pixel_values_video.shape[0]  # frame dim is always after batch dim
+
             image_seqlen = (height // processor.patch_size) * (width // processor.patch_size) + 1
             video_seqlen = image_seqlen * num_frames
-            if processor.vision_feature_select_strategy == "default":
+            if getattr(processor, "vision_feature_select_strategy") == "default":
                 image_seqlen -= 1
+
             for message in messages:
                 content = message["content"]
-                while self.image_token in content:
+                while IMAGE_PLACEHOLDER in content:
                     num_image_tokens += 1
-                    content = content.replace(self.image_token, "{{image}}", 1)
-                while self.video_token in content:
+                    content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)
+
+                while VIDEO_PLACEHOLDER in content:
                     num_video_tokens += 1
-                    content = content.replace(self.video_token, "{{video}}", 1)
+                    content = content.replace(VIDEO_PLACEHOLDER, "{{video}}" * video_seqlen, 1)
 
-                content = content.replace("{{image}}", self.image_token * image_seqlen)
-                message["content"] = content.replace("{{video}}", self.video_token * video_seqlen)
+                content = content.replace("{{image}}", self.image_token)
+                message["content"] = content.replace("{{video}}", self.video_token)
 
         if len(images) != num_image_tokens:
-            raise ValueError(f"The number of images does not match the number of {self.image_token} tokens")
+            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
 
         if len(videos) != num_video_tokens:
-            raise ValueError(f"The number of videos does not match the number of {self.video_token} tokens")
+            raise ValueError(f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens.")
 
         return messages
 
@@ -669,13 +682,86 @@ class VideoLlavaPlugin(BasePlugin):
         videos: Sequence["VideoInput"],
         imglens: Sequence[int],
         vidlens: Sequence[int],
-        seqlens: Sequence[int],
+        batch_ids: Sequence[List[int]],
         processor: Optional["ProcessorMixin"],
     ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
         self._validate_input(images, videos)
         return self._get_mm_inputs(images, videos, processor)
 
 
+class MllamaPlugin(BasePlugin):
+    @override
+    def process_messages(
+        self,
+        messages: Sequence[Dict[str, str]],
+        images: Sequence["ImageInput"],
+        videos: Sequence["VideoInput"],
+        processor: Optional["ProcessorMixin"],
+    ) -> List[Dict[str, str]]:
+        self._validate_input(images, videos)
+        num_image_tokens = 0
+        messages = deepcopy(messages)
+        for message in messages:
+            content = message["content"]
+            num_image_tokens += content.count(IMAGE_PLACEHOLDER)
+            message["content"] = content.replace(IMAGE_PLACEHOLDER, self.image_token)
+
+        if len(images) != num_image_tokens:
+            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
+
+        return messages
+
+    @override
+    def _get_mm_inputs(
+        self,
+        images: Sequence["ImageInput"],
+        videos: Sequence["VideoInput"],
+        processor: "ProcessorMixin",
+    ) -> Dict[str, "torch.Tensor"]:
+        r"""
+        Processes visual inputs for mllama because its image processor only accepts List[List[ImageInput]].
+
+        Returns:
+            pixel_values: tensor with shape
+                (batch_size, max_num_images, max_image_tiles, channels, tile_height, tile_width)
+                For example, (2, 1, 4, 3, 560, 560).
+            aspect_ratio_ids: tensor with shape (batch_size, max_num_images). For example, (2, 1).
+            aspect_ratio_mask: tensor with shape (batch_size, max_num_images, max_image_tiles). For example, (2, 1, 4).
+            num_tiles: List[List[int]] with shape (batch_size, num_images_in_batch). For example, (2, 1).
+        """
+        image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
+        images = self._regularize_images(images, image_resolution=getattr(processor, "image_resolution", 512 * 512))
+        return image_processor([[image] for image in images], return_tensors="pt")
+
+    def get_mm_inputs(
+        self,
+        images: Sequence["ImageInput"],
+        videos: Sequence["VideoInput"],
+        imglens: Sequence[int],
+        vidlens: Sequence[int],
+        batch_ids: Sequence[List[int]],
+        processor: Optional["ProcessorMixin"],
+    ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
+        self._validate_input(images, videos)
+        if len(images) != len(batch_ids):
+            raise ValueError("Mllama only supports one image per sample.")
+
+        mm_inputs = self._get_mm_inputs(images, videos, processor)
+        num_tiles = mm_inputs.pop("num_tiles")
+        image_token_id = getattr(processor, "image_token_id")
+        max_image_tiles = getattr(processor.image_processor, "max_image_tiles")
+        cross_attention_token_mask = [
+            get_cross_attention_token_mask(input_ids, image_token_id) for input_ids in batch_ids
+        ]
+        mm_inputs["cross_attention_mask"] = convert_sparse_cross_attention_mask_to_dense(
+            cross_attention_token_mask,
+            num_tiles=num_tiles,
+            max_num_tiles=max_image_tiles,
+            length=max(len(input_ids) for input_ids in batch_ids),
+        )
+        return mm_inputs
+
+
 PLUGINS = {
     "base": BasePlugin,
     "llava": LlavaPlugin,
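The new `MllamaPlugin.get_mm_inputs` is the reason plugins now receive `batch_ids`: Llama 3.2 Vision attends from text tokens to image tiles through a cross-attention mask, and that mask depends on where the `<|image|>` tokens sit in each sequence. A toy, self-contained sketch of the sparse mask idea (the helper below is hand-rolled for illustration and is not the transformers implementation):

```python
from typing import List


def toy_cross_attention_token_mask(input_ids: List[int], image_token_id: int) -> List[List[int]]:
    # Illustrative only: for each image token, record the span of positions
    # (from that image token to the end of the sequence) that may attend to it.
    starts = [i for i, tok in enumerate(input_ids) if tok == image_token_id]
    return [[start, len(input_ids)] for start in starts]


# Assumed values: 9 is a stand-in image token id.
batch_ids = [
    [1, 9, 42, 43, 44],
    [1, 40, 9, 41, 2],
]
masks = [toy_cross_attention_token_mask(ids, image_token_id=9) for ids in batch_ids]
print(masks)  # [[[1, 5]], [[2, 5]]]
# The real plugin feeds such spans, plus the per-image tile counts, into
# convert_sparse_cross_attention_mask_to_dense, yielding a dense mask of
# roughly shape (batch_size, seq_len, max_num_images, max_num_tiles).
```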
@@ -685,6 +771,7 @@ PLUGINS = {
     "pixtral": PixtralPlugin,
     "qwen2_vl": Qwen2vlPlugin,
     "video_llava": VideoLlavaPlugin,
+    "mllama": MllamaPlugin,
 }
 
 
|
@ -762,6 +762,33 @@ _register_template(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
_register_template(
|
||||||
|
name="mllama",
|
||||||
|
format_user=StringFormatter(
|
||||||
|
slots=[
|
||||||
|
(
|
||||||
|
"<|start_header_id|>user<|end_header_id|>\n\n{{content}}<|eot_id|>"
|
||||||
|
"<|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||||
|
)
|
||||||
|
]
|
||||||
|
),
|
||||||
|
format_system=StringFormatter(slots=["<|start_header_id|>system<|end_header_id|>\n\n{{content}}<|eot_id|>"]),
|
||||||
|
format_observation=StringFormatter(
|
||||||
|
slots=[
|
||||||
|
(
|
||||||
|
"<|start_header_id|>tool<|end_header_id|>\n\n{{content}}<|eot_id|>"
|
||||||
|
"<|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||||
|
)
|
||||||
|
]
|
||||||
|
),
|
||||||
|
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
|
||||||
|
stop_words=["<|eot_id|>"],
|
||||||
|
replace_eos=True,
|
||||||
|
replace_jinja_template=False,
|
||||||
|
mm_plugin=get_mm_plugin(name="mllama", image_token="<|image|>"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
_register_template(
|
_register_template(
|
||||||
name="llava",
|
name="llava",
|
||||||
format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT:"]),
|
format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT:"]),
|
||||||
|
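The mllama template reuses the Llama 3 header/eot chat markup and simply maps the image placeholder to `<|image|>`. A rough sketch of the string this template produces for a single user turn (assembled by hand here; the real prompt is built by the registered formatter slots, and the BOS string is an assumption):

```python
# Hand-assembled illustration of the mllama chat format.
bos_token = "<|begin_of_text|>"  # assumed BOS for Llama 3 style tokenizers
user_message = "<|image|>Describe this photo."

prompt = (
    bos_token
    + "<|start_header_id|>user<|end_header_id|>\n\n"
    + user_message
    + "<|eot_id|>"
    + "<|start_header_id|>assistant<|end_header_id|>\n\n"
)
print(prompt)
```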
@@ -855,6 +855,22 @@ register_model_group(
 )
 
 
+register_model_group(
+    models={
+        "Llama-3.2-11B-Vision-Instruct": {
+            DownloadSource.DEFAULT: "meta-llama/Llama-3.2-11B-Vision-Instruct",
+            DownloadSource.MODELSCOPE: "LLM-Research/Llama-3.2-11B-Vision-Instruct",
+        },
+        "Llama-3.2-90B-Vision-Instruct": {
+            DownloadSource.DEFAULT: "meta-llama/Llama-3.2-90B-Vision-Instruct",
+            DownloadSource.MODELSCOPE: "LLM-Research/Llama-3.2-90B-Vision-Instruct",
+        },
+    },
+    template="mllama",
+    vision=True,
+)
+
+
 register_model_group(
     models={
         "LLaVA-1.5-7B-Chat": {
@@ -75,8 +75,8 @@ def is_starlette_available():
 
 
 @lru_cache
-def is_transformers_version_greater_than_4_43():
-    return _get_package_version("transformers") >= version.parse("4.43.0")
+def is_transformers_version_greater_than(content: str):
+    return _get_package_version("transformers") >= version.parse(content)
 
 
 @lru_cache
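The one-off `is_transformers_version_greater_than_4_43` helper becomes a parameterized, cached check, which is what lets mm_plugin guard the mllama imports behind "4.45.0" while the flash-attention patches keep their "4.43.0" checks. A minimal sketch of the same pattern (standalone; `_get_package_version` here is a stand-in for the project's helper):

```python
import importlib.metadata
from functools import lru_cache

from packaging import version


def _get_package_version(name: str) -> "version.Version":
    # Stand-in for the project's helper: read the installed version,
    # falling back to 0.0.0 when the package is missing.
    try:
        return version.parse(importlib.metadata.version(name))
    except importlib.metadata.PackageNotFoundError:
        return version.parse("0.0.0")


@lru_cache
def is_transformers_version_greater_than(content: str) -> bool:
    return _get_package_version("transformers") >= version.parse(content)


if is_transformers_version_greater_than("4.45.0"):
    # guard imports that only exist in newer transformers releases
    from transformers.models.mllama.processing_mllama import get_cross_attention_token_mask  # noqa: F401
```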
@@ -59,12 +59,12 @@ class ProcessorArguments:
     """
 
     image_resolution: int = field(
-        default=512,
-        metadata={"help": "Keeps the height or width of image below this resolution."},
+        default=512 * 512,
+        metadata={"help": "Keeps the number of pixels of image below this resolution."},
     )
     video_resolution: int = field(
-        default=128,
-        metadata={"help": "Keeps the height or width of video below this resolution."},
+        default=128 * 128,
+        metadata={"help": "Keeps the number of pixels of video below this resolution."},
     )
     video_fps: float = field(
         default=2.0,
@@ -35,7 +35,7 @@ from transformers.utils.versions import require_version
 
 from ...extras import logging
 from ...extras.constants import SUPPORTED_CLASS_FOR_S2ATTN
-from ...extras.packages import is_transformers_version_greater_than_4_43
+from ...extras.packages import is_transformers_version_greater_than
 
 
 if TYPE_CHECKING:
@@ -209,7 +209,7 @@ def llama_flash_attention_2_forward(
         if attention_mask is not None:
             attention_mask = attention_mask[:, :groupsz].repeat(num_groups, 1)
 
-    if is_transformers_version_greater_than_4_43():
+    if is_transformers_version_greater_than("4.43.0"):
         from transformers.modeling_flash_attention_utils import _flash_attention_forward
 
         attn_output: "torch.Tensor" = _flash_attention_forward(
@@ -34,13 +34,15 @@ def find_all_linear_modules(model: "PreTrainedModel", freeze_vision_tower: bool)
         forbidden_modules.add("output_layer")
     elif model_type == "internlm2":
         forbidden_modules.add("output")
-    elif model_type in ["llava", "llava_next", "llava_next_video", "paligemma", "video_llava"]:
+    elif model_type in ["llava", "llava_next", "llava_next_video", "mllama", "paligemma", "video_llava"]:
         forbidden_modules.add("multi_modal_projector")
     elif model_type == "qwen2_vl":
         forbidden_modules.add("merger")
 
     if freeze_vision_tower:
-        if model_type == "qwen2_vl":
+        if model_type == "mllama":
+            forbidden_modules.add("vision_model")
+        elif model_type == "qwen2_vl":
             forbidden_modules.add("visual")
         else:
             forbidden_modules.add("vision_tower")
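`find_all_linear_modules` exists to pick LoRA target modules automatically, so for mllama the frozen vision encoder (`vision_model`) and the `multi_modal_projector` must be excluded. A rough sketch of the filtering idea on a generic `torch.nn.Module` (not the project's exact function; the toy model below is an assumption):

```python
import torch


def toy_find_linear_module_names(model: torch.nn.Module, forbidden: set) -> set:
    # Collect the last name component of every nn.Linear whose path does not
    # pass through a forbidden submodule (e.g. "vision_model").
    names = set()
    for full_name, module in model.named_modules():
        if any(part in forbidden for part in full_name.split(".")):
            continue
        if isinstance(module, torch.nn.Linear):
            names.add(full_name.split(".")[-1])
    return names


# Assumed toy model standing in for a vision-language checkpoint.
model = torch.nn.ModuleDict(
    {
        "vision_model": torch.nn.Linear(4, 4),
        "multi_modal_projector": torch.nn.Linear(4, 4),
        "language_model": torch.nn.ModuleDict({"q_proj": torch.nn.Linear(4, 4)}),
    }
)
print(toy_find_linear_module_names(model, {"vision_model", "multi_modal_projector"}))  # {'q_proj'}
```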
@@ -45,7 +45,7 @@ from transformers.utils.versions import require_version
 
 from ...extras import logging
 from ...extras.constants import SUPPORTED_CLASS_FOR_BLOCK_DIAG_ATTN
-from ...extras.packages import is_transformers_version_greater_than_4_43
+from ...extras.packages import is_transformers_version_greater_than
 
 
 if TYPE_CHECKING:
@@ -115,7 +115,7 @@ def get_unpad_data(attention_mask: "torch.Tensor") -> Tuple["torch.Tensor", "tor
 
 def _patch_for_block_diag_attn(model_type: str) -> None:
     require_version("transformers>=4.41.2,<=4.46.1", "To fix: pip install transformers>=4.41.2,<=4.46.1")
-    if is_transformers_version_greater_than_4_43():
+    if is_transformers_version_greater_than("4.43.0"):
         import transformers.modeling_flash_attention_utils
 
         transformers.modeling_flash_attention_utils._get_unpad_data = get_unpad_data
@@ -26,7 +26,7 @@ from ...extras import logging
 
 
 if TYPE_CHECKING:
-    from transformers import LlavaConfig, PretrainedConfig, PreTrainedModel
+    from transformers import LlavaConfig, PretrainedConfig, PreTrainedModel, ProcessorMixin
 
     from ...hparams import FinetuningArguments, ModelArguments
 
@@ -163,19 +163,21 @@ def get_image_seqlen(config: "PretrainedConfig") -> int:
     return image_seqlen
 
 
-def get_patch_size(config: "PretrainedConfig") -> int:
+def get_patch_size(config: "PretrainedConfig", processor: "ProcessorMixin") -> int:
     r"""
     Computes the patch size of the vit.
     """
-    patch_size = getattr(config.vision_config, "patch_size", -1)
+    patch_size = getattr(config.vision_config, "patch_size", getattr(processor, "patch_size", -1))
     return patch_size
 
 
-def get_vision_feature_select_strategy(config: "PretrainedConfig") -> int:
+def get_vision_feature_select_strategy(config: "PretrainedConfig", processor: "ProcessorMixin") -> int:
     r"""
     Get the vision_feature_select_strategy.
     """
-    vision_feature_select_strategy = getattr(config, "vision_feature_select_strategy", "default")
+    vision_feature_select_strategy = getattr(
+        config, "vision_feature_select_strategy", getattr(processor, "vision_feature_select_strategy", "default")
+    )
     return vision_feature_select_strategy
 
 
@@ -189,6 +191,8 @@ def patch_target_modules(
     if finetuning_args.freeze_vision_tower:
         if model_type in ["llava", "llava_next", "llava_next_video", "paligemma", "pixtral", "video_llava"]:
             return "^(?!.*vision_tower).*(?:{}).*".format("|".join(target_modules))
+        elif model_type == "mllama":
+            return "^(?!.*vision_model).*(?:{}).*".format("|".join(target_modules))
         elif model_type == "qwen2_vl":
             return "^(?!.*visual).*(?:{}).*".format("|".join(target_modules))
         else:
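`patch_target_modules` excludes frozen vision weights from LoRA by regex, so mllama gets its own pattern that filters out anything under `vision_model`. A quick check of how such a pattern matches, with assumed target module names:

```python
import re

# Assumed LoRA target modules; the pattern string mirrors the mllama branch above.
target_modules = ["q_proj", "v_proj"]
pattern = "^(?!.*vision_model).*(?:{}).*".format("|".join(target_modules))

print(bool(re.match(pattern, "language_model.model.layers.0.self_attn.q_proj")))      # True
print(bool(re.match(pattern, "vision_model.transformer.layers.0.self_attn.q_proj")))  # False
```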
@@ -66,11 +66,11 @@ def patch_processor(
     setattr(processor, "tokenizer", tokenizer)
     setattr(processor, "image_seqlen", get_image_seqlen(config))
     setattr(processor, "image_resolution", model_args.image_resolution)
-    setattr(processor, "patch_size", get_patch_size(config))
+    setattr(processor, "patch_size", get_patch_size(config, processor))
     setattr(processor, "video_resolution", model_args.video_resolution)
     setattr(processor, "video_fps", model_args.video_fps)
     setattr(processor, "video_maxlen", model_args.video_maxlen)
-    setattr(processor, "vision_feature_select_strategy", get_vision_feature_select_strategy(config))
+    setattr(processor, "vision_feature_select_strategy", get_vision_feature_select_strategy(config, processor))
 
 
 def patch_config(