diff --git a/src/llamafactory/data/mm_plugin.py b/src/llamafactory/data/mm_plugin.py index cd2604d8..7631c937 100644 --- a/src/llamafactory/data/mm_plugin.py +++ b/src/llamafactory/data/mm_plugin.py @@ -18,36 +18,14 @@ if TYPE_CHECKING: from transformers.image_processing_utils import BaseImageProcessor -def get_pixel_values(images: Sequence["ImageObject"], processor: "ProcessorMixin") -> "torch.Tensor": +def _get_mm_inputs(images: Sequence["ImageObject"], processor: "ProcessorMixin") -> Dict[str, "torch.Tensor"]: r""" - Processes visual inputs. (currently only supports a single image) + Processes visual inputs. - Returns: + Returns: (llava and paligemma) pixel_values: tensor with shape (B, C, H, W) - """ - image_processor: "BaseImageProcessor" = getattr(processor, "image_processor") - image = images[0] if len(images) != 0 else Image.new("RGB", (100, 100), (255, 255, 255)) - return image_processor([image], return_tensors="pt")["pixel_values"] - -def get_paligemma_token_type_ids(input_len: int, processor: "ProcessorMixin") -> List[List[int]]: - r""" - Gets paligemma token type ids for computing loss. - - Returns: - token_type_ids: shape (1, seq_len) - """ - image_seq_length = getattr(processor, "image_seq_length") - return [[0] * image_seq_length + [1] * (input_len - image_seq_length)] - - -def get_qwen2vl_image_inputs( - images: Sequence["ImageObject"], processor: "ProcessorMixin" -) -> Dict[str, "torch.Tensor"]: - r""" - Processes qwen2-vl visual inputs. Supports multiple images. - - Returns: + Returns: (qwen2-vl) pixel_values: tensor with shape (num_patches, patch_dim) image_grid_thw: tensot with shape (num_images, 3), where the three numbers are time, width, height @@ -59,9 +37,22 @@ def get_qwen2vl_image_inputs( else: image = Image.new("RGB", (56, 56), (255, 255, 255)) image_inputs = image_processor(images=[image], return_tensors="pt") - image_inputs["image_grid_thw"][0][0] = 0 # fake image + if "image_grid_thw" in image_inputs: # fake image for qwen2-vl + image_inputs["image_grid_thw"][0][0] = 0 - return {"pixel_values": image_inputs["pixel_values"], "image_grid_thw": image_inputs["image_grid_thw"]} + return image_inputs + + +def _get_paligemma_token_type_ids(input_len: int, processor: "ProcessorMixin") -> List[List[int]]: + r""" + Gets paligemma token type ids for computing loss. + + Returns: + token_type_ids: shape (1, seq_len) + """ + image_processor: "BaseImageProcessor" = getattr(processor, "image_processor") + image_seq_length: int = getattr(image_processor, "image_seq_length") + return [[0] * image_seq_length + [1] * (input_len - image_seq_length)] class BasePlugin: @@ -131,8 +122,9 @@ class LlavaPlugin(BasePlugin): if image_count > 1: raise ValueError("Llava model only accepts one image per sample.") - content = content.replace(IMAGE_PLACEHOLDER, self.image_token, 1) + content = content.replace(IMAGE_PLACEHOLDER, "{{image}}", 1) + content = content.replace("{{image}}", self.image_token) new_messages.append({"role": message["role"], "content": content}) return new_messages @@ -143,7 +135,7 @@ class LlavaPlugin(BasePlugin): feature_seqlens: Dict[str, int], processor: Optional["ProcessorMixin"], ) -> Dict[str, Any]: - return {"pixel_values": get_pixel_values(images, processor)} + return _get_mm_inputs(images, processor) def process_model_inputs( self, @@ -153,7 +145,8 @@ class LlavaPlugin(BasePlugin): processor: Optional["ProcessorMixin"], ) -> None: mm_inputs = self.get_mm_inputs(images, feature_seqlens, processor) - model_inputs["pixel_values"].append(mm_inputs["pixel_values"][0]) + for key, value in mm_inputs.items(): + model_inputs[key].append(value[0]) class PaliGemmaPlugin(BasePlugin): @@ -200,9 +193,9 @@ class PaliGemmaPlugin(BasePlugin): feature_seqlens: Dict[str, int], processor: Optional["ProcessorMixin"], ) -> Dict[str, Any]: - mm_inputs = {"pixel_values": get_pixel_values(images, processor)} + mm_inputs = _get_mm_inputs(images, processor) for feature_name, feature_length in feature_seqlens.items(): - mm_inputs[feature_name] = get_paligemma_token_type_ids(feature_length, processor) + mm_inputs[feature_name] = _get_paligemma_token_type_ids(feature_length, processor) return mm_inputs @@ -214,9 +207,8 @@ class PaliGemmaPlugin(BasePlugin): processor: Optional["ProcessorMixin"], ) -> None: mm_inputs = self.get_mm_inputs(images, feature_seqlens, processor) - model_inputs["pixel_values"].append(mm_inputs["pixel_values"][0]) - for feature_name in feature_seqlens.keys(): - model_inputs[feature_name].append(mm_inputs[feature_name][0]) + for key, value in mm_inputs.items(): + model_inputs[key].append(value[0]) class Qwen2vlPlugin(BasePlugin): @@ -229,7 +221,7 @@ class Qwen2vlPlugin(BasePlugin): image_processor: "BaseImageProcessor" = getattr(processor, "image_processor") merge_length: int = getattr(image_processor, "merge_size") ** 2 if len(images) > 0: - image_grid_thw = get_qwen2vl_image_inputs(images, processor)["image_grid_thw"] + image_grid_thw = _get_mm_inputs(images, processor)["image_grid_thw"] index = 0 new_messages = [] @@ -255,7 +247,7 @@ class Qwen2vlPlugin(BasePlugin): feature_seqlens: Dict[str, int], processor: Optional["ProcessorMixin"], ) -> Dict[str, Any]: - return get_qwen2vl_image_inputs(images, processor) + return _get_mm_inputs(images, processor) def process_model_inputs( self, @@ -265,11 +257,12 @@ class Qwen2vlPlugin(BasePlugin): processor: Optional["ProcessorMixin"], ) -> None: mm_inputs = self.get_mm_inputs(images, feature_seqlens, processor) - model_inputs["pixel_values"].append(mm_inputs["pixel_values"]) - model_inputs["image_grid_thw"].append(mm_inputs["image_grid_thw"]) + for key, value in mm_inputs.items(): + model_inputs[key].append(value) # support multi-image PLUGINS = { + "base": BasePlugin, "llava": LlavaPlugin, "paligemma": PaliGemmaPlugin, "qwen2_vl": Qwen2vlPlugin, diff --git a/src/llamafactory/data/template.py b/src/llamafactory/data/template.py index fe0a104f..63564e8f 100644 --- a/src/llamafactory/data/template.py +++ b/src/llamafactory/data/template.py @@ -19,13 +19,14 @@ from ..extras.constants import IMAGE_PLACEHOLDER from ..extras.logging import get_logger from .data_utils import Role from .formatter import EmptyFormatter, FunctionFormatter, StringFormatter, ToolFormatter -from .mm_plugin import BasePlugin, get_mm_plugin +from .mm_plugin import get_mm_plugin if TYPE_CHECKING: from transformers import PreTrainedTokenizer from .formatter import SLOTS, Formatter + from .mm_plugin import BasePlugin logger = get_logger(__name__) @@ -209,7 +210,7 @@ def _register_template( stop_words: Sequence[str] = [], efficient_eos: bool = False, replace_eos: bool = False, - mm_plugin: "BasePlugin" = BasePlugin(IMAGE_PLACEHOLDER), + mm_plugin: "BasePlugin" = get_mm_plugin(name="base", image_token=IMAGE_PLACEHOLDER), ) -> None: r""" Registers a chat template. diff --git a/src/llamafactory/model/loader.py b/src/llamafactory/model/loader.py index 9a16c0ce..374748d0 100644 --- a/src/llamafactory/model/loader.py +++ b/src/llamafactory/model/loader.py @@ -99,6 +99,11 @@ def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule": except Exception: processor = None + # Avoid load tokenizer, see: + # https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/auto/processing_auto.py#L324 + if "Processor" not in processor.__class__.__name__: + processor = None + return {"tokenizer": tokenizer, "processor": processor} diff --git a/src/llamafactory/webui/components/export.py b/src/llamafactory/webui/components/export.py index d9d4b937..190e8855 100644 --- a/src/llamafactory/webui/components/export.py +++ b/src/llamafactory/webui/components/export.py @@ -46,7 +46,6 @@ def save_model( finetuning_type: str, checkpoint_path: Union[str, List[str]], template: str, - visual_inputs: bool, export_size: int, export_quantization_bit: str, export_quantization_dataset: str, @@ -78,7 +77,6 @@ def save_model( model_name_or_path=model_path, finetuning_type=finetuning_type, template=template, - visual_inputs=visual_inputs, export_dir=export_dir, export_hub_model_id=export_hub_model_id or None, export_size=export_size, @@ -129,7 +127,6 @@ def create_export_tab(engine: "Engine") -> Dict[str, "Component"]: engine.manager.get_elem_by_id("top.finetuning_type"), engine.manager.get_elem_by_id("top.checkpoint_path"), engine.manager.get_elem_by_id("top.template"), - engine.manager.get_elem_by_id("top.visual_inputs"), export_size, export_quantization_bit, export_quantization_dataset, diff --git a/tests/data/test_mm_plugin.py b/tests/data/test_mm_plugin.py new file mode 100644 index 00000000..52061075 --- /dev/null +++ b/tests/data/test_mm_plugin.py @@ -0,0 +1,151 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from collections import defaultdict +from typing import TYPE_CHECKING, Any, Dict + +import pytest +import torch +from PIL import Image + +from llamafactory.data.mm_plugin import get_mm_plugin +from llamafactory.hparams import ModelArguments +from llamafactory.model import load_tokenizer + + +if TYPE_CHECKING: + from transformers import ProcessorMixin + from transformers.image_processing_utils import BaseImageProcessor + + +HF_TOKEN = os.environ.get("HF_TOKEN", None) + +TINY_LLAMA = os.environ.get("TINY_LLAMA", "llamafactory/tiny-random-Llama-3") + +MESSAGES = [ + {"role": "user", "content": "What is in this image?"}, + {"role": "assistant", "content": "A cat."}, +] + +IMAGES = [Image.new("RGB", (32, 32), (255, 255, 255))] + +INPUT_IDS = [0, 1, 2, 3, 4] + +LABELS = [0, 1, 2, 3, 4] + +FEATURE_SEQLENS = {"token_type_ids": 1024} + + +def _get_mm_inputs(processor: "ProcessorMixin") -> Dict[str, "torch.Tensor"]: + image_processor: "BaseImageProcessor" = getattr(processor, "image_processor") + return image_processor(images=IMAGES, return_tensors="pt") + + +def _is_close(batch_a: Dict[str, Any], batch_b: Dict[str, Any]): + assert batch_a.keys() == batch_b.keys() + for key in batch_a.keys(): + if isinstance(batch_a[key], list): + assert len(batch_a[key]) == len(batch_b[key]) + for i in range(len(batch_a[key])): + if isinstance(batch_a[key][i], torch.Tensor): + assert torch.allclose(batch_a[key][i], batch_b[key][i], rtol=1e-4, atol=1e-5) + else: + assert batch_a[key][i] == batch_b[key][i] + elif isinstance(batch_a[key], torch.Tensor): + assert torch.allclose(batch_a[key], batch_b[key], rtol=1e-4, atol=1e-5) + else: + raise NotImplementedError + + +def test_base_plugin(): + model_args = ModelArguments(model_name_or_path=TINY_LLAMA) + tokenizer_module = load_tokenizer(model_args) + tokenizer = tokenizer_module["tokenizer"] + processor = tokenizer_module["processor"] + + base_plugin = get_mm_plugin(name="base", image_token="") + model_inputs = defaultdict(list) + base_plugin.process_model_inputs(model_inputs, IMAGES, FEATURE_SEQLENS, processor) + + assert base_plugin.process_messages(MESSAGES, IMAGES, processor) + assert base_plugin.process_token_ids(INPUT_IDS, LABELS, tokenizer, processor) == (INPUT_IDS, LABELS) + _is_close(base_plugin.get_mm_inputs(IMAGES, FEATURE_SEQLENS, processor), {}) + _is_close(model_inputs, {}) + + +def test_llava_plugin(): + model_args = ModelArguments(model_name_or_path="llava-hf/llava-1.5-7b-hf") + tokenizer_module = load_tokenizer(model_args) + tokenizer = tokenizer_module["tokenizer"] + processor = tokenizer_module["processor"] + + mm_inputs = _get_mm_inputs(processor) + expected_model_inputs = {key: [value[0]] for key, value in mm_inputs.items()} + + llava_plugin = get_mm_plugin(name="llava", image_token="") + model_inputs = defaultdict(list) + llava_plugin.process_model_inputs(model_inputs, IMAGES, FEATURE_SEQLENS, processor) + + assert llava_plugin.process_messages(MESSAGES, IMAGES, processor) + assert llava_plugin.process_token_ids(INPUT_IDS, LABELS, tokenizer, processor) == (INPUT_IDS, LABELS) + _is_close(llava_plugin.get_mm_inputs(IMAGES, FEATURE_SEQLENS, processor), mm_inputs) + _is_close(model_inputs, expected_model_inputs) + + +@pytest.mark.skipif(not HF_TOKEN, reason="Gated model.") +def test_paligemma_plugin(): + model_args = ModelArguments(model_name_or_path="google/paligemma-3b-pt-224") + tokenizer_module = load_tokenizer(model_args) + tokenizer = tokenizer_module["tokenizer"] + processor = tokenizer_module["processor"] + image_processor: "BaseImageProcessor" = getattr(processor, "image_processor") + image_seq_length: int = getattr(image_processor, "image_seq_length") + + mm_inputs = _get_mm_inputs(processor) + mm_inputs["token_type_ids"] = [[0] * image_seq_length + [1] * (1024 - image_seq_length)] + expected_model_inputs = {key: [value[0]] for key, value in mm_inputs.items()} + expected_input_ids = [tokenizer.convert_tokens_to_ids("")] * image_seq_length + INPUT_IDS + expected_labels = [-100] * image_seq_length + LABELS + + paligemma_plugin = get_mm_plugin(name="paligemma", image_token="") + model_inputs = defaultdict(list) + paligemma_plugin.process_model_inputs(model_inputs, IMAGES, FEATURE_SEQLENS, processor) + + assert paligemma_plugin.process_messages(MESSAGES, IMAGES, processor) + assert paligemma_plugin.process_token_ids(INPUT_IDS, LABELS, tokenizer, processor) == ( + expected_input_ids, + expected_labels, + ) + _is_close(paligemma_plugin.get_mm_inputs(IMAGES, FEATURE_SEQLENS, processor), mm_inputs) + _is_close(model_inputs, expected_model_inputs) + + +def test_qwen2_vl_plugin(): + model_args = ModelArguments(model_name_or_path="Qwen/Qwen2-VL-7B-Instruct") + tokenizer_module = load_tokenizer(model_args) + tokenizer = tokenizer_module["tokenizer"] + processor = tokenizer_module["processor"] + + mm_inputs = _get_mm_inputs(processor) + expected_model_inputs = {key: [value] for key, value in mm_inputs.items()} + + llava_plugin = get_mm_plugin(name="qwen2_vl", image_token="<|image_pad|>") + model_inputs = defaultdict(list) + llava_plugin.process_model_inputs(model_inputs, IMAGES, FEATURE_SEQLENS, processor) + + assert llava_plugin.process_messages(MESSAGES, IMAGES, processor) + assert llava_plugin.process_token_ids(INPUT_IDS, LABELS, tokenizer, processor) == (INPUT_IDS, LABELS) + _is_close(llava_plugin.get_mm_inputs(IMAGES, FEATURE_SEQLENS, processor), mm_inputs) + _is_close(model_inputs, expected_model_inputs)