Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-08-02 03:32:50 +08:00)

add test mm plugin

commit 09a2ecebc4 (parent f31e7e0dfc)
Former-commit-id: a2a8c0b92c49fb1ee65de271aec651e011dcabc4
@@ -18,36 +18,14 @@ if TYPE_CHECKING:
     from transformers.image_processing_utils import BaseImageProcessor
 
 
-def get_pixel_values(images: Sequence["ImageObject"], processor: "ProcessorMixin") -> "torch.Tensor":
+def _get_mm_inputs(images: Sequence["ImageObject"], processor: "ProcessorMixin") -> Dict[str, "torch.Tensor"]:
     r"""
-    Processes visual inputs. (currently only supports a single image)
+    Processes visual inputs.
 
-    Returns:
+    Returns: (llava and paligemma)
         pixel_values: tensor with shape (B, C, H, W)
-    """
-    image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
-    image = images[0] if len(images) != 0 else Image.new("RGB", (100, 100), (255, 255, 255))
-    return image_processor([image], return_tensors="pt")["pixel_values"]
-
-
-def get_paligemma_token_type_ids(input_len: int, processor: "ProcessorMixin") -> List[List[int]]:
-    r"""
-    Gets paligemma token type ids for computing loss.
-
-    Returns:
-        token_type_ids: shape (1, seq_len)
-    """
-    image_seq_length = getattr(processor, "image_seq_length")
-    return [[0] * image_seq_length + [1] * (input_len - image_seq_length)]
-
-
-def get_qwen2vl_image_inputs(
-    images: Sequence["ImageObject"], processor: "ProcessorMixin"
-) -> Dict[str, "torch.Tensor"]:
-    r"""
-    Processes qwen2-vl visual inputs. Supports multiple images.
 
-    Returns:
+    Returns: (qwen2-vl)
         pixel_values: tensor with shape (num_patches, patch_dim)
         image_grid_thw: tensor with shape (num_images, 3), where the three numbers are time, width, height
 
@@ -59,9 +37,22 @@ def get_qwen2vl_image_inputs
     else:
         image = Image.new("RGB", (56, 56), (255, 255, 255))
         image_inputs = image_processor(images=[image], return_tensors="pt")
-        image_inputs["image_grid_thw"][0][0] = 0  # fake image
+        if "image_grid_thw" in image_inputs:  # fake image for qwen2-vl
+            image_inputs["image_grid_thw"][0][0] = 0
 
-    return {"pixel_values": image_inputs["pixel_values"], "image_grid_thw": image_inputs["image_grid_thw"]}
+    return image_inputs
+
+
+def _get_paligemma_token_type_ids(input_len: int, processor: "ProcessorMixin") -> List[List[int]]:
+    r"""
+    Gets paligemma token type ids for computing loss.
+
+    Returns:
+        token_type_ids: shape (1, seq_len)
+    """
+    image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
+    image_seq_length: int = getattr(image_processor, "image_seq_length")
+    return [[0] * image_seq_length + [1] * (input_len - image_seq_length)]
 
 
 class BasePlugin:
@@ -131,8 +122,9 @@ class LlavaPlugin(BasePlugin):
                 if image_count > 1:
                     raise ValueError("Llava model only accepts one image per sample.")
 
-                content = content.replace(IMAGE_PLACEHOLDER, self.image_token, 1)
+                content = content.replace(IMAGE_PLACEHOLDER, "{{image}}", 1)
 
+            content = content.replace("{{image}}", self.image_token)
             new_messages.append({"role": message["role"], "content": content})
 
         return new_messages
@@ -143,7 +135,7 @@ class LlavaPlugin(BasePlugin):
         feature_seqlens: Dict[str, int],
         processor: Optional["ProcessorMixin"],
     ) -> Dict[str, Any]:
-        return {"pixel_values": get_pixel_values(images, processor)}
+        return _get_mm_inputs(images, processor)
 
     def process_model_inputs(
         self,
@@ -153,7 +145,8 @@ class LlavaPlugin(BasePlugin):
         processor: Optional["ProcessorMixin"],
     ) -> None:
         mm_inputs = self.get_mm_inputs(images, feature_seqlens, processor)
-        model_inputs["pixel_values"].append(mm_inputs["pixel_values"][0])
+        for key, value in mm_inputs.items():
+            model_inputs[key].append(value[0])
 
 
 class PaliGemmaPlugin(BasePlugin):
@@ -200,9 +193,9 @@ class PaliGemmaPlugin(BasePlugin):
         feature_seqlens: Dict[str, int],
         processor: Optional["ProcessorMixin"],
     ) -> Dict[str, Any]:
-        mm_inputs = {"pixel_values": get_pixel_values(images, processor)}
+        mm_inputs = _get_mm_inputs(images, processor)
         for feature_name, feature_length in feature_seqlens.items():
-            mm_inputs[feature_name] = get_paligemma_token_type_ids(feature_length, processor)
+            mm_inputs[feature_name] = _get_paligemma_token_type_ids(feature_length, processor)
 
         return mm_inputs
 
@@ -214,9 +207,8 @@ class PaliGemmaPlugin(BasePlugin):
         processor: Optional["ProcessorMixin"],
     ) -> None:
         mm_inputs = self.get_mm_inputs(images, feature_seqlens, processor)
-        model_inputs["pixel_values"].append(mm_inputs["pixel_values"][0])
-        for feature_name in feature_seqlens.keys():
-            model_inputs[feature_name].append(mm_inputs[feature_name][0])
+        for key, value in mm_inputs.items():
+            model_inputs[key].append(value[0])
 
 
 class Qwen2vlPlugin(BasePlugin):
@@ -229,7 +221,7 @@ class Qwen2vlPlugin(BasePlugin):
         image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
         merge_length: int = getattr(image_processor, "merge_size") ** 2
         if len(images) > 0:
-            image_grid_thw = get_qwen2vl_image_inputs(images, processor)["image_grid_thw"]
+            image_grid_thw = _get_mm_inputs(images, processor)["image_grid_thw"]
 
         index = 0
         new_messages = []
@@ -255,7 +247,7 @@ class Qwen2vlPlugin(BasePlugin):
         feature_seqlens: Dict[str, int],
         processor: Optional["ProcessorMixin"],
     ) -> Dict[str, Any]:
-        return get_qwen2vl_image_inputs(images, processor)
+        return _get_mm_inputs(images, processor)
 
     def process_model_inputs(
         self,
@@ -265,11 +257,12 @@ class Qwen2vlPlugin(BasePlugin):
         processor: Optional["ProcessorMixin"],
     ) -> None:
         mm_inputs = self.get_mm_inputs(images, feature_seqlens, processor)
-        model_inputs["pixel_values"].append(mm_inputs["pixel_values"])
-        model_inputs["image_grid_thw"].append(mm_inputs["image_grid_thw"])
+        for key, value in mm_inputs.items():
+            model_inputs[key].append(value)  # support multi-image
 
 
 PLUGINS = {
     "base": BasePlugin,
     "llava": LlavaPlugin,
     "paligemma": PaliGemmaPlugin,
     "qwen2_vl": Qwen2vlPlugin,
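For orientation, a short usage sketch of the consolidated helpers. The plugin name and image token come from the PLUGINS registry above and the new tests; the sequence lengths are illustrative values, not taken from the repository:

    from llamafactory.data.mm_plugin import get_mm_plugin

    # Resolve a plugin through the PLUGINS registry shown above.
    qwen2_vl_plugin = get_mm_plugin(name="qwen2_vl", image_token="<|image_pad|>")

    # _get_paligemma_token_type_ids masks the image prefix with 0s and the text with 1s,
    # mirroring the one-line body shown in the diff (toy lengths below are assumptions).
    image_seq_length, input_len = 3, 5
    token_type_ids = [[0] * image_seq_length + [1] * (input_len - image_seq_length)]
    assert token_type_ids == [[0, 0, 0, 1, 1]]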
@@ -19,13 +19,14 @@ from ..extras.constants import IMAGE_PLACEHOLDER
 from ..extras.logging import get_logger
 from .data_utils import Role
 from .formatter import EmptyFormatter, FunctionFormatter, StringFormatter, ToolFormatter
-from .mm_plugin import BasePlugin, get_mm_plugin
+from .mm_plugin import get_mm_plugin
 
 
 if TYPE_CHECKING:
     from transformers import PreTrainedTokenizer
 
     from .formatter import SLOTS, Formatter
+    from .mm_plugin import BasePlugin
 
 
 logger = get_logger(__name__)
@@ -209,7 +210,7 @@ def _register_template(
     stop_words: Sequence[str] = [],
     efficient_eos: bool = False,
     replace_eos: bool = False,
-    mm_plugin: "BasePlugin" = BasePlugin(IMAGE_PLACEHOLDER),
+    mm_plugin: "BasePlugin" = get_mm_plugin(name="base", image_token=IMAGE_PLACEHOLDER),
 ) -> None:
     r"""
     Registers a chat template.
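In effect, a template registered without an explicit mm_plugin now receives a freshly constructed "base" plugin from the factory instead of a shared module-level BasePlugin instance. A minimal sketch of the equivalent call (illustrative only, not part of the diff; the "<image>" token stands in for IMAGE_PLACEHOLDER and matches the value used in the new tests):

    from llamafactory.data.mm_plugin import get_mm_plugin

    # Equivalent of the new default argument above (illustrative token value).
    default_plugin = get_mm_plugin(name="base", image_token="<image>")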
@@ -99,6 +99,11 @@ def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
     except Exception:
         processor = None
 
+    # Avoid loading the tokenizer as a processor, see:
+    # https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/auto/processing_auto.py#L324
+    if "Processor" not in processor.__class__.__name__:
+        processor = None
+
     return {"tokenizer": tokenizer, "processor": processor}
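The guard above matters for the new tests: for a text-only checkpoint, AutoProcessor falls back to a tokenizer-like object whose class name does not contain "Processor", so the processor slot ends up empty. A minimal sketch of the expected behavior (assuming llamafactory is installed and the tiny test model used in the new test file is reachable):

    from llamafactory.hparams import ModelArguments
    from llamafactory.model import load_tokenizer

    # Text-only model: after the new check, "processor" should come back as None (assumed outcome).
    module = load_tokenizer(ModelArguments(model_name_or_path="llamafactory/tiny-random-Llama-3"))
    assert module["processor"] is None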
@@ -46,7 +46,6 @@ def save_model(
     finetuning_type: str,
     checkpoint_path: Union[str, List[str]],
     template: str,
-    visual_inputs: bool,
     export_size: int,
     export_quantization_bit: str,
     export_quantization_dataset: str,
@@ -78,7 +77,6 @@ def save_model(
         model_name_or_path=model_path,
         finetuning_type=finetuning_type,
         template=template,
-        visual_inputs=visual_inputs,
         export_dir=export_dir,
         export_hub_model_id=export_hub_model_id or None,
         export_size=export_size,
@@ -129,7 +127,6 @@ def create_export_tab(engine: "Engine") -> Dict[str, "Component"]:
             engine.manager.get_elem_by_id("top.finetuning_type"),
             engine.manager.get_elem_by_id("top.checkpoint_path"),
             engine.manager.get_elem_by_id("top.template"),
-            engine.manager.get_elem_by_id("top.visual_inputs"),
             export_size,
             export_quantization_bit,
             export_quantization_dataset,
tests/data/test_mm_plugin.py (new file, 151 lines)
@@ -0,0 +1,151 @@
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict

import pytest
import torch
from PIL import Image

from llamafactory.data.mm_plugin import get_mm_plugin
from llamafactory.hparams import ModelArguments
from llamafactory.model import load_tokenizer


if TYPE_CHECKING:
    from transformers import ProcessorMixin
    from transformers.image_processing_utils import BaseImageProcessor


HF_TOKEN = os.environ.get("HF_TOKEN", None)

TINY_LLAMA = os.environ.get("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")

MESSAGES = [
    {"role": "user", "content": "<image>What is in this image?"},
    {"role": "assistant", "content": "A cat."},
]

IMAGES = [Image.new("RGB", (32, 32), (255, 255, 255))]

INPUT_IDS = [0, 1, 2, 3, 4]

LABELS = [0, 1, 2, 3, 4]

FEATURE_SEQLENS = {"token_type_ids": 1024}


def _get_mm_inputs(processor: "ProcessorMixin") -> Dict[str, "torch.Tensor"]:
    image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
    return image_processor(images=IMAGES, return_tensors="pt")


def _is_close(batch_a: Dict[str, Any], batch_b: Dict[str, Any]):
    assert batch_a.keys() == batch_b.keys()
    for key in batch_a.keys():
        if isinstance(batch_a[key], list):
            assert len(batch_a[key]) == len(batch_b[key])
            for i in range(len(batch_a[key])):
                if isinstance(batch_a[key][i], torch.Tensor):
                    assert torch.allclose(batch_a[key][i], batch_b[key][i], rtol=1e-4, atol=1e-5)
                else:
                    assert batch_a[key][i] == batch_b[key][i]
        elif isinstance(batch_a[key], torch.Tensor):
            assert torch.allclose(batch_a[key], batch_b[key], rtol=1e-4, atol=1e-5)
        else:
            raise NotImplementedError


def test_base_plugin():
    model_args = ModelArguments(model_name_or_path=TINY_LLAMA)
    tokenizer_module = load_tokenizer(model_args)
    tokenizer = tokenizer_module["tokenizer"]
    processor = tokenizer_module["processor"]

    base_plugin = get_mm_plugin(name="base", image_token="<image>")
    model_inputs = defaultdict(list)
    base_plugin.process_model_inputs(model_inputs, IMAGES, FEATURE_SEQLENS, processor)

    assert base_plugin.process_messages(MESSAGES, IMAGES, processor)
    assert base_plugin.process_token_ids(INPUT_IDS, LABELS, tokenizer, processor) == (INPUT_IDS, LABELS)
    _is_close(base_plugin.get_mm_inputs(IMAGES, FEATURE_SEQLENS, processor), {})
    _is_close(model_inputs, {})


def test_llava_plugin():
    model_args = ModelArguments(model_name_or_path="llava-hf/llava-1.5-7b-hf")
    tokenizer_module = load_tokenizer(model_args)
    tokenizer = tokenizer_module["tokenizer"]
    processor = tokenizer_module["processor"]

    mm_inputs = _get_mm_inputs(processor)
    expected_model_inputs = {key: [value[0]] for key, value in mm_inputs.items()}

    llava_plugin = get_mm_plugin(name="llava", image_token="<image>")
    model_inputs = defaultdict(list)
    llava_plugin.process_model_inputs(model_inputs, IMAGES, FEATURE_SEQLENS, processor)

    assert llava_plugin.process_messages(MESSAGES, IMAGES, processor)
    assert llava_plugin.process_token_ids(INPUT_IDS, LABELS, tokenizer, processor) == (INPUT_IDS, LABELS)
    _is_close(llava_plugin.get_mm_inputs(IMAGES, FEATURE_SEQLENS, processor), mm_inputs)
    _is_close(model_inputs, expected_model_inputs)


@pytest.mark.skipif(not HF_TOKEN, reason="Gated model.")
def test_paligemma_plugin():
    model_args = ModelArguments(model_name_or_path="google/paligemma-3b-pt-224")
    tokenizer_module = load_tokenizer(model_args)
    tokenizer = tokenizer_module["tokenizer"]
    processor = tokenizer_module["processor"]
    image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
    image_seq_length: int = getattr(image_processor, "image_seq_length")

    mm_inputs = _get_mm_inputs(processor)
    mm_inputs["token_type_ids"] = [[0] * image_seq_length + [1] * (1024 - image_seq_length)]
    expected_model_inputs = {key: [value[0]] for key, value in mm_inputs.items()}
    expected_input_ids = [tokenizer.convert_tokens_to_ids("<image>")] * image_seq_length + INPUT_IDS
    expected_labels = [-100] * image_seq_length + LABELS

    paligemma_plugin = get_mm_plugin(name="paligemma", image_token="<image>")
    model_inputs = defaultdict(list)
    paligemma_plugin.process_model_inputs(model_inputs, IMAGES, FEATURE_SEQLENS, processor)

    assert paligemma_plugin.process_messages(MESSAGES, IMAGES, processor)
    assert paligemma_plugin.process_token_ids(INPUT_IDS, LABELS, tokenizer, processor) == (
        expected_input_ids,
        expected_labels,
    )
    _is_close(paligemma_plugin.get_mm_inputs(IMAGES, FEATURE_SEQLENS, processor), mm_inputs)
    _is_close(model_inputs, expected_model_inputs)


def test_qwen2_vl_plugin():
    model_args = ModelArguments(model_name_or_path="Qwen/Qwen2-VL-7B-Instruct")
    tokenizer_module = load_tokenizer(model_args)
    tokenizer = tokenizer_module["tokenizer"]
    processor = tokenizer_module["processor"]

    mm_inputs = _get_mm_inputs(processor)
    expected_model_inputs = {key: [value] for key, value in mm_inputs.items()}

    qwen2_vl_plugin = get_mm_plugin(name="qwen2_vl", image_token="<|image_pad|>")
    model_inputs = defaultdict(list)
    qwen2_vl_plugin.process_model_inputs(model_inputs, IMAGES, FEATURE_SEQLENS, processor)

    assert qwen2_vl_plugin.process_messages(MESSAGES, IMAGES, processor)
    assert qwen2_vl_plugin.process_token_ids(INPUT_IDS, LABELS, tokenizer, processor) == (INPUT_IDS, LABELS)
    _is_close(qwen2_vl_plugin.get_mm_inputs(IMAGES, FEATURE_SEQLENS, processor), mm_inputs)
    _is_close(model_inputs, expected_model_inputs)
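To exercise just the new tests locally, one option (a hypothetical invocation, not part of the commit) is to drive pytest from Python; the PaliGemma case is skipped unless HF_TOKEN is set, because that checkpoint is gated:

    import pytest

    # Run only the new multimodal plugin tests; -q keeps the output short.
    pytest.main(["tests/data/test_mm_plugin.py", "-q"])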