Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-08-23 06:12:50 +08:00)
Support Llama 3.2 Vision (mllama).
Former-commit-id: 3f2c056253c651e8e614c787e2045f4232e82666
Parent: 24419dd3f1
Commit: 20faaf3418
src/llamafactory/data/mm_plugin.py

@@ -4,6 +4,7 @@ from io import BytesIO
 from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, TypedDict, Union

 import numpy as np
+import torch
 from transformers.image_utils import get_image_size, to_numpy_array
 from typing_extensions import override

@@ -21,7 +22,6 @@ if is_pyav_available():


 if TYPE_CHECKING:
-    import torch
     from av.stream import Stream
     from transformers import PreTrainedTokenizer, ProcessorMixin
     from transformers.image_processing_utils import BaseImageProcessor
@@ -676,6 +676,49 @@ class VideoLlavaPlugin(BasePlugin):
         return self._get_mm_inputs(images, videos, processor)


+class MllamaPlugin(BasePlugin):
+    @override
+    def process_messages(
+        self,
+        messages: Sequence[Dict[str, str]],
+        images: Sequence["ImageInput"],
+        videos: Sequence["VideoInput"],
+        processor: Optional["ProcessorMixin"],
+    ) -> List[Dict[str, str]]:
+        self._validate_input(images, videos)
+        num_image_tokens = 0
+        messages = deepcopy(messages)
+        for message in messages:
+            content = message["content"]
+            while IMAGE_PLACEHOLDER in content:
+                num_image_tokens += 1
+                content = content.replace(IMAGE_PLACEHOLDER, "<|image|>", 1)
+
+            message["content"] = content
+
+        if len(images) != num_image_tokens:
+            raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER))
+
+        return messages
+
+    def get_mm_inputs(
+        self,
+        images: Sequence["ImageInput"],
+        videos: Sequence["VideoInput"],
+        imglens: Sequence[int],
+        vidlens: Sequence[int],
+        seqlens: Sequence[int],
+        processor: Optional["ProcessorMixin"],
+    ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
+        super().get_mm_inputs(images, videos, imglens, vidlens, seqlens, processor)
+        if images is not None:
+            images = [Image.open(image) if isinstance(image, str) else image for image in images]
+            image_features = processor.image_processor(images)
+            _ = image_features.pop("num_tiles")
+            image_features = {k: v if isinstance(v, torch.Tensor) else torch.tensor(v) for k, v in image_features.items()}
+            return image_features
+
+
 PLUGINS = {
     "base": BasePlugin,
     "llava": LlavaPlugin,
@@ -685,6 +728,7 @@ PLUGINS = {
     "pixtral": PixtralPlugin,
     "qwen2_vl": Qwen2vlPlugin,
     "video_llava": VideoLlavaPlugin,
+    "mllama": MllamaPlugin,
 }
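For orientation, a minimal usage sketch of the new plugin (not part of the commit; it assumes IMAGE_PLACEHOLDER is "<image>", matching the template registration below, and the message content and image path are made up):

    from llamafactory.data.mm_plugin import get_mm_plugin

    plugin = get_mm_plugin(name="mllama", image_token="<image>")
    messages = [{"role": "user", "content": "<image>What is in this picture?"}]

    # Each <image> placeholder is rewritten to the literal <|image|> token;
    # a count mismatch between placeholders and images raises the ValueError above.
    messages = plugin.process_messages(messages, images=["cat.png"], videos=[], processor=None)
    assert messages[0]["content"] == "<|image|>What is in this picture?"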
src/llamafactory/data/template.py

@@ -762,6 +762,33 @@ _register_template(
 )


+_register_template(
+    name="mllama",
+    format_user=StringFormatter(
+        slots=[
+            (
+                "<|start_header_id|>user<|end_header_id|>\n\n{{content}}<|eot_id|>"
+                "<|start_header_id|>assistant<|end_header_id|>\n\n"
+            )
+        ]
+    ),
+    format_system=StringFormatter(slots=["<|start_header_id|>system<|end_header_id|>\n\n{{content}}<|eot_id|>"]),
+    format_observation=StringFormatter(
+        slots=[
+            (
+                "<|start_header_id|>tool<|end_header_id|>\n\n{{content}}<|eot_id|>"
+                "<|start_header_id|>assistant<|end_header_id|>\n\n"
+            )
+        ]
+    ),
+    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
+    stop_words=["<|eot_id|>"],
+    replace_eos=True,
+    replace_jinja_template=False,
+    mm_plugin=get_mm_plugin(name="mllama", image_token="<image>"),
+)
+
+
 _register_template(
     name="llava",
     format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT:"]),
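With this template, a single user turn renders roughly as follows (a sketch assembled from the slots above, after MllamaPlugin has already replaced the <image> placeholder; <|begin_of_text|> is the Llama 3 BOS token supplied by format_prefix):

    <|begin_of_text|><|start_header_id|>user<|end_header_id|>

    <|image|>What is in this picture?<|eot_id|><|start_header_id|>assistant<|end_header_id|>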
src/llamafactory/extras/constants.py

@@ -855,6 +855,22 @@ register_model_group(
 )


+register_model_group(
+    models={
+        "Llama-3.2-11B-Vision-Instruct": {
+            DownloadSource.DEFAULT: "meta-llama/Llama-3.2-11B-Vision-Instruct",
+            DownloadSource.MODELSCOPE: "LLM-Research/Llama-3.2-11B-Vision-Instruct",
+        },
+        "Llama-3.2-90B-Vision-Instruct": {
+            DownloadSource.DEFAULT: "meta-llama/Llama-3.2-90B-Vision-Instruct",
+            DownloadSource.MODELSCOPE: "LLM-Research/Llama-3.2-90B-Vision-Instruct",
+        },
+    },
+    template="mllama",
+    vision=True,
+)
+
+
 register_model_group(
     models={
         "LLaVA-1.5-7B-Chat": {
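As a quick sanity check outside LLaMA-Factory, the checkpoints registered above load with the mllama classes in plain transformers (a sketch, assuming transformers >= 4.45; the image path is a placeholder and the prompt mirrors the template sketch above):

    import torch
    from PIL import Image
    from transformers import AutoProcessor, MllamaForConditionalGeneration

    model_id = "meta-llama/Llama-3.2-11B-Vision-Instruct"  # DownloadSource.DEFAULT above
    processor = AutoProcessor.from_pretrained(model_id)
    model = MllamaForConditionalGeneration.from_pretrained(
        model_id, torch_dtype=torch.bfloat16, device_map="auto"
    )

    image = Image.open("cat.png")  # hypothetical local image
    prompt = (
        "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n"
        "<|image|>What is in this picture?<|eot_id|>"
        "<|start_header_id|>assistant<|end_header_id|>\n\n"
    )
    inputs = processor(images=image, text=prompt, return_tensors="pt").to(model.device)
    output = model.generate(**inputs, max_new_tokens=64)
    print(processor.decode(output[0], skip_special_tokens=True))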