Support llama3.2vl.

Add an MllamaPlugin for Llama 3.2 Vision multimodal inputs, register a matching "mllama" chat template, and register the Llama-3.2-11B/90B-Vision-Instruct model group.

Former-commit-id: 3f2c056253c651e8e614c787e2045f4232e82666
marko1616 2024-09-26 09:43:53 -04:00 committed by hiyouga
parent 24419dd3f1
commit 20faaf3418
3 changed files with 88 additions and 1 deletion


@@ -4,6 +4,7 @@ from io import BytesIO
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, TypedDict, Union
import numpy as np
import torch
from transformers.image_utils import get_image_size, to_numpy_array
from typing_extensions import override
@@ -21,7 +22,6 @@ if is_pyav_available():
if TYPE_CHECKING:
    import torch
    from av.stream import Stream
    from transformers import PreTrainedTokenizer, ProcessorMixin
    from transformers.image_processing_utils import BaseImageProcessor
@@ -676,6 +676,49 @@ class VideoLlavaPlugin(BasePlugin):
        return self._get_mm_inputs(images, videos, processor)

class MllamaPlugin(BasePlugin):
    @override
    def process_messages(
        self,
        messages: Sequence[Dict[str, str]],
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        processor: Optional["ProcessorMixin"],
    ) -> List[Dict[str, str]]:
        self._validate_input(images, videos)
        num_image_tokens = 0
        messages = deepcopy(messages)
        for message in messages:
            content = message["content"]
            while IMAGE_PLACEHOLDER in content:
                num_image_tokens += 1
                content = content.replace(IMAGE_PLACEHOLDER, "<|image|>", 1)

            message["content"] = content

        if len(images) != num_image_tokens:
            raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER))

        return messages

    def get_mm_inputs(
        self,
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        imglens: Sequence[int],
        vidlens: Sequence[int],
        seqlens: Sequence[int],
        processor: Optional["ProcessorMixin"],
    ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
        super().get_mm_inputs(images, videos, imglens, vidlens, seqlens, processor)
        if images is not None:
            images = [Image.open(image) if isinstance(image, str) else image for image in images]

        image_features = processor.image_processor(images)
        _ = image_features.pop("num_tiles")
        image_features = {k: v if isinstance(v, torch.Tensor) else torch.tensor(v) for k, v in image_features.items()}
        return image_features

PLUGINS = {
    "base": BasePlugin,
    "llava": LlavaPlugin,
@@ -685,6 +728,7 @@ PLUGINS = {
    "pixtral": PixtralPlugin,
    "qwen2_vl": Qwen2vlPlugin,
    "video_llava": VideoLlavaPlugin,
    "mllama": MllamaPlugin,
}
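
Note: a minimal, self-contained sketch of the placeholder rewrite that MllamaPlugin.process_messages performs, assuming the shared IMAGE_PLACEHOLDER constant is "<image>" (the same token the mllama template passes to get_mm_plugin below); the sample message is illustrative, not taken from the repo.

# Illustrative sketch only: mirrors the loop inside MllamaPlugin.process_messages.
IMAGE_PLACEHOLDER = "<image>"  # assumed value of the shared constant

messages = [{"role": "user", "content": "<image>What animal is shown here?"}]

num_image_tokens = 0
for message in messages:
    content = message["content"]
    while IMAGE_PLACEHOLDER in content:
        num_image_tokens += 1
        content = content.replace(IMAGE_PLACEHOLDER, "<|image|>", 1)
    message["content"] = content

assert num_image_tokens == 1  # one placeholder, so exactly one image must be attached
print(messages[0]["content"])  # <|image|>What animal is shown here?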


@@ -762,6 +762,33 @@ _register_template(
)


_register_template(
    name="mllama",
    format_user=StringFormatter(
        slots=[
            (
                "<|start_header_id|>user<|end_header_id|>\n\n{{content}}<|eot_id|>"
                "<|start_header_id|>assistant<|end_header_id|>\n\n"
            )
        ]
    ),
    format_system=StringFormatter(slots=["<|start_header_id|>system<|end_header_id|>\n\n{{content}}<|eot_id|>"]),
    format_observation=StringFormatter(
        slots=[
            (
                "<|start_header_id|>tool<|end_header_id|>\n\n{{content}}<|eot_id|>"
                "<|start_header_id|>assistant<|end_header_id|>\n\n"
            )
        ]
    ),
    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
    stop_words=["<|eot_id|>"],
    replace_eos=True,
    replace_jinja_template=False,
    mm_plugin=get_mm_plugin(name="mllama", image_token="<image>"),
)
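
For orientation, a rough sketch of the prompt these slots assemble for one system message and one user turn; format_prefix additionally prepends the tokenizer's BOS token (typically <|begin_of_text|> for Llama 3), and the real assembly is done by the formatter classes, so this is only an approximation.

# Sketch only: hand-expands the {{content}} slots of the "mllama" template above.
system_msg = "You are a helpful assistant."
user_msg = "<|image|>Describe this image."

prompt = (
    "<|start_header_id|>system<|end_header_id|>\n\n" + system_msg + "<|eot_id|>"
    + "<|start_header_id|>user<|end_header_id|>\n\n" + user_msg + "<|eot_id|>"
    + "<|start_header_id|>assistant<|end_header_id|>\n\n"
)
print(prompt)  # the model's reply ends with the <|eot_id|> stop word registered above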

_register_template(
    name="llava",
    format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT:"]),


@@ -855,6 +855,22 @@ register_model_group(
)

register_model_group(
    models={
        "Llama-3.2-11B-Vision-Instruct": {
            DownloadSource.DEFAULT: "meta-llama/Llama-3.2-11B-Vision-Instruct",
            DownloadSource.MODELSCOPE: "LLM-Research/Llama-3.2-11B-Vision-Instruct",
        },
        "Llama-3.2-90B-Vision-Instruct": {
            DownloadSource.DEFAULT: "meta-llama/Llama-3.2-90B-Vision-Instruct",
            DownloadSource.MODELSCOPE: "LLM-Research/Llama-3.2-90B-Vision-Instruct",
        },
    },
    template="mllama",
    vision=True,
)
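
As a cross-check outside LLaMA-Factory, both registered checkpoints are standard Hugging Face mllama models; a loading sketch, assuming transformers >= 4.45 and access to the gated meta-llama repos.

# Sketch only: loads the 11B vision-instruct checkpoint with stock transformers classes.
from transformers import AutoProcessor, MllamaForConditionalGeneration

model_id = "meta-llama/Llama-3.2-11B-Vision-Instruct"
processor = AutoProcessor.from_pretrained(model_id)
model = MllamaForConditionalGeneration.from_pretrained(model_id, torch_dtype="auto", device_map="auto")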

register_model_group(
    models={
        "LLaVA-1.5-7B-Chat": {