Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-08-22 22:02:51 +08:00)
Support llama3.2vl.
Former-commit-id: 3f2c056253c651e8e614c787e2045f4232e82666
This commit is contained in:
parent 24419dd3f1
commit 20faaf3418
@@ -4,6 +4,7 @@ from io import BytesIO
 from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, TypedDict, Union

 import numpy as np
+import torch
 from transformers.image_utils import get_image_size, to_numpy_array
 from typing_extensions import override

@@ -21,7 +22,6 @@ if is_pyav_available():


 if TYPE_CHECKING:
-    import torch
     from av.stream import Stream
     from transformers import PreTrainedTokenizer, ProcessorMixin
     from transformers.image_processing_utils import BaseImageProcessor
@@ -676,6 +676,49 @@ class VideoLlavaPlugin(BasePlugin):
         return self._get_mm_inputs(images, videos, processor)


+class MllamaPlugin(BasePlugin):
+    @override
+    def process_messages(
+        self,
+        messages: Sequence[Dict[str, str]],
+        images: Sequence["ImageInput"],
+        videos: Sequence["VideoInput"],
+        processor: Optional["ProcessorMixin"],
+    ) -> List[Dict[str, str]]:
+        self._validate_input(images, videos)
+        num_image_tokens = 0
+        messages = deepcopy(messages)
+        for message in messages:
+            content = message["content"]
+            while IMAGE_PLACEHOLDER in content:
+                num_image_tokens += 1
+                content = content.replace(IMAGE_PLACEHOLDER, "<|image|>", 1)
+
+            message["content"] = content
+
+        if len(images) != num_image_tokens:
+            raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER))
+
+        return messages
+
+    def get_mm_inputs(
+        self,
+        images: Sequence["ImageInput"],
+        videos: Sequence["VideoInput"],
+        imglens: Sequence[int],
+        vidlens: Sequence[int],
+        seqlens: Sequence[int],
+        processor: Optional["ProcessorMixin"],
+    ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
+        super().get_mm_inputs(images, videos, imglens, vidlens, seqlens, processor)
+        if images is not None:
+            images = [Image.open(image) if isinstance(image, str) else image for image in images]
+            image_features = processor.image_processor(images)
+            _ = image_features.pop("num_tiles")
+            image_features = {k: v if isinstance(v, torch.Tensor) else torch.tensor(v) for k, v in image_features.items()}
+        return image_features
+
+
 PLUGINS = {
     "base": BasePlugin,
     "llava": LlavaPlugin,
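For context, the placeholder rewriting in MllamaPlugin.process_messages behaves roughly as sketched below on a single user turn. This is an illustrative snippet rather than part of the diff, and it assumes the default IMAGE_PLACEHOLDER value of "<image>" used elsewhere in LLaMA-Factory.

    # Illustrative only: each "<image>" placeholder becomes mllama's "<|image|>" token,
    # and the placeholder count is later checked against len(images).
    IMAGE_PLACEHOLDER = "<image>"  # assumed default placeholder string
    messages = [{"role": "user", "content": "<image>Describe this picture."}]
    num_image_tokens = 0
    for message in messages:
        content = message["content"]
        while IMAGE_PLACEHOLDER in content:
            num_image_tokens += 1
            content = content.replace(IMAGE_PLACEHOLDER, "<|image|>", 1)
        message["content"] = content
    assert messages[0]["content"] == "<|image|>Describe this picture."
    assert num_image_tokens == 1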
@@ -685,6 +728,7 @@ PLUGINS = {
     "pixtral": PixtralPlugin,
     "qwen2_vl": Qwen2vlPlugin,
     "video_llava": VideoLlavaPlugin,
+    "mllama": MllamaPlugin,
 }

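Registering the plugin under the "mllama" key makes it resolvable by name through get_mm_plugin (the helper already used by the template registration below). A minimal sketch of that lookup, with an assumed signature rather than the exact source:

    from typing import Optional

    def get_mm_plugin(name: str, image_token: Optional[str] = None, video_token: Optional[str] = None) -> "BasePlugin":
        # Look up the plugin class registered under `name` and instantiate it with the tokens.
        plugin_class = PLUGINS.get(name, None)
        if plugin_class is None:
            raise ValueError("Multimodal plugin `{}` not found.".format(name))
        return plugin_class(image_token, video_token)

    # e.g. get_mm_plugin(name="mllama", image_token="<image>") would now return an MllamaPlugin instance.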
@@ -762,6 +762,33 @@ _register_template(
 )


+_register_template(
+    name="mllama",
+    format_user=StringFormatter(
+        slots=[
+            (
+                "<|start_header_id|>user<|end_header_id|>\n\n{{content}}<|eot_id|>"
+                "<|start_header_id|>assistant<|end_header_id|>\n\n"
+            )
+        ]
+    ),
+    format_system=StringFormatter(slots=["<|start_header_id|>system<|end_header_id|>\n\n{{content}}<|eot_id|>"]),
+    format_observation=StringFormatter(
+        slots=[
+            (
+                "<|start_header_id|>tool<|end_header_id|>\n\n{{content}}<|eot_id|>"
+                "<|start_header_id|>assistant<|end_header_id|>\n\n"
+            )
+        ]
+    ),
+    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
+    stop_words=["<|eot_id|>"],
+    replace_eos=True,
+    replace_jinja_template=False,
+    mm_plugin=get_mm_plugin(name="mllama", image_token="<image>"),
+)
+
+
 _register_template(
     name="llava",
     format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT:"]),
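For intuition, applying the mllama format_prefix and format_user slots (after the plugin has rewritten the image placeholder) yields a prompt along these lines for a single user turn. The BOS string below assumes the Llama 3 tokenizer's "<|begin_of_text|>" token; treat it as an illustration, not exact output.

    rendered = (
        "<|begin_of_text|>"  # from format_prefix's {"bos_token"} slot (assumed BOS string)
        "<|start_header_id|>user<|end_header_id|>\n\n"
        "<|image|>What is shown in this photo?<|eot_id|>"
        "<|start_header_id|>assistant<|end_header_id|>\n\n"
    )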
@@ -855,6 +855,22 @@ register_model_group(
 )


+register_model_group(
+    models={
+        "Llama-3.2-11B-Vision-Instruct": {
+            DownloadSource.DEFAULT: "meta-llama/Llama-3.2-11B-Vision-Instruct",
+            DownloadSource.MODELSCOPE: "LLM-Research/Llama-3.2-11B-Vision-Instruct",
+        },
+        "Llama-3.2-90B-Vision-Instruct": {
+            DownloadSource.DEFAULT: "meta-llama/Llama-3.2-90B-Vision-Instruct",
+            DownloadSource.MODELSCOPE: "LLM-Research/Llama-3.2-90B-Vision-Instruct",
+        },
+    },
+    template="mllama",
+    vision=True,
+)
+
+
 register_model_group(
     models={
         "LLaVA-1.5-7B-Chat": {
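Loosely sketched, the effect of this registration is to record hub paths and the default template for the two checkpoints. The SUPPORTED_MODELS and DEFAULT_TEMPLATE names below refer to the existing registries that register_model_group populates; this simplified sketch uses plain-string keys where the real module keys download paths by the DownloadSource enum.

    # Simplified, illustrative view of what the registration records (not the real data structures):
    SUPPORTED_MODELS = {}
    DEFAULT_TEMPLATE = {}
    for name, sources in {
        "Llama-3.2-11B-Vision-Instruct": {
            "DEFAULT": "meta-llama/Llama-3.2-11B-Vision-Instruct",
            "MODELSCOPE": "LLM-Research/Llama-3.2-11B-Vision-Instruct",
        },
        "Llama-3.2-90B-Vision-Instruct": {
            "DEFAULT": "meta-llama/Llama-3.2-90B-Vision-Instruct",
            "MODELSCOPE": "LLM-Research/Llama-3.2-90B-Vision-Instruct",
        },
    }.items():
        SUPPORTED_MODELS[name] = sources      # model name -> download paths
        DEFAULT_TEMPLATE[name] = "mllama"     # model name -> chat template / mm plugin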