From 20faaf341847565f38ab5c23f633680c7191b13b Mon Sep 17 00:00:00 2001
From: marko1616
Date: Thu, 26 Sep 2024 09:43:53 -0400
Subject: [PATCH] Support llama3.2vl.

Former-commit-id: 3f2c056253c651e8e614c787e2045f4232e82666
---
 src/llamafactory/data/mm_plugin.py   | 46 +++++++++++++++++++++++++++-
 src/llamafactory/data/template.py    | 27 ++++++++++++++++
 src/llamafactory/extras/constants.py | 16 ++++++++++
 3 files changed, 88 insertions(+), 1 deletion(-)

diff --git a/src/llamafactory/data/mm_plugin.py b/src/llamafactory/data/mm_plugin.py
index 6a174838..7103224b 100644
--- a/src/llamafactory/data/mm_plugin.py
+++ b/src/llamafactory/data/mm_plugin.py
@@ -4,6 +4,7 @@ from io import BytesIO
 from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, TypedDict, Union
 
 import numpy as np
+import torch
 from transformers.image_utils import get_image_size, to_numpy_array
 from typing_extensions import override
 
@@ -21,7 +22,6 @@ if is_pyav_available():
 
 
 if TYPE_CHECKING:
-    import torch
     from av.stream import Stream
     from transformers import PreTrainedTokenizer, ProcessorMixin
     from transformers.image_processing_utils import BaseImageProcessor
@@ -676,6 +676,49 @@ class VideoLlavaPlugin(BasePlugin):
         return self._get_mm_inputs(images, videos, processor)
 
 
+class MllamaPlugin(BasePlugin):
+    @override
+    def process_messages(
+        self,
+        messages: Sequence[Dict[str, str]],
+        images: Sequence["ImageInput"],
+        videos: Sequence["VideoInput"],
+        processor: Optional["ProcessorMixin"],
+    ) -> List[Dict[str, str]]:
+        self._validate_input(images, videos)
+        num_image_tokens = 0
+        messages = deepcopy(messages)
+        for message in messages:
+            content = message["content"]
+            while IMAGE_PLACEHOLDER in content:  # swap each generic placeholder for mllama's image token
+                num_image_tokens += 1
+                content = content.replace(IMAGE_PLACEHOLDER, self.image_token, 1)
+
+            message["content"] = content
+
+        if len(images) != num_image_tokens:
+            raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER))
+
+        return messages
+
+    def get_mm_inputs(
+        self,
+        images: Sequence["ImageInput"],
+        videos: Sequence["VideoInput"],
+        imglens: Sequence[int],
+        vidlens: Sequence[int],
+        seqlens: Sequence[int],
+        processor: Optional["ProcessorMixin"],
+    ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
+        super().get_mm_inputs(images, videos, imglens, vidlens, seqlens, processor)  # base class only validates inputs
+        if images is not None:
+            images = [Image.open(image) if isinstance(image, str) else image for image in images]
+        image_features = processor.image_processor(images)
+        _ = image_features.pop("num_tiles")  # dropped: the model forward does not accept num_tiles
+        image_features = {k: v if isinstance(v, torch.Tensor) else torch.tensor(v) for k, v in image_features.items()}
+        return image_features
+
+
 PLUGINS = {
     "base": BasePlugin,
     "llava": LlavaPlugin,
@@ -685,6 +728,7 @@ PLUGINS = {
     "pixtral": PixtralPlugin,
     "qwen2_vl": Qwen2vlPlugin,
     "video_llava": VideoLlavaPlugin,
+    "mllama": MllamaPlugin,
 }
 
 
diff --git a/src/llamafactory/data/template.py b/src/llamafactory/data/template.py
index 27ffe9e8..045332af 100644
--- a/src/llamafactory/data/template.py
+++ b/src/llamafactory/data/template.py
@@ -762,6 +762,33 @@ _register_template(
 )
 
 
+_register_template(
+    name="mllama",
+    format_user=StringFormatter(
+        slots=[
+            (
+                "<|start_header_id|>user<|end_header_id|>\n\n{{content}}<|eot_id|>"
+                "<|start_header_id|>assistant<|end_header_id|>\n\n"
+            )
+        ]
+    ),
+    format_system=StringFormatter(slots=["<|start_header_id|>system<|end_header_id|>\n\n{{content}}<|eot_id|>"]),
+    format_observation=StringFormatter(
+        slots=[
+            (
"<|start_header_id|>tool<|end_header_id|>\n\n{{content}}<|eot_id|>" + "<|start_header_id|>assistant<|end_header_id|>\n\n" + ) + ] + ), + format_prefix=EmptyFormatter(slots=[{"bos_token"}]), + stop_words=["<|eot_id|>"], + replace_eos=True, + replace_jinja_template=False, + mm_plugin=get_mm_plugin(name="mllama", image_token=""), +) + + _register_template( name="llava", format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT:"]), diff --git a/src/llamafactory/extras/constants.py b/src/llamafactory/extras/constants.py index f6738f81..9316f230 100644 --- a/src/llamafactory/extras/constants.py +++ b/src/llamafactory/extras/constants.py @@ -855,6 +855,22 @@ register_model_group( ) +register_model_group( + models={ + "Llama-3.2-11B-Vision-Instruct": { + DownloadSource.DEFAULT: "meta-llama/Llama-3.2-11B-Vision-Instruct", + DownloadSource.MODELSCOPE: "LLM-Research/Llama-3.2-11B-Vision-Instruct", + }, + "LlamaVision3.2-90B-Instruct": { + DownloadSource.DEFAULT: "meta-llama/Llama-3.2-90B-Vision-Instruct", + DownloadSource.MODELSCOPE: "LLM-Research/Llama-3.2-90B-Vision-Instruct", + }, + }, + template="mllama", + vision=True, +) + + register_model_group( models={ "LLaVA-1.5-7B-Chat": {