mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-02 03:32:50 +08:00
modify style & little change
Former-commit-id: 9d6143e36a12e0f295139d057aeb1843535435cf
This commit is contained in:
parent
a24f94a36c
commit
62cbcb646a
@ -165,8 +165,17 @@ class HuggingfaceEngine(BaseEngine):
|
||||
)
|
||||
|
||||
mm_inputs = template.mm_plugin.get_mm_inputs(**mm_input_dict, seqlens=[prompt_length], processor=processor)
|
||||
|
||||
for key, value in mm_inputs.items():
|
||||
value = value if isinstance(value, torch.Tensor) else torch.tensor(value)
|
||||
value = (
|
||||
value
|
||||
if isinstance(value, torch.Tensor)
|
||||
else (
|
||||
torch.stack(value)
|
||||
if isinstance(value, list) and all(isinstance(v, torch.Tensor) for v in value)
|
||||
else torch.tensor(value)
|
||||
)
|
||||
)
|
||||
gen_kwargs[key] = value.to(model.device)
|
||||
|
||||
return gen_kwargs, prompt_length
|
||||
|
@ -99,6 +99,9 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
||||
|
||||
features: Dict[str, "torch.Tensor"] = super().__call__(features)
|
||||
features.update(mm_inputs)
|
||||
if features.get("pixel_values") is not None and isinstance(features["pixel_values"], list):
|
||||
features = features.data
|
||||
|
||||
return features
|
||||
|
||||
|
||||
|
@ -4,6 +4,7 @@ from io import BytesIO
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, TypedDict, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers.image_utils import get_image_size, to_numpy_array
|
||||
from typing_extensions import override
|
||||
|
||||
@ -447,6 +448,7 @@ class PaliGemmaPlugin(BasePlugin):
|
||||
mm_inputs["token_type_ids"] = _get_paligemma_token_type_ids(imglens, seqlens, processor)
|
||||
return mm_inputs
|
||||
|
||||
|
||||
class PixtralPlugin(BasePlugin):
|
||||
@override
|
||||
def process_messages(
|
||||
@ -466,32 +468,28 @@ class PixtralPlugin(BasePlugin):
|
||||
img_kwargs = self._get_mm_inputs(images, videos, processor)
|
||||
image_input_sizes = None
|
||||
|
||||
if img_kwargs.get("pixel_values") is not None:
|
||||
image_input_sizes = img_kwargs["image_sizes"]
|
||||
image_input_sizes = img_kwargs.get("image_sizes", None)
|
||||
|
||||
messages = deepcopy(messages)
|
||||
for message in messages:
|
||||
content = message["content"]
|
||||
img_id = 0
|
||||
while IMAGE_PLACEHOLDER in content:
|
||||
|
||||
if image_input_sizes is None:
|
||||
raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER))
|
||||
raise ValueError(
|
||||
"The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER)
|
||||
)
|
||||
|
||||
image_size = image_input_sizes[0][img_id]
|
||||
image_size = image_input_sizes[0][num_image_tokens]
|
||||
height, width = image_size
|
||||
num_height_tokens = height // patch_size
|
||||
num_width_tokens = width // patch_size
|
||||
replace_tokens = [
|
||||
[image_token] * num_width_tokens + [image_break_token]
|
||||
] * num_height_tokens
|
||||
replace_tokens = [[image_token] * num_width_tokens + [image_break_token]] * num_height_tokens
|
||||
# Flatten list
|
||||
replace_tokens = [item for sublist in replace_tokens for item in sublist]
|
||||
replace_tokens[-1] = image_end_token
|
||||
replace_str = "".join(replace_tokens)
|
||||
content = content.replace(IMAGE_PLACEHOLDER, replace_str, 1)
|
||||
|
||||
img_id += 1
|
||||
num_image_tokens += 1
|
||||
|
||||
message["content"] = content
|
||||
@ -514,14 +512,13 @@ class PixtralPlugin(BasePlugin):
|
||||
self._validate_input(images, videos)
|
||||
mm_inputs = self._get_mm_inputs(images, videos, processor)
|
||||
# hack for hf engine
|
||||
if mm_inputs.get("pixel_values") and len(mm_inputs.get("pixel_values")[0]) == 1:
|
||||
mm_inputs["pixel_values"] = mm_inputs["pixel_values"][0][0].unsqueeze(0)
|
||||
|
||||
if mm_inputs.get("image_sizes"):
|
||||
del mm_inputs["image_sizes"]
|
||||
if mm_inputs.get("pixel_values"):
|
||||
mm_inputs["pixel_values"] = mm_inputs["pixel_values"][0]
|
||||
|
||||
mm_inputs.pop("image_sizes", None)
|
||||
return mm_inputs
|
||||
|
||||
|
||||
class Qwen2vlPlugin(BasePlugin):
|
||||
@override
|
||||
def _preprocess_image(self, image: "ImageObject", **kwargs) -> "ImageObject":
|
||||
@ -698,9 +695,10 @@ def get_mm_plugin(
|
||||
plugin_class = PLUGINS.get(name, None)
|
||||
if plugin_class == "PixtralPlugin":
|
||||
from transformers.utils.versions import require_version
|
||||
|
||||
try:
|
||||
require_version("transformers==4.46.0.dev0")
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
raise ImportError("PixtralPlugin requires transformers>=4.46.0.dev0. Please install it first.")
|
||||
if plugin_class is None:
|
||||
raise ValueError("Multimodal plugin `{}` not found.".format(name))
|
||||
|
@ -938,7 +938,7 @@ _register_template(
|
||||
name="pixtral",
|
||||
format_user=StringFormatter(slots=["[INST] {{content}} [/INST]"]),
|
||||
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
|
||||
mm_plugin=get_mm_plugin(name="pixtral", image_token="[IMG]")
|
||||
mm_plugin=get_mm_plugin(name="pixtral", image_token="[IMG]"),
|
||||
)
|
||||
|
||||
|
||||
|
@ -1185,7 +1185,7 @@ register_model_group(
|
||||
}
|
||||
},
|
||||
template="pixtral",
|
||||
vision=True
|
||||
vision=True,
|
||||
)
|
||||
|
||||
|
||||
|
@ -116,7 +116,7 @@ def load_config(model_args: "ModelArguments") -> "PretrainedConfig":
|
||||
Loads model config.
|
||||
"""
|
||||
init_kwargs = _get_init_kwargs(model_args)
|
||||
|
||||
|
||||
return AutoConfig.from_pretrained(model_args.model_name_or_path, **init_kwargs)
|
||||
|
||||
|
||||
|
@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Sequence, Tuple, Union
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Sequence, Tuple
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
@ -74,6 +74,10 @@ def _is_close(batch_a: Dict[str, Any], batch_b: Dict[str, Any]) -> None:
|
||||
for key in batch_a.keys():
|
||||
if isinstance(batch_a[key], torch.Tensor):
|
||||
assert torch.allclose(batch_a[key], batch_b[key], rtol=1e-4, atol=1e-5)
|
||||
elif isinstance(batch_a[key], list) and all(isinstance(item, torch.Tensor) for item in batch_a[key]):
|
||||
assert len(batch_a[key]) == len(batch_b[key])
|
||||
for tensor_a, tensor_b in zip(batch_a[key], batch_b[key]):
|
||||
assert torch.allclose(tensor_a, tensor_b, rtol=1e-4, atol=1e-5)
|
||||
else:
|
||||
assert batch_a[key] == batch_b[key]
|
||||
|
||||
@ -185,13 +189,19 @@ def test_pixtral_plugin():
|
||||
image_slice_height, image_slice_width = 2, 2
|
||||
check_inputs = {"plugin": pixtral_plugin, "tokenizer": tokenizer, "processor": processor}
|
||||
check_inputs["expected_mm_messages"] = [
|
||||
{key: value.replace("<image>", ("{}[IMG_BREAK]".format("[IMG]" * image_slice_width) * image_slice_height).rsplit("[IMG_BREAK]", 1)[0] + "[IMG_END]")
|
||||
for key, value in message.items()} for message in MM_MESSAGES
|
||||
{
|
||||
key: value.replace(
|
||||
"<image>",
|
||||
("{}[IMG_BREAK]".format("[IMG]" * image_slice_width) * image_slice_height).rsplit("[IMG_BREAK]", 1)[0]
|
||||
+ "[IMG_END]",
|
||||
)
|
||||
for key, value in message.items()
|
||||
}
|
||||
for message in MM_MESSAGES
|
||||
]
|
||||
check_inputs["expected_mm_inputs"] = _get_mm_inputs(processor)
|
||||
# TODO works needed for pixtral plugin test & hack hf engine input below for now
|
||||
check_inputs["expected_mm_inputs"].pop("image_sizes")
|
||||
check_inputs["expected_mm_inputs"]["pixel_values"] = check_inputs["expected_mm_inputs"]["pixel_values"][0][0].unsqueeze(0)
|
||||
check_inputs["expected_mm_inputs"]["pixel_values"] = check_inputs["expected_mm_inputs"]["pixel_values"][0]
|
||||
_check_plugin(**check_inputs)
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user