[data] add min resolution option (#6975)

Former-commit-id: 7faecc0301709326efa21e7a3fdb75fe0a9635c2
This commit is contained in:
hoshi-hiyouga 2025-02-18 01:40:46 +08:00 committed by GitHub
parent a8c9d5663d
commit b1d31ff0f9
9 changed files with 59 additions and 24 deletions

View File

@ -4,6 +4,7 @@ API_HOST=
API_PORT=
API_KEY=
API_MODEL_NAME=
API_VERBOSE=
FASTAPI_ROOT_PATH=
MAX_CONCURRENT=
# general

View File

@ -1,7 +1,7 @@
### model
model_name_or_path: Qwen/Qwen2-VL-7B-Instruct
image_resolution: 262144
video_resolution: 16384
image_max_resolution: 262144
video_max_resolution: 16384
trust_remote_code: true
### method

View File

@ -1,7 +1,5 @@
### model
model_name_or_path: llava-hf/llava-1.5-7b-hf
image_resolution: 262144
video_resolution: 16384
trust_remote_code: true
### method

View File

@ -1,7 +1,7 @@
### model
model_name_or_path: Qwen/Qwen2-VL-7B-Instruct
image_resolution: 262144
video_resolution: 16384
image_max_resolution: 262144
video_max_resolution: 16384
trust_remote_code: true
### method

View File

@ -1,7 +1,7 @@
### model
model_name_or_path: Qwen/Qwen2-VL-7B-Instruct
image_resolution: 262144
video_resolution: 16384
image_max_resolution: 262144
video_max_resolution: 16384
trust_remote_code: true
### method

View File

@ -22,6 +22,8 @@ from typing import TYPE_CHECKING, AsyncGenerator, Dict, List, Optional, Tuple
from ..data import Role as DataRole
from ..extras import logging
from ..extras.constants import IMAGE_PLACEHOLDER
from ..extras.misc import is_env_enabled
from ..extras.packages import is_fastapi_available, is_pillow_available, is_requests_available
from .common import dictify, jsonify
from .protocol import (
@ -70,7 +72,8 @@ ROLE_MAPPING = {
def _process_request(
request: "ChatCompletionRequest",
) -> Tuple[List[Dict[str, str]], Optional[str], Optional[str], Optional[List["ImageInput"]]]:
logger.info_rank0(f"==== request ====\n{json.dumps(dictify(request), indent=2, ensure_ascii=False)}")
if is_env_enabled("API_VERBOSE", "1"):
logger.info_rank0(f"==== request ====\n{json.dumps(dictify(request), indent=2, ensure_ascii=False)}")
if len(request.messages) == 0:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid length")
@ -99,10 +102,12 @@ def _process_request(
content = json.dumps(tool_calls, ensure_ascii=False)
input_messages.append({"role": ROLE_MAPPING[Role.FUNCTION], "content": content})
elif isinstance(message.content, list):
text_content = ""
for input_item in message.content:
if input_item.type == "text":
input_messages.append({"role": ROLE_MAPPING[message.role], "content": input_item.text})
text_content += input_item.text
else:
text_content += IMAGE_PLACEHOLDER
image_url = input_item.image_url.url
if re.match(r"^data:image\/(png|jpg|jpeg|gif|bmp);base64,(.+)$", image_url): # base64 image
image_stream = io.BytesIO(base64.b64decode(image_url.split(",", maxsplit=1)[1]))
@ -112,6 +117,8 @@ def _process_request(
image_stream = requests.get(image_url, stream=True).raw
images.append(Image.open(image_stream).convert("RGB"))
input_messages.append({"role": ROLE_MAPPING[message.role], "content": text_content})
else:
input_messages.append({"role": ROLE_MAPPING[message.role], "content": message.content})

View File

@ -104,12 +104,19 @@ class MMPluginMixin:
"This model does not support audio input. Please check whether the correct `template` is used."
)
def _preprocess_image(self, image: "ImageObject", image_resolution: int, **kwargs) -> "ImageObject":
def _preprocess_image(
self, image: "ImageObject", image_max_resolution: int, image_min_resolution: int, **kwargs
) -> "ImageObject":
r"""
Pre-processes a single image.
"""
if (image.width * image.height) > image_resolution:
resize_factor = math.sqrt(image_resolution / (image.width * image.height))
if (image.width * image.height) > image_max_resolution:
resize_factor = math.sqrt(image_max_resolution / (image.width * image.height))
width, height = int(image.width * resize_factor), int(image.height * resize_factor)
image = image.resize((width, height), resample=Image.Resampling.NEAREST)
if (image.width * image.height) < image_min_resolution:
resize_factor = math.sqrt(image_min_resolution / (image.width * image.height))
width, height = int(image.width * resize_factor), int(image.height * resize_factor)
image = image.resize((width, height), resample=Image.Resampling.NEAREST)
@ -217,14 +224,17 @@ class MMPluginMixin:
if len(images) != 0:
images = self._regularize_images(
images, image_resolution=getattr(processor, "image_resolution", 768 * 768)
images,
image_max_resolution=getattr(processor, "image_max_resolution", 768 * 768),
image_min_resolution=getattr(processor, "image_min_resolution", 32 * 32),
)
mm_inputs.update(image_processor(images, return_tensors="pt"))
if len(videos) != 0:
videos = self._regularize_videos(
videos,
image_resolution=getattr(processor, "video_resolution", 256 * 256),
image_max_resolution=getattr(processor, "video_max_resolution", 256 * 256),
image_min_resolution=getattr(processor, "video_min_resolution", 16 * 16),
video_fps=getattr(processor, "video_fps", 2.0),
video_maxlen=getattr(processor, "video_maxlen", 128),
)
@ -606,7 +616,8 @@ class MiniCPMVPlugin(BasePlugin):
if len(images) != 0:
images = self._regularize_images(
images,
image_resolution=getattr(processor, "image_resolution", 768 * 768),
image_max_resolution=getattr(processor, "image_max_resolution", 768 * 768),
image_min_resolution=getattr(processor, "image_min_resolution", 32 * 32),
)
if "valid_image_nums_ls" in kwargs:
valid_image_nums_ls = kwargs["valid_image_nums_ls"]
@ -626,7 +637,8 @@ class MiniCPMVPlugin(BasePlugin):
if len(videos) != 0:
videos = self._regularize_videos(
videos,
image_resolution=getattr(processor, "video_resolution", 256 * 256),
image_max_resolution=getattr(processor, "video_max_resolution", 256 * 256),
image_min_resolution=getattr(processor, "video_min_resolution", 16 * 16),
video_fps=getattr(processor, "video_fps", 2.0),
video_maxlen=getattr(processor, "video_maxlen", 128),
)
@ -774,7 +786,11 @@ class MllamaPlugin(BasePlugin):
num_tiles: List[List[int]] with shape (batch_size, num_images_in_batch). For example, (2, 1).
"""
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
images = self._regularize_images(images, image_resolution=getattr(processor, "image_resolution", 768 * 768))
images = self._regularize_images(
images,
image_max_resolution=getattr(processor, "image_max_resolution", 768 * 768),
image_min_resolution=getattr(processor, "image_min_resolution", 32 * 32),
)
batch_images = []
for image_length in imglens:
batch_images.append(images[:image_length])
@ -1065,14 +1081,17 @@ class Qwen2vlPlugin(BasePlugin):
mm_inputs = {}
if len(images) != 0:
images = self._regularize_images(
images, image_resolution=getattr(processor, "image_resolution", 768 * 768)
images,
image_max_resolution=getattr(processor, "image_max_resolution", 768 * 768),
image_min_resolution=getattr(processor, "image_min_resolution", 32 * 32),
)
mm_inputs.update(image_processor(images, return_tensors="pt"))
if len(videos) != 0:
videos, fps_per_video = self._regularize_videos(
videos,
image_resolution=getattr(processor, "video_resolution", 256 * 256),
image_max_resolution=getattr(processor, "video_max_resolution", 256 * 256),
image_min_resolution=getattr(processor, "video_min_resolution", 16 * 16),
video_fps=getattr(processor, "video_fps", 2.0),
video_maxlen=getattr(processor, "video_maxlen", 128),
)

View File

@ -58,14 +58,22 @@ class ProcessorArguments:
Arguments pertaining to the image processor.
"""
image_resolution: int = field(
image_max_resolution: int = field(
default=768 * 768,
metadata={"help": "The maximum number of pixels of image inputs."},
)
video_resolution: int = field(
image_min_resolution: int = field(
default=32 * 32,
metadata={"help": "The minimum number of pixels of image inputs."},
)
video_max_resolution: int = field(
default=256 * 256,
metadata={"help": "The maximum number of pixels of video inputs."},
)
video_min_resolution: int = field(
default=16 * 16,
metadata={"help": "The minimum number of pixels of video inputs."},
)
video_fps: float = field(
default=2.0,
metadata={"help": "The frames to sample per second for video inputs."},

View File

@ -79,9 +79,11 @@ def patch_processor(
setattr(processor, "tokenizer", tokenizer)
if getattr(config, "vision_config", None) is not None: # visual models
setattr(processor, "image_seqlen", get_image_seqlen(config))
setattr(processor, "image_resolution", model_args.image_resolution)
setattr(processor, "patch_size", get_patch_size(config, processor))
setattr(processor, "video_resolution", model_args.video_resolution)
setattr(processor, "image_max_resolution", model_args.image_max_resolution)
setattr(processor, "image_min_resolution", model_args.image_min_resolution)
setattr(processor, "video_max_resolution", model_args.video_max_resolution)
setattr(processor, "video_min_resolution", model_args.video_min_resolution)
setattr(processor, "video_fps", model_args.video_fps)
setattr(processor, "video_maxlen", model_args.video_maxlen)
setattr(processor, "vision_feature_select_strategy", get_vision_feature_select_strategy(config, processor))