mirror of
				https://github.com/hiyouga/LLaMA-Factory.git
				synced 2025-11-04 18:02:19 +08:00 
			
		
		
		
	[infer] vllm video/audio inference (#7566)
This commit is contained in:
		
							parent
							
								
									2bfcad2394
								
							
						
					
					
						commit
						5e22597ff1
					
				@ -92,8 +92,20 @@ def vllm_infer(
 | 
			
		||||
            multi_modal_data = {
 | 
			
		||||
                "image": template_obj.mm_plugin._regularize_images(
 | 
			
		||||
                    sample["images"], image_max_pixels=image_max_pixels, image_min_pixels=image_min_pixels
 | 
			
		||||
                )
 | 
			
		||||
                )["images"]
 | 
			
		||||
            }
 | 
			
		||||
        elif sample["videos"]:
 | 
			
		||||
            multi_modal_data = {
 | 
			
		||||
                "video": template_obj.mm_plugin._regularize_videos(
 | 
			
		||||
                    sample["videos"], image_max_pixels=image_max_pixels, image_min_pixels=image_min_pixels
 | 
			
		||||
                )["videos"]
 | 
			
		||||
            }
 | 
			
		||||
        elif sample["audios"]:
 | 
			
		||||
            audio_data = template_obj.mm_plugin._regularize_audios(
 | 
			
		||||
                sample["audios"],
 | 
			
		||||
                sampling_rate=16000,
 | 
			
		||||
            )
 | 
			
		||||
            multi_modal_data = {"audio": zip(audio_data["audios"], audio_data["sampling_rates"])}
 | 
			
		||||
        else:
 | 
			
		||||
            multi_modal_data = None
 | 
			
		||||
 | 
			
		||||
@ -131,7 +143,7 @@ def vllm_infer(
 | 
			
		||||
        "enable_lora": model_args.adapter_name_or_path is not None,
 | 
			
		||||
    }
 | 
			
		||||
    if template_obj.mm_plugin.__class__.__name__ != "BasePlugin":
 | 
			
		||||
        engine_args["limit_mm_per_prompt"] = {"image": 4, "video": 2}
 | 
			
		||||
        engine_args["limit_mm_per_prompt"] = {"image": 4, "video": 2, "audio": 2}
 | 
			
		||||
 | 
			
		||||
    if isinstance(model_args.vllm_config, dict):
 | 
			
		||||
        engine_args.update(model_args.vllm_config)
 | 
			
		||||
 | 
			
		||||
@ -23,7 +23,7 @@ from typing import TYPE_CHECKING, Optional
 | 
			
		||||
 | 
			
		||||
from ..data import Role as DataRole
 | 
			
		||||
from ..extras import logging
 | 
			
		||||
from ..extras.constants import IMAGE_PLACEHOLDER
 | 
			
		||||
from ..extras.constants import AUDIO_PLACEHOLDER, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
 | 
			
		||||
from ..extras.misc import is_env_enabled
 | 
			
		||||
from ..extras.packages import is_fastapi_available, is_pillow_available, is_requests_available
 | 
			
		||||
from .common import dictify, jsonify
 | 
			
		||||
@ -56,7 +56,7 @@ if is_requests_available():
 | 
			
		||||
 | 
			
		||||
if TYPE_CHECKING:
 | 
			
		||||
    from ..chat import ChatModel
 | 
			
		||||
    from ..data.mm_plugin import ImageInput
 | 
			
		||||
    from ..data.mm_plugin import AudioInput, ImageInput, VideoInput
 | 
			
		||||
    from .protocol import ChatCompletionRequest, ScoreEvaluationRequest
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -72,7 +72,14 @@ ROLE_MAPPING = {
 | 
			
		||||
 | 
			
		||||
def _process_request(
 | 
			
		||||
    request: "ChatCompletionRequest",
 | 
			
		||||
) -> tuple[list[dict[str, str]], Optional[str], Optional[str], Optional[list["ImageInput"]]]:
 | 
			
		||||
) -> tuple[
 | 
			
		||||
    list[dict[str, str]],
 | 
			
		||||
    Optional[str],
 | 
			
		||||
    Optional[str],
 | 
			
		||||
    Optional[list["ImageInput"]],
 | 
			
		||||
    Optional[list["VideoInput"]],
 | 
			
		||||
    Optional[list["AudioInput"]],
 | 
			
		||||
]:
 | 
			
		||||
    if is_env_enabled("API_VERBOSE", "1"):
 | 
			
		||||
        logger.info_rank0(f"==== request ====\n{json.dumps(dictify(request), indent=2, ensure_ascii=False)}")
 | 
			
		||||
 | 
			
		||||
@ -88,7 +95,7 @@ def _process_request(
 | 
			
		||||
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...")
 | 
			
		||||
 | 
			
		||||
    input_messages = []
 | 
			
		||||
    images = []
 | 
			
		||||
    images, videos, audios = [], [], []
 | 
			
		||||
    for i, message in enumerate(request.messages):
 | 
			
		||||
        if i % 2 == 0 and message.role not in [Role.USER, Role.TOOL]:
 | 
			
		||||
            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")
 | 
			
		||||
@ -107,7 +114,7 @@ def _process_request(
 | 
			
		||||
            for input_item in message.content:
 | 
			
		||||
                if input_item.type == "text":
 | 
			
		||||
                    text_content += input_item.text
 | 
			
		||||
                else:
 | 
			
		||||
                elif input_item.type == "image_url":
 | 
			
		||||
                    text_content += IMAGE_PLACEHOLDER
 | 
			
		||||
                    image_url = input_item.image_url.url
 | 
			
		||||
                    if re.match(r"^data:image\/(png|jpg|jpeg|gif|bmp);base64,(.+)$", image_url):  # base64 image
 | 
			
		||||
@ -118,6 +125,28 @@ def _process_request(
 | 
			
		||||
                        image_stream = requests.get(image_url, stream=True).raw
 | 
			
		||||
 | 
			
		||||
                    images.append(Image.open(image_stream).convert("RGB"))
 | 
			
		||||
                elif input_item.type == "video_url":
 | 
			
		||||
                    text_content += VIDEO_PLACEHOLDER
 | 
			
		||||
                    video_url = input_item.video_url.url
 | 
			
		||||
                    if os.path.isfile(video_url):  # local file
 | 
			
		||||
                        video_stream = open(video_url, "rb")
 | 
			
		||||
                    else:  # web uri
 | 
			
		||||
                        video_stream = requests.get(video_url, stream=True).raw
 | 
			
		||||
 | 
			
		||||
                    videos.append(video_stream)
 | 
			
		||||
                elif input_item.type == "audio_url":
 | 
			
		||||
                    text_content += AUDIO_PLACEHOLDER
 | 
			
		||||
                    audio_url = input_item.audio_url.url
 | 
			
		||||
                    if os.path.isfile(audio_url):  # local file
 | 
			
		||||
                        audio_stream = open(audio_url, "rb")
 | 
			
		||||
                    else:  # web uri
 | 
			
		||||
                        audio_stream = requests.get(audio_url, stream=True).raw
 | 
			
		||||
 | 
			
		||||
                    audios.append(audio_stream)
 | 
			
		||||
                else:
 | 
			
		||||
                    raise HTTPException(
 | 
			
		||||
                        status_code=status.HTTP_400_BAD_REQUEST, detail=f"Invalid input type {input_item.type}."
 | 
			
		||||
                    )
 | 
			
		||||
 | 
			
		||||
            input_messages.append({"role": ROLE_MAPPING[message.role], "content": text_content})
 | 
			
		||||
        else:
 | 
			
		||||
@ -132,7 +161,7 @@ def _process_request(
 | 
			
		||||
    else:
 | 
			
		||||
        tools = None
 | 
			
		||||
 | 
			
		||||
    return input_messages, system, tools, images or None
 | 
			
		||||
    return input_messages, system, tools, images or None, videos or None, audios or None
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _create_stream_chat_completion_chunk(
 | 
			
		||||
@ -151,12 +180,14 @@ async def create_chat_completion_response(
 | 
			
		||||
    request: "ChatCompletionRequest", chat_model: "ChatModel"
 | 
			
		||||
) -> "ChatCompletionResponse":
 | 
			
		||||
    completion_id = f"chatcmpl-{uuid.uuid4().hex}"
 | 
			
		||||
    input_messages, system, tools, images = _process_request(request)
 | 
			
		||||
    input_messages, system, tools, images, videos, audios = _process_request(request)
 | 
			
		||||
    responses = await chat_model.achat(
 | 
			
		||||
        input_messages,
 | 
			
		||||
        system,
 | 
			
		||||
        tools,
 | 
			
		||||
        images,
 | 
			
		||||
        videos,
 | 
			
		||||
        audios,
 | 
			
		||||
        do_sample=request.do_sample,
 | 
			
		||||
        temperature=request.temperature,
 | 
			
		||||
        top_p=request.top_p,
 | 
			
		||||
@ -202,7 +233,7 @@ async def create_stream_chat_completion_response(
 | 
			
		||||
    request: "ChatCompletionRequest", chat_model: "ChatModel"
 | 
			
		||||
) -> AsyncGenerator[str, None]:
 | 
			
		||||
    completion_id = f"chatcmpl-{uuid.uuid4().hex}"
 | 
			
		||||
    input_messages, system, tools, images = _process_request(request)
 | 
			
		||||
    input_messages, system, tools, images, videos, audios = _process_request(request)
 | 
			
		||||
    if tools:
 | 
			
		||||
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot stream function calls.")
 | 
			
		||||
 | 
			
		||||
@ -217,6 +248,8 @@ async def create_stream_chat_completion_response(
 | 
			
		||||
        system,
 | 
			
		||||
        tools,
 | 
			
		||||
        images,
 | 
			
		||||
        videos,
 | 
			
		||||
        audios,
 | 
			
		||||
        do_sample=request.do_sample,
 | 
			
		||||
        temperature=request.temperature,
 | 
			
		||||
        top_p=request.top_p,
 | 
			
		||||
 | 
			
		||||
@ -70,14 +70,17 @@ class FunctionCall(BaseModel):
 | 
			
		||||
    function: Function
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ImageURL(BaseModel):
 | 
			
		||||
class URL(BaseModel):
 | 
			
		||||
    url: str
 | 
			
		||||
    detail: Literal["auto", "low", "high"] = "auto"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class MultimodalInputItem(BaseModel):
 | 
			
		||||
    type: Literal["text", "image_url"]
 | 
			
		||||
    type: Literal["text", "image_url", "video_url", "audio_url"]
 | 
			
		||||
    text: Optional[str] = None
 | 
			
		||||
    image_url: Optional[ImageURL] = None
 | 
			
		||||
    image_url: Optional[URL] = None
 | 
			
		||||
    video_url: Optional[URL] = None
 | 
			
		||||
    audio_url: Optional[URL] = None
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ChatMessage(BaseModel):
 | 
			
		||||
 | 
			
		||||
@ -33,7 +33,7 @@ from .base_engine import BaseEngine, Response
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if is_sglang_available():
 | 
			
		||||
    from sglang.utils import launch_server_cmd, terminate_process, wait_for_server
 | 
			
		||||
    from sglang.utils import launch_server_cmd, terminate_process, wait_for_server  # type: ignore
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if TYPE_CHECKING:
 | 
			
		||||
@ -134,24 +134,17 @@ class SGLangEngine(BaseEngine):
 | 
			
		||||
        audios: Optional[list["AudioInput"]] = None,
 | 
			
		||||
        **input_kwargs,
 | 
			
		||||
    ) -> AsyncIterator[dict[str, Any]]:
 | 
			
		||||
        mm_input_dict = {"images": [], "videos": [], "audios": [], "imglens": [0], "vidlens": [0], "audlens": [0]}
 | 
			
		||||
        if images is not None:
 | 
			
		||||
            mm_input_dict.update({"images": images, "imglens": [len(images)]})
 | 
			
		||||
            if not any(IMAGE_PLACEHOLDER in message["content"] for message in messages):
 | 
			
		||||
                messages[0]["content"] = IMAGE_PLACEHOLDER * len(images) + messages[0]["content"]
 | 
			
		||||
        if images is not None and not any(IMAGE_PLACEHOLDER in message["content"] for message in messages):
 | 
			
		||||
            messages[0]["content"] = IMAGE_PLACEHOLDER * len(images) + messages[0]["content"]
 | 
			
		||||
 | 
			
		||||
        if videos is not None:
 | 
			
		||||
            mm_input_dict.update({"videos": videos, "vidlens": [len(videos)]})
 | 
			
		||||
            if not any(VIDEO_PLACEHOLDER in message["content"] for message in messages):
 | 
			
		||||
                messages[0]["content"] = VIDEO_PLACEHOLDER * len(videos) + messages[0]["content"]
 | 
			
		||||
        if videos is not None and not any(VIDEO_PLACEHOLDER in message["content"] for message in messages):
 | 
			
		||||
            messages[0]["content"] = VIDEO_PLACEHOLDER * len(videos) + messages[0]["content"]
 | 
			
		||||
 | 
			
		||||
        if audios is not None:
 | 
			
		||||
            mm_input_dict.update({"audios": audios, "audlens": [len(audios)]})
 | 
			
		||||
            if not any(AUDIO_PLACEHOLDER in message["content"] for message in messages):
 | 
			
		||||
                messages[0]["content"] = AUDIO_PLACEHOLDER * len(audios) + messages[0]["content"]
 | 
			
		||||
        if audios is not None and not any(AUDIO_PLACEHOLDER in message["content"] for message in messages):
 | 
			
		||||
            messages[0]["content"] = AUDIO_PLACEHOLDER * len(audios) + messages[0]["content"]
 | 
			
		||||
 | 
			
		||||
        messages = self.template.mm_plugin.process_messages(
 | 
			
		||||
            messages, mm_input_dict["images"], mm_input_dict["videos"], mm_input_dict["audios"], self.processor
 | 
			
		||||
            messages, images or [], videos or [], audios or [], self.processor
 | 
			
		||||
        )
 | 
			
		||||
        paired_messages = messages + [{"role": "assistant", "content": ""}]
 | 
			
		||||
        system = system or self.generating_args["default_system"]
 | 
			
		||||
 | 
			
		||||
@ -83,7 +83,7 @@ class VllmEngine(BaseEngine):
 | 
			
		||||
            "max_lora_rank": model_args.vllm_max_lora_rank,
 | 
			
		||||
        }
 | 
			
		||||
        if self.template.mm_plugin.__class__.__name__ != "BasePlugin":
 | 
			
		||||
            engine_args["limit_mm_per_prompt"] = {"image": 4, "video": 2}
 | 
			
		||||
            engine_args["limit_mm_per_prompt"] = {"image": 4, "video": 2, "audio": 2}
 | 
			
		||||
 | 
			
		||||
        if isinstance(model_args.vllm_config, dict):
 | 
			
		||||
            engine_args.update(model_args.vllm_config)
 | 
			
		||||
@ -111,24 +111,17 @@ class VllmEngine(BaseEngine):
 | 
			
		||||
        **input_kwargs,
 | 
			
		||||
    ) -> AsyncIterator["RequestOutput"]:
 | 
			
		||||
        request_id = f"chatcmpl-{uuid.uuid4().hex}"
 | 
			
		||||
        mm_input_dict = {"images": [], "videos": [], "audios": [], "imglens": [0], "vidlens": [0], "audlens": [0]}
 | 
			
		||||
        if images is not None:
 | 
			
		||||
            mm_input_dict.update({"images": images, "imglens": [len(images)]})
 | 
			
		||||
            if not any(IMAGE_PLACEHOLDER in message["content"] for message in messages):
 | 
			
		||||
                messages[0]["content"] = IMAGE_PLACEHOLDER * len(images) + messages[0]["content"]
 | 
			
		||||
        if images is not None and not any(IMAGE_PLACEHOLDER in message["content"] for message in messages):
 | 
			
		||||
            messages[0]["content"] = IMAGE_PLACEHOLDER * len(images) + messages[0]["content"]
 | 
			
		||||
 | 
			
		||||
        if videos is not None:
 | 
			
		||||
            mm_input_dict.update({"videos": videos, "vidlens": [len(videos)]})
 | 
			
		||||
            if not any(VIDEO_PLACEHOLDER in message["content"] for message in messages):
 | 
			
		||||
                messages[0]["content"] = VIDEO_PLACEHOLDER * len(videos) + messages[0]["content"]
 | 
			
		||||
        if videos is not None and not any(VIDEO_PLACEHOLDER in message["content"] for message in messages):
 | 
			
		||||
            messages[0]["content"] = VIDEO_PLACEHOLDER * len(videos) + messages[0]["content"]
 | 
			
		||||
 | 
			
		||||
        if audios is not None:
 | 
			
		||||
            mm_input_dict.update({"audios": audios, "audlens": [len(audios)]})
 | 
			
		||||
            if not any(AUDIO_PLACEHOLDER in message["content"] for message in messages):
 | 
			
		||||
                messages[0]["content"] = AUDIO_PLACEHOLDER * len(audios) + messages[0]["content"]
 | 
			
		||||
        if audios is not None and not any(AUDIO_PLACEHOLDER in message["content"] for message in messages):
 | 
			
		||||
            messages[0]["content"] = AUDIO_PLACEHOLDER * len(audios) + messages[0]["content"]
 | 
			
		||||
 | 
			
		||||
        messages = self.template.mm_plugin.process_messages(
 | 
			
		||||
            messages, mm_input_dict["images"], mm_input_dict["videos"], mm_input_dict["audios"], self.processor
 | 
			
		||||
            messages, images or [], videos or [], audios or [], self.processor
 | 
			
		||||
        )
 | 
			
		||||
        paired_messages = messages + [{"role": "assistant", "content": ""}]
 | 
			
		||||
        system = system or self.generating_args["default_system"]
 | 
			
		||||
@ -186,8 +179,24 @@ class VllmEngine(BaseEngine):
 | 
			
		||||
                    images,
 | 
			
		||||
                    image_max_pixels=self.model_args.image_max_pixels,
 | 
			
		||||
                    image_min_pixels=self.model_args.image_min_pixels,
 | 
			
		||||
                )
 | 
			
		||||
                )["images"]
 | 
			
		||||
            }
 | 
			
		||||
        elif videos is not None:
 | 
			
		||||
            multi_modal_data = {
 | 
			
		||||
                "video": self.template.mm_plugin._regularize_videos(
 | 
			
		||||
                    videos,
 | 
			
		||||
                    image_max_pixels=self.model_args.video_max_pixels,
 | 
			
		||||
                    image_min_pixels=self.model_args.video_min_pixels,
 | 
			
		||||
                    video_fps=self.model_args.video_fps,
 | 
			
		||||
                    video_maxlen=self.model_args.video_maxlen,
 | 
			
		||||
                )["videos"]
 | 
			
		||||
            }
 | 
			
		||||
        elif audios is not None:
 | 
			
		||||
            audio_data = self.template.mm_plugin._regularize_audios(
 | 
			
		||||
                audios,
 | 
			
		||||
                sampling_rate=self.model_args.audio_sampling_rate,
 | 
			
		||||
            )
 | 
			
		||||
            multi_modal_data = {"audio": zip(audio_data["audios"], audio_data["sampling_rates"])}
 | 
			
		||||
        else:
 | 
			
		||||
            multi_modal_data = None
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -26,8 +26,12 @@ if TYPE_CHECKING:
 | 
			
		||||
    from transformers import Seq2SeqTrainingArguments
 | 
			
		||||
 | 
			
		||||
    from ..hparams import DataArguments
 | 
			
		||||
    from .mm_plugin import AudioInput, ImageInput, VideoInput
 | 
			
		||||
    from .parser import DatasetAttr
 | 
			
		||||
 | 
			
		||||
    MediaType = Union[ImageInput, VideoInput, AudioInput]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
logger = logging.get_logger(__name__)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -36,10 +40,12 @@ class DatasetConverter:
 | 
			
		||||
    dataset_attr: "DatasetAttr"
 | 
			
		||||
    data_args: "DataArguments"
 | 
			
		||||
 | 
			
		||||
    def _find_medias(self, medias: Union[Any, list[Any]]) -> Optional[list[Any]]:
 | 
			
		||||
    def _find_medias(self, medias: Union["MediaType", list["MediaType"], None]) -> Optional[list["MediaType"]]:
 | 
			
		||||
        r"""Optionally concatenate media path to media dir when loading from local disk."""
 | 
			
		||||
        if not isinstance(medias, list):
 | 
			
		||||
            medias = [medias] if medias is not None else []
 | 
			
		||||
        if medias is None:
 | 
			
		||||
            return None
 | 
			
		||||
        elif not isinstance(medias, list):
 | 
			
		||||
            medias = [medias]
 | 
			
		||||
        elif len(medias) == 0:
 | 
			
		||||
            return None
 | 
			
		||||
        else:
 | 
			
		||||
 | 
			
		||||
@ -21,7 +21,7 @@ import re
 | 
			
		||||
from copy import deepcopy
 | 
			
		||||
from dataclasses import dataclass
 | 
			
		||||
from io import BytesIO
 | 
			
		||||
from typing import TYPE_CHECKING, Literal, Optional, TypedDict, Union
 | 
			
		||||
from typing import TYPE_CHECKING, BinaryIO, Literal, Optional, TypedDict, Union
 | 
			
		||||
 | 
			
		||||
import numpy as np
 | 
			
		||||
import torch
 | 
			
		||||
@ -68,9 +68,9 @@ if TYPE_CHECKING:
 | 
			
		||||
        path: Optional[str]
 | 
			
		||||
        bytes: Optional[bytes]
 | 
			
		||||
 | 
			
		||||
    ImageInput = Union[str, bytes, EncodedImage, ImageObject]
 | 
			
		||||
    VideoInput = str
 | 
			
		||||
    AudioInput = Union[str, NDArray]
 | 
			
		||||
    ImageInput = Union[str, bytes, EncodedImage, BinaryIO, ImageObject]
 | 
			
		||||
    VideoInput = Union[str, BinaryIO]
 | 
			
		||||
    AudioInput = Union[str, BinaryIO, NDArray]
 | 
			
		||||
 | 
			
		||||
    class MMProcessor(ProcessorMixin):
 | 
			
		||||
        patch_size: int
 | 
			
		||||
@ -146,12 +146,6 @@ class MMPluginMixin:
 | 
			
		||||
        video_processor: BaseImageProcessor = getattr(
 | 
			
		||||
            processor, "video_processor", getattr(processor, "image_processor", None)
 | 
			
		||||
        )
 | 
			
		||||
        if image_processor is None and video_processor is None:  # hack for qwen2_5_omni
 | 
			
		||||
            image_processor, video_processor = (
 | 
			
		||||
                getattr(processor, "omni_processor", None),
 | 
			
		||||
                getattr(processor, "omni_processor", None),
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        feature_extractor: SequenceFeatureExtractor = getattr(processor, "feature_extractor", None)
 | 
			
		||||
        if len(images) != 0 and self.image_token is None:
 | 
			
		||||
            raise ValueError(
 | 
			
		||||
@ -211,11 +205,11 @@ class MMPluginMixin:
 | 
			
		||||
        sample_frames = min(total_frames, video_maxlen, sample_frames)
 | 
			
		||||
        return np.linspace(0, total_frames - 1, sample_frames).astype(np.int32)
 | 
			
		||||
 | 
			
		||||
    def _regularize_images(self, images: list["ImageInput"], **kwargs) -> list["ImageObject"]:
 | 
			
		||||
    def _regularize_images(self, images: list["ImageInput"], **kwargs) -> dict[str, list["ImageObject"]]:
 | 
			
		||||
        r"""Regularize images to avoid error. Including reading and pre-processing."""
 | 
			
		||||
        results = []
 | 
			
		||||
        for image in images:
 | 
			
		||||
            if isinstance(image, str):
 | 
			
		||||
            if isinstance(image, (str, BinaryIO)):
 | 
			
		||||
                image = Image.open(image)
 | 
			
		||||
            elif isinstance(image, bytes):
 | 
			
		||||
                image = Image.open(BytesIO(image))
 | 
			
		||||
@ -230,9 +224,9 @@ class MMPluginMixin:
 | 
			
		||||
 | 
			
		||||
            results.append(self._preprocess_image(image, **kwargs))
 | 
			
		||||
 | 
			
		||||
        return results
 | 
			
		||||
        return {"images": results}
 | 
			
		||||
 | 
			
		||||
    def _regularize_videos(self, videos: list["VideoInput"], **kwargs) -> list[list["ImageObject"]]:
 | 
			
		||||
    def _regularize_videos(self, videos: list["VideoInput"], **kwargs) -> dict[str, list[list["ImageObject"]]]:
 | 
			
		||||
        r"""Regularizes videos to avoid error. Including reading, resizing and converting."""
 | 
			
		||||
        results = []
 | 
			
		||||
        for video in videos:
 | 
			
		||||
@ -245,24 +239,27 @@ class MMPluginMixin:
 | 
			
		||||
                if frame_idx in sample_indices:
 | 
			
		||||
                    frames.append(frame.to_image())
 | 
			
		||||
 | 
			
		||||
            frames = self._regularize_images(frames, **kwargs)
 | 
			
		||||
            frames = self._regularize_images(frames, **kwargs)["images"]
 | 
			
		||||
            results.append(frames)
 | 
			
		||||
 | 
			
		||||
        return results
 | 
			
		||||
        return {"videos": results}
 | 
			
		||||
 | 
			
		||||
    def _regularize_audios(self, audios: list["AudioInput"], sampling_rate: float, **kwargs) -> list["NDArray"]:
 | 
			
		||||
    def _regularize_audios(
 | 
			
		||||
        self, audios: list["AudioInput"], sampling_rate: float, **kwargs
 | 
			
		||||
    ) -> dict[str, Union[list["NDArray"], list[float]]]:
 | 
			
		||||
        r"""Regularizes audios to avoid error. Including reading and resampling."""
 | 
			
		||||
        results = []
 | 
			
		||||
        results, sampling_rates = [], []
 | 
			
		||||
        for audio in audios:
 | 
			
		||||
            if isinstance(audio, str):
 | 
			
		||||
                audio = librosa.load(audio, sr=sampling_rate)[0]
 | 
			
		||||
            if isinstance(audio, (str, BinaryIO)):
 | 
			
		||||
                audio, sampling_rate = librosa.load(audio, sr=sampling_rate)
 | 
			
		||||
 | 
			
		||||
            if not isinstance(audio, np.ndarray):
 | 
			
		||||
                raise ValueError(f"Expect input is a list of audios, but got {type(audio)}.")
 | 
			
		||||
 | 
			
		||||
            results.append(audio)
 | 
			
		||||
            sampling_rates.append(sampling_rate)
 | 
			
		||||
 | 
			
		||||
        return results
 | 
			
		||||
        return {"audios": results, "sampling_rates": sampling_rates}
 | 
			
		||||
 | 
			
		||||
    def _get_mm_inputs(
 | 
			
		||||
        self,
 | 
			
		||||
@ -298,8 +295,8 @@ class MMPluginMixin:
 | 
			
		||||
                images,
 | 
			
		||||
                image_max_pixels=getattr(processor, "image_max_pixels", 768 * 768),
 | 
			
		||||
                image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32),
 | 
			
		||||
            )
 | 
			
		||||
            if imglens is not None:
 | 
			
		||||
            )["images"]
 | 
			
		||||
            if imglens is not None:  # if imglens are provided, make batched images
 | 
			
		||||
                images = _make_batched_images(images, imglens)
 | 
			
		||||
 | 
			
		||||
            image_processor_kwargs = {}
 | 
			
		||||
@ -325,7 +322,7 @@ class MMPluginMixin:
 | 
			
		||||
                image_min_pixels=getattr(processor, "video_min_pixels", 16 * 16),
 | 
			
		||||
                video_fps=getattr(processor, "video_fps", 2.0),
 | 
			
		||||
                video_maxlen=getattr(processor, "video_maxlen", 128),
 | 
			
		||||
            )
 | 
			
		||||
            )["videos"]
 | 
			
		||||
            if "videos" in inspect.signature(video_processor.preprocess).parameters:  # for qwen2_vl and video_llava
 | 
			
		||||
                mm_inputs.update(video_processor(images=None, videos=videos, return_tensors="pt"))
 | 
			
		||||
            else:  # for llava_next_video
 | 
			
		||||
@ -335,12 +332,12 @@ class MMPluginMixin:
 | 
			
		||||
            feature_extractor: SequenceFeatureExtractor = getattr(processor, "feature_extractor", None)
 | 
			
		||||
            audios = self._regularize_audios(
 | 
			
		||||
                audios,
 | 
			
		||||
                sampling_rate=getattr(feature_extractor, "sampling_rate", 16000),
 | 
			
		||||
            )
 | 
			
		||||
                sampling_rate=getattr(processor, "audio_sampling_rate", 16000),
 | 
			
		||||
            )["audios"]
 | 
			
		||||
            mm_inputs.update(
 | 
			
		||||
                feature_extractor(
 | 
			
		||||
                    audios,
 | 
			
		||||
                    sampling_rate=getattr(feature_extractor, "sampling_rate", 16000),
 | 
			
		||||
                    sampling_rate=getattr(processor, "audio_sampling_rate", 16000),
 | 
			
		||||
                    return_attention_mask=True,
 | 
			
		||||
                    padding="max_length",
 | 
			
		||||
                    return_tensors="pt",
 | 
			
		||||
@ -726,14 +723,13 @@ class MiniCPMVPlugin(BasePlugin):
 | 
			
		||||
        **kwargs,
 | 
			
		||||
    ) -> dict[str, "torch.Tensor"]:
 | 
			
		||||
        image_processor: BaseImageProcessor = getattr(processor, "image_processor")
 | 
			
		||||
        feature_extractor: SequenceFeatureExtractor = getattr(processor, "feature_extractor", None)
 | 
			
		||||
        mm_inputs = {}
 | 
			
		||||
        if len(images) != 0:
 | 
			
		||||
            images = self._regularize_images(
 | 
			
		||||
                images,
 | 
			
		||||
                image_max_pixels=getattr(processor, "image_max_pixels", 768 * 768),
 | 
			
		||||
                image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32),
 | 
			
		||||
            )
 | 
			
		||||
            )["images"]
 | 
			
		||||
            if "valid_image_nums_ls" in kwargs:
 | 
			
		||||
                valid_image_nums_ls = kwargs["valid_image_nums_ls"]
 | 
			
		||||
                new_images = []
 | 
			
		||||
@ -756,15 +752,15 @@ class MiniCPMVPlugin(BasePlugin):
 | 
			
		||||
                image_min_pixels=getattr(processor, "video_min_pixels", 16 * 16),
 | 
			
		||||
                video_fps=getattr(processor, "video_fps", 2.0),
 | 
			
		||||
                video_maxlen=getattr(processor, "video_maxlen", 128),
 | 
			
		||||
            )
 | 
			
		||||
            )["videos"]
 | 
			
		||||
            video_inputs = image_processor(videos, do_pad=True, max_slice_nums=2, return_tensors="pt")
 | 
			
		||||
            mm_inputs.update(video_inputs)
 | 
			
		||||
 | 
			
		||||
        if len(audios) != 0:
 | 
			
		||||
            audios = self._regularize_audios(
 | 
			
		||||
                audios,
 | 
			
		||||
                sampling_rate=getattr(feature_extractor, "sampling_rate", 16000),
 | 
			
		||||
            )
 | 
			
		||||
                sampling_rate=getattr(processor, "audio_sampling_rate", 16000),
 | 
			
		||||
            )["audios"]
 | 
			
		||||
            if "valid_audio_nums_ls" in kwargs:
 | 
			
		||||
                valid_audio_nums_ls = kwargs["valid_audio_nums_ls"]
 | 
			
		||||
                audios_ls = []
 | 
			
		||||
@ -778,7 +774,7 @@ class MiniCPMVPlugin(BasePlugin):
 | 
			
		||||
            audio_features, audio_feature_lens, audio_phs = processor.audio_feature_extract(
 | 
			
		||||
                audios_ls,
 | 
			
		||||
                chunk_input=True,
 | 
			
		||||
                sampling_rate=16000,
 | 
			
		||||
                sampling_rate=getattr(processor, "audio_sampling_rate", 16000),
 | 
			
		||||
            )
 | 
			
		||||
            audio_feature_lens = [torch.tensor(audio_feature_len) for audio_feature_len in audio_feature_lens]
 | 
			
		||||
            mm_inputs.update({"audio_features": audio_features, "audio_feature_lens": audio_feature_lens})
 | 
			
		||||
@ -1110,195 +1106,6 @@ class Qwen2AudioPlugin(BasePlugin):
 | 
			
		||||
        return self._get_mm_inputs(images, videos, audios, processor)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Qwen2OmniPlugin(BasePlugin):
 | 
			
		||||
    @override
 | 
			
		||||
    def _get_mm_inputs(
 | 
			
		||||
        self,
 | 
			
		||||
        images: list["ImageInput"],
 | 
			
		||||
        videos: list["VideoInput"],
 | 
			
		||||
        audios: list["AudioInput"],
 | 
			
		||||
        processor: "MMProcessor",
 | 
			
		||||
        imglens: Optional[list[int]] = None,
 | 
			
		||||
    ) -> dict[str, "torch.Tensor"]:
 | 
			
		||||
        mm_inputs = {}
 | 
			
		||||
        if len(images) != 0:
 | 
			
		||||
            image_processor: BaseImageProcessor = getattr(processor, "omni_processor", None)  # FIXME
 | 
			
		||||
            images = self._regularize_images(
 | 
			
		||||
                images,
 | 
			
		||||
                image_max_pixels=getattr(processor, "image_max_pixels", 768 * 768),
 | 
			
		||||
                image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32),
 | 
			
		||||
            )
 | 
			
		||||
            if imglens is not None:
 | 
			
		||||
                images = _make_batched_images(images, imglens)
 | 
			
		||||
 | 
			
		||||
            image_processor_kwargs = {}
 | 
			
		||||
            mm_inputs.update(image_processor(images, return_tensors="pt", **image_processor_kwargs))
 | 
			
		||||
 | 
			
		||||
        if len(videos) != 0:
 | 
			
		||||
            video_processor: BaseImageProcessor = getattr(
 | 
			
		||||
                processor, "video_processor", getattr(processor, "omni_processor", None)
 | 
			
		||||
            )
 | 
			
		||||
            videos = self._regularize_videos(
 | 
			
		||||
                videos,
 | 
			
		||||
                image_max_pixels=getattr(processor, "video_max_pixels", 256 * 256),
 | 
			
		||||
                image_min_pixels=getattr(processor, "video_min_pixels", 16 * 16),
 | 
			
		||||
                video_fps=getattr(processor, "video_fps", 2.0),
 | 
			
		||||
                video_maxlen=getattr(processor, "video_maxlen", 128),
 | 
			
		||||
            )
 | 
			
		||||
            if "videos" in inspect.signature(video_processor.preprocess).parameters:  # for qwen2_vl and video_llava
 | 
			
		||||
                mm_inputs.update(video_processor(images=None, videos=videos, return_tensors="pt"))
 | 
			
		||||
                fps = [2.0] * len(videos)  # FIXME hardcode
 | 
			
		||||
                video_second_per_grid = [fps[i] / video_processor.temporal_patch_size for i in range(len(fps))]
 | 
			
		||||
                mm_inputs["video_second_per_grid"] = torch.tensor(video_second_per_grid)
 | 
			
		||||
 | 
			
		||||
            else:
 | 
			
		||||
                raise NotImplementedError
 | 
			
		||||
 | 
			
		||||
        if len(audios) != 0:
 | 
			
		||||
            feature_extractor: SequenceFeatureExtractor = getattr(processor, "feature_extractor", None)
 | 
			
		||||
            audios = self._regularize_audios(
 | 
			
		||||
                audios,
 | 
			
		||||
                sampling_rate=getattr(feature_extractor, "sampling_rate", 16000),
 | 
			
		||||
            )
 | 
			
		||||
            mm_inputs.update(
 | 
			
		||||
                feature_extractor(
 | 
			
		||||
                    audios,
 | 
			
		||||
                    sampling_rate=getattr(feature_extractor, "sampling_rate", 16000),
 | 
			
		||||
                    return_attention_mask=True,
 | 
			
		||||
                    padding="max_length",
 | 
			
		||||
                    return_tensors="pt",
 | 
			
		||||
                )
 | 
			
		||||
            )
 | 
			
		||||
            mm_inputs["feature_attention_mask"] = mm_inputs.pop("attention_mask")  # prevent conflicts
 | 
			
		||||
 | 
			
		||||
        return mm_inputs
 | 
			
		||||
 | 
			
		||||
    @override
 | 
			
		||||
    def process_messages(
 | 
			
		||||
        self,
 | 
			
		||||
        messages: list[dict[str, str]],
 | 
			
		||||
        images: list["ImageInput"],
 | 
			
		||||
        videos: list["VideoInput"],
 | 
			
		||||
        audios: list["AudioInput"],
 | 
			
		||||
        processor: Optional["MMProcessor"],
 | 
			
		||||
    ) -> list[dict[str, str]]:
 | 
			
		||||
        self._validate_input(processor, images, videos, audios)
 | 
			
		||||
        messages = deepcopy(messages)
 | 
			
		||||
        if self.expand_mm_tokens:
 | 
			
		||||
            mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
 | 
			
		||||
        else:
 | 
			
		||||
            mm_inputs = {}
 | 
			
		||||
 | 
			
		||||
        num_audio_tokens, num_image_tokens, num_video_tokens = 0, 0, 0
 | 
			
		||||
        use_audio_in_video = getattr(processor, "use_audio_in_video", False)
 | 
			
		||||
 | 
			
		||||
        # get length or size from mm_inputs
 | 
			
		||||
        if "feature_attention_mask" in mm_inputs:
 | 
			
		||||
            input_lengths = (mm_inputs["feature_attention_mask"].sum(-1).numpy() - 1) // 2 + 1
 | 
			
		||||
            audio_lengths = (input_lengths - 2) // 2 + 1
 | 
			
		||||
 | 
			
		||||
        if mm_inputs.get("image_grid_thw", None) is not None:
 | 
			
		||||
            image_grid_thw = mm_inputs["image_grid_thw"]
 | 
			
		||||
            merge_length = processor.omni_processor.merge_size**2
 | 
			
		||||
 | 
			
		||||
        if mm_inputs.get("video_grid_thw", None) is not None:
 | 
			
		||||
            video_grid_thw = mm_inputs["video_grid_thw"]
 | 
			
		||||
            merge_length = processor.omni_processor.merge_size**2
 | 
			
		||||
 | 
			
		||||
        if use_audio_in_video:
 | 
			
		||||
            if audio_lengths is None:
 | 
			
		||||
                raise ValueError("audio_lengths should exist when use_audio_in_video is `True`.")
 | 
			
		||||
 | 
			
		||||
            if not mm_inputs.get("video_grid_thw", None):
 | 
			
		||||
                raise ValueError("video_grid_thw should exist when use_audio_in_video is `True`.")
 | 
			
		||||
 | 
			
		||||
            positions_list = []
 | 
			
		||||
            for i, message in enumerate(messages):  # get multimodal index when use_audio
 | 
			
		||||
                positions = []
 | 
			
		||||
                for special_token in [self.audio_token, self.image_token, self.video_token]:
 | 
			
		||||
                    start = 0
 | 
			
		||||
                    while True:
 | 
			
		||||
                        pos = message[i].find(special_token, start)
 | 
			
		||||
                        if pos == -1:
 | 
			
		||||
                            break
 | 
			
		||||
                        positions.append((pos, special_token))
 | 
			
		||||
                        start = pos + len(special_token)
 | 
			
		||||
 | 
			
		||||
                positions_list.append(positions.sort(key=lambda x: x[0]))
 | 
			
		||||
 | 
			
		||||
        for message in messages:
 | 
			
		||||
            content = message["content"]
 | 
			
		||||
            # separate with audio-video
 | 
			
		||||
            while IMAGE_PLACEHOLDER in content:
 | 
			
		||||
                image_token_replace_length = image_grid_thw[num_image_tokens].prod() // merge_length
 | 
			
		||||
                content = content.replace(
 | 
			
		||||
                    IMAGE_PLACEHOLDER,
 | 
			
		||||
                    f"<|vision_bos|>{self.image_token * image_token_replace_length}<|vision_eos|>",
 | 
			
		||||
                    1,
 | 
			
		||||
                )
 | 
			
		||||
                num_image_tokens += 1
 | 
			
		||||
 | 
			
		||||
            if not use_audio_in_video:
 | 
			
		||||
                while AUDIO_PLACEHOLDER in content:
 | 
			
		||||
                    audio_token_replace_length = audio_lengths[num_audio_tokens]
 | 
			
		||||
                    content = content.replace(
 | 
			
		||||
                        AUDIO_PLACEHOLDER,
 | 
			
		||||
                        f"<|audio_bos|>{self.audio_token * audio_token_replace_length}<|audio_eos|>",
 | 
			
		||||
                        1,
 | 
			
		||||
                    )
 | 
			
		||||
                    num_audio_tokens += 1
 | 
			
		||||
                # TODO handle video_input and use_audio_in_video
 | 
			
		||||
                while VIDEO_PLACEHOLDER in content:
 | 
			
		||||
                    video_replace_length = video_grid_thw[num_video_tokens].prod() // merge_length
 | 
			
		||||
                    content = content.replace(
 | 
			
		||||
                        VIDEO_PLACEHOLDER, f"<|vision_bos|>{self.video_token * video_replace_length}<|vision_eos|>", 1
 | 
			
		||||
                    )
 | 
			
		||||
                    num_video_tokens += 1
 | 
			
		||||
            else:  # if use the audio of video # deal video token and audio token togather
 | 
			
		||||
                while VIDEO_PLACEHOLDER in content:
 | 
			
		||||
                    audio_t_index = torch.arange(audio_lengths[num_audio_tokens])
 | 
			
		||||
                    video_t_index = (
 | 
			
		||||
                        torch.arange(video_grid_thw[num_video_tokens][0])
 | 
			
		||||
                        .view(-1, 1, 1)
 | 
			
		||||
                        .expand(
 | 
			
		||||
                            -1,
 | 
			
		||||
                            video_grid_thw[num_video_tokens][1] // self.omni_processor.merge_size,
 | 
			
		||||
                            video_grid_thw[num_video_tokens][2] // self.omni_processor.merge_size,
 | 
			
		||||
                        )
 | 
			
		||||
                        .flatten()
 | 
			
		||||
                        * mm_inputs["video_second_per_grid"][num_video_tokens]
 | 
			
		||||
                        * 25  # FIXME hardcode of position_id_per_seconds=25
 | 
			
		||||
                    ).long()
 | 
			
		||||
                    t_ntoken_per_chunk = 50  # FIXME hardcode: [25 * 2]
 | 
			
		||||
                    video_chunk_indices = processor.get_chunked_index(video_t_index, t_ntoken_per_chunk)
 | 
			
		||||
                    audio_chunk_indices = self.get_chunked_index(audio_t_index, t_ntoken_per_chunk)
 | 
			
		||||
                    placeholder_string = ""
 | 
			
		||||
                    for j in range(max(len(video_chunk_indices), len(audio_chunk_indices))):
 | 
			
		||||
                        video_chunk_index = video_chunk_indices[j] if j < len(video_chunk_indices) else None
 | 
			
		||||
                        audio_chunk_index = audio_chunk_indices[j] if j < len(audio_chunk_indices) else None
 | 
			
		||||
                        placeholder_string = "<|vision_bos|>" + "<|audio_bos|>"
 | 
			
		||||
                        if video_chunk_index is not None:
 | 
			
		||||
                            placeholder_string += self.video_token * (video_chunk_index[1] - video_chunk_index[0])
 | 
			
		||||
                        if audio_chunk_index is not None:
 | 
			
		||||
                            placeholder_string += self.audio_token * (audio_chunk_index[1] - audio_chunk_index[0])
 | 
			
		||||
                        placeholder_string += "<|audio_eos|>" + "<|vision_eos|>"
 | 
			
		||||
                    content = content.replace(VIDEO_PLACEHOLDER, placeholder_string, 1)
 | 
			
		||||
                    content = content.replace(AUDIO_PLACEHOLDER, "", 1)
 | 
			
		||||
                    num_audio_tokens += 1
 | 
			
		||||
                    num_video_tokens += 1
 | 
			
		||||
 | 
			
		||||
            message["content"] = content
 | 
			
		||||
 | 
			
		||||
        if len(audios) != num_audio_tokens:
 | 
			
		||||
            raise ValueError(f"The number of audios does not match the number of {AUDIO_PLACEHOLDER} tokens.")
 | 
			
		||||
        if len(images) != num_image_tokens:
 | 
			
		||||
            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
 | 
			
		||||
        if len(videos) != num_video_tokens:
 | 
			
		||||
            raise ValueError(f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens.")
 | 
			
		||||
 | 
			
		||||
        return messages
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@dataclass
 | 
			
		||||
class Qwen2VLPlugin(BasePlugin):
 | 
			
		||||
    @override
 | 
			
		||||
@ -1321,7 +1128,7 @@ class Qwen2VLPlugin(BasePlugin):
 | 
			
		||||
    @override
 | 
			
		||||
    def _regularize_videos(
 | 
			
		||||
        self, videos: list["VideoInput"], **kwargs
 | 
			
		||||
    ) -> tuple[list[list["ImageObject"]], list[float]]:
 | 
			
		||||
    ) -> dict[str, Union[list[list["ImageObject"]], list[float]]]:
 | 
			
		||||
        results, fps_per_video = [], []
 | 
			
		||||
        for video in videos:
 | 
			
		||||
            container = av.open(video, "r")
 | 
			
		||||
@ -1336,14 +1143,14 @@ class Qwen2VLPlugin(BasePlugin):
 | 
			
		||||
            if len(frames) % 2 != 0:  # qwen2-vl requires even number of frames
 | 
			
		||||
                frames.append(frames[-1])
 | 
			
		||||
 | 
			
		||||
            frames = self._regularize_images(frames, **kwargs)
 | 
			
		||||
            frames = self._regularize_images(frames, **kwargs)["images"]
 | 
			
		||||
            results.append(frames)
 | 
			
		||||
            if video_stream.duration is None:
 | 
			
		||||
                fps_per_video.append(2.0)
 | 
			
		||||
            else:
 | 
			
		||||
                fps_per_video.append(len(sample_indices) / float(video_stream.duration * video_stream.time_base))
 | 
			
		||||
 | 
			
		||||
        return results, fps_per_video
 | 
			
		||||
        return {"videos": results, "fps_per_video": fps_per_video}
 | 
			
		||||
 | 
			
		||||
    @override
 | 
			
		||||
    def _get_mm_inputs(
 | 
			
		||||
@ -1360,19 +1167,19 @@ class Qwen2VLPlugin(BasePlugin):
 | 
			
		||||
                images,
 | 
			
		||||
                image_max_pixels=getattr(processor, "image_max_pixels", 768 * 768),
 | 
			
		||||
                image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32),
 | 
			
		||||
            )
 | 
			
		||||
            )["images"]
 | 
			
		||||
            mm_inputs.update(image_processor(images, return_tensors="pt"))
 | 
			
		||||
 | 
			
		||||
        if len(videos) != 0:
 | 
			
		||||
            videos, fps_per_video = self._regularize_videos(
 | 
			
		||||
            video_data = self._regularize_videos(
 | 
			
		||||
                videos,
 | 
			
		||||
                image_max_pixels=getattr(processor, "video_max_pixels", 256 * 256),
 | 
			
		||||
                image_min_pixels=getattr(processor, "video_min_pixels", 16 * 16),
 | 
			
		||||
                video_fps=getattr(processor, "video_fps", 2.0),
 | 
			
		||||
                video_maxlen=getattr(processor, "video_maxlen", 128),
 | 
			
		||||
            )
 | 
			
		||||
            mm_inputs.update(image_processor(images=None, videos=videos, return_tensors="pt"))
 | 
			
		||||
            mm_inputs["fps_per_video"] = fps_per_video
 | 
			
		||||
            mm_inputs.update(image_processor(images=None, videos=video_data["videos"], return_tensors="pt"))
 | 
			
		||||
            mm_inputs["fps_per_video"] = video_data["fps_per_video"]
 | 
			
		||||
 | 
			
		||||
        return mm_inputs
 | 
			
		||||
 | 
			
		||||
@ -1454,6 +1261,186 @@ class Qwen2VLPlugin(BasePlugin):
 | 
			
		||||
        return mm_inputs
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Qwen2OmniPlugin(Qwen2VLPlugin):
 | 
			
		||||
    @override
 | 
			
		||||
    def _get_mm_inputs(
 | 
			
		||||
        self,
 | 
			
		||||
        images: list["ImageInput"],
 | 
			
		||||
        videos: list["VideoInput"],
 | 
			
		||||
        audios: list["AudioInput"],
 | 
			
		||||
        processor: "MMProcessor",
 | 
			
		||||
    ) -> dict[str, "torch.Tensor"]:
 | 
			
		||||
        image_processor: BaseImageProcessor = getattr(processor, "image_processor", None)
 | 
			
		||||
        feature_extractor: SequenceFeatureExtractor = getattr(processor, "feature_extractor", None)
 | 
			
		||||
        mm_inputs = {}
 | 
			
		||||
        if len(images) != 0:
 | 
			
		||||
            images = self._regularize_images(
 | 
			
		||||
                images,
 | 
			
		||||
                image_max_pixels=getattr(processor, "image_max_pixels", 768 * 768),
 | 
			
		||||
                image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32),
 | 
			
		||||
            )["images"]
 | 
			
		||||
            mm_inputs.update(image_processor(images, return_tensors="pt"))
 | 
			
		||||
 | 
			
		||||
        if len(videos) != 0:
 | 
			
		||||
            video_dict = self._regularize_videos(
 | 
			
		||||
                videos,
 | 
			
		||||
                image_max_pixels=getattr(processor, "video_max_pixels", 256 * 256),
 | 
			
		||||
                image_min_pixels=getattr(processor, "video_min_pixels", 16 * 16),
 | 
			
		||||
                video_fps=getattr(processor, "video_fps", 2.0),
 | 
			
		||||
                video_maxlen=getattr(processor, "video_maxlen", 128),
 | 
			
		||||
            )
 | 
			
		||||
            mm_inputs.update(image_processor(images=None, videos=video_dict["videos"], return_tensors="pt"))
 | 
			
		||||
            mm_inputs["fps_per_video"] = video_dict["fps_per_video"]
 | 
			
		||||
 | 
			
		||||
        if len(audios) != 0:
 | 
			
		||||
            audios = self._regularize_audios(
 | 
			
		||||
                audios,
 | 
			
		||||
                sampling_rate=getattr(processor, "audio_sampling_rate", 16000),
 | 
			
		||||
            )["audios"]
 | 
			
		||||
            mm_inputs.update(
 | 
			
		||||
                feature_extractor(
 | 
			
		||||
                    audios,
 | 
			
		||||
                    sampling_rate=getattr(processor, "audio_sampling_rate", 16000),
 | 
			
		||||
                    return_attention_mask=True,
 | 
			
		||||
                    padding="max_length",
 | 
			
		||||
                    return_tensors="pt",
 | 
			
		||||
                )
 | 
			
		||||
            )
 | 
			
		||||
            mm_inputs["feature_attention_mask"] = mm_inputs.pop("attention_mask")  # prevent conflicts
 | 
			
		||||
 | 
			
		||||
        return mm_inputs
 | 
			
		||||
 | 
			
		||||
    @override
 | 
			
		||||
    def process_messages(
 | 
			
		||||
        self,
 | 
			
		||||
        messages: list[dict[str, str]],
 | 
			
		||||
        images: list["ImageInput"],
 | 
			
		||||
        videos: list["VideoInput"],
 | 
			
		||||
        audios: list["AudioInput"],
 | 
			
		||||
        processor: Optional["MMProcessor"],
 | 
			
		||||
    ) -> list[dict[str, str]]:
 | 
			
		||||
        self._validate_input(processor, images, videos, audios)
 | 
			
		||||
        messages = deepcopy(messages)
 | 
			
		||||
        if self.expand_mm_tokens:
 | 
			
		||||
            mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
 | 
			
		||||
        else:
 | 
			
		||||
            mm_inputs = {}
 | 
			
		||||
 | 
			
		||||
        num_audio_tokens, num_image_tokens, num_video_tokens = 0, 0, 0
 | 
			
		||||
        use_audio_in_video = getattr(processor, "use_audio_in_video", False)
 | 
			
		||||
 | 
			
		||||
        # get length or size from mm_inputs
 | 
			
		||||
        if "feature_attention_mask" in mm_inputs:
 | 
			
		||||
            input_lengths = (mm_inputs["feature_attention_mask"].sum(-1).numpy() - 1) // 2 + 1
 | 
			
		||||
            audio_lengths = (input_lengths - 2) // 2 + 1
 | 
			
		||||
 | 
			
		||||
        if mm_inputs.get("image_grid_thw", None) is not None:
 | 
			
		||||
            image_grid_thw = mm_inputs["image_grid_thw"]
 | 
			
		||||
            merge_length = processor.image_processor.merge_size**2
 | 
			
		||||
 | 
			
		||||
        if mm_inputs.get("video_grid_thw", None) is not None:
 | 
			
		||||
            video_grid_thw = mm_inputs["video_grid_thw"]
 | 
			
		||||
            merge_length = processor.image_processor.merge_size**2
 | 
			
		||||
 | 
			
		||||
        if use_audio_in_video:
 | 
			
		||||
            if audio_lengths is None:
 | 
			
		||||
                raise ValueError("audio_lengths should exist when use_audio_in_video is `True`.")
 | 
			
		||||
 | 
			
		||||
            if not mm_inputs.get("video_grid_thw", None):
 | 
			
		||||
                raise ValueError("video_grid_thw should exist when use_audio_in_video is `True`.")
 | 
			
		||||
 | 
			
		||||
            positions_list = []
 | 
			
		||||
            for i, message in enumerate(messages):  # get multimodal index when use_audio
 | 
			
		||||
                positions = []
 | 
			
		||||
                for special_token in [self.audio_token, self.image_token, self.video_token]:
 | 
			
		||||
                    start = 0
 | 
			
		||||
                    while True:
 | 
			
		||||
                        pos = message[i].find(special_token, start)
 | 
			
		||||
                        if pos == -1:
 | 
			
		||||
                            break
 | 
			
		||||
                        positions.append((pos, special_token))
 | 
			
		||||
                        start = pos + len(special_token)
 | 
			
		||||
 | 
			
		||||
                positions_list.append(positions.sort(key=lambda x: x[0]))
 | 
			
		||||
 | 
			
		||||
        for message in messages:
 | 
			
		||||
            content = message["content"]
 | 
			
		||||
            # separate with audio-video
 | 
			
		||||
            while IMAGE_PLACEHOLDER in content:
 | 
			
		||||
                image_token_replace_length = image_grid_thw[num_image_tokens].prod() // merge_length
 | 
			
		||||
                content = content.replace(
 | 
			
		||||
                    IMAGE_PLACEHOLDER,
 | 
			
		||||
                    f"<|vision_bos|>{self.image_token * image_token_replace_length}<|vision_eos|>",
 | 
			
		||||
                    1,
 | 
			
		||||
                )
 | 
			
		||||
                num_image_tokens += 1
 | 
			
		||||
 | 
			
		||||
            if not use_audio_in_video:
 | 
			
		||||
                while AUDIO_PLACEHOLDER in content:
 | 
			
		||||
                    audio_token_replace_length = audio_lengths[num_audio_tokens]
 | 
			
		||||
                    content = content.replace(
 | 
			
		||||
                        AUDIO_PLACEHOLDER,
 | 
			
		||||
                        f"<|audio_bos|>{self.audio_token * audio_token_replace_length}<|audio_eos|>",
 | 
			
		||||
                        1,
 | 
			
		||||
                    )
 | 
			
		||||
                    num_audio_tokens += 1
 | 
			
		||||
 | 
			
		||||
                # TODO handle video_input and use_audio_in_video
 | 
			
		||||
                while VIDEO_PLACEHOLDER in content:
 | 
			
		||||
                    video_replace_length = video_grid_thw[num_video_tokens].prod() // merge_length
 | 
			
		||||
                    content = content.replace(
 | 
			
		||||
                        VIDEO_PLACEHOLDER, f"<|vision_bos|>{self.video_token * video_replace_length}<|vision_eos|>", 1
 | 
			
		||||
                    )
 | 
			
		||||
                    num_video_tokens += 1
 | 
			
		||||
 | 
			
		||||
            else:  # if use the audio of video # deal video token and audio token togather
 | 
			
		||||
                while VIDEO_PLACEHOLDER in content:
 | 
			
		||||
                    audio_t_index = torch.arange(audio_lengths[num_audio_tokens])
 | 
			
		||||
                    video_t_index = (
 | 
			
		||||
                        torch.arange(video_grid_thw[num_video_tokens][0])
 | 
			
		||||
                        .view(-1, 1, 1)
 | 
			
		||||
                        .expand(
 | 
			
		||||
                            -1,
 | 
			
		||||
                            video_grid_thw[num_video_tokens][1] // self.image_processor.merge_size,
 | 
			
		||||
                            video_grid_thw[num_video_tokens][2] // self.image_processor.merge_size,
 | 
			
		||||
                        )
 | 
			
		||||
                        .flatten()
 | 
			
		||||
                        * mm_inputs["video_second_per_grid"][num_video_tokens]
 | 
			
		||||
                        * 25  # FIXME hardcode of position_id_per_seconds=25
 | 
			
		||||
                    ).long()
 | 
			
		||||
                    t_ntoken_per_chunk = 50  # FIXME hardcode: [25 * 2]
 | 
			
		||||
                    video_chunk_indices = processor.get_chunked_index(video_t_index, t_ntoken_per_chunk)
 | 
			
		||||
                    audio_chunk_indices = self.get_chunked_index(audio_t_index, t_ntoken_per_chunk)
 | 
			
		||||
                    placeholder_string = ""
 | 
			
		||||
                    for j in range(max(len(video_chunk_indices), len(audio_chunk_indices))):
 | 
			
		||||
                        video_chunk_index = video_chunk_indices[j] if j < len(video_chunk_indices) else None
 | 
			
		||||
                        audio_chunk_index = audio_chunk_indices[j] if j < len(audio_chunk_indices) else None
 | 
			
		||||
                        placeholder_string = "<|vision_bos|>" + "<|audio_bos|>"
 | 
			
		||||
                        if video_chunk_index is not None:
 | 
			
		||||
                            placeholder_string += self.video_token * (video_chunk_index[1] - video_chunk_index[0])
 | 
			
		||||
                        if audio_chunk_index is not None:
 | 
			
		||||
                            placeholder_string += self.audio_token * (audio_chunk_index[1] - audio_chunk_index[0])
 | 
			
		||||
                        placeholder_string += "<|audio_eos|>" + "<|vision_eos|>"
 | 
			
		||||
 | 
			
		||||
                    content = content.replace(VIDEO_PLACEHOLDER, placeholder_string, 1)
 | 
			
		||||
                    content = content.replace(AUDIO_PLACEHOLDER, "", 1)
 | 
			
		||||
                    num_audio_tokens += 1
 | 
			
		||||
                    num_video_tokens += 1
 | 
			
		||||
 | 
			
		||||
            message["content"] = content
 | 
			
		||||
 | 
			
		||||
        if len(audios) != num_audio_tokens:
 | 
			
		||||
            raise ValueError(f"The number of audios does not match the number of {AUDIO_PLACEHOLDER} tokens.")
 | 
			
		||||
 | 
			
		||||
        if len(images) != num_image_tokens:
 | 
			
		||||
            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
 | 
			
		||||
 | 
			
		||||
        if len(videos) != num_video_tokens:
 | 
			
		||||
            raise ValueError(f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens.")
 | 
			
		||||
 | 
			
		||||
        return messages
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@dataclass
 | 
			
		||||
class VideoLlavaPlugin(BasePlugin):
 | 
			
		||||
    @override
 | 
			
		||||
 | 
			
		||||
@ -242,6 +242,10 @@ class ProcessorArguments:
 | 
			
		||||
        default=128,
 | 
			
		||||
        metadata={"help": "The maximum number of sampled frames for video inputs."},
 | 
			
		||||
    )
 | 
			
		||||
    audio_sampling_rate: int = field(
 | 
			
		||||
        default=16000,
 | 
			
		||||
        metadata={"help": "The sampling rate of audio inputs."},
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    def __post_init__(self):
 | 
			
		||||
        if self.image_max_pixels < self.image_min_pixels:
 | 
			
		||||
 | 
			
		||||
@ -262,9 +262,7 @@ _register_composite_model(
 | 
			
		||||
    projector_key="visual.merger",
 | 
			
		||||
    vision_model_keys=["visual.patch_embed", "visual.blocks", "audio_tower"],
 | 
			
		||||
    language_model_keys=["model", "lm_head"],
 | 
			
		||||
    lora_conflict_keys=[
 | 
			
		||||
        "patch_embed",
 | 
			
		||||
    ],
 | 
			
		||||
    lora_conflict_keys=["patch_embed"],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -78,6 +78,7 @@ def patch_processor(
 | 
			
		||||
    setattr(processor, "video_min_pixels", model_args.video_min_pixels)
 | 
			
		||||
    setattr(processor, "video_fps", model_args.video_fps)
 | 
			
		||||
    setattr(processor, "video_maxlen", model_args.video_maxlen)
 | 
			
		||||
    setattr(processor, "audio_sampling_rate", model_args.audio_sampling_rate)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def patch_config(
 | 
			
		||||
@ -123,15 +124,13 @@ def patch_config(
 | 
			
		||||
    # deepspeed zero3 is not compatible with low_cpu_mem_usage
 | 
			
		||||
    init_kwargs["low_cpu_mem_usage"] = model_args.low_cpu_mem_usage and (not is_deepspeed_zero3_enabled())
 | 
			
		||||
 | 
			
		||||
    # cast data type of the model if:
 | 
			
		||||
    # 1. not deepspeed zero3 and not fsdp (keep zero3 or fsdp in float32)
 | 
			
		||||
    # 2. quantization_bit is not None (qlora)
 | 
			
		||||
    if (not is_deepspeed_zero3_enabled() and not is_fsdp_enabled()) or model_args.quantization_bit is not None:
 | 
			
		||||
    # do not cast data type of the model deepspeed zero3 without qlora
 | 
			
		||||
    if not (is_deepspeed_zero3_enabled() and model_args.quantization_bit is None):
 | 
			
		||||
        init_kwargs["torch_dtype"] = model_args.compute_dtype
 | 
			
		||||
 | 
			
		||||
        if init_kwargs["low_cpu_mem_usage"]:  # device map requires low_cpu_mem_usage=True
 | 
			
		||||
        if init_kwargs["low_cpu_mem_usage"] and not is_fsdp_enabled():  # fsdp does not need device map
 | 
			
		||||
            if "device_map" not in init_kwargs and model_args.device_map:
 | 
			
		||||
                init_kwargs["device_map"] = model_args.device_map
 | 
			
		||||
                init_kwargs["device_map"] = model_args.device_map  # device map requires low_cpu_mem_usage=True
 | 
			
		||||
 | 
			
		||||
            if init_kwargs.get("device_map", None) == "auto":
 | 
			
		||||
                init_kwargs["offload_folder"] = model_args.offload_folder
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user