Modify code style and make minor changes

Former-commit-id: 9d6143e36a12e0f295139d057aeb1843535435cf
This commit is contained in:
KUANGDD 2024-10-23 15:24:07 +08:00
parent a24f94a36c
commit 62cbcb646a
7 changed files with 45 additions and 25 deletions

View File

@ -165,8 +165,17 @@ class HuggingfaceEngine(BaseEngine):
) )
mm_inputs = template.mm_plugin.get_mm_inputs(**mm_input_dict, seqlens=[prompt_length], processor=processor) mm_inputs = template.mm_plugin.get_mm_inputs(**mm_input_dict, seqlens=[prompt_length], processor=processor)
for key, value in mm_inputs.items(): for key, value in mm_inputs.items():
value = value if isinstance(value, torch.Tensor) else torch.tensor(value) value = (
value
if isinstance(value, torch.Tensor)
else (
torch.stack(value)
if isinstance(value, list) and all(isinstance(v, torch.Tensor) for v in value)
else torch.tensor(value)
)
)
gen_kwargs[key] = value.to(model.device) gen_kwargs[key] = value.to(model.device)
return gen_kwargs, prompt_length return gen_kwargs, prompt_length

View File

@ -99,6 +99,9 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
features: Dict[str, "torch.Tensor"] = super().__call__(features) features: Dict[str, "torch.Tensor"] = super().__call__(features)
features.update(mm_inputs) features.update(mm_inputs)
if features.get("pixel_values") is not None and isinstance(features["pixel_values"], list):
features = features.data
return features return features

View File

@ -4,6 +4,7 @@ from io import BytesIO
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, TypedDict, Union from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, TypedDict, Union
import numpy as np import numpy as np
import torch
from transformers.image_utils import get_image_size, to_numpy_array from transformers.image_utils import get_image_size, to_numpy_array
from typing_extensions import override from typing_extensions import override
@ -447,6 +448,7 @@ class PaliGemmaPlugin(BasePlugin):
mm_inputs["token_type_ids"] = _get_paligemma_token_type_ids(imglens, seqlens, processor) mm_inputs["token_type_ids"] = _get_paligemma_token_type_ids(imglens, seqlens, processor)
return mm_inputs return mm_inputs
class PixtralPlugin(BasePlugin): class PixtralPlugin(BasePlugin):
@override @override
def process_messages( def process_messages(
@ -466,32 +468,28 @@ class PixtralPlugin(BasePlugin):
img_kwargs = self._get_mm_inputs(images, videos, processor) img_kwargs = self._get_mm_inputs(images, videos, processor)
image_input_sizes = None image_input_sizes = None
if img_kwargs.get("pixel_values") is not None: image_input_sizes = img_kwargs.get("image_sizes", None)
image_input_sizes = img_kwargs["image_sizes"]
messages = deepcopy(messages) messages = deepcopy(messages)
for message in messages: for message in messages:
content = message["content"] content = message["content"]
img_id = 0
while IMAGE_PLACEHOLDER in content: while IMAGE_PLACEHOLDER in content:
if image_input_sizes is None: if image_input_sizes is None:
raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER)) raise ValueError(
"The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER)
)
image_size = image_input_sizes[0][img_id] image_size = image_input_sizes[0][num_image_tokens]
height, width = image_size height, width = image_size
num_height_tokens = height // patch_size num_height_tokens = height // patch_size
num_width_tokens = width // patch_size num_width_tokens = width // patch_size
replace_tokens = [ replace_tokens = [[image_token] * num_width_tokens + [image_break_token]] * num_height_tokens
[image_token] * num_width_tokens + [image_break_token]
] * num_height_tokens
# Flatten list # Flatten list
replace_tokens = [item for sublist in replace_tokens for item in sublist] replace_tokens = [item for sublist in replace_tokens for item in sublist]
replace_tokens[-1] = image_end_token replace_tokens[-1] = image_end_token
replace_str = "".join(replace_tokens) replace_str = "".join(replace_tokens)
content = content.replace(IMAGE_PLACEHOLDER, replace_str, 1) content = content.replace(IMAGE_PLACEHOLDER, replace_str, 1)
img_id += 1
num_image_tokens += 1 num_image_tokens += 1
message["content"] = content message["content"] = content
@ -514,14 +512,13 @@ class PixtralPlugin(BasePlugin):
self._validate_input(images, videos) self._validate_input(images, videos)
mm_inputs = self._get_mm_inputs(images, videos, processor) mm_inputs = self._get_mm_inputs(images, videos, processor)
# hack for hf engine # hack for hf engine
if mm_inputs.get("pixel_values") and len(mm_inputs.get("pixel_values")[0]) == 1: if mm_inputs.get("pixel_values"):
mm_inputs["pixel_values"] = mm_inputs["pixel_values"][0][0].unsqueeze(0) mm_inputs["pixel_values"] = mm_inputs["pixel_values"][0]
if mm_inputs.get("image_sizes"):
del mm_inputs["image_sizes"]
mm_inputs.pop("image_sizes", None)
return mm_inputs return mm_inputs
class Qwen2vlPlugin(BasePlugin): class Qwen2vlPlugin(BasePlugin):
@override @override
def _preprocess_image(self, image: "ImageObject", **kwargs) -> "ImageObject": def _preprocess_image(self, image: "ImageObject", **kwargs) -> "ImageObject":
@ -698,9 +695,10 @@ def get_mm_plugin(
plugin_class = PLUGINS.get(name, None) plugin_class = PLUGINS.get(name, None)
if plugin_class == "PixtralPlugin": if plugin_class == "PixtralPlugin":
from transformers.utils.versions import require_version from transformers.utils.versions import require_version
try: try:
require_version("transformers==4.46.0.dev0") require_version("transformers==4.46.0.dev0")
except Exception as e: except Exception:
raise ImportError("PixtralPlugin requires transformers>=4.46.0.dev0. Please install it first.") raise ImportError("PixtralPlugin requires transformers>=4.46.0.dev0. Please install it first.")
if plugin_class is None: if plugin_class is None:
raise ValueError("Multimodal plugin `{}` not found.".format(name)) raise ValueError("Multimodal plugin `{}` not found.".format(name))

View File

@ -938,7 +938,7 @@ _register_template(
name="pixtral", name="pixtral",
format_user=StringFormatter(slots=["[INST] {{content}} [/INST]"]), format_user=StringFormatter(slots=["[INST] {{content}} [/INST]"]),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]), format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
mm_plugin=get_mm_plugin(name="pixtral", image_token="[IMG]") mm_plugin=get_mm_plugin(name="pixtral", image_token="[IMG]"),
) )

View File

@ -1185,7 +1185,7 @@ register_model_group(
} }
}, },
template="pixtral", template="pixtral",
vision=True vision=True,
) )

View File

@ -116,7 +116,7 @@ def load_config(model_args: "ModelArguments") -> "PretrainedConfig":
Loads model config. Loads model config.
""" """
init_kwargs = _get_init_kwargs(model_args) init_kwargs = _get_init_kwargs(model_args)
return AutoConfig.from_pretrained(model_args.model_name_or_path, **init_kwargs) return AutoConfig.from_pretrained(model_args.model_name_or_path, **init_kwargs)

View File

@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
import os import os
from typing import TYPE_CHECKING, Any, Dict, List, Sequence, Tuple, Union from typing import TYPE_CHECKING, Any, Dict, List, Sequence, Tuple
import pytest import pytest
import torch import torch
@ -74,6 +74,10 @@ def _is_close(batch_a: Dict[str, Any], batch_b: Dict[str, Any]) -> None:
for key in batch_a.keys(): for key in batch_a.keys():
if isinstance(batch_a[key], torch.Tensor): if isinstance(batch_a[key], torch.Tensor):
assert torch.allclose(batch_a[key], batch_b[key], rtol=1e-4, atol=1e-5) assert torch.allclose(batch_a[key], batch_b[key], rtol=1e-4, atol=1e-5)
elif isinstance(batch_a[key], list) and all(isinstance(item, torch.Tensor) for item in batch_a[key]):
assert len(batch_a[key]) == len(batch_b[key])
for tensor_a, tensor_b in zip(batch_a[key], batch_b[key]):
assert torch.allclose(tensor_a, tensor_b, rtol=1e-4, atol=1e-5)
else: else:
assert batch_a[key] == batch_b[key] assert batch_a[key] == batch_b[key]
@ -185,13 +189,19 @@ def test_pixtral_plugin():
image_slice_height, image_slice_width = 2, 2 image_slice_height, image_slice_width = 2, 2
check_inputs = {"plugin": pixtral_plugin, "tokenizer": tokenizer, "processor": processor} check_inputs = {"plugin": pixtral_plugin, "tokenizer": tokenizer, "processor": processor}
check_inputs["expected_mm_messages"] = [ check_inputs["expected_mm_messages"] = [
{key: value.replace("<image>", ("{}[IMG_BREAK]".format("[IMG]" * image_slice_width) * image_slice_height).rsplit("[IMG_BREAK]", 1)[0] + "[IMG_END]") {
for key, value in message.items()} for message in MM_MESSAGES key: value.replace(
"<image>",
("{}[IMG_BREAK]".format("[IMG]" * image_slice_width) * image_slice_height).rsplit("[IMG_BREAK]", 1)[0]
+ "[IMG_END]",
)
for key, value in message.items()
}
for message in MM_MESSAGES
] ]
check_inputs["expected_mm_inputs"] = _get_mm_inputs(processor) check_inputs["expected_mm_inputs"] = _get_mm_inputs(processor)
# TODO works needed for pixtral plugin test & hack hf engine input below for now
check_inputs["expected_mm_inputs"].pop("image_sizes") check_inputs["expected_mm_inputs"].pop("image_sizes")
check_inputs["expected_mm_inputs"]["pixel_values"] = check_inputs["expected_mm_inputs"]["pixel_values"][0][0].unsqueeze(0) check_inputs["expected_mm_inputs"]["pixel_values"] = check_inputs["expected_mm_inputs"]["pixel_values"][0]
_check_plugin(**check_inputs) _check_plugin(**check_inputs)