fix paligemma inference

This commit is contained in:
hiyouga
2024-05-20 23:36:43 +08:00
parent 7262679666
commit 542229abb3
4 changed files with 47 additions and 20 deletions

View File

@@ -8,6 +8,7 @@ import torch
from transformers import GenerationConfig, TextIteratorStreamer
from ..data import get_template_and_fix_tokenizer
from ..extras.constants import IMAGE_TOKEN
from ..extras.misc import get_logits_processor
from ..model import load_model, load_tokenizer
from .base_engine import BaseEngine, Response
@@ -55,14 +56,28 @@ class HuggingfaceEngine(BaseEngine):
image: Optional["NDArray"] = None,
input_kwargs: Optional[Dict[str, Any]] = {},
) -> Tuple[Dict[str, Any], int]:
if processor is not None and image is not None and "<image>" not in messages[0]["content"]:
messages[0]["content"] = "<image>" + messages[0]["content"]
if (
processor is not None
and image is not None
and not hasattr(processor, "image_seq_length")
and IMAGE_TOKEN not in messages[0]["content"]
): # llava case
messages[0]["content"] = IMAGE_TOKEN + messages[0]["content"]
paired_messages = messages + [{"role": "assistant", "content": ""}]
system = system or generating_args["default_system"]
pixel_values = None
prompt_ids, _ = template.encode_oneturn(
tokenizer=tokenizer, messages=paired_messages, system=system, tools=tools
)
if processor is not None and image is not None: # add image features
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
batch_feature = image_processor(image, return_tensors="pt")
pixel_values = batch_feature.to(model.device)["pixel_values"] # shape (B, C, H, W)
if hasattr(processor, "image_seq_length"): # paligemma case
image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)
prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + prompt_ids
prompt_length = len(prompt_ids)
inputs = torch.tensor([prompt_ids], device=model.device)
@@ -122,10 +137,8 @@ class HuggingfaceEngine(BaseEngine):
logits_processor=get_logits_processor(),
)
if processor is not None and image is not None:
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
pixel_values: "torch.Tensor" = image_processor(image, return_tensors="pt")["pixel_values"]
gen_kwargs["pixel_values"] = pixel_values.to(model.device)
if pixel_values is not None:
gen_kwargs["pixel_values"] = pixel_values
return gen_kwargs, prompt_length