modify style & minor changes

Former-commit-id: 9d6143e36a12e0f295139d057aeb1843535435cf
KUANGDD 2024-10-23 15:24:07 +08:00
parent a24f94a36c
commit 62cbcb646a
7 changed files with 45 additions and 25 deletions


@@ -165,8 +165,17 @@ class HuggingfaceEngine(BaseEngine):
             )
         mm_inputs = template.mm_plugin.get_mm_inputs(**mm_input_dict, seqlens=[prompt_length], processor=processor)
         for key, value in mm_inputs.items():
-            value = value if isinstance(value, torch.Tensor) else torch.tensor(value)
+            value = (
+                value
+                if isinstance(value, torch.Tensor)
+                else (
+                    torch.stack(value)
+                    if isinstance(value, list) and all(isinstance(v, torch.Tensor) for v in value)
+                    else torch.tensor(value)
+                )
+            )
             gen_kwargs[key] = value.to(model.device)
         return gen_kwargs, prompt_length
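The new branch matters because torch.tensor cannot build a tensor from a list of multi-element tensors, while torch.stack can, provided they share a shape. A minimal sketch of the behavior, with hypothetical shapes:

import torch

# Per-image feature tensors of equal shape, as a multimodal plugin might return them.
values = [torch.zeros(3, 224, 224), torch.zeros(3, 224, 224)]

# torch.tensor(values) would raise:
#   ValueError: only one element tensors can be converted to Python scalars
batched = torch.stack(values)  # shape: (2, 3, 224, 224)
print(batched.shape)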


@@ -99,6 +99,9 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
         features: Dict[str, "torch.Tensor"] = super().__call__(features)
         features.update(mm_inputs)
+        if features.get("pixel_values") is not None and isinstance(features["pixel_values"], list):
+            features = features.data
         return features
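For context, the object returned by super().__call__ is (in current transformers) a BatchEncoding, a UserDict whose .data attribute is the underlying plain dict; unwrapping it keeps a list-valued pixel_values entry from being coerced downstream. A small sketch under that assumption, with hypothetical values:

import torch
from transformers import BatchEncoding

features = BatchEncoding({"input_ids": torch.ones(1, 4, dtype=torch.long)})
features["pixel_values"] = [torch.zeros(3, 16, 16), torch.zeros(3, 32, 32)]

plain = features.data  # the underlying dict: same objects, no tensor coercion
assert isinstance(plain, dict)
assert isinstance(plain["pixel_values"], list)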


@@ -4,6 +4,7 @@ from io import BytesIO
 from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, TypedDict, Union
 import numpy as np
+import torch
 from transformers.image_utils import get_image_size, to_numpy_array
 from typing_extensions import override
@@ -447,6 +448,7 @@ class PaliGemmaPlugin(BasePlugin):
         mm_inputs["token_type_ids"] = _get_paligemma_token_type_ids(imglens, seqlens, processor)
         return mm_inputs
 class PixtralPlugin(BasePlugin):
     @override
     def process_messages(
@@ -466,32 +468,28 @@ class PixtralPlugin(BasePlugin):
         img_kwargs = self._get_mm_inputs(images, videos, processor)
         image_input_sizes = None
         if img_kwargs.get("pixel_values") is not None:
-            image_input_sizes = img_kwargs["image_sizes"]
+            image_input_sizes = img_kwargs.get("image_sizes", None)
         messages = deepcopy(messages)
         for message in messages:
             content = message["content"]
-            img_id = 0
             while IMAGE_PLACEHOLDER in content:
                 if image_input_sizes is None:
-                    raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER))
+                    raise ValueError(
+                        "The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER)
+                    )
-                image_size = image_input_sizes[0][img_id]
+                image_size = image_input_sizes[0][num_image_tokens]
                 height, width = image_size
                 num_height_tokens = height // patch_size
                 num_width_tokens = width // patch_size
-                replace_tokens = [
-                    [image_token] * num_width_tokens + [image_break_token]
-                ] * num_height_tokens
+                replace_tokens = [[image_token] * num_width_tokens + [image_break_token]] * num_height_tokens
                 # Flatten list
                 replace_tokens = [item for sublist in replace_tokens for item in sublist]
                 replace_tokens[-1] = image_end_token
                 replace_str = "".join(replace_tokens)
                 content = content.replace(IMAGE_PLACEHOLDER, replace_str, 1)
-                img_id += 1
+                num_image_tokens += 1
             message["content"] = content
@@ -514,14 +512,13 @@ class PixtralPlugin(BasePlugin):
         self._validate_input(images, videos)
         mm_inputs = self._get_mm_inputs(images, videos, processor)
         # hack for hf engine
-        if mm_inputs.get("pixel_values") and len(mm_inputs.get("pixel_values")[0]) == 1:
-            mm_inputs["pixel_values"] = mm_inputs["pixel_values"][0][0].unsqueeze(0)
-        if mm_inputs.get("image_sizes"):
-            del mm_inputs["image_sizes"]
+        if mm_inputs.get("pixel_values"):
+            mm_inputs["pixel_values"] = mm_inputs["pixel_values"][0]
+        mm_inputs.pop("image_sizes", None)
         return mm_inputs
 class Qwen2vlPlugin(BasePlugin):
     @override
     def _preprocess_image(self, image: "ImageObject", **kwargs) -> "ImageObject":
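The unwrapping relies on the processor nesting pixel_values one list level per batch element, so indexing [0] leaves one tensor per image. A sketch under that assumption, with hypothetical shapes:

import torch

# Assumed processor output for a single sample containing two images.
pixel_values = [[torch.zeros(3, 32, 32), torch.zeros(3, 64, 64)]]

per_image = pixel_values[0]  # one tensor per image; sizes may differ
assert all(isinstance(t, torch.Tensor) for t in per_image)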
@@ -698,9 +695,10 @@ def get_mm_plugin(
     plugin_class = PLUGINS.get(name, None)
     if plugin_class == "PixtralPlugin":
         from transformers.utils.versions import require_version
         try:
             require_version("transformers==4.46.0.dev0")
-        except Exception as e:
+        except Exception:
             raise ImportError("PixtralPlugin requires transformers>=4.46.0.dev0. Please install it first.")
     if plugin_class is None:
         raise ValueError("Multimodal plugin `{}` not found.".format(name))
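require_version raises an ImportError when the installed version misses the pin, which the except clause above converts into a clearer message. A usage sketch, with an illustrative requirement string:

from transformers.utils.versions import require_version

try:
    require_version("transformers>=4.46.0.dev0")
except ImportError as exc:
    print(f"Pixtral support unavailable: {exc}")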


@@ -938,7 +938,7 @@ _register_template(
     name="pixtral",
     format_user=StringFormatter(slots=["[INST] {{content}} [/INST]"]),
     format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
-    mm_plugin=get_mm_plugin(name="pixtral", image_token="[IMG]")
+    mm_plugin=get_mm_plugin(name="pixtral", image_token="[IMG]"),
 )


@@ -1185,7 +1185,7 @@ register_model_group(
         }
     },
     template="pixtral",
-    vision=True
+    vision=True,
 )


@@ -116,7 +116,7 @@ def load_config(model_args: "ModelArguments") -> "PretrainedConfig":
     Loads model config.
     """
     init_kwargs = _get_init_kwargs(model_args)
     return AutoConfig.from_pretrained(model_args.model_name_or_path, **init_kwargs)


@@ -13,7 +13,7 @@
 # limitations under the License.
 import os
-from typing import TYPE_CHECKING, Any, Dict, List, Sequence, Tuple, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Sequence, Tuple
 import pytest
 import torch
@@ -74,6 +74,10 @@ def _is_close(batch_a: Dict[str, Any], batch_b: Dict[str, Any]) -> None:
     for key in batch_a.keys():
         if isinstance(batch_a[key], torch.Tensor):
             assert torch.allclose(batch_a[key], batch_b[key], rtol=1e-4, atol=1e-5)
+        elif isinstance(batch_a[key], list) and all(isinstance(item, torch.Tensor) for item in batch_a[key]):
+            assert len(batch_a[key]) == len(batch_b[key])
+            for tensor_a, tensor_b in zip(batch_a[key], batch_b[key]):
+                assert torch.allclose(tensor_a, tensor_b, rtol=1e-4, atol=1e-5)
         else:
             assert batch_a[key] == batch_b[key]
@@ -185,13 +189,19 @@ def test_pixtral_plugin():
     image_slice_height, image_slice_width = 2, 2
     check_inputs = {"plugin": pixtral_plugin, "tokenizer": tokenizer, "processor": processor}
     check_inputs["expected_mm_messages"] = [
-        {key: value.replace("<image>", ("{}[IMG_BREAK]".format("[IMG]" * image_slice_width) * image_slice_height).rsplit("[IMG_BREAK]", 1)[0] + "[IMG_END]")
-        for key, value in message.items()} for message in MM_MESSAGES
+        {
+            key: value.replace(
+                "<image>",
+                ("{}[IMG_BREAK]".format("[IMG]" * image_slice_width) * image_slice_height).rsplit("[IMG_BREAK]", 1)[0]
+                + "[IMG_END]",
+            )
+            for key, value in message.items()
+        }
+        for message in MM_MESSAGES
     ]
     check_inputs["expected_mm_inputs"] = _get_mm_inputs(processor)
     # TODO: more work needed for the pixtral plugin test; hack the hf engine input below for now
     check_inputs["expected_mm_inputs"].pop("image_sizes")
-    check_inputs["expected_mm_inputs"]["pixel_values"] = check_inputs["expected_mm_inputs"]["pixel_values"][0][0].unsqueeze(0)
+    check_inputs["expected_mm_inputs"]["pixel_values"] = check_inputs["expected_mm_inputs"]["pixel_values"][0]
     _check_plugin(**check_inputs)
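As a sanity check, the test's one-liner builds exactly the string the plugin produces for a 2x2 slice grid, worked out below under those values:

image_slice_height, image_slice_width = 2, 2
row = "{}[IMG_BREAK]".format("[IMG]" * image_slice_width)      # "[IMG][IMG][IMG_BREAK]"
grid = (row * image_slice_height).rsplit("[IMG_BREAK]", 1)[0]  # drop the final break
assert grid + "[IMG_END]" == "[IMG][IMG][IMG_BREAK][IMG][IMG][IMG_END]"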