From 903db09822b167dbe2d361aadb521e3c6caaad7a Mon Sep 17 00:00:00 2001 From: hoshi-hiyouga Date: Wed, 2 Apr 2025 02:27:04 +0800 Subject: [PATCH] [infer] vllm video/audio inference (#7566) --- scripts/vllm_infer.py | 16 +- src/llamafactory/api/chat.py | 49 +- src/llamafactory/api/protocol.py | 9 +- src/llamafactory/chat/sglang_engine.py | 23 +- src/llamafactory/chat/vllm_engine.py | 41 +- src/llamafactory/data/converter.py | 12 +- src/llamafactory/data/mm_plugin.py | 445 +++++++++---------- src/llamafactory/hparams/model_args.py | 4 + src/llamafactory/model/model_utils/visual.py | 4 +- src/llamafactory/model/patcher.py | 11 +- 10 files changed, 329 insertions(+), 285 deletions(-) diff --git a/scripts/vllm_infer.py b/scripts/vllm_infer.py index dceb1d31..ce17adfb 100644 --- a/scripts/vllm_infer.py +++ b/scripts/vllm_infer.py @@ -92,8 +92,20 @@ def vllm_infer( multi_modal_data = { "image": template_obj.mm_plugin._regularize_images( sample["images"], image_max_pixels=image_max_pixels, image_min_pixels=image_min_pixels - ) + )["images"] } + elif sample["videos"]: + multi_modal_data = { + "video": template_obj.mm_plugin._regularize_videos( + sample["videos"], image_max_pixels=image_max_pixels, image_min_pixels=image_min_pixels + )["videos"] + } + elif sample["audios"]: + audio_data = template_obj.mm_plugin._regularize_audios( + sample["audios"], + sampling_rate=16000, + ) + multi_modal_data = {"audio": zip(audio_data["audios"], audio_data["sampling_rates"])} else: multi_modal_data = None @@ -131,7 +143,7 @@ def vllm_infer( "enable_lora": model_args.adapter_name_or_path is not None, } if template_obj.mm_plugin.__class__.__name__ != "BasePlugin": - engine_args["limit_mm_per_prompt"] = {"image": 4, "video": 2} + engine_args["limit_mm_per_prompt"] = {"image": 4, "video": 2, "audio": 2} if isinstance(model_args.vllm_config, dict): engine_args.update(model_args.vllm_config) diff --git a/src/llamafactory/api/chat.py b/src/llamafactory/api/chat.py index ed40e8f8..8340ccd4 100644 --- a/src/llamafactory/api/chat.py +++ b/src/llamafactory/api/chat.py @@ -23,7 +23,7 @@ from typing import TYPE_CHECKING, Optional from ..data import Role as DataRole from ..extras import logging -from ..extras.constants import IMAGE_PLACEHOLDER +from ..extras.constants import AUDIO_PLACEHOLDER, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER from ..extras.misc import is_env_enabled from ..extras.packages import is_fastapi_available, is_pillow_available, is_requests_available from .common import dictify, jsonify @@ -56,7 +56,7 @@ if is_requests_available(): if TYPE_CHECKING: from ..chat import ChatModel - from ..data.mm_plugin import ImageInput + from ..data.mm_plugin import AudioInput, ImageInput, VideoInput from .protocol import ChatCompletionRequest, ScoreEvaluationRequest @@ -72,7 +72,14 @@ ROLE_MAPPING = { def _process_request( request: "ChatCompletionRequest", -) -> tuple[list[dict[str, str]], Optional[str], Optional[str], Optional[list["ImageInput"]]]: +) -> tuple[ + list[dict[str, str]], + Optional[str], + Optional[str], + Optional[list["ImageInput"]], + Optional[list["VideoInput"]], + Optional[list["AudioInput"]], +]: if is_env_enabled("API_VERBOSE", "1"): logger.info_rank0(f"==== request ====\n{json.dumps(dictify(request), indent=2, ensure_ascii=False)}") @@ -88,7 +95,7 @@ def _process_request( raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...") input_messages = [] - images = [] + images, videos, audios = [], [], [] for i, message in enumerate(request.messages): if i % 2 == 0 and message.role not in [Role.USER, Role.TOOL]: raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role") @@ -107,7 +114,7 @@ def _process_request( for input_item in message.content: if input_item.type == "text": text_content += input_item.text - else: + elif input_item.type == "image_url": text_content += IMAGE_PLACEHOLDER image_url = input_item.image_url.url if re.match(r"^data:image\/(png|jpg|jpeg|gif|bmp);base64,(.+)$", image_url): # base64 image @@ -118,6 +125,28 @@ def _process_request( image_stream = requests.get(image_url, stream=True).raw images.append(Image.open(image_stream).convert("RGB")) + elif input_item.type == "video_url": + text_content += VIDEO_PLACEHOLDER + video_url = input_item.video_url.url + if os.path.isfile(video_url): # local file + video_stream = open(video_url, "rb") + else: # web uri + video_stream = requests.get(video_url, stream=True).raw + + videos.append(video_stream) + elif input_item.type == "audio_url": + text_content += AUDIO_PLACEHOLDER + audio_url = input_item.audio_url.url + if os.path.isfile(audio_url): # local file + audio_stream = open(audio_url, "rb") + else: # web uri + audio_stream = requests.get(audio_url, stream=True).raw + + audios.append(audio_stream) + else: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, detail=f"Invalid input type {input_item.type}." + ) input_messages.append({"role": ROLE_MAPPING[message.role], "content": text_content}) else: @@ -132,7 +161,7 @@ def _process_request( else: tools = None - return input_messages, system, tools, images or None + return input_messages, system, tools, images or None, videos or None, audios or None def _create_stream_chat_completion_chunk( @@ -151,12 +180,14 @@ async def create_chat_completion_response( request: "ChatCompletionRequest", chat_model: "ChatModel" ) -> "ChatCompletionResponse": completion_id = f"chatcmpl-{uuid.uuid4().hex}" - input_messages, system, tools, images = _process_request(request) + input_messages, system, tools, images, videos, audios = _process_request(request) responses = await chat_model.achat( input_messages, system, tools, images, + videos, + audios, do_sample=request.do_sample, temperature=request.temperature, top_p=request.top_p, @@ -202,7 +233,7 @@ async def create_stream_chat_completion_response( request: "ChatCompletionRequest", chat_model: "ChatModel" ) -> AsyncGenerator[str, None]: completion_id = f"chatcmpl-{uuid.uuid4().hex}" - input_messages, system, tools, images = _process_request(request) + input_messages, system, tools, images, videos, audios = _process_request(request) if tools: raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot stream function calls.") @@ -217,6 +248,8 @@ async def create_stream_chat_completion_response( system, tools, images, + videos, + audios, do_sample=request.do_sample, temperature=request.temperature, top_p=request.top_p, diff --git a/src/llamafactory/api/protocol.py b/src/llamafactory/api/protocol.py index bb3029d5..ac9746ef 100644 --- a/src/llamafactory/api/protocol.py +++ b/src/llamafactory/api/protocol.py @@ -70,14 +70,17 @@ class FunctionCall(BaseModel): function: Function -class ImageURL(BaseModel): +class URL(BaseModel): url: str + detail: Literal["auto", "low", "high"] = "auto" class MultimodalInputItem(BaseModel): - type: Literal["text", "image_url"] + type: Literal["text", "image_url", "video_url", "audio_url"] text: Optional[str] = None - image_url: Optional[ImageURL] = None + image_url: Optional[URL] = None + video_url: Optional[URL] = None + audio_url: Optional[URL] = None class ChatMessage(BaseModel): diff --git a/src/llamafactory/chat/sglang_engine.py b/src/llamafactory/chat/sglang_engine.py index 7bcb05d2..3fc3aeb5 100644 --- a/src/llamafactory/chat/sglang_engine.py +++ b/src/llamafactory/chat/sglang_engine.py @@ -33,7 +33,7 @@ from .base_engine import BaseEngine, Response if is_sglang_available(): - from sglang.utils import launch_server_cmd, terminate_process, wait_for_server + from sglang.utils import launch_server_cmd, terminate_process, wait_for_server # type: ignore if TYPE_CHECKING: @@ -134,24 +134,17 @@ class SGLangEngine(BaseEngine): audios: Optional[list["AudioInput"]] = None, **input_kwargs, ) -> AsyncIterator[dict[str, Any]]: - mm_input_dict = {"images": [], "videos": [], "audios": [], "imglens": [0], "vidlens": [0], "audlens": [0]} - if images is not None: - mm_input_dict.update({"images": images, "imglens": [len(images)]}) - if not any(IMAGE_PLACEHOLDER in message["content"] for message in messages): - messages[0]["content"] = IMAGE_PLACEHOLDER * len(images) + messages[0]["content"] + if images is not None and not any(IMAGE_PLACEHOLDER in message["content"] for message in messages): + messages[0]["content"] = IMAGE_PLACEHOLDER * len(images) + messages[0]["content"] - if videos is not None: - mm_input_dict.update({"videos": videos, "vidlens": [len(videos)]}) - if not any(VIDEO_PLACEHOLDER in message["content"] for message in messages): - messages[0]["content"] = VIDEO_PLACEHOLDER * len(videos) + messages[0]["content"] + if videos is not None and not any(VIDEO_PLACEHOLDER in message["content"] for message in messages): + messages[0]["content"] = VIDEO_PLACEHOLDER * len(videos) + messages[0]["content"] - if audios is not None: - mm_input_dict.update({"audios": audios, "audlens": [len(audios)]}) - if not any(AUDIO_PLACEHOLDER in message["content"] for message in messages): - messages[0]["content"] = AUDIO_PLACEHOLDER * len(audios) + messages[0]["content"] + if audios is not None and not any(AUDIO_PLACEHOLDER in message["content"] for message in messages): + messages[0]["content"] = AUDIO_PLACEHOLDER * len(audios) + messages[0]["content"] messages = self.template.mm_plugin.process_messages( - messages, mm_input_dict["images"], mm_input_dict["videos"], mm_input_dict["audios"], self.processor + messages, images or [], videos or [], audios or [], self.processor ) paired_messages = messages + [{"role": "assistant", "content": ""}] system = system or self.generating_args["default_system"] diff --git a/src/llamafactory/chat/vllm_engine.py b/src/llamafactory/chat/vllm_engine.py index ef2405bc..1100fc8a 100644 --- a/src/llamafactory/chat/vllm_engine.py +++ b/src/llamafactory/chat/vllm_engine.py @@ -83,7 +83,7 @@ class VllmEngine(BaseEngine): "max_lora_rank": model_args.vllm_max_lora_rank, } if self.template.mm_plugin.__class__.__name__ != "BasePlugin": - engine_args["limit_mm_per_prompt"] = {"image": 4, "video": 2} + engine_args["limit_mm_per_prompt"] = {"image": 4, "video": 2, "audio": 2} if isinstance(model_args.vllm_config, dict): engine_args.update(model_args.vllm_config) @@ -111,24 +111,17 @@ class VllmEngine(BaseEngine): **input_kwargs, ) -> AsyncIterator["RequestOutput"]: request_id = f"chatcmpl-{uuid.uuid4().hex}" - mm_input_dict = {"images": [], "videos": [], "audios": [], "imglens": [0], "vidlens": [0], "audlens": [0]} - if images is not None: - mm_input_dict.update({"images": images, "imglens": [len(images)]}) - if not any(IMAGE_PLACEHOLDER in message["content"] for message in messages): - messages[0]["content"] = IMAGE_PLACEHOLDER * len(images) + messages[0]["content"] + if images is not None and not any(IMAGE_PLACEHOLDER in message["content"] for message in messages): + messages[0]["content"] = IMAGE_PLACEHOLDER * len(images) + messages[0]["content"] - if videos is not None: - mm_input_dict.update({"videos": videos, "vidlens": [len(videos)]}) - if not any(VIDEO_PLACEHOLDER in message["content"] for message in messages): - messages[0]["content"] = VIDEO_PLACEHOLDER * len(videos) + messages[0]["content"] + if videos is not None and not any(VIDEO_PLACEHOLDER in message["content"] for message in messages): + messages[0]["content"] = VIDEO_PLACEHOLDER * len(videos) + messages[0]["content"] - if audios is not None: - mm_input_dict.update({"audios": audios, "audlens": [len(audios)]}) - if not any(AUDIO_PLACEHOLDER in message["content"] for message in messages): - messages[0]["content"] = AUDIO_PLACEHOLDER * len(audios) + messages[0]["content"] + if audios is not None and not any(AUDIO_PLACEHOLDER in message["content"] for message in messages): + messages[0]["content"] = AUDIO_PLACEHOLDER * len(audios) + messages[0]["content"] messages = self.template.mm_plugin.process_messages( - messages, mm_input_dict["images"], mm_input_dict["videos"], mm_input_dict["audios"], self.processor + messages, images or [], videos or [], audios or [], self.processor ) paired_messages = messages + [{"role": "assistant", "content": ""}] system = system or self.generating_args["default_system"] @@ -186,8 +179,24 @@ class VllmEngine(BaseEngine): images, image_max_pixels=self.model_args.image_max_pixels, image_min_pixels=self.model_args.image_min_pixels, - ) + )["images"] } + elif videos is not None: + multi_modal_data = { + "video": self.template.mm_plugin._regularize_videos( + videos, + image_max_pixels=self.model_args.video_max_pixels, + image_min_pixels=self.model_args.video_min_pixels, + video_fps=self.model_args.video_fps, + video_maxlen=self.model_args.video_maxlen, + )["videos"] + } + elif audios is not None: + audio_data = self.template.mm_plugin._regularize_audios( + audios, + sampling_rate=self.model_args.audio_sampling_rate, + ) + multi_modal_data = {"audio": zip(audio_data["audios"], audio_data["sampling_rates"])} else: multi_modal_data = None diff --git a/src/llamafactory/data/converter.py b/src/llamafactory/data/converter.py index 25f39545..f0c791b7 100644 --- a/src/llamafactory/data/converter.py +++ b/src/llamafactory/data/converter.py @@ -26,8 +26,12 @@ if TYPE_CHECKING: from transformers import Seq2SeqTrainingArguments from ..hparams import DataArguments + from .mm_plugin import AudioInput, ImageInput, VideoInput from .parser import DatasetAttr + MediaType = Union[ImageInput, VideoInput, AudioInput] + + logger = logging.get_logger(__name__) @@ -36,10 +40,12 @@ class DatasetConverter: dataset_attr: "DatasetAttr" data_args: "DataArguments" - def _find_medias(self, medias: Union[Any, list[Any]]) -> Optional[list[Any]]: + def _find_medias(self, medias: Union["MediaType", list["MediaType"], None]) -> Optional[list["MediaType"]]: r"""Optionally concatenate media path to media dir when loading from local disk.""" - if not isinstance(medias, list): - medias = [medias] if medias is not None else [] + if medias is None: + return None + elif not isinstance(medias, list): + medias = [medias] elif len(medias) == 0: return None else: diff --git a/src/llamafactory/data/mm_plugin.py b/src/llamafactory/data/mm_plugin.py index b9180db2..85ce0bd3 100644 --- a/src/llamafactory/data/mm_plugin.py +++ b/src/llamafactory/data/mm_plugin.py @@ -21,7 +21,7 @@ import re from copy import deepcopy from dataclasses import dataclass from io import BytesIO -from typing import TYPE_CHECKING, Literal, Optional, TypedDict, Union +from typing import TYPE_CHECKING, BinaryIO, Literal, Optional, TypedDict, Union import numpy as np import torch @@ -68,9 +68,9 @@ if TYPE_CHECKING: path: Optional[str] bytes: Optional[bytes] - ImageInput = Union[str, bytes, EncodedImage, ImageObject] - VideoInput = str - AudioInput = Union[str, NDArray] + ImageInput = Union[str, bytes, EncodedImage, BinaryIO, ImageObject] + VideoInput = Union[str, BinaryIO] + AudioInput = Union[str, BinaryIO, NDArray] class MMProcessor(ProcessorMixin): patch_size: int @@ -146,12 +146,6 @@ class MMPluginMixin: video_processor: BaseImageProcessor = getattr( processor, "video_processor", getattr(processor, "image_processor", None) ) - if image_processor is None and video_processor is None: # hack for qwen2_5_omni - image_processor, video_processor = ( - getattr(processor, "omni_processor", None), - getattr(processor, "omni_processor", None), - ) - feature_extractor: SequenceFeatureExtractor = getattr(processor, "feature_extractor", None) if len(images) != 0 and self.image_token is None: raise ValueError( @@ -211,11 +205,11 @@ class MMPluginMixin: sample_frames = min(total_frames, video_maxlen, sample_frames) return np.linspace(0, total_frames - 1, sample_frames).astype(np.int32) - def _regularize_images(self, images: list["ImageInput"], **kwargs) -> list["ImageObject"]: + def _regularize_images(self, images: list["ImageInput"], **kwargs) -> dict[str, list["ImageObject"]]: r"""Regularize images to avoid error. Including reading and pre-processing.""" results = [] for image in images: - if isinstance(image, str): + if isinstance(image, (str, BinaryIO)): image = Image.open(image) elif isinstance(image, bytes): image = Image.open(BytesIO(image)) @@ -230,9 +224,9 @@ class MMPluginMixin: results.append(self._preprocess_image(image, **kwargs)) - return results + return {"images": results} - def _regularize_videos(self, videos: list["VideoInput"], **kwargs) -> list[list["ImageObject"]]: + def _regularize_videos(self, videos: list["VideoInput"], **kwargs) -> dict[str, list[list["ImageObject"]]]: r"""Regularizes videos to avoid error. Including reading, resizing and converting.""" results = [] for video in videos: @@ -245,24 +239,27 @@ class MMPluginMixin: if frame_idx in sample_indices: frames.append(frame.to_image()) - frames = self._regularize_images(frames, **kwargs) + frames = self._regularize_images(frames, **kwargs)["images"] results.append(frames) - return results + return {"videos": results} - def _regularize_audios(self, audios: list["AudioInput"], sampling_rate: float, **kwargs) -> list["NDArray"]: + def _regularize_audios( + self, audios: list["AudioInput"], sampling_rate: float, **kwargs + ) -> dict[str, Union[list["NDArray"], list[float]]]: r"""Regularizes audios to avoid error. Including reading and resampling.""" - results = [] + results, sampling_rates = [], [] for audio in audios: - if isinstance(audio, str): - audio = librosa.load(audio, sr=sampling_rate)[0] + if isinstance(audio, (str, BinaryIO)): + audio, sampling_rate = librosa.load(audio, sr=sampling_rate) if not isinstance(audio, np.ndarray): raise ValueError(f"Expect input is a list of audios, but got {type(audio)}.") results.append(audio) + sampling_rates.append(sampling_rate) - return results + return {"audios": results, "sampling_rates": sampling_rates} def _get_mm_inputs( self, @@ -298,8 +295,8 @@ class MMPluginMixin: images, image_max_pixels=getattr(processor, "image_max_pixels", 768 * 768), image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32), - ) - if imglens is not None: + )["images"] + if imglens is not None: # if imglens are provided, make batched images images = _make_batched_images(images, imglens) image_processor_kwargs = {} @@ -325,7 +322,7 @@ class MMPluginMixin: image_min_pixels=getattr(processor, "video_min_pixels", 16 * 16), video_fps=getattr(processor, "video_fps", 2.0), video_maxlen=getattr(processor, "video_maxlen", 128), - ) + )["videos"] if "videos" in inspect.signature(video_processor.preprocess).parameters: # for qwen2_vl and video_llava mm_inputs.update(video_processor(images=None, videos=videos, return_tensors="pt")) else: # for llava_next_video @@ -335,12 +332,12 @@ class MMPluginMixin: feature_extractor: SequenceFeatureExtractor = getattr(processor, "feature_extractor", None) audios = self._regularize_audios( audios, - sampling_rate=getattr(feature_extractor, "sampling_rate", 16000), - ) + sampling_rate=getattr(processor, "audio_sampling_rate", 16000), + )["audios"] mm_inputs.update( feature_extractor( audios, - sampling_rate=getattr(feature_extractor, "sampling_rate", 16000), + sampling_rate=getattr(processor, "audio_sampling_rate", 16000), return_attention_mask=True, padding="max_length", return_tensors="pt", @@ -726,14 +723,13 @@ class MiniCPMVPlugin(BasePlugin): **kwargs, ) -> dict[str, "torch.Tensor"]: image_processor: BaseImageProcessor = getattr(processor, "image_processor") - feature_extractor: SequenceFeatureExtractor = getattr(processor, "feature_extractor", None) mm_inputs = {} if len(images) != 0: images = self._regularize_images( images, image_max_pixels=getattr(processor, "image_max_pixels", 768 * 768), image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32), - ) + )["images"] if "valid_image_nums_ls" in kwargs: valid_image_nums_ls = kwargs["valid_image_nums_ls"] new_images = [] @@ -756,15 +752,15 @@ class MiniCPMVPlugin(BasePlugin): image_min_pixels=getattr(processor, "video_min_pixels", 16 * 16), video_fps=getattr(processor, "video_fps", 2.0), video_maxlen=getattr(processor, "video_maxlen", 128), - ) + )["videos"] video_inputs = image_processor(videos, do_pad=True, max_slice_nums=2, return_tensors="pt") mm_inputs.update(video_inputs) if len(audios) != 0: audios = self._regularize_audios( audios, - sampling_rate=getattr(feature_extractor, "sampling_rate", 16000), - ) + sampling_rate=getattr(processor, "audio_sampling_rate", 16000), + )["audios"] if "valid_audio_nums_ls" in kwargs: valid_audio_nums_ls = kwargs["valid_audio_nums_ls"] audios_ls = [] @@ -778,7 +774,7 @@ class MiniCPMVPlugin(BasePlugin): audio_features, audio_feature_lens, audio_phs = processor.audio_feature_extract( audios_ls, chunk_input=True, - sampling_rate=16000, + sampling_rate=getattr(processor, "audio_sampling_rate", 16000), ) audio_feature_lens = [torch.tensor(audio_feature_len) for audio_feature_len in audio_feature_lens] mm_inputs.update({"audio_features": audio_features, "audio_feature_lens": audio_feature_lens}) @@ -1110,195 +1106,6 @@ class Qwen2AudioPlugin(BasePlugin): return self._get_mm_inputs(images, videos, audios, processor) -class Qwen2OmniPlugin(BasePlugin): - @override - def _get_mm_inputs( - self, - images: list["ImageInput"], - videos: list["VideoInput"], - audios: list["AudioInput"], - processor: "MMProcessor", - imglens: Optional[list[int]] = None, - ) -> dict[str, "torch.Tensor"]: - mm_inputs = {} - if len(images) != 0: - image_processor: BaseImageProcessor = getattr(processor, "omni_processor", None) # FIXME - images = self._regularize_images( - images, - image_max_pixels=getattr(processor, "image_max_pixels", 768 * 768), - image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32), - ) - if imglens is not None: - images = _make_batched_images(images, imglens) - - image_processor_kwargs = {} - mm_inputs.update(image_processor(images, return_tensors="pt", **image_processor_kwargs)) - - if len(videos) != 0: - video_processor: BaseImageProcessor = getattr( - processor, "video_processor", getattr(processor, "omni_processor", None) - ) - videos = self._regularize_videos( - videos, - image_max_pixels=getattr(processor, "video_max_pixels", 256 * 256), - image_min_pixels=getattr(processor, "video_min_pixels", 16 * 16), - video_fps=getattr(processor, "video_fps", 2.0), - video_maxlen=getattr(processor, "video_maxlen", 128), - ) - if "videos" in inspect.signature(video_processor.preprocess).parameters: # for qwen2_vl and video_llava - mm_inputs.update(video_processor(images=None, videos=videos, return_tensors="pt")) - fps = [2.0] * len(videos) # FIXME hardcode - video_second_per_grid = [fps[i] / video_processor.temporal_patch_size for i in range(len(fps))] - mm_inputs["video_second_per_grid"] = torch.tensor(video_second_per_grid) - - else: - raise NotImplementedError - - if len(audios) != 0: - feature_extractor: SequenceFeatureExtractor = getattr(processor, "feature_extractor", None) - audios = self._regularize_audios( - audios, - sampling_rate=getattr(feature_extractor, "sampling_rate", 16000), - ) - mm_inputs.update( - feature_extractor( - audios, - sampling_rate=getattr(feature_extractor, "sampling_rate", 16000), - return_attention_mask=True, - padding="max_length", - return_tensors="pt", - ) - ) - mm_inputs["feature_attention_mask"] = mm_inputs.pop("attention_mask") # prevent conflicts - - return mm_inputs - - @override - def process_messages( - self, - messages: list[dict[str, str]], - images: list["ImageInput"], - videos: list["VideoInput"], - audios: list["AudioInput"], - processor: Optional["MMProcessor"], - ) -> list[dict[str, str]]: - self._validate_input(processor, images, videos, audios) - messages = deepcopy(messages) - if self.expand_mm_tokens: - mm_inputs = self._get_mm_inputs(images, videos, audios, processor) - else: - mm_inputs = {} - - num_audio_tokens, num_image_tokens, num_video_tokens = 0, 0, 0 - use_audio_in_video = getattr(processor, "use_audio_in_video", False) - - # get length or size from mm_inputs - if "feature_attention_mask" in mm_inputs: - input_lengths = (mm_inputs["feature_attention_mask"].sum(-1).numpy() - 1) // 2 + 1 - audio_lengths = (input_lengths - 2) // 2 + 1 - - if mm_inputs.get("image_grid_thw", None) is not None: - image_grid_thw = mm_inputs["image_grid_thw"] - merge_length = processor.omni_processor.merge_size**2 - - if mm_inputs.get("video_grid_thw", None) is not None: - video_grid_thw = mm_inputs["video_grid_thw"] - merge_length = processor.omni_processor.merge_size**2 - - if use_audio_in_video: - if audio_lengths is None: - raise ValueError("audio_lengths should exist when use_audio_in_video is `True`.") - - if not mm_inputs.get("video_grid_thw", None): - raise ValueError("video_grid_thw should exist when use_audio_in_video is `True`.") - - positions_list = [] - for i, message in enumerate(messages): # get multimodal index when use_audio - positions = [] - for special_token in [self.audio_token, self.image_token, self.video_token]: - start = 0 - while True: - pos = message[i].find(special_token, start) - if pos == -1: - break - positions.append((pos, special_token)) - start = pos + len(special_token) - - positions_list.append(positions.sort(key=lambda x: x[0])) - - for message in messages: - content = message["content"] - # separate with audio-video - while IMAGE_PLACEHOLDER in content: - image_token_replace_length = image_grid_thw[num_image_tokens].prod() // merge_length - content = content.replace( - IMAGE_PLACEHOLDER, - f"<|vision_bos|>{self.image_token * image_token_replace_length}<|vision_eos|>", - 1, - ) - num_image_tokens += 1 - - if not use_audio_in_video: - while AUDIO_PLACEHOLDER in content: - audio_token_replace_length = audio_lengths[num_audio_tokens] - content = content.replace( - AUDIO_PLACEHOLDER, - f"<|audio_bos|>{self.audio_token * audio_token_replace_length}<|audio_eos|>", - 1, - ) - num_audio_tokens += 1 - # TODO handle video_input and use_audio_in_video - while VIDEO_PLACEHOLDER in content: - video_replace_length = video_grid_thw[num_video_tokens].prod() // merge_length - content = content.replace( - VIDEO_PLACEHOLDER, f"<|vision_bos|>{self.video_token * video_replace_length}<|vision_eos|>", 1 - ) - num_video_tokens += 1 - else: # if use the audio of video # deal video token and audio token togather - while VIDEO_PLACEHOLDER in content: - audio_t_index = torch.arange(audio_lengths[num_audio_tokens]) - video_t_index = ( - torch.arange(video_grid_thw[num_video_tokens][0]) - .view(-1, 1, 1) - .expand( - -1, - video_grid_thw[num_video_tokens][1] // self.omni_processor.merge_size, - video_grid_thw[num_video_tokens][2] // self.omni_processor.merge_size, - ) - .flatten() - * mm_inputs["video_second_per_grid"][num_video_tokens] - * 25 # FIXME hardcode of position_id_per_seconds=25 - ).long() - t_ntoken_per_chunk = 50 # FIXME hardcode: [25 * 2] - video_chunk_indices = processor.get_chunked_index(video_t_index, t_ntoken_per_chunk) - audio_chunk_indices = self.get_chunked_index(audio_t_index, t_ntoken_per_chunk) - placeholder_string = "" - for j in range(max(len(video_chunk_indices), len(audio_chunk_indices))): - video_chunk_index = video_chunk_indices[j] if j < len(video_chunk_indices) else None - audio_chunk_index = audio_chunk_indices[j] if j < len(audio_chunk_indices) else None - placeholder_string = "<|vision_bos|>" + "<|audio_bos|>" - if video_chunk_index is not None: - placeholder_string += self.video_token * (video_chunk_index[1] - video_chunk_index[0]) - if audio_chunk_index is not None: - placeholder_string += self.audio_token * (audio_chunk_index[1] - audio_chunk_index[0]) - placeholder_string += "<|audio_eos|>" + "<|vision_eos|>" - content = content.replace(VIDEO_PLACEHOLDER, placeholder_string, 1) - content = content.replace(AUDIO_PLACEHOLDER, "", 1) - num_audio_tokens += 1 - num_video_tokens += 1 - - message["content"] = content - - if len(audios) != num_audio_tokens: - raise ValueError(f"The number of audios does not match the number of {AUDIO_PLACEHOLDER} tokens.") - if len(images) != num_image_tokens: - raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.") - if len(videos) != num_video_tokens: - raise ValueError(f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens.") - - return messages - - @dataclass class Qwen2VLPlugin(BasePlugin): @override @@ -1321,7 +1128,7 @@ class Qwen2VLPlugin(BasePlugin): @override def _regularize_videos( self, videos: list["VideoInput"], **kwargs - ) -> tuple[list[list["ImageObject"]], list[float]]: + ) -> dict[str, Union[list[list["ImageObject"]], list[float]]]: results, fps_per_video = [], [] for video in videos: container = av.open(video, "r") @@ -1336,14 +1143,14 @@ class Qwen2VLPlugin(BasePlugin): if len(frames) % 2 != 0: # qwen2-vl requires even number of frames frames.append(frames[-1]) - frames = self._regularize_images(frames, **kwargs) + frames = self._regularize_images(frames, **kwargs)["images"] results.append(frames) if video_stream.duration is None: fps_per_video.append(2.0) else: fps_per_video.append(len(sample_indices) / float(video_stream.duration * video_stream.time_base)) - return results, fps_per_video + return {"videos": results, "fps_per_video": fps_per_video} @override def _get_mm_inputs( @@ -1360,19 +1167,19 @@ class Qwen2VLPlugin(BasePlugin): images, image_max_pixels=getattr(processor, "image_max_pixels", 768 * 768), image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32), - ) + )["images"] mm_inputs.update(image_processor(images, return_tensors="pt")) if len(videos) != 0: - videos, fps_per_video = self._regularize_videos( + video_data = self._regularize_videos( videos, image_max_pixels=getattr(processor, "video_max_pixels", 256 * 256), image_min_pixels=getattr(processor, "video_min_pixels", 16 * 16), video_fps=getattr(processor, "video_fps", 2.0), video_maxlen=getattr(processor, "video_maxlen", 128), ) - mm_inputs.update(image_processor(images=None, videos=videos, return_tensors="pt")) - mm_inputs["fps_per_video"] = fps_per_video + mm_inputs.update(image_processor(images=None, videos=video_data["videos"], return_tensors="pt")) + mm_inputs["fps_per_video"] = video_data["fps_per_video"] return mm_inputs @@ -1454,6 +1261,186 @@ class Qwen2VLPlugin(BasePlugin): return mm_inputs +class Qwen2OmniPlugin(Qwen2VLPlugin): + @override + def _get_mm_inputs( + self, + images: list["ImageInput"], + videos: list["VideoInput"], + audios: list["AudioInput"], + processor: "MMProcessor", + ) -> dict[str, "torch.Tensor"]: + image_processor: BaseImageProcessor = getattr(processor, "image_processor", None) + feature_extractor: SequenceFeatureExtractor = getattr(processor, "feature_extractor", None) + mm_inputs = {} + if len(images) != 0: + images = self._regularize_images( + images, + image_max_pixels=getattr(processor, "image_max_pixels", 768 * 768), + image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32), + )["images"] + mm_inputs.update(image_processor(images, return_tensors="pt")) + + if len(videos) != 0: + video_dict = self._regularize_videos( + videos, + image_max_pixels=getattr(processor, "video_max_pixels", 256 * 256), + image_min_pixels=getattr(processor, "video_min_pixels", 16 * 16), + video_fps=getattr(processor, "video_fps", 2.0), + video_maxlen=getattr(processor, "video_maxlen", 128), + ) + mm_inputs.update(image_processor(images=None, videos=video_dict["videos"], return_tensors="pt")) + mm_inputs["fps_per_video"] = video_dict["fps_per_video"] + + if len(audios) != 0: + audios = self._regularize_audios( + audios, + sampling_rate=getattr(processor, "audio_sampling_rate", 16000), + )["audios"] + mm_inputs.update( + feature_extractor( + audios, + sampling_rate=getattr(processor, "audio_sampling_rate", 16000), + return_attention_mask=True, + padding="max_length", + return_tensors="pt", + ) + ) + mm_inputs["feature_attention_mask"] = mm_inputs.pop("attention_mask") # prevent conflicts + + return mm_inputs + + @override + def process_messages( + self, + messages: list[dict[str, str]], + images: list["ImageInput"], + videos: list["VideoInput"], + audios: list["AudioInput"], + processor: Optional["MMProcessor"], + ) -> list[dict[str, str]]: + self._validate_input(processor, images, videos, audios) + messages = deepcopy(messages) + if self.expand_mm_tokens: + mm_inputs = self._get_mm_inputs(images, videos, audios, processor) + else: + mm_inputs = {} + + num_audio_tokens, num_image_tokens, num_video_tokens = 0, 0, 0 + use_audio_in_video = getattr(processor, "use_audio_in_video", False) + + # get length or size from mm_inputs + if "feature_attention_mask" in mm_inputs: + input_lengths = (mm_inputs["feature_attention_mask"].sum(-1).numpy() - 1) // 2 + 1 + audio_lengths = (input_lengths - 2) // 2 + 1 + + if mm_inputs.get("image_grid_thw", None) is not None: + image_grid_thw = mm_inputs["image_grid_thw"] + merge_length = processor.image_processor.merge_size**2 + + if mm_inputs.get("video_grid_thw", None) is not None: + video_grid_thw = mm_inputs["video_grid_thw"] + merge_length = processor.image_processor.merge_size**2 + + if use_audio_in_video: + if audio_lengths is None: + raise ValueError("audio_lengths should exist when use_audio_in_video is `True`.") + + if not mm_inputs.get("video_grid_thw", None): + raise ValueError("video_grid_thw should exist when use_audio_in_video is `True`.") + + positions_list = [] + for i, message in enumerate(messages): # get multimodal index when use_audio + positions = [] + for special_token in [self.audio_token, self.image_token, self.video_token]: + start = 0 + while True: + pos = message[i].find(special_token, start) + if pos == -1: + break + positions.append((pos, special_token)) + start = pos + len(special_token) + + positions_list.append(positions.sort(key=lambda x: x[0])) + + for message in messages: + content = message["content"] + # separate with audio-video + while IMAGE_PLACEHOLDER in content: + image_token_replace_length = image_grid_thw[num_image_tokens].prod() // merge_length + content = content.replace( + IMAGE_PLACEHOLDER, + f"<|vision_bos|>{self.image_token * image_token_replace_length}<|vision_eos|>", + 1, + ) + num_image_tokens += 1 + + if not use_audio_in_video: + while AUDIO_PLACEHOLDER in content: + audio_token_replace_length = audio_lengths[num_audio_tokens] + content = content.replace( + AUDIO_PLACEHOLDER, + f"<|audio_bos|>{self.audio_token * audio_token_replace_length}<|audio_eos|>", + 1, + ) + num_audio_tokens += 1 + + # TODO handle video_input and use_audio_in_video + while VIDEO_PLACEHOLDER in content: + video_replace_length = video_grid_thw[num_video_tokens].prod() // merge_length + content = content.replace( + VIDEO_PLACEHOLDER, f"<|vision_bos|>{self.video_token * video_replace_length}<|vision_eos|>", 1 + ) + num_video_tokens += 1 + + else: # if use the audio of video # deal video token and audio token togather + while VIDEO_PLACEHOLDER in content: + audio_t_index = torch.arange(audio_lengths[num_audio_tokens]) + video_t_index = ( + torch.arange(video_grid_thw[num_video_tokens][0]) + .view(-1, 1, 1) + .expand( + -1, + video_grid_thw[num_video_tokens][1] // self.image_processor.merge_size, + video_grid_thw[num_video_tokens][2] // self.image_processor.merge_size, + ) + .flatten() + * mm_inputs["video_second_per_grid"][num_video_tokens] + * 25 # FIXME hardcode of position_id_per_seconds=25 + ).long() + t_ntoken_per_chunk = 50 # FIXME hardcode: [25 * 2] + video_chunk_indices = processor.get_chunked_index(video_t_index, t_ntoken_per_chunk) + audio_chunk_indices = self.get_chunked_index(audio_t_index, t_ntoken_per_chunk) + placeholder_string = "" + for j in range(max(len(video_chunk_indices), len(audio_chunk_indices))): + video_chunk_index = video_chunk_indices[j] if j < len(video_chunk_indices) else None + audio_chunk_index = audio_chunk_indices[j] if j < len(audio_chunk_indices) else None + placeholder_string = "<|vision_bos|>" + "<|audio_bos|>" + if video_chunk_index is not None: + placeholder_string += self.video_token * (video_chunk_index[1] - video_chunk_index[0]) + if audio_chunk_index is not None: + placeholder_string += self.audio_token * (audio_chunk_index[1] - audio_chunk_index[0]) + placeholder_string += "<|audio_eos|>" + "<|vision_eos|>" + + content = content.replace(VIDEO_PLACEHOLDER, placeholder_string, 1) + content = content.replace(AUDIO_PLACEHOLDER, "", 1) + num_audio_tokens += 1 + num_video_tokens += 1 + + message["content"] = content + + if len(audios) != num_audio_tokens: + raise ValueError(f"The number of audios does not match the number of {AUDIO_PLACEHOLDER} tokens.") + + if len(images) != num_image_tokens: + raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.") + + if len(videos) != num_video_tokens: + raise ValueError(f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens.") + + return messages + + @dataclass class VideoLlavaPlugin(BasePlugin): @override diff --git a/src/llamafactory/hparams/model_args.py b/src/llamafactory/hparams/model_args.py index fec05374..a9f61289 100644 --- a/src/llamafactory/hparams/model_args.py +++ b/src/llamafactory/hparams/model_args.py @@ -242,6 +242,10 @@ class ProcessorArguments: default=128, metadata={"help": "The maximum number of sampled frames for video inputs."}, ) + audio_sampling_rate: int = field( + default=16000, + metadata={"help": "The sampling rate of audio inputs."}, + ) def __post_init__(self): if self.image_max_pixels < self.image_min_pixels: diff --git a/src/llamafactory/model/model_utils/visual.py b/src/llamafactory/model/model_utils/visual.py index 0158a50b..c69bc690 100644 --- a/src/llamafactory/model/model_utils/visual.py +++ b/src/llamafactory/model/model_utils/visual.py @@ -262,9 +262,7 @@ _register_composite_model( projector_key="visual.merger", vision_model_keys=["visual.patch_embed", "visual.blocks", "audio_tower"], language_model_keys=["model", "lm_head"], - lora_conflict_keys=[ - "patch_embed", - ], + lora_conflict_keys=["patch_embed"], ) diff --git a/src/llamafactory/model/patcher.py b/src/llamafactory/model/patcher.py index 6b690f40..28cb599f 100644 --- a/src/llamafactory/model/patcher.py +++ b/src/llamafactory/model/patcher.py @@ -78,6 +78,7 @@ def patch_processor( setattr(processor, "video_min_pixels", model_args.video_min_pixels) setattr(processor, "video_fps", model_args.video_fps) setattr(processor, "video_maxlen", model_args.video_maxlen) + setattr(processor, "audio_sampling_rate", model_args.audio_sampling_rate) def patch_config( @@ -123,15 +124,13 @@ def patch_config( # deepspeed zero3 is not compatible with low_cpu_mem_usage init_kwargs["low_cpu_mem_usage"] = model_args.low_cpu_mem_usage and (not is_deepspeed_zero3_enabled()) - # cast data type of the model if: - # 1. not deepspeed zero3 and not fsdp (keep zero3 or fsdp in float32) - # 2. quantization_bit is not None (qlora) - if (not is_deepspeed_zero3_enabled() and not is_fsdp_enabled()) or model_args.quantization_bit is not None: + # do not cast data type of the model deepspeed zero3 without qlora + if not (is_deepspeed_zero3_enabled() and model_args.quantization_bit is None): init_kwargs["torch_dtype"] = model_args.compute_dtype - if init_kwargs["low_cpu_mem_usage"]: # device map requires low_cpu_mem_usage=True + if init_kwargs["low_cpu_mem_usage"] and not is_fsdp_enabled(): # fsdp does not need device map if "device_map" not in init_kwargs and model_args.device_map: - init_kwargs["device_map"] = model_args.device_map + init_kwargs["device_map"] = model_args.device_map # device map requires low_cpu_mem_usage=True if init_kwargs.get("device_map", None) == "auto": init_kwargs["offload_folder"] = model_args.offload_folder