Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-12-15 03:10:35 +08:00)
@@ -8,6 +8,7 @@ import torch
 from transformers import GenerationConfig, TextIteratorStreamer
 
 from ..data import get_template_and_fix_tokenizer
+from ..extras.constants import IMAGE_TOKEN
 from ..extras.misc import get_logits_processor
 from ..model import load_model, load_tokenizer
 from .base_engine import BaseEngine, Response
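The only change in this hunk pulls in the shared IMAGE_TOKEN constant from ..extras.constants, which the hunks below use in place of the hard-coded "<image>" literal that appears in the removed lines. A minimal sketch of what the constant presumably holds, inferred only from the literal it replaces:

# Assumed value: the diff swaps the raw "<image>" string for this named constant.
IMAGE_TOKEN = "<image>"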
@@ -55,14 +56,28 @@ class HuggingfaceEngine(BaseEngine):
         image: Optional["NDArray"] = None,
         input_kwargs: Optional[Dict[str, Any]] = {},
     ) -> Tuple[Dict[str, Any], int]:
-        if processor is not None and image is not None and "<image>" not in messages[0]["content"]:
-            messages[0]["content"] = "<image>" + messages[0]["content"]
+        if (
+            processor is not None
+            and image is not None
+            and not hasattr(processor, "image_seq_length")
+            and IMAGE_TOKEN not in messages[0]["content"]
+        ):  # llava case
+            messages[0]["content"] = IMAGE_TOKEN + messages[0]["content"]
 
         paired_messages = messages + [{"role": "assistant", "content": ""}]
         system = system or generating_args["default_system"]
+        pixel_values = None
         prompt_ids, _ = template.encode_oneturn(
             tokenizer=tokenizer, messages=paired_messages, system=system, tools=tools
         )
+        if processor is not None and image is not None:  # add image features
+            image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
+            batch_feature = image_processor(image, return_tensors="pt")
+            pixel_values = batch_feature.to(model.device)["pixel_values"]  # shape (B, C, H, W)
+            if hasattr(processor, "image_seq_length"):  # paligemma case
+                image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)
+                prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + prompt_ids
+
         prompt_length = len(prompt_ids)
         inputs = torch.tensor([prompt_ids], device=model.device)
 
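The hunk above splits multimodal prompt building into two paths: llava-style models get a single IMAGE_TOKEN prepended to the raw text, while paligemma-style models (detected via the processor's image_seq_length attribute) get image_seq_length image-token ids prepended to the already-encoded prompt ids. A minimal, self-contained sketch of that branching; build_prompt_ids, the literal ids, and the seq length of 4 are illustrative stand-ins, not the repo's API:

def build_prompt_ids(prompt_ids, has_image_seq_length, image_seq_length, image_token_id):
    # llava-style: a single IMAGE_TOKEN was already prepended to the text,
    # so the encoded prompt ids need no further change.
    if not has_image_seq_length:
        return prompt_ids
    # paligemma-style: the processor expects image_seq_length image-token ids
    # in front of the encoded text ids.
    return [image_token_id] * image_seq_length + prompt_ids

print(build_prompt_ids([5, 6, 7], True, 4, 99))   # [99, 99, 99, 99, 5, 6, 7]
print(build_prompt_ids([5, 6, 7], False, 4, 99))  # [5, 6, 7]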
@@ -122,10 +137,8 @@ class HuggingfaceEngine(BaseEngine):
             logits_processor=get_logits_processor(),
         )
 
-        if processor is not None and image is not None:
-            image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
-            pixel_values: "torch.Tensor" = image_processor(image, return_tensors="pt")["pixel_values"]
-            gen_kwargs["pixel_values"] = pixel_values.to(model.device)
+        if pixel_values is not None:
+            gen_kwargs["pixel_values"] = pixel_values
 
         return gen_kwargs, prompt_length
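With pixel_values now computed once earlier in _process_args, this hunk only attaches it to gen_kwargs when it exists instead of running the image processor a second time before generation. A hedged sketch of the resulting pattern; finalize_gen_kwargs is an illustrative helper, not a function in the repo:

from typing import Any, Dict, Optional

import torch

def finalize_gen_kwargs(inputs: "torch.Tensor", pixel_values: Optional["torch.Tensor"]) -> Dict[str, Any]:
    gen_kwargs = {"inputs": inputs}
    if pixel_values is not None:  # only multimodal requests carry image features
        gen_kwargs["pixel_values"] = pixel_values
    return gen_kwargs

# model.generate(**gen_kwargs) then receives pixel_values only when an image was supplied.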
@@ -2,6 +2,7 @@ import uuid
 from typing import TYPE_CHECKING, AsyncGenerator, AsyncIterator, Dict, List, Optional, Sequence, Union
 
 from ..data import get_template_and_fix_tokenizer
+from ..extras.constants import IMAGE_TOKEN
 from ..extras.logging import get_logger
 from ..extras.misc import get_device_count, infer_optim_dtype
 from ..extras.packages import is_vllm_available
@@ -17,7 +18,6 @@ if is_vllm_available():
 
 
 if TYPE_CHECKING:
-    import torch
     from numpy.typing import NDArray
     from transformers.image_processing_utils import BaseImageProcessor
 
@@ -67,7 +67,7 @@ class VllmEngine(BaseEngine):
             patch_size = config.vision_config.patch_size
             self.image_feature_size = (image_size // patch_size) ** 2
             engine_args["image_input_type"] = "pixel_values"
-            engine_args["image_token_id"] = self.tokenizer.convert_tokens_to_ids("<image>")
+            engine_args["image_token_id"] = self.tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)
             engine_args["image_input_shape"] = "1,3,{},{}".format(image_size, image_size)
             engine_args["image_feature_size"] = self.image_feature_size
             if getattr(config, "is_yi_vl_derived_model", None):
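The image_feature_size set here is the number of visual tokens vLLM reserves per image, i.e. the square of the patch-grid side length derived from the vision config. A worked example assuming a llava-1.5-style vision tower (336-pixel input, 14-pixel patches; these numbers are illustrative, not read from this diff):

image_size, patch_size = 336, 14                    # assumed llava-1.5-style vision config
image_feature_size = (image_size // patch_size) ** 2
print(image_feature_size)                           # 576
print("1,3,{},{}".format(image_size, image_size))   # "1,3,336,336" -> image_input_shape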
@@ -92,14 +92,28 @@ class VllmEngine(BaseEngine):
         **input_kwargs,
     ) -> AsyncIterator["RequestOutput"]:
         request_id = "chatcmpl-{}".format(uuid.uuid4().hex)
-        if self.processor is not None and image is not None and "<image>" not in messages[0]["content"]:
-            messages[0]["content"] = "<image>" * self.image_feature_size + messages[0]["content"]
+
+        if (
+            self.processor is not None
+            and image is not None
+            and not hasattr(self.processor, "image_seq_length")
+            and IMAGE_TOKEN not in messages[0]["content"]
+        ):  # llava case
+            messages[0]["content"] = IMAGE_TOKEN * self.image_feature_size + messages[0]["content"]
 
         paired_messages = messages + [{"role": "assistant", "content": ""}]
         system = system or self.generating_args["default_system"]
         prompt_ids, _ = self.template.encode_oneturn(
             tokenizer=self.tokenizer, messages=paired_messages, system=system, tools=tools
         )
+
+        if self.processor is not None and image is not None:  # add image features
+            image_processor: "BaseImageProcessor" = getattr(self.processor, "image_processor")
+            pixel_values = image_processor(image, return_tensors="pt")["pixel_values"]
+            multi_modal_data = MultiModalData(type=MultiModalData.Type.IMAGE, data=pixel_values)
+        else:
+            multi_modal_data = None
+
         prompt_length = len(prompt_ids)
 
         use_beam_search: bool = self.generating_args["num_beams"] > 1
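Unlike the HuggingFace path, the vLLM path repeats IMAGE_TOKEN image_feature_size times in the text, because the engine was configured above to reserve exactly that many placeholder positions per image. A small sketch of what the mutated first message looks like; the user text and the 576 value are illustrative:

IMAGE_TOKEN = "<image>"      # assumed value of the constant
image_feature_size = 576     # e.g. (336 // 14) ** 2, as in the earlier example

messages = [{"role": "user", "content": "Describe this picture."}]
if IMAGE_TOKEN not in messages[0]["content"]:  # llava case
    messages[0]["content"] = IMAGE_TOKEN * image_feature_size + messages[0]["content"]

print(messages[0]["content"].count(IMAGE_TOKEN))  # 576 placeholder positions for one image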
@@ -144,13 +158,6 @@ class VllmEngine(BaseEngine):
             skip_special_tokens=True,
         )
 
-        if self.processor is not None and image is not None:
-            image_processor: "BaseImageProcessor" = getattr(self.processor, "image_processor")
-            pixel_values: "torch.Tensor" = image_processor(image, return_tensors="pt")["pixel_values"]
-            multi_modal_data = MultiModalData(type=MultiModalData.Type.IMAGE, data=pixel_values)
-        else:
-            multi_modal_data = None
-
         result_generator = self.model.generate(
             prompt=None,
             sampling_params=sampling_params,
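The block removed in this last hunk duplicated the MultiModalData construction that now happens right after prompt encoding, so the image features are packaged exactly once per request. A hedged sketch of that single packaging step as a standalone helper; package_image is illustrative, while the real code inlines this next to encode_oneturn and uses vLLM's MultiModalData directly:

from typing import Any, Optional

def package_image(image: Optional[Any], image_processor: Any, multi_modal_data_cls: Any) -> Optional[Any]:
    if image is None:
        return None  # text-only request: nothing extra is handed to the engine
    pixel_values = image_processor(image, return_tensors="pt")["pixel_values"]
    return multi_modal_data_cls(type=multi_modal_data_cls.Type.IMAGE, data=pixel_values)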