From 62cbcb646ab91b050b058ca788da4ea1c3c578bc Mon Sep 17 00:00:00 2001
From: KUANGDD
Date: Wed, 23 Oct 2024 15:24:07 +0800
Subject: [PATCH] modify style & little change

Former-commit-id: 9d6143e36a12e0f295139d057aeb1843535435cf
---
 src/llamafactory/chat/hf_engine.py   | 11 +++++++++-
 src/llamafactory/data/collator.py    |  3 +++
 src/llamafactory/data/mm_plugin.py   | 30 +++++++++++++---------------
 src/llamafactory/data/template.py    |  2 +-
 src/llamafactory/extras/constants.py |  2 +-
 src/llamafactory/model/loader.py     |  2 +-
 tests/data/test_mm_plugin.py         | 20 ++++++++++++++-----
 7 files changed, 45 insertions(+), 25 deletions(-)

diff --git a/src/llamafactory/chat/hf_engine.py b/src/llamafactory/chat/hf_engine.py
index 2b1d9fe5..53fb666a 100644
--- a/src/llamafactory/chat/hf_engine.py
+++ b/src/llamafactory/chat/hf_engine.py
@@ -165,8 +165,17 @@ class HuggingfaceEngine(BaseEngine):
         )
 
         mm_inputs = template.mm_plugin.get_mm_inputs(**mm_input_dict, seqlens=[prompt_length], processor=processor)
+
         for key, value in mm_inputs.items():
-            value = value if isinstance(value, torch.Tensor) else torch.tensor(value)
+            value = (
+                value
+                if isinstance(value, torch.Tensor)
+                else (
+                    torch.stack(value)
+                    if isinstance(value, list) and all(isinstance(v, torch.Tensor) for v in value)
+                    else torch.tensor(value)
+                )
+            )
             gen_kwargs[key] = value.to(model.device)
 
         return gen_kwargs, prompt_length

diff --git a/src/llamafactory/data/collator.py b/src/llamafactory/data/collator.py
index 92d86cc7..e92d2ab3 100644
--- a/src/llamafactory/data/collator.py
+++ b/src/llamafactory/data/collator.py
@@ -99,6 +99,9 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
 
         features: Dict[str, "torch.Tensor"] = super().__call__(features)
         features.update(mm_inputs)
+        if features.get("pixel_values") is not None and isinstance(features["pixel_values"], list):
+            features = features.data
+
         return features
 
 
diff --git a/src/llamafactory/data/mm_plugin.py b/src/llamafactory/data/mm_plugin.py
index f67737f5..a138c058 100644
--- a/src/llamafactory/data/mm_plugin.py
+++ b/src/llamafactory/data/mm_plugin.py
@@ -4,6 +4,7 @@ from io import BytesIO
 from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, TypedDict, Union
 
 import numpy as np
+import torch
 from transformers.image_utils import get_image_size, to_numpy_array
 from typing_extensions import override
 
@@ -447,6 +448,7 @@ class PaliGemmaPlugin(BasePlugin):
         mm_inputs["token_type_ids"] = _get_paligemma_token_type_ids(imglens, seqlens, processor)
         return mm_inputs
 
+
 class PixtralPlugin(BasePlugin):
     @override
     def process_messages(
@@ -466,32 +468,28 @@ class PixtralPlugin(BasePlugin):
 
         img_kwargs = self._get_mm_inputs(images, videos, processor)
         image_input_sizes = None
-        if img_kwargs.get("pixel_values") is not None:
-            image_input_sizes = img_kwargs["image_sizes"]
+        image_input_sizes = img_kwargs.get("image_sizes", None)
 
         messages = deepcopy(messages)
         for message in messages:
             content = message["content"]
-            img_id = 0
             while IMAGE_PLACEHOLDER in content:
                 if image_input_sizes is None:
-                    raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER))
+                    raise ValueError(
+                        "The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER)
+                    )
 
-                image_size = image_input_sizes[0][img_id]
+                image_size = image_input_sizes[0][num_image_tokens]
                 height, width = image_size
                 num_height_tokens = height // patch_size
                 num_width_tokens = width // patch_size
-                replace_tokens = [
-                    [image_token] * num_width_tokens + [image_break_token]
-                ] * num_height_tokens
+                replace_tokens = [[image_token] * num_width_tokens + [image_break_token]] * num_height_tokens
                 # Flatten list
                 replace_tokens = [item for sublist in replace_tokens for item in sublist]
                 replace_tokens[-1] = image_end_token
                 replace_str = "".join(replace_tokens)
                 content = content.replace(IMAGE_PLACEHOLDER, replace_str, 1)
-
-                img_id += 1
                 num_image_tokens += 1
 
             message["content"] = content
 
@@ -514,14 +512,13 @@ class PixtralPlugin(BasePlugin):
         self._validate_input(images, videos)
         mm_inputs = self._get_mm_inputs(images, videos, processor)
         # hack for hf engine
-        if mm_inputs.get("pixel_values") and len(mm_inputs.get("pixel_values")[0]) == 1:
-            mm_inputs["pixel_values"] = mm_inputs["pixel_values"][0][0].unsqueeze(0)
-
-        if mm_inputs.get("image_sizes"):
-            del mm_inputs["image_sizes"]
+        if mm_inputs.get("pixel_values"):
+            mm_inputs["pixel_values"] = mm_inputs["pixel_values"][0]
+        mm_inputs.pop("image_sizes", None)
 
         return mm_inputs
 
+
 class Qwen2vlPlugin(BasePlugin):
     @override
     def _preprocess_image(self, image: "ImageObject", **kwargs) -> "ImageObject":
@@ -698,9 +695,10 @@ def get_mm_plugin(
     plugin_class = PLUGINS.get(name, None)
     if plugin_class == "PixtralPlugin":
         from transformers.utils.versions import require_version
+
         try:
             require_version("transformers==4.46.0.dev0")
-        except Exception as e:
+        except Exception:
             raise ImportError("PixtralPlugin requires transformers>=4.46.0.dev0. Please install it first.")
     if plugin_class is None:
         raise ValueError("Multimodal plugin `{}` not found.".format(name))

diff --git a/src/llamafactory/data/template.py b/src/llamafactory/data/template.py
index 28ad2295..a9618885 100644
--- a/src/llamafactory/data/template.py
+++ b/src/llamafactory/data/template.py
@@ -938,7 +938,7 @@ _register_template(
     name="pixtral",
     format_user=StringFormatter(slots=["[INST] {{content}} [/INST]"]),
     format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
-    mm_plugin=get_mm_plugin(name="pixtral", image_token="[IMG]")
+    mm_plugin=get_mm_plugin(name="pixtral", image_token="[IMG]"),
 )
 
 
diff --git a/src/llamafactory/extras/constants.py b/src/llamafactory/extras/constants.py
index cf5df20c..237afeec 100644
--- a/src/llamafactory/extras/constants.py
+++ b/src/llamafactory/extras/constants.py
@@ -1185,7 +1185,7 @@ register_model_group(
         }
     },
     template="pixtral",
-    vision=True
+    vision=True,
 )
 
 
diff --git a/src/llamafactory/model/loader.py b/src/llamafactory/model/loader.py
index 957d5e4e..299e6333 100644
--- a/src/llamafactory/model/loader.py
+++ b/src/llamafactory/model/loader.py
@@ -116,7 +116,7 @@ def load_config(model_args: "ModelArguments") -> "PretrainedConfig":
     Loads model config.
     """
     init_kwargs = _get_init_kwargs(model_args)
-    
+
     return AutoConfig.from_pretrained(model_args.model_name_or_path, **init_kwargs)
 
 
diff --git a/tests/data/test_mm_plugin.py b/tests/data/test_mm_plugin.py
index b342e658..66e9b57c 100644
--- a/tests/data/test_mm_plugin.py
+++ b/tests/data/test_mm_plugin.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 import os
-from typing import TYPE_CHECKING, Any, Dict, List, Sequence, Tuple, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Sequence, Tuple
 
 import pytest
 import torch
@@ -74,6 +74,10 @@ def _is_close(batch_a: Dict[str, Any], batch_b: Dict[str, Any]) -> None:
     for key in batch_a.keys():
         if isinstance(batch_a[key], torch.Tensor):
             assert torch.allclose(batch_a[key], batch_b[key], rtol=1e-4, atol=1e-5)
+        elif isinstance(batch_a[key], list) and all(isinstance(item, torch.Tensor) for item in batch_a[key]):
+            assert len(batch_a[key]) == len(batch_b[key])
+            for tensor_a, tensor_b in zip(batch_a[key], batch_b[key]):
+                assert torch.allclose(tensor_a, tensor_b, rtol=1e-4, atol=1e-5)
         else:
             assert batch_a[key] == batch_b[key]
 
@@ -185,13 +189,19 @@ def test_pixtral_plugin():
     image_slice_height, image_slice_width = 2, 2
     check_inputs = {"plugin": pixtral_plugin, "tokenizer": tokenizer, "processor": processor}
     check_inputs["expected_mm_messages"] = [
-        {key: value.replace("<image>", ("{}[IMG_BREAK]".format("[IMG]" * image_slice_width) * image_slice_height).rsplit("[IMG_BREAK]", 1)[0] + "[IMG_END]")
-        for key, value in message.items()} for message in MM_MESSAGES
+        {
+            key: value.replace(
+                "<image>",
+                ("{}[IMG_BREAK]".format("[IMG]" * image_slice_width) * image_slice_height).rsplit("[IMG_BREAK]", 1)[0]
+                + "[IMG_END]",
+            )
+            for key, value in message.items()
+        }
+        for message in MM_MESSAGES
     ]
     check_inputs["expected_mm_inputs"] = _get_mm_inputs(processor)
-    # TODO works needed for pixtral plugin test & hack hf engine input below for now
     check_inputs["expected_mm_inputs"].pop("image_sizes")
-    check_inputs["expected_mm_inputs"]["pixel_values"] = check_inputs["expected_mm_inputs"]["pixel_values"][0][0].unsqueeze(0)
+    check_inputs["expected_mm_inputs"]["pixel_values"] = check_inputs["expected_mm_inputs"]["pixel_values"][0]
     _check_plugin(**check_inputs)
 
 
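
A note on the two behaviors this patch introduces, restated outside the diff for readability. The sketch below is illustrative only: the placeholder and token strings, the 16-pixel patch size, and the helper names are assumptions made for demonstration rather than values read from the actual Pixtral processor, but the expressions mirror the ones added to mm_plugin.py and hf_engine.py above.

    import torch

    # Illustrative constants; the real plugin receives the image token from the
    # template registration ("[IMG]") and the sizes from the processor output.
    IMAGE_PLACEHOLDER = "<image>"
    IMAGE_TOKEN, IMAGE_BREAK_TOKEN, IMAGE_END_TOKEN = "[IMG]", "[IMG_BREAK]", "[IMG_END]"


    def expand_image_placeholder(content: str, image_size: tuple, patch_size: int = 16) -> str:
        """Expand one placeholder into a height x width grid of image tokens,
        mirroring the single-line list expression in PixtralPlugin.process_messages."""
        height, width = image_size
        num_height_tokens = height // patch_size
        num_width_tokens = width // patch_size
        # Each row is `num_width_tokens` image tokens followed by a break token.
        replace_tokens = [[IMAGE_TOKEN] * num_width_tokens + [IMAGE_BREAK_TOKEN]] * num_height_tokens
        replace_tokens = [item for sublist in replace_tokens for item in sublist]  # flatten the rows
        replace_tokens[-1] = IMAGE_END_TOKEN  # the last break marks the end of the image
        return content.replace(IMAGE_PLACEHOLDER, "".join(replace_tokens), 1)


    def to_batched_tensor(value):
        """Mirror of the hf_engine.py change: keep tensors as-is, stack a list of
        same-shaped tensors, and fall back to torch.tensor for plain Python values."""
        if isinstance(value, torch.Tensor):
            return value
        if isinstance(value, list) and all(isinstance(v, torch.Tensor) for v in value):
            return torch.stack(value)
        return torch.tensor(value)


    # A 32x32 image with 16-pixel patches expands to a 2x2 token grid:
    # [IMG][IMG][IMG_BREAK][IMG][IMG][IMG_END]
    print(expand_image_placeholder("What is in this image?<image>", (32, 32)))
    print(to_batched_tensor([torch.zeros(2, 3), torch.ones(2, 3)]).shape)  # torch.Size([2, 2, 3])

The same expansion explains the expected string assembled in test_pixtral_plugin: with image_slice_height = image_slice_width = 2, the row "[IMG][IMG][IMG_BREAK]" is repeated twice and the trailing break is swapped for "[IMG_END]" via the rsplit call.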