Merge pull request #5365 from hiyouga/video_finetuning

Support Qwen2-VL Fine-Tuning on Video Datasets

Former-commit-id: 46b1765d0374bfc93d6a3af8669af1c2307814a7
hoshi-hiyouga committed 2024-09-05 02:24:58 +08:00 (via GitHub)
commit ce77a89d8c
27 changed files with 409 additions and 148 deletions

View File

@ -38,6 +38,20 @@
"assistant_tag": "assistant" "assistant_tag": "assistant"
} }
}, },
"mllm_video_demo": {
"file_name": "mllm_video_demo.json",
"formatting": "sharegpt",
"columns": {
"messages": "messages",
"videos": "videos"
},
"tags": {
"role_tag": "role",
"content_tag": "content",
"user_tag": "user",
"assistant_tag": "assistant"
}
},
"alpaca_en": { "alpaca_en": {
"hf_hub_url": "llamafactory/alpaca_en", "hf_hub_url": "llamafactory/alpaca_en",
"ms_hub_url": "llamafactory/alpaca_en" "ms_hub_url": "llamafactory/alpaca_en"

BIN
data/mllm_demo_data/1.mp4 Normal file

Binary file not shown.

BIN
data/mllm_demo_data/2.avi Normal file

Binary file not shown.

BIN
data/mllm_demo_data/3.mp4 Normal file

Binary file not shown.

data/mllm_video_demo.json Normal file (47 lines added)
View File

@ -0,0 +1,47 @@
[
{
"messages": [
{
"content": "<video>Why is this video funny?",
"role": "user"
},
{
"content": "Because a baby is reading, and he is so cute!",
"role": "assistant"
}
],
"videos": [
"mllm_demo_data/1.mp4"
]
},
{
"messages": [
{
"content": "<video>What is she doing?",
"role": "user"
},
{
"content": "She is cooking.",
"role": "assistant"
}
],
"videos": [
"mllm_demo_data/2.avi"
]
},
{
"messages": [
{
"content": "<video>What's in the video?",
"role": "user"
},
{
"content": "A baby is playing in the living room.",
"role": "assistant"
}
],
"videos": [
"mllm_demo_data/3.mp4"
]
}
]
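Each record pairs one <video> placeholder in the user turn with one path in its videos list; Qwen2vlPlugin later raises a ValueError when the counts disagree. A small sanity check along those lines (the check itself is an illustration, not part of the commit):

# Sketch: verify that <video> placeholders and video paths match per record.
import json

with open("data/mllm_video_demo.json", encoding="utf-8") as f:
    records = json.load(f)

for record in records:
    num_tags = sum(m["content"].count("<video>") for m in record["messages"])
    num_videos = len(record.get("videos", []))
    assert num_tags == num_videos, f"mismatch: {num_tags} tags vs {num_videos} videos"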

View File

@ -18,11 +18,11 @@ from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, List, Literal, Opti
if TYPE_CHECKING: if TYPE_CHECKING:
from numpy.typing import NDArray
from transformers import PreTrainedModel, PreTrainedTokenizer from transformers import PreTrainedModel, PreTrainedTokenizer
from vllm import AsyncLLMEngine from vllm import AsyncLLMEngine
from ..data import Template from ..data import Template
from ..data.mm_plugin import ImageInput, VideoInput
from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
@ -56,7 +56,8 @@ class BaseEngine(ABC):
messages: Sequence[Dict[str, str]], messages: Sequence[Dict[str, str]],
system: Optional[str] = None, system: Optional[str] = None,
tools: Optional[str] = None, tools: Optional[str] = None,
image: Optional["NDArray"] = None, image: Optional["ImageInput"] = None,
video: Optional["VideoInput"] = None,
**input_kwargs, **input_kwargs,
) -> List["Response"]: ... ) -> List["Response"]: ...
@ -66,7 +67,8 @@ class BaseEngine(ABC):
messages: Sequence[Dict[str, str]], messages: Sequence[Dict[str, str]],
system: Optional[str] = None, system: Optional[str] = None,
tools: Optional[str] = None, tools: Optional[str] = None,
image: Optional["NDArray"] = None, image: Optional["ImageInput"] = None,
video: Optional["VideoInput"] = None,
**input_kwargs, **input_kwargs,
) -> AsyncGenerator[str, None]: ... ) -> AsyncGenerator[str, None]: ...

View File

@ -27,8 +27,7 @@ from .vllm_engine import VllmEngine
if TYPE_CHECKING: if TYPE_CHECKING:
from PIL.Image import Image from ..data.mm_plugin import ImageInput, VideoInput
from .base_engine import BaseEngine, Response from .base_engine import BaseEngine, Response
@ -56,10 +55,13 @@ class ChatModel:
messages: Sequence[Dict[str, str]], messages: Sequence[Dict[str, str]],
system: Optional[str] = None, system: Optional[str] = None,
tools: Optional[str] = None, tools: Optional[str] = None,
image: Optional["Image"] = None, image: Optional["ImageInput"] = None,
video: Optional["VideoInput"] = None,
**input_kwargs, **input_kwargs,
) -> List["Response"]: ) -> List["Response"]:
task = asyncio.run_coroutine_threadsafe(self.achat(messages, system, tools, image, **input_kwargs), self._loop) task = asyncio.run_coroutine_threadsafe(
self.achat(messages, system, tools, image, video, **input_kwargs), self._loop
)
return task.result() return task.result()
async def achat( async def achat(
@ -67,20 +69,22 @@ class ChatModel:
messages: Sequence[Dict[str, str]], messages: Sequence[Dict[str, str]],
system: Optional[str] = None, system: Optional[str] = None,
tools: Optional[str] = None, tools: Optional[str] = None,
image: Optional["Image"] = None, image: Optional["ImageInput"] = None,
video: Optional["VideoInput"] = None,
**input_kwargs, **input_kwargs,
) -> List["Response"]: ) -> List["Response"]:
return await self.engine.chat(messages, system, tools, image, **input_kwargs) return await self.engine.chat(messages, system, tools, image, video, **input_kwargs)
def stream_chat( def stream_chat(
self, self,
messages: Sequence[Dict[str, str]], messages: Sequence[Dict[str, str]],
system: Optional[str] = None, system: Optional[str] = None,
tools: Optional[str] = None, tools: Optional[str] = None,
image: Optional["Image"] = None, image: Optional["ImageInput"] = None,
video: Optional["VideoInput"] = None,
**input_kwargs, **input_kwargs,
) -> Generator[str, None, None]: ) -> Generator[str, None, None]:
generator = self.astream_chat(messages, system, tools, image, **input_kwargs) generator = self.astream_chat(messages, system, tools, image, video, **input_kwargs)
while True: while True:
try: try:
task = asyncio.run_coroutine_threadsafe(generator.__anext__(), self._loop) task = asyncio.run_coroutine_threadsafe(generator.__anext__(), self._loop)
@ -93,10 +97,11 @@ class ChatModel:
messages: Sequence[Dict[str, str]], messages: Sequence[Dict[str, str]],
system: Optional[str] = None, system: Optional[str] = None,
tools: Optional[str] = None, tools: Optional[str] = None,
image: Optional["Image"] = None, image: Optional["ImageInput"] = None,
video: Optional["VideoInput"] = None,
**input_kwargs, **input_kwargs,
) -> AsyncGenerator[str, None]: ) -> AsyncGenerator[str, None]:
async for new_token in self.engine.stream_chat(messages, system, tools, image, **input_kwargs): async for new_token in self.engine.stream_chat(messages, system, tools, image, video, **input_kwargs):
yield new_token yield new_token
def get_scores( def get_scores(
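With video threaded through chat, achat, stream_chat and astream_chat, inference on a clip looks roughly like the image case. A hedged usage sketch; the model path and video path are placeholders, and argument names follow the signatures in this diff:

# Sketch: single-turn video chat through ChatModel (paths are placeholders).
from llamafactory.chat import ChatModel

chat_model = ChatModel(dict(
    model_name_or_path="Qwen/Qwen2-VL-7B-Instruct",
    template="qwen2_vl",
))
messages = [{"role": "user", "content": "<video>What is happening in this clip?"}]
responses = chat_model.chat(messages, video="data/mllm_demo_data/1.mp4")
print(responses[0].response_text)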

View File

@ -22,7 +22,7 @@ import torch
from transformers import GenerationConfig, TextIteratorStreamer from transformers import GenerationConfig, TextIteratorStreamer
from ..data import get_template_and_fix_tokenizer from ..data import get_template_and_fix_tokenizer
from ..extras.constants import IMAGE_PLACEHOLDER from ..extras.constants import IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
from ..extras.logging import get_logger from ..extras.logging import get_logger
from ..extras.misc import get_logits_processor from ..extras.misc import get_logits_processor
from ..model import load_model, load_tokenizer from ..model import load_model, load_tokenizer
@ -30,11 +30,11 @@ from .base_engine import BaseEngine, Response
if TYPE_CHECKING: if TYPE_CHECKING:
from PIL.Image import Image
from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
from trl import PreTrainedModelWrapper from trl import PreTrainedModelWrapper
from ..data import Template from ..data import Template
from ..data.mm_plugin import ImageInput, VideoInput
from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
@ -78,20 +78,30 @@ class HuggingfaceEngine(BaseEngine):
messages: Sequence[Dict[str, str]], messages: Sequence[Dict[str, str]],
system: Optional[str] = None, system: Optional[str] = None,
tools: Optional[str] = None, tools: Optional[str] = None,
image: Optional["Image"] = None, image: Optional["ImageInput"] = None,
video: Optional["VideoInput"] = None,
input_kwargs: Optional[Dict[str, Any]] = {}, input_kwargs: Optional[Dict[str, Any]] = {},
) -> Tuple[Dict[str, Any], int]: ) -> Tuple[Dict[str, Any], int]:
mm_input_dict = {"images": [], "videos": [], "imglens": [0], "vidlens": [0]}
if image is not None: if image is not None:
mm_input_dict.update({"images": [image], "imglens": [1]})
if IMAGE_PLACEHOLDER not in messages[0]["content"]: if IMAGE_PLACEHOLDER not in messages[0]["content"]:
messages[0]["content"] = IMAGE_PLACEHOLDER + messages[0]["content"] messages[0]["content"] = IMAGE_PLACEHOLDER + messages[0]["content"]
messages = template.mm_plugin.process_messages(messages, [image], processor) if video is not None:
mm_input_dict.update({"videos": [video], "vidlens": [1]})
if VIDEO_PLACEHOLDER not in messages[0]["content"]:
messages[0]["content"] = VIDEO_PLACEHOLDER + messages[0]["content"]
messages = template.mm_plugin.process_messages(
messages, mm_input_dict["images"], mm_input_dict["videos"], processor
)
paired_messages = messages + [{"role": "assistant", "content": ""}] paired_messages = messages + [{"role": "assistant", "content": ""}]
system = system or generating_args["default_system"] system = system or generating_args["default_system"]
prompt_ids, _ = template.encode_oneturn(tokenizer, paired_messages, system, tools) prompt_ids, _ = template.encode_oneturn(tokenizer, paired_messages, system, tools)
if image is not None: prompt_ids, _ = template.mm_plugin.process_token_ids(
prompt_ids, _ = template.mm_plugin.process_token_ids(prompt_ids, None, [image], tokenizer, processor) prompt_ids, None, mm_input_dict["images"], mm_input_dict["videos"], tokenizer, processor
)
prompt_length = len(prompt_ids) prompt_length = len(prompt_ids)
inputs = torch.tensor([prompt_ids], device=model.device) inputs = torch.tensor([prompt_ids], device=model.device)
@ -154,10 +164,7 @@ class HuggingfaceEngine(BaseEngine):
logits_processor=get_logits_processor(), logits_processor=get_logits_processor(),
) )
if image is not None: mm_inputs = template.mm_plugin.get_mm_inputs(**mm_input_dict, seqlens=[prompt_length], processor=processor)
mm_inputs = template.mm_plugin.get_mm_inputs(
images=[image], imglens=[1], seqlens=[prompt_length], processor=processor
)
for key, value in mm_inputs.items(): for key, value in mm_inputs.items():
value = value if isinstance(value, torch.Tensor) else torch.tensor(value) value = value if isinstance(value, torch.Tensor) else torch.tensor(value)
gen_kwargs[key] = value.to(model.device) gen_kwargs[key] = value.to(model.device)
@ -175,11 +182,12 @@ class HuggingfaceEngine(BaseEngine):
messages: Sequence[Dict[str, str]], messages: Sequence[Dict[str, str]],
system: Optional[str] = None, system: Optional[str] = None,
tools: Optional[str] = None, tools: Optional[str] = None,
image: Optional["Image"] = None, image: Optional["ImageInput"] = None,
video: Optional["VideoInput"] = None,
input_kwargs: Optional[Dict[str, Any]] = {}, input_kwargs: Optional[Dict[str, Any]] = {},
) -> List["Response"]: ) -> List["Response"]:
gen_kwargs, prompt_length = HuggingfaceEngine._process_args( gen_kwargs, prompt_length = HuggingfaceEngine._process_args(
model, tokenizer, processor, template, generating_args, messages, system, tools, image, input_kwargs model, tokenizer, processor, template, generating_args, messages, system, tools, image, video, input_kwargs
) )
generate_output = model.generate(**gen_kwargs) generate_output = model.generate(**gen_kwargs)
response_ids = generate_output[:, prompt_length:] response_ids = generate_output[:, prompt_length:]
@ -210,11 +218,12 @@ class HuggingfaceEngine(BaseEngine):
messages: Sequence[Dict[str, str]], messages: Sequence[Dict[str, str]],
system: Optional[str] = None, system: Optional[str] = None,
tools: Optional[str] = None, tools: Optional[str] = None,
image: Optional["Image"] = None, image: Optional["ImageInput"] = None,
video: Optional["VideoInput"] = None,
input_kwargs: Optional[Dict[str, Any]] = {}, input_kwargs: Optional[Dict[str, Any]] = {},
) -> Callable[[], str]: ) -> Callable[[], str]:
gen_kwargs, _ = HuggingfaceEngine._process_args( gen_kwargs, _ = HuggingfaceEngine._process_args(
model, tokenizer, processor, template, generating_args, messages, system, tools, image, input_kwargs model, tokenizer, processor, template, generating_args, messages, system, tools, image, video, input_kwargs
) )
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True) streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
gen_kwargs["streamer"] = streamer gen_kwargs["streamer"] = streamer
@ -267,7 +276,8 @@ class HuggingfaceEngine(BaseEngine):
messages: Sequence[Dict[str, str]], messages: Sequence[Dict[str, str]],
system: Optional[str] = None, system: Optional[str] = None,
tools: Optional[str] = None, tools: Optional[str] = None,
image: Optional["Image"] = None, image: Optional["ImageInput"] = None,
video: Optional["VideoInput"] = None,
**input_kwargs, **input_kwargs,
) -> List["Response"]: ) -> List["Response"]:
if not self.can_generate: if not self.can_generate:
@ -284,6 +294,7 @@ class HuggingfaceEngine(BaseEngine):
system, system,
tools, tools,
image, image,
video,
input_kwargs, input_kwargs,
) )
async with self.semaphore: async with self.semaphore:
@ -295,7 +306,8 @@ class HuggingfaceEngine(BaseEngine):
messages: Sequence[Dict[str, str]], messages: Sequence[Dict[str, str]],
system: Optional[str] = None, system: Optional[str] = None,
tools: Optional[str] = None, tools: Optional[str] = None,
image: Optional["Image"] = None, image: Optional["ImageInput"] = None,
video: Optional["VideoInput"] = None,
**input_kwargs, **input_kwargs,
) -> AsyncGenerator[str, None]: ) -> AsyncGenerator[str, None]:
if not self.can_generate: if not self.can_generate:
@ -312,6 +324,7 @@ class HuggingfaceEngine(BaseEngine):
system, system,
tools, tools,
image, image,
video,
input_kwargs, input_kwargs,
) )
async with self.semaphore: async with self.semaphore:
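_process_args now always assembles one multimodal dict per request, so a single plugin call covers text-only, image and video prompts alike. For a video-only request the dict ends up as below (the path is a placeholder):

# Sketch: the per-request multimodal dict built in _process_args for one video.
mm_input_dict = {"images": [], "videos": [], "imglens": [0], "vidlens": [0]}
video = "data/mllm_demo_data/1.mp4"  # placeholder
if video is not None:
    mm_input_dict.update({"videos": [video], "vidlens": [1]})
# mm_input_dict == {"images": [], "videos": ["data/mllm_demo_data/1.mp4"],
#                   "imglens": [0], "vidlens": [1]}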

View File

@ -25,7 +25,7 @@ if TYPE_CHECKING:
from transformers import Seq2SeqTrainingArguments from transformers import Seq2SeqTrainingArguments
from ..hparams import DataArguments from ..hparams import DataArguments
from .mm_plugin import ImageInput from .mm_plugin import ImageInput, VideoInput
from .parser import DatasetAttr from .parser import DatasetAttr
@ -52,6 +52,26 @@ def _convert_images(
return images return images
def _convert_videos(
videos: Sequence["VideoInput"],
dataset_attr: "DatasetAttr",
data_args: "DataArguments",
) -> Optional[List["VideoInput"]]:
r"""
Optionally concatenates video path to dataset dir when loading from local disk.
"""
if len(videos) == 0:
return None
videos = videos[:]
if dataset_attr.load_from in ["script", "file"]:
for i in range(len(videos)):
if isinstance(videos[i], str) and os.path.isfile(os.path.join(data_args.dataset_dir, videos[i])):
videos[i] = os.path.join(data_args.dataset_dir, videos[i])
return videos
def convert_alpaca( def convert_alpaca(
example: Dict[str, Any], example: Dict[str, Any],
dataset_attr: "DatasetAttr", dataset_attr: "DatasetAttr",
@ -96,12 +116,14 @@ def convert_alpaca(
response = [] response = []
convert_images = partial(_convert_images, dataset_attr=dataset_attr, data_args=data_args) convert_images = partial(_convert_images, dataset_attr=dataset_attr, data_args=data_args)
convert_videos = partial(_convert_videos, dataset_attr=dataset_attr, data_args=data_args)
output = { output = {
"_prompt": prompt, "_prompt": prompt,
"_response": response, "_response": response,
"_system": example[dataset_attr.system] if dataset_attr.system else "", "_system": example[dataset_attr.system] if dataset_attr.system else "",
"_tools": example[dataset_attr.tools] if dataset_attr.tools else "", "_tools": example[dataset_attr.tools] if dataset_attr.tools else "",
"_images": convert_images(example[dataset_attr.images]) if dataset_attr.images else None, "_images": convert_images(example[dataset_attr.images]) if dataset_attr.images else None,
"_videos": convert_videos(example[dataset_attr.videos]) if dataset_attr.videos else None,
} }
return output return output
@ -187,12 +209,14 @@ def convert_sharegpt(
prompt, response = [], [] prompt, response = [], []
convert_images = partial(_convert_images, dataset_attr=dataset_attr, data_args=data_args) convert_images = partial(_convert_images, dataset_attr=dataset_attr, data_args=data_args)
convert_videos = partial(_convert_videos, dataset_attr=dataset_attr, data_args=data_args)
output = { output = {
"_prompt": prompt, "_prompt": prompt,
"_response": response, "_response": response,
"_system": system, "_system": system,
"_tools": example[dataset_attr.tools] if dataset_attr.tools else "", "_tools": example[dataset_attr.tools] if dataset_attr.tools else "",
"_images": convert_images(example[dataset_attr.images]) if dataset_attr.images else None, "_images": convert_images(example[dataset_attr.images]) if dataset_attr.images else None,
"_videos": convert_videos(example[dataset_attr.videos]) if dataset_attr.videos else None,
} }
return output return output
@ -210,6 +234,7 @@ def align_dataset(
_system: "..." _system: "..."
_tools: "...", _tools: "...",
_images: [], _images: [],
_videos: [],
""" """
if dataset_attr.formatting == "alpaca": if dataset_attr.formatting == "alpaca":
convert_func = partial(convert_alpaca, dataset_attr=dataset_attr, data_args=data_args) convert_func = partial(convert_alpaca, dataset_attr=dataset_attr, data_args=data_args)

View File

@ -79,14 +79,19 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
processor: Optional["ProcessorMixin"] = None processor: Optional["ProcessorMixin"] = None
def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]: def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
batch_images, batch_imglens, batch_seqlens = [], [], [] batch_images, batch_videos, batch_imglens, batch_vidlens, batch_seqlens = [], [], [], [], []
for feature in features: for feature in features:
images = feature.pop("images") or [] # avoid NoneType images = feature.pop("images") or [] # avoid NoneType
videos = feature.pop("videos") or []
batch_images.extend(images) batch_images.extend(images)
batch_videos.extend(videos)
batch_imglens.append(len(images)) batch_imglens.append(len(images))
batch_vidlens.append(len(videos))
batch_seqlens.append(len(feature["input_ids"])) batch_seqlens.append(len(feature["input_ids"]))
mm_inputs = self.template.mm_plugin.get_mm_inputs(batch_images, batch_imglens, batch_seqlens, self.processor) mm_inputs = self.template.mm_plugin.get_mm_inputs(
batch_images, batch_videos, batch_imglens, batch_vidlens, batch_seqlens, self.processor
)
if "token_type_ids" in mm_inputs: if "token_type_ids" in mm_inputs:
token_type_ids = mm_inputs.pop("token_type_ids") token_type_ids = mm_inputs.pop("token_type_ids")
for i, feature in enumerate(features): for i, feature in enumerate(features):
@ -136,6 +141,7 @@ class PairwiseDataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
"attention_mask": feature["{}_attention_mask".format(key)], "attention_mask": feature["{}_attention_mask".format(key)],
"labels": feature["{}_labels".format(key)], "labels": feature["{}_labels".format(key)],
"images": feature["images"], "images": feature["images"],
"videos": feature["videos"],
} }
concatenated_features.append(target_feature) concatenated_features.append(target_feature)
@ -158,12 +164,14 @@ class KTODataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
"attention_mask": feature["attention_mask"], "attention_mask": feature["attention_mask"],
"labels": feature["labels"], "labels": feature["labels"],
"images": feature["images"], "images": feature["images"],
"videos": feature["videos"],
} }
kl_feature = { kl_feature = {
"input_ids": feature["kl_input_ids"], "input_ids": feature["kl_input_ids"],
"attention_mask": feature["kl_attention_mask"], "attention_mask": feature["kl_attention_mask"],
"labels": feature["kl_labels"], "labels": feature["kl_labels"],
"images": feature["images"], "images": feature["images"],
"videos": feature["videos"],
} }
target_features.append(target_feature) target_features.append(target_feature)
kl_features.append(kl_feature) kl_features.append(kl_feature)
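Every tokenized feature now carries both an images and a videos field (None when empty), and the collator flattens them into flat per-batch lists plus per-example counts before a single get_mm_inputs call. Roughly, with illustrative values:

# Sketch: what the collator extracts from one image-only and one video-only feature.
features = [
    {"input_ids": [1, 2, 3], "images": ["img.png"], "videos": None},
    {"input_ids": [4, 5], "images": None, "videos": ["clip.mp4"]},
]
batch_images, batch_videos, batch_imglens, batch_vidlens, batch_seqlens = [], [], [], [], []
for feature in features:
    images = feature.pop("images") or []  # avoid NoneType
    videos = feature.pop("videos") or []
    batch_images.extend(images)
    batch_videos.extend(videos)
    batch_imglens.append(len(images))
    batch_vidlens.append(len(videos))
    batch_seqlens.append(len(feature["input_ids"]))
# batch_imglens == [1, 0]; batch_vidlens == [0, 1]; batch_seqlens == [3, 2]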

View File

@ -2,11 +2,10 @@ from copy import deepcopy
from io import BytesIO from io import BytesIO
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, TypedDict, Union from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, TypedDict, Union
from PIL.Image import Image import numpy as np
from transformers import ProcessorMixin
from ..extras.constants import IGNORE_INDEX, IMAGE_PLACEHOLDER from ..extras.constants import IGNORE_INDEX, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
from ..extras.packages import is_pillow_available from ..extras.packages import is_pillow_available, is_pyav_available
if is_pillow_available(): if is_pillow_available():
@ -14,8 +13,13 @@ if is_pillow_available():
from PIL.Image import Image as ImageObject from PIL.Image import Image as ImageObject
if is_pyav_available():
import av
if TYPE_CHECKING: if TYPE_CHECKING:
import torch import torch
from numpy.typing import NDArray
from transformers import PreTrainedTokenizer, ProcessorMixin from transformers import PreTrainedTokenizer, ProcessorMixin
from transformers.image_processing_utils import BaseImageProcessor from transformers.image_processing_utils import BaseImageProcessor
@ -24,13 +28,14 @@ if TYPE_CHECKING:
bytes: Optional[bytes] bytes: Optional[bytes]
ImageInput = Union[str, EncodedImage, ImageObject] ImageInput = Union[str, EncodedImage, ImageObject]
VideoInput = str
def _regularize_images(images: Sequence["ImageInput"], processor: "ProcessorMixin") -> List["ImageObject"]: def _regularize_images(images: Sequence["ImageInput"], processor: "ProcessorMixin") -> List["ImageObject"]:
r""" r"""
Regularizes images to avoid error. Including reading, resizing and converting. Regularizes images to avoid error. Including reading, resizing and converting.
""" """
image_resolution = getattr(processor, "image_resolution", 512) image_resolution: int = getattr(processor, "image_resolution", 512)
results = [] results = []
for image in images: for image in images:
if isinstance(image, str): if isinstance(image, str):
@ -56,7 +61,37 @@ def _regularize_images(images: Sequence["ImageInput"], processor: "ProcessorMixi
return results return results
def _get_mm_inputs(images: Sequence["ImageInput"], processor: "ProcessorMixin") -> Dict[str, "torch.Tensor"]: def _regularize_videos(videos: Sequence["VideoInput"], processor: "ProcessorMixin") -> List["NDArray"]:
r"""
Regularizes videos to avoid error. Including reading, resizing and converting.
"""
video_fps: float = getattr(processor, "video_fps", 1.0)
video_factor: int = getattr(processor, "video_factor", 1)
results = []
for video in videos:
container = av.open(video, "r")
video_stream = next(stream for stream in container.streams if stream.type == "video")
total_frames = video_stream.frames
sample_frames = float(video_stream.duration * video_stream.time_base) * video_fps
sample_frames = round(sample_frames / video_factor) * video_factor # for qwen2_vl
sample_indices = np.linspace(0, total_frames - 1, sample_frames).astype(np.int32)
frames: List["ImageObject"] = []
container.seek(0)
for frame_idx, frame in enumerate(container.decode(video_stream)):
if frame_idx in sample_indices:
frames.append(frame.to_image())
frames = _regularize_images(frames, processor)
results.append(frames)
return results
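The sampling in _regularize_videos above is driven by two processor attributes introduced in this commit: video_fps (frames sampled per second of footage) and video_factor (the sample count is rounded to a multiple of it, 2 for Qwen2-VL). A standalone sketch of the same logic, assuming PyAV is installed, using a placeholder path and the Qwen2-VL values that load_tokenizer attaches later in this diff:

# Sketch: frame sampling as performed by _regularize_videos (path is a placeholder).
import av
import numpy as np

video_fps, video_factor = 2.0, 2  # values attached to the processor for Qwen2-VL

container = av.open("data/mllm_demo_data/1.mp4", "r")
video_stream = next(s for s in container.streams if s.type == "video")
total_frames = video_stream.frames
duration_sec = float(video_stream.duration * video_stream.time_base)
sample_frames = round(duration_sec * video_fps / video_factor) * video_factor
sample_indices = np.linspace(0, total_frames - 1, sample_frames).astype(np.int32)

frames = []
container.seek(0)
for frame_idx, frame in enumerate(container.decode(video_stream)):
    if frame_idx in sample_indices:
        frames.append(frame.to_image())  # one PIL.Image per sampled frame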
def _get_mm_inputs(
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
processor: "ProcessorMixin",
) -> Dict[str, "torch.Tensor"]:
r""" r"""
Processes visual inputs. Processes visual inputs.
@ -70,13 +105,19 @@ def _get_mm_inputs(images: Sequence["ImageInput"], processor: "ProcessorMixin")
It holds num_patches == torch.prod(image_grid_thw) It holds num_patches == torch.prod(image_grid_thw)
""" """
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor") image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
input_dict = {"images": None, "videos": None}
if len(images) != 0: if len(images) != 0:
images = _regularize_images(images, processor) images = _regularize_images(images, processor)
image_inputs = image_processor(images=images, return_tensors="pt") input_dict["images"] = images
else:
image_inputs = {}
return image_inputs if len(videos) != 0:
videos = _regularize_videos(videos, processor)
input_dict["videos"] = videos
if input_dict["images"] is not None or input_dict["videos"] is not None:
return image_processor(**input_dict, return_tensors="pt")
else:
return {}
def _get_paligemma_token_type_ids( def _get_paligemma_token_type_ids(
@ -97,18 +138,32 @@ def _get_paligemma_token_type_ids(
class BasePlugin: class BasePlugin:
def __init__(self, image_token: str) -> None: def __init__(self, image_token: Optional[str], video_token: Optional[str]) -> None:
self.image_token = image_token self.image_token = image_token
self.video_token = video_token
def _validate_input(
self,
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
) -> None:
if len(images) != 0 and self.image_token is None:
raise ValueError("This model does not support image input.")
if len(videos) != 0 and self.video_token is None:
raise ValueError("This model does not support video input.")
def process_messages( def process_messages(
self, self,
messages: Sequence[Dict[str, str]], messages: Sequence[Dict[str, str]],
images: Sequence["ImageInput"], images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
processor: Optional["ProcessorMixin"], processor: Optional["ProcessorMixin"],
) -> List[Dict[str, str]]: ) -> List[Dict[str, str]]:
r""" r"""
Pre-processes input messages before tokenization for VLMs. Pre-processes input messages before tokenization for VLMs.
""" """
self._validate_input(images, videos)
return messages return messages
def process_token_ids( def process_token_ids(
@ -116,24 +171,29 @@ class BasePlugin:
input_ids: List[int], input_ids: List[int],
labels: Optional[List[int]], labels: Optional[List[int]],
images: Sequence["ImageInput"], images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
tokenizer: "PreTrainedTokenizer", tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"], processor: Optional["ProcessorMixin"],
) -> Tuple[List[int], Optional[List[int]]]: ) -> Tuple[List[int], Optional[List[int]]]:
r""" r"""
Pre-processes token ids after tokenization for VLMs. Pre-processes token ids after tokenization for VLMs.
""" """
self._validate_input(images, videos)
return input_ids, labels return input_ids, labels
def get_mm_inputs( def get_mm_inputs(
self, self,
images: Sequence["ImageInput"], images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
imglens: Sequence[int], imglens: Sequence[int],
vidlens: Sequence[int],
seqlens: Sequence[int], seqlens: Sequence[int],
processor: Optional["ProcessorMixin"], processor: Optional["ProcessorMixin"],
) -> Dict[str, Union[List[int], "torch.Tensor"]]: ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
r""" r"""
Builds batched multimodal inputs for VLMs. Builds batched multimodal inputs for VLMs.
""" """
self._validate_input(images, videos)
return {} return {}
@ -142,8 +202,10 @@ class LlavaPlugin(BasePlugin):
self, self,
messages: Sequence[Dict[str, str]], messages: Sequence[Dict[str, str]],
images: Sequence["ImageInput"], images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
processor: Optional["ProcessorMixin"], processor: Optional["ProcessorMixin"],
) -> List[Dict[str, str]]: ) -> List[Dict[str, str]]:
self._validate_input(images, videos)
num_image_tokens = 0 num_image_tokens = 0
image_seqlen = getattr(processor, "image_seqlen") image_seqlen = getattr(processor, "image_seqlen")
messages = deepcopy(messages) messages = deepcopy(messages)
@ -163,11 +225,14 @@ class LlavaPlugin(BasePlugin):
def get_mm_inputs( def get_mm_inputs(
self, self,
images: Sequence["ImageInput"], images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
imglens: Sequence[int], imglens: Sequence[int],
vidlens: Sequence[int],
seqlens: Sequence[int], seqlens: Sequence[int],
processor: Optional["ProcessorMixin"], processor: Optional["ProcessorMixin"],
) -> Dict[str, Union[List[int], "torch.Tensor"]]: ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
return _get_mm_inputs(images, processor) self._validate_input(images, videos)
return _get_mm_inputs(images, videos, processor)
class PaliGemmaPlugin(BasePlugin): class PaliGemmaPlugin(BasePlugin):
@ -175,8 +240,10 @@ class PaliGemmaPlugin(BasePlugin):
self, self,
messages: Sequence[Dict[str, str]], messages: Sequence[Dict[str, str]],
images: Sequence["ImageInput"], images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
processor: Optional["ProcessorMixin"], processor: Optional["ProcessorMixin"],
) -> List[Dict[str, str]]: ) -> List[Dict[str, str]]:
self._validate_input(images, videos)
num_image_tokens = 0 num_image_tokens = 0
messages = deepcopy(messages) messages = deepcopy(messages)
for message in messages: for message in messages:
@ -197,9 +264,11 @@ class PaliGemmaPlugin(BasePlugin):
input_ids: List[int], input_ids: List[int],
labels: Optional[List[int]], labels: Optional[List[int]],
images: Sequence["ImageInput"], images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
tokenizer: "PreTrainedTokenizer", tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"], processor: Optional["ProcessorMixin"],
) -> Tuple[List[int], Optional[List[int]]]: ) -> Tuple[List[int], Optional[List[int]]]:
self._validate_input(images, videos)
num_images = len(images) num_images = len(images)
image_seqlen = num_images * getattr(processor, "image_seqlen") image_seqlen = num_images * getattr(processor, "image_seqlen")
image_token_id = tokenizer.convert_tokens_to_ids(self.image_token) image_token_id = tokenizer.convert_tokens_to_ids(self.image_token)
@ -212,11 +281,14 @@ class PaliGemmaPlugin(BasePlugin):
def get_mm_inputs( def get_mm_inputs(
self, self,
images: Sequence["ImageInput"], images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
imglens: Sequence[int], imglens: Sequence[int],
vidlens: Sequence[int],
seqlens: Sequence[int], seqlens: Sequence[int],
processor: Optional["ProcessorMixin"], processor: Optional["ProcessorMixin"],
) -> Dict[str, Union[List[int], "torch.Tensor"]]: ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
mm_inputs = _get_mm_inputs(images, processor) self._validate_input(images, videos)
mm_inputs = _get_mm_inputs(images, videos, processor)
mm_inputs["token_type_ids"] = _get_paligemma_token_type_ids(imglens, seqlens, processor) mm_inputs["token_type_ids"] = _get_paligemma_token_type_ids(imglens, seqlens, processor)
return mm_inputs return mm_inputs
@ -226,16 +298,17 @@ class Qwen2vlPlugin(BasePlugin):
self, self,
messages: Sequence[Dict[str, str]], messages: Sequence[Dict[str, str]],
images: Sequence["ImageInput"], images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
processor: Optional["ProcessorMixin"], processor: Optional["ProcessorMixin"],
) -> List[Dict[str, str]]: ) -> List[Dict[str, str]]:
self._validate_input(images, videos)
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor") image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
merge_length: int = getattr(image_processor, "merge_size") ** 2 merge_length: int = getattr(image_processor, "merge_size") ** 2
if len(images) != 0: mm_inputs = _get_mm_inputs(images, videos, processor)
image_grid_thw = _get_mm_inputs(images, processor)["image_grid_thw"] image_grid_thw = mm_inputs.get("image_grid_thw", [])
else: video_grid_thw = mm_inputs.get("video_grid_thw", [])
image_grid_thw = []
num_image_tokens = 0 num_image_tokens, num_video_tokens = 0, 0
messages = deepcopy(messages) messages = deepcopy(messages)
for message in messages: for message in messages:
content = message["content"] content = message["content"]
@ -252,21 +325,40 @@ class Qwen2vlPlugin(BasePlugin):
) )
num_image_tokens += 1 num_image_tokens += 1
while VIDEO_PLACEHOLDER in content:
if num_video_tokens >= len(video_grid_thw):
raise ValueError("`len(videos)` is less than the number of {} tokens.".format(VIDEO_PLACEHOLDER))
content = content.replace(
VIDEO_PLACEHOLDER,
"<|vision_start|>{}<|vision_end|>".format(
self.video_token * (video_grid_thw[num_video_tokens].prod() // merge_length)
),
1,
)
num_video_tokens += 1
message["content"] = content message["content"] = content
if len(images) != num_image_tokens: if len(images) != num_image_tokens:
raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER)) raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER))
if len(videos) != num_video_tokens:
raise ValueError("The number of videos does not match the number of {} tokens".format(VIDEO_PLACEHOLDER))
return messages return messages
def get_mm_inputs( def get_mm_inputs(
self, self,
images: Sequence["ImageInput"], images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
imglens: Sequence[int], imglens: Sequence[int],
vidlens: Sequence[int],
seqlens: Sequence[int], seqlens: Sequence[int],
processor: Optional["ProcessorMixin"], processor: Optional["ProcessorMixin"],
) -> Dict[str, Union[List[int], "torch.Tensor"]]: ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
return _get_mm_inputs(images, processor) self._validate_input(images, videos)
return _get_mm_inputs(images, videos, processor)
PLUGINS = { PLUGINS = {
@ -277,9 +369,13 @@ PLUGINS = {
} }
def get_mm_plugin(name: str, image_token: str) -> "BasePlugin": def get_mm_plugin(
name: str,
image_token: Optional[str] = None,
video_token: Optional[str] = None,
) -> "BasePlugin":
plugin_class = PLUGINS.get(name, None) plugin_class = PLUGINS.get(name, None)
if plugin_class is None: if plugin_class is None:
raise ValueError("Multimodal plugin `{}` not found.".format(name)) raise ValueError("Multimodal plugin `{}` not found.".format(name))
return plugin_class(image_token) return plugin_class(image_token, video_token)

View File

@ -43,6 +43,7 @@ class DatasetAttr:
system: Optional[str] = None system: Optional[str] = None
tools: Optional[str] = None tools: Optional[str] = None
images: Optional[str] = None images: Optional[str] = None
videos: Optional[str] = None
# rlhf columns # rlhf columns
chosen: Optional[str] = None chosen: Optional[str] = None
rejected: Optional[str] = None rejected: Optional[str] = None
@ -126,7 +127,7 @@ def get_dataset_list(dataset_names: Optional[Sequence[str]], dataset_dir: str) -
dataset_attr.set_attr("num_samples", dataset_info[name]) dataset_attr.set_attr("num_samples", dataset_info[name])
if "columns" in dataset_info[name]: if "columns" in dataset_info[name]:
column_names = ["system", "tools", "images", "chosen", "rejected", "kto_tag"] column_names = ["system", "tools", "images", "videos", "chosen", "rejected", "kto_tag"]
if dataset_attr.formatting == "alpaca": if dataset_attr.formatting == "alpaca":
column_names.extend(["prompt", "query", "response", "history"]) column_names.extend(["prompt", "query", "response", "history"])
else: else:

View File

@ -24,7 +24,7 @@ if TYPE_CHECKING:
from transformers import PreTrainedTokenizer, ProcessorMixin from transformers import PreTrainedTokenizer, ProcessorMixin
from ...hparams import DataArguments from ...hparams import DataArguments
from ..mm_plugin import ImageInput from ..mm_plugin import ImageInput, VideoInput
from ..template import Template from ..template import Template
@ -38,6 +38,7 @@ def _encode_feedback_example(
system: Optional[str], system: Optional[str],
tools: Optional[str], tools: Optional[str],
images: Sequence["ImageInput"], images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
template: "Template", template: "Template",
tokenizer: "PreTrainedTokenizer", tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"], processor: Optional["ProcessorMixin"],
@ -55,8 +56,8 @@ def _encode_feedback_example(
else: else:
kl_messages = prompt + [kl_response[1]] kl_messages = prompt + [kl_response[1]]
messages = template.mm_plugin.process_messages(messages, images, processor) messages = template.mm_plugin.process_messages(messages, images, videos, processor)
kl_messages = template.mm_plugin.process_messages(kl_messages, images, processor) kl_messages = template.mm_plugin.process_messages(kl_messages, images, videos, processor)
prompt_ids, response_ids = template.encode_oneturn(tokenizer, messages, system, tools) prompt_ids, response_ids = template.encode_oneturn(tokenizer, messages, system, tools)
kl_prompt_ids, kl_response_ids = template.encode_oneturn(tokenizer, kl_messages, system, tools) kl_prompt_ids, kl_response_ids = template.encode_oneturn(tokenizer, kl_messages, system, tools)
@ -64,8 +65,8 @@ def _encode_feedback_example(
response_ids += [tokenizer.eos_token_id] response_ids += [tokenizer.eos_token_id]
kl_response_ids += [tokenizer.eos_token_id] kl_response_ids += [tokenizer.eos_token_id]
prompt_ids, _ = template.mm_plugin.process_token_ids(prompt_ids, None, images, tokenizer, processor) prompt_ids, _ = template.mm_plugin.process_token_ids(prompt_ids, None, images, videos, tokenizer, processor)
kl_prompt_ids, _ = template.mm_plugin.process_token_ids(kl_prompt_ids, None, images, tokenizer, processor) kl_prompt_ids, _ = template.mm_plugin.process_token_ids(kl_prompt_ids, None, images, videos, tokenizer, processor)
source_len, target_len = infer_seqlen(len(prompt_ids), len(response_ids), cutoff_len) source_len, target_len = infer_seqlen(len(prompt_ids), len(response_ids), cutoff_len)
prompt_ids = prompt_ids[:source_len] prompt_ids = prompt_ids[:source_len]
@ -103,6 +104,7 @@ def preprocess_feedback_dataset(
system=examples["_system"][i], system=examples["_system"][i],
tools=examples["_tools"][i], tools=examples["_tools"][i],
images=examples["_images"][i] or [], images=examples["_images"][i] or [],
videos=examples["_videos"][i] or [],
template=template, template=template,
tokenizer=tokenizer, tokenizer=tokenizer,
processor=processor, processor=processor,
@ -116,6 +118,7 @@ def preprocess_feedback_dataset(
model_inputs["kl_labels"].append(kl_labels) model_inputs["kl_labels"].append(kl_labels)
model_inputs["kto_tags"].append(kto_tag) model_inputs["kto_tags"].append(kto_tag)
model_inputs["images"].append(examples["_images"][i]) model_inputs["images"].append(examples["_images"][i])
model_inputs["videos"].append(examples["_videos"][i])
desirable_num = sum([1 for tag in model_inputs["kto_tags"] if tag]) desirable_num = sum([1 for tag in model_inputs["kto_tags"] if tag])
undesirable_num = len(model_inputs["kto_tags"]) - desirable_num undesirable_num = len(model_inputs["kto_tags"]) - desirable_num

View File

@ -24,7 +24,7 @@ if TYPE_CHECKING:
from transformers import PreTrainedTokenizer, ProcessorMixin from transformers import PreTrainedTokenizer, ProcessorMixin
from ...hparams import DataArguments from ...hparams import DataArguments
from ..mm_plugin import ImageInput from ..mm_plugin import ImageInput, VideoInput
from ..template import Template from ..template import Template
@ -37,13 +37,14 @@ def _encode_pairwise_example(
system: Optional[str], system: Optional[str],
tools: Optional[str], tools: Optional[str],
images: Sequence["ImageInput"], images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
template: "Template", template: "Template",
tokenizer: "PreTrainedTokenizer", tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"], processor: Optional["ProcessorMixin"],
cutoff_len: int, cutoff_len: int,
) -> Tuple[List[int], List[int], List[int], List[int]]: ) -> Tuple[List[int], List[int], List[int], List[int]]:
chosen_messages = template.mm_plugin.process_messages(prompt + [response[0]], images, processor) chosen_messages = template.mm_plugin.process_messages(prompt + [response[0]], images, videos, processor)
rejected_messages = template.mm_plugin.process_messages(prompt + [response[1]], images, processor) rejected_messages = template.mm_plugin.process_messages(prompt + [response[1]], images, videos, processor)
prompt_ids, chosen_ids = template.encode_oneturn(tokenizer, chosen_messages, system, tools) prompt_ids, chosen_ids = template.encode_oneturn(tokenizer, chosen_messages, system, tools)
_, rejected_ids = template.encode_oneturn(tokenizer, rejected_messages, system, tools) _, rejected_ids = template.encode_oneturn(tokenizer, rejected_messages, system, tools)
@ -51,7 +52,7 @@ def _encode_pairwise_example(
chosen_ids += [tokenizer.eos_token_id] chosen_ids += [tokenizer.eos_token_id]
rejected_ids += [tokenizer.eos_token_id] rejected_ids += [tokenizer.eos_token_id]
prompt_ids, _ = template.mm_plugin.process_token_ids(prompt_ids, None, images, tokenizer, processor) prompt_ids, _ = template.mm_plugin.process_token_ids(prompt_ids, None, images, videos, tokenizer, processor)
# consider the response is more important # consider the response is more important
source_len, target_len = infer_seqlen(len(prompt_ids), max(len(chosen_ids), len(rejected_ids)), cutoff_len) source_len, target_len = infer_seqlen(len(prompt_ids), max(len(chosen_ids), len(rejected_ids)), cutoff_len)
prompt_ids = prompt_ids[:source_len] prompt_ids = prompt_ids[:source_len]
@ -85,6 +86,7 @@ def preprocess_pairwise_dataset(
system=examples["_system"][i], system=examples["_system"][i],
tools=examples["_tools"][i], tools=examples["_tools"][i],
images=examples["_images"][i] or [], images=examples["_images"][i] or [],
videos=examples["_videos"][i] or [],
template=template, template=template,
tokenizer=tokenizer, tokenizer=tokenizer,
processor=processor, processor=processor,
@ -97,6 +99,7 @@ def preprocess_pairwise_dataset(
model_inputs["rejected_attention_mask"].append([1] * len(rejected_input_ids)) model_inputs["rejected_attention_mask"].append([1] * len(rejected_input_ids))
model_inputs["rejected_labels"].append(rejected_labels) model_inputs["rejected_labels"].append(rejected_labels)
model_inputs["images"].append(examples["_images"][i]) model_inputs["images"].append(examples["_images"][i])
model_inputs["videos"].append(examples["_videos"][i])
return model_inputs return model_inputs

View File

@ -24,7 +24,7 @@ if TYPE_CHECKING:
from transformers import PreTrainedTokenizer, ProcessorMixin from transformers import PreTrainedTokenizer, ProcessorMixin
from ...hparams import DataArguments from ...hparams import DataArguments
from ..mm_plugin import ImageInput from ..mm_plugin import ImageInput, VideoInput
from ..template import Template from ..template import Template
@ -37,6 +37,7 @@ def _encode_supervised_example(
system: Optional[str], system: Optional[str],
tools: Optional[str], tools: Optional[str],
images: Sequence["ImageInput"], images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
template: "Template", template: "Template",
tokenizer: "PreTrainedTokenizer", tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"], processor: Optional["ProcessorMixin"],
@ -44,8 +45,8 @@ def _encode_supervised_example(
train_on_prompt: bool, train_on_prompt: bool,
mask_history: bool, mask_history: bool,
) -> Tuple[List[int], List[int]]: ) -> Tuple[List[int], List[int]]:
messages = template.mm_plugin.process_messages(prompt + response, images, processor) messages = template.mm_plugin.process_messages(prompt + response, images, videos, processor)
input_ids, labels = template.mm_plugin.process_token_ids([], [], images, tokenizer, processor) input_ids, labels = template.mm_plugin.process_token_ids([], [], images, videos, tokenizer, processor)
encoded_pairs = template.encode_multiturn(tokenizer, messages, system, tools) encoded_pairs = template.encode_multiturn(tokenizer, messages, system, tools)
total_length = len(input_ids) + (1 if template.efficient_eos else 0) total_length = len(input_ids) + (1 if template.efficient_eos else 0)
if mask_history: if mask_history:
@ -107,6 +108,7 @@ def preprocess_supervised_dataset(
system=examples["_system"][i], system=examples["_system"][i],
tools=examples["_tools"][i], tools=examples["_tools"][i],
images=examples["_images"][i] or [], images=examples["_images"][i] or [],
videos=examples["_videos"][i] or [],
template=template, template=template,
tokenizer=tokenizer, tokenizer=tokenizer,
processor=processor, processor=processor,
@ -118,6 +120,7 @@ def preprocess_supervised_dataset(
model_inputs["attention_mask"].append([1] * len(input_ids)) model_inputs["attention_mask"].append([1] * len(input_ids))
model_inputs["labels"].append(labels) model_inputs["labels"].append(labels)
model_inputs["images"].append(examples["_images"][i]) model_inputs["images"].append(examples["_images"][i])
model_inputs["videos"].append(examples["_videos"][i])
return model_inputs return model_inputs
@ -132,11 +135,8 @@ def preprocess_packed_supervised_dataset(
# TODO: use `position_ids` to achieve packing # TODO: use `position_ids` to achieve packing
# build inputs with format `<bos> X1 Y1 <eos> <bos> X2 Y2 <eos>` # build inputs with format `<bos> X1 Y1 <eos> <bos> X2 Y2 <eos>`
# and labels with format `<ignore> ... <ignore> Y1 <eos> <ignore> ... <ignore> Y2 <eos>` # and labels with format `<ignore> ... <ignore> Y1 <eos> <ignore> ... <ignore> Y2 <eos>`
if processor is not None:
raise NotImplementedError("`packing` have not been implemented for multimodal datasets.")
valid_num = 0 valid_num = 0
batch_input_ids, batch_labels = [], [] batch_input_ids, batch_labels, batch_images, batch_videos = [], [], [], []
lengths = [] lengths = []
length2indexes = defaultdict(list) length2indexes = defaultdict(list)
for i in range(len(examples["_prompt"])): for i in range(len(examples["_prompt"])):
@ -150,9 +150,10 @@ def preprocess_packed_supervised_dataset(
system=examples["_system"][i], system=examples["_system"][i],
tools=examples["_tools"][i], tools=examples["_tools"][i],
images=examples["_images"][i] or [], images=examples["_images"][i] or [],
videos=examples["_videos"][i] or [],
template=template, template=template,
tokenizer=tokenizer, tokenizer=tokenizer,
processor=None, processor=processor,
cutoff_len=data_args.cutoff_len - 1, # reserved for the padding token cutoff_len=data_args.cutoff_len - 1, # reserved for the padding token
train_on_prompt=data_args.train_on_prompt, train_on_prompt=data_args.train_on_prompt,
mask_history=data_args.mask_history, mask_history=data_args.mask_history,
@ -165,16 +166,21 @@ def preprocess_packed_supervised_dataset(
length2indexes[length].append(valid_num) length2indexes[length].append(valid_num)
batch_input_ids.append(input_ids) batch_input_ids.append(input_ids)
batch_labels.append(labels) batch_labels.append(labels)
batch_images.append(examples["_images"][i] or [])
batch_videos.append(examples["_videos"][i] or [])
valid_num += 1 valid_num += 1
model_inputs = defaultdict(list) model_inputs = defaultdict(list)
knapsacks = greedy_knapsack(lengths, data_args.cutoff_len - 1) # reserved for the padding token knapsacks = greedy_knapsack(lengths, data_args.cutoff_len - 1) # reserved for the padding token
for knapsack in knapsacks: for knapsack in knapsacks:
packed_input_ids, packed_attention_masks, packed_labels = [], [], [] packed_input_ids, packed_attention_masks, packed_labels = [], [], []
packed_images, packed_videos = [], []
for i, length in enumerate(knapsack): for i, length in enumerate(knapsack):
index = length2indexes[length].pop() index = length2indexes[length].pop()
packed_input_ids += batch_input_ids[index] packed_input_ids += batch_input_ids[index]
packed_labels += batch_labels[index] packed_labels += batch_labels[index]
packed_images += batch_images[index]
packed_videos += batch_videos[index]
if data_args.neat_packing: if data_args.neat_packing:
packed_attention_masks += [i + 1] * len(batch_input_ids[index]) # start from 1 packed_attention_masks += [i + 1] * len(batch_input_ids[index]) # start from 1
else: else:
@ -195,7 +201,8 @@ def preprocess_packed_supervised_dataset(
model_inputs["input_ids"].append(packed_input_ids) model_inputs["input_ids"].append(packed_input_ids)
model_inputs["attention_mask"].append(packed_attention_masks) model_inputs["attention_mask"].append(packed_attention_masks)
model_inputs["labels"].append(packed_labels) model_inputs["labels"].append(packed_labels)
model_inputs["images"].append(examples["_images"][i]) model_inputs["images"].append(packed_images or None)
model_inputs["videos"].append(packed_videos or None)
return model_inputs return model_inputs

View File

@ -24,7 +24,7 @@ if TYPE_CHECKING:
from transformers import PreTrainedTokenizer, ProcessorMixin from transformers import PreTrainedTokenizer, ProcessorMixin
from ...hparams import DataArguments from ...hparams import DataArguments
from ..mm_plugin import ImageInput from ..mm_plugin import ImageInput, VideoInput
from ..template import Template from ..template import Template
@ -37,6 +37,7 @@ def _encode_unsupervised_example(
system: Optional[str], system: Optional[str],
tools: Optional[str], tools: Optional[str],
images: Sequence["ImageInput"], images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
template: "Template", template: "Template",
tokenizer: "PreTrainedTokenizer", tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"], processor: Optional["ProcessorMixin"],
@ -47,12 +48,12 @@ def _encode_unsupervised_example(
else: else:
messages = prompt + [{"role": Role.ASSISTANT.value, "content": ""}] messages = prompt + [{"role": Role.ASSISTANT.value, "content": ""}]
messages = template.mm_plugin.process_messages(messages, images, processor) messages = template.mm_plugin.process_messages(messages, images, videos, processor)
input_ids, labels = template.encode_oneturn(tokenizer, messages, system, tools) input_ids, labels = template.encode_oneturn(tokenizer, messages, system, tools)
if template.efficient_eos: if template.efficient_eos:
labels += [tokenizer.eos_token_id] labels += [tokenizer.eos_token_id]
input_ids, _ = template.mm_plugin.process_token_ids(input_ids, None, images, tokenizer, processor) input_ids, _ = template.mm_plugin.process_token_ids(input_ids, None, images, videos, tokenizer, processor)
source_len, target_len = infer_seqlen(len(input_ids), len(labels), cutoff_len) source_len, target_len = infer_seqlen(len(input_ids), len(labels), cutoff_len)
input_ids = input_ids[:source_len] input_ids = input_ids[:source_len]
labels = labels[:target_len] labels = labels[:target_len]
@ -79,6 +80,7 @@ def preprocess_unsupervised_dataset(
system=examples["_system"][i], system=examples["_system"][i],
tools=examples["_tools"][i], tools=examples["_tools"][i],
images=examples["_images"][i] or [], images=examples["_images"][i] or [],
videos=examples["_videos"][i] or [],
template=template, template=template,
tokenizer=tokenizer, tokenizer=tokenizer,
processor=processor, processor=processor,
@ -88,6 +90,7 @@ def preprocess_unsupervised_dataset(
model_inputs["attention_mask"].append([1] * len(input_ids)) model_inputs["attention_mask"].append([1] * len(input_ids))
model_inputs["labels"].append(labels) model_inputs["labels"].append(labels)
model_inputs["images"].append(examples["_images"][i]) model_inputs["images"].append(examples["_images"][i])
model_inputs["videos"].append(examples["_videos"][i])
return model_inputs return model_inputs

View File

@ -17,7 +17,6 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Union
from transformers.utils.versions import require_version from transformers.utils.versions import require_version
from ..extras.constants import IMAGE_PLACEHOLDER
from ..extras.logging import get_logger from ..extras.logging import get_logger
from .data_utils import Role from .data_utils import Role
from .formatter import EmptyFormatter, FunctionFormatter, StringFormatter, ToolFormatter from .formatter import EmptyFormatter, FunctionFormatter, StringFormatter, ToolFormatter
@ -213,7 +212,7 @@ def _register_template(
stop_words: Sequence[str] = [], stop_words: Sequence[str] = [],
efficient_eos: bool = False, efficient_eos: bool = False,
replace_eos: bool = False, replace_eos: bool = False,
mm_plugin: "BasePlugin" = get_mm_plugin(name="base", image_token=IMAGE_PLACEHOLDER), mm_plugin: "BasePlugin" = get_mm_plugin(name="base"),
) -> None: ) -> None:
r""" r"""
Registers a chat template. Registers a chat template.
@ -826,7 +825,7 @@ _register_template(
default_system="You are a helpful assistant.", default_system="You are a helpful assistant.",
stop_words=["<|im_end|>"], stop_words=["<|im_end|>"],
replace_eos=True, replace_eos=True,
mm_plugin=get_mm_plugin(name="qwen2_vl", image_token="<|image_pad|>"), mm_plugin=get_mm_plugin(name="qwen2_vl", image_token="<|image_pad|>", video_token="<|video_pad|>"),
) )
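Because get_mm_plugin now carries an optional video_token, templates without one reject video inputs up front instead of failing deep inside the processor. A quick sketch of that behaviour, with the tokens copied from the qwen2_vl registration above and a placeholder video path:

# Sketch: plugins refuse a modality whose token is not registered.
from llamafactory.data.mm_plugin import get_mm_plugin

qwen2_vl_plugin = get_mm_plugin("qwen2_vl", image_token="<|image_pad|>", video_token="<|video_pad|>")
text_only_plugin = get_mm_plugin("base")  # neither token registered

try:
    text_only_plugin.process_messages(messages=[], images=[], videos=["clip.mp4"], processor=None)
except ValueError as err:
    print(err)  # "This model does not support video input."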

View File

@ -95,6 +95,8 @@ SUPPORTED_CLASS_FOR_BLOCK_DIAG_ATTN = {
SUPPORTED_CLASS_FOR_S2ATTN = {"llama"} SUPPORTED_CLASS_FOR_S2ATTN = {"llama"}
VIDEO_PLACEHOLDER = "<video>"
V_HEAD_WEIGHTS_NAME = "value_head.bin" V_HEAD_WEIGHTS_NAME = "value_head.bin"
V_HEAD_SAFE_WEIGHTS_NAME = "value_head.safetensors" V_HEAD_SAFE_WEIGHTS_NAME = "value_head.safetensors"

View File

@ -38,6 +38,10 @@ def _get_package_version(name: str) -> "Version":
return version.parse("0.0.0") return version.parse("0.0.0")
def is_pyav_available():
return _is_package_available("av")
def is_fastapi_available(): def is_fastapi_available():
return _is_package_available("fastapi") return _is_package_available("fastapi")
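Video decoding depends on PyAV, which the new helper merely probes for; installing it is left to the user (pip install av is the usual route, not something this commit adds). A tiny guard as it might be used:

# Sketch: fail early if the optional video-decoding dependency is absent.
from llamafactory.extras.packages import is_pyav_available

if not is_pyav_available():
    raise RuntimeError("Video datasets require PyAV; install it with `pip install av`.")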

View File

@ -142,6 +142,10 @@ class ModelArguments:
default=512, default=512,
metadata={"help": "Keeps the height or width of image below this resolution."}, metadata={"help": "Keeps the height or width of image below this resolution."},
) )
video_fps: float = field(
default=2.0,
metadata={"help": "The frames to sample per second for video training."},
)
infer_backend: Literal["huggingface", "vllm"] = field( infer_backend: Literal["huggingface", "vllm"] = field(
default="huggingface", default="huggingface",
metadata={"help": "Backend engine used at inference."}, metadata={"help": "Backend engine used at inference."},

View File

@ -100,6 +100,11 @@ def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
setattr(processor, "tokenizer", tokenizer) setattr(processor, "tokenizer", tokenizer)
setattr(processor, "image_seqlen", get_image_seqlen(config)) setattr(processor, "image_seqlen", get_image_seqlen(config))
setattr(processor, "image_resolution", model_args.image_resolution) setattr(processor, "image_resolution", model_args.image_resolution)
setattr(processor, "video_fps", model_args.video_fps)
if getattr(config, "model_type", None) == "qwen2_vl":
setattr(processor, "video_factor", 2)
else:
setattr(processor, "video_factor", 1)
except Exception: except Exception:
processor = None processor = None
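load_tokenizer attaches video_fps and video_factor to the processor so the multimodal plugin can read them without extra arguments; with the values here (video_fps=2.0 and a factor of 2 for Qwen2-VL), the frame budget is simply duration times two, rounded to an even count. A quick check of that arithmetic:

# Sketch: frame count for a clip, using the Qwen2-VL values from this commit.
def sampled_frames(duration_sec: float, video_fps: float = 2.0, video_factor: int = 2) -> int:
    return round(duration_sec * video_fps / video_factor) * video_factor

print(sampled_frames(7.3))   # 14
print(sampled_frames(10.0))  # 20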

View File

@ -133,6 +133,7 @@ class WebChatModel(ChatModel):
system: str, system: str,
tools: str, tools: str,
image: Optional[Any], image: Optional[Any],
video: Optional[Any],
max_new_tokens: int, max_new_tokens: int,
top_p: float, top_p: float,
temperature: float, temperature: float,
@ -140,7 +141,7 @@ class WebChatModel(ChatModel):
chatbot[-1][1] = "" chatbot[-1][1] = ""
response = "" response = ""
for new_text in self.stream_chat( for new_text in self.stream_chat(
messages, system, tools, image, max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature messages, system, tools, image, video, max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature
): ):
response += new_text response += new_text
if tools: if tools:

View File

@ -43,9 +43,13 @@ def create_chat_box(
system = gr.Textbox(show_label=False) system = gr.Textbox(show_label=False)
tools = gr.Textbox(show_label=False, lines=3) tools = gr.Textbox(show_label=False, lines=3)
with gr.Column() as image_box: with gr.Column() as mm_box:
with gr.Tab("Image"):
image = gr.Image(sources=["upload"], type="pil") image = gr.Image(sources=["upload"], type="pil")
with gr.Tab("Video"):
video = gr.Video(sources=["upload"])
query = gr.Textbox(show_label=False, lines=8) query = gr.Textbox(show_label=False, lines=8)
submit_btn = gr.Button(variant="primary") submit_btn = gr.Button(variant="primary")
@ -63,7 +67,7 @@ def create_chat_box(
[chatbot, messages, query], [chatbot, messages, query],
).then( ).then(
engine.chatter.stream, engine.chatter.stream,
[chatbot, messages, system, tools, image, max_new_tokens, top_p, temperature], [chatbot, messages, system, tools, image, video, max_new_tokens, top_p, temperature],
[chatbot, messages], [chatbot, messages],
) )
clear_btn.click(lambda: ([], []), outputs=[chatbot, messages]) clear_btn.click(lambda: ([], []), outputs=[chatbot, messages])
@ -76,8 +80,9 @@ def create_chat_box(
role=role, role=role,
system=system, system=system,
tools=tools, tools=tools,
image_box=image_box, mm_box=mm_box,
image=image, image=image,
video=video,
query=query, query=query,
submit_btn=submit_btn, submit_btn=submit_btn,
max_new_tokens=max_new_tokens, max_new_tokens=max_new_tokens,
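Editor's note: the Web UI wiring above is spread across several callbacks, so here is a self-contained Gradio sketch of the same Image/Video tab pattern with a stand-in callback in place of the streaming chat handler. Only the `gr.Image(...)` and `gr.Video(...)` calls mirror the diff; every other name and the callback are illustrative.

```python
# Editor's sketch (not part of the PR): the Image/Video upload tabs in isolation,
# wired to a dummy callback instead of the Web UI's streaming chat handler.
import gradio as gr

def describe(query, image, video):
    # Gradio passes a PIL image (or None) and a video file path (or None).
    media = "an image" if image is not None else "a video" if video else "no media"
    return f"Received {media} with query: {query!r}"

with gr.Blocks() as demo:
    with gr.Column():
        with gr.Tab("Image"):
            image = gr.Image(sources=["upload"], type="pil")
        with gr.Tab("Video"):
            video = gr.Video(sources=["upload"])
    query = gr.Textbox(show_label=False, lines=2)
    answer = gr.Textbox(show_label=False)
    gr.Button("Submit", variant="primary").click(describe, [query, image, video], [answer])

if __name__ == "__main__":
    demo.launch()
```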

View File

@@ -68,7 +68,7 @@ def create_infer_tab(engine: "Engine") -> Dict[str, "Component"]:
engine.manager.get_elem_by_id("top.model_name").change( engine.manager.get_elem_by_id("top.model_name").change(
lambda model_name: gr.Column(visible=get_visual(model_name)), lambda model_name: gr.Column(visible=get_visual(model_name)),
[engine.manager.get_elem_by_id("top.model_name")], [engine.manager.get_elem_by_id("top.model_name")],
[chat_elems["image_box"]], [chat_elems["mm_box"]],
) )
return elem_dict return elem_dict

View File

@@ -59,7 +59,7 @@ class Engine:
init_dict["train.output_dir"] = {"value": "train_{}".format(current_time)} init_dict["train.output_dir"] = {"value": "train_{}".format(current_time)}
init_dict["train.config_path"] = {"value": "{}.yaml".format(current_time)} init_dict["train.config_path"] = {"value": "{}.yaml".format(current_time)}
init_dict["eval.output_dir"] = {"value": "eval_{}".format(current_time)} init_dict["eval.output_dir"] = {"value": "eval_{}".format(current_time)}
init_dict["infer.image_box"] = {"visible": False} init_dict["infer.mm_box"] = {"visible": False}
if user_config.get("last_model", None): if user_config.get("last_model", None):
init_dict["top.model_name"] = {"value": user_config["last_model"]} init_dict["top.model_name"] = {"value": user_config["last_model"]}

View File

@@ -1691,6 +1691,20 @@ LOCALES = {
"label": "이미지 (선택 사항)", "label": "이미지 (선택 사항)",
}, },
}, },
"video": {
"en": {
"label": "Video (optional)",
},
"ru": {
"label": "Видео (по желанию)",
},
"zh": {
"label": "视频(非必填)",
},
"ko": {
"label": "비디오 (선택 사항)",
},
},
"query": { "query": {
"en": { "en": {
"placeholder": "Input...", "placeholder": "Input...",

View File

@@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
import os import os
from typing import TYPE_CHECKING, Any, Dict, Tuple from typing import TYPE_CHECKING, Any, Dict, List, Sequence, Tuple
import pytest import pytest
import torch import torch
@@ -28,6 +28,8 @@ if TYPE_CHECKING:
from transformers import PreTrainedTokenizer, ProcessorMixin from transformers import PreTrainedTokenizer, ProcessorMixin
from transformers.image_processing_utils import BaseImageProcessor from transformers.image_processing_utils import BaseImageProcessor
from llamafactory.data.mm_plugin import BasePlugin
HF_TOKEN = os.environ.get("HF_TOKEN", None) HF_TOKEN = os.environ.get("HF_TOKEN", None)
@@ -47,10 +49,14 @@ IMAGES = [Image.new("RGB", (32, 32), (255, 255, 255))]
NO_IMAGES = [] NO_IMAGES = []
NO_VIDEOS = []
IMGLENS = [1] IMGLENS = [1]
NO_IMGLENS = [0] NO_IMGLENS = [0]
NO_VIDLENS = [0]
INPUT_IDS = [0, 1, 2, 3, 4] INPUT_IDS = [0, 1, 2, 3, 4]
LABELS = [0, 1, 2, 3, 4] LABELS = [0, 1, 2, 3, 4]
@@ -78,92 +84,86 @@ def _load_tokenizer_module(model_name_or_path: str) -> Tuple["PreTrainedTokenize
return tokenizer_module["tokenizer"], tokenizer_module["processor"] return tokenizer_module["tokenizer"], tokenizer_module["processor"]
def _check_plugin(
plugin: "BasePlugin",
tokenizer: "PreTrainedTokenizer",
processor: "ProcessorMixin",
expected_mm_messages: Sequence[Dict[str, str]] = MM_MESSAGES,
expected_input_ids: List[int] = INPUT_IDS,
expected_labels: List[int] = LABELS,
expected_mm_inputs: Dict[str, Any] = {},
expected_no_mm_inputs: Dict[str, Any] = {},
) -> None:
# test mm_messages
assert plugin.process_messages(MM_MESSAGES, IMAGES, NO_VIDEOS, processor) == expected_mm_messages
assert plugin.process_token_ids(INPUT_IDS, LABELS, IMAGES, NO_VIDEOS, tokenizer, processor) == (
expected_input_ids,
expected_labels,
)
_is_close(
plugin.get_mm_inputs(IMAGES, NO_VIDEOS, IMGLENS, NO_VIDLENS, SEQLENS, processor),
expected_mm_inputs,
)
# test text_messages
assert plugin.process_messages(TEXT_MESSAGES, NO_IMAGES, NO_VIDEOS, processor) == TEXT_MESSAGES
assert plugin.process_token_ids(INPUT_IDS, LABELS, NO_IMAGES, NO_VIDEOS, tokenizer, processor) == (
INPUT_IDS,
LABELS,
)
_is_close(
plugin.get_mm_inputs(NO_IMAGES, NO_VIDEOS, NO_IMGLENS, NO_VIDLENS, SEQLENS, processor),
expected_no_mm_inputs,
)
def test_base_plugin(): def test_base_plugin():
tokenizer, processor = _load_tokenizer_module(model_name_or_path=TINY_LLAMA) tokenizer, processor = _load_tokenizer_module(model_name_or_path=TINY_LLAMA)
base_plugin = get_mm_plugin(name="base", image_token="<image>") base_plugin = get_mm_plugin(name="base", image_token="<image>")
# test mm_messages check_inputs = {"plugin": base_plugin, "tokenizer": tokenizer, "processor": processor}
assert base_plugin.process_messages(MM_MESSAGES, IMAGES, processor) == MM_MESSAGES _check_plugin(**check_inputs)
assert base_plugin.process_token_ids(INPUT_IDS, LABELS, IMAGES, tokenizer, processor) == (INPUT_IDS, LABELS)
_is_close(base_plugin.get_mm_inputs(IMAGES, IMGLENS, SEQLENS, processor), {})
# test text_messages
assert base_plugin.process_messages(TEXT_MESSAGES, NO_IMAGES, processor) == TEXT_MESSAGES
assert base_plugin.process_token_ids(INPUT_IDS, LABELS, NO_IMAGES, tokenizer, processor) == (INPUT_IDS, LABELS)
_is_close(base_plugin.get_mm_inputs(NO_IMAGES, NO_IMGLENS, SEQLENS, processor), {})
def test_llava_plugin(): def test_llava_plugin():
tokenizer, processor = _load_tokenizer_module(model_name_or_path="llava-hf/llava-1.5-7b-hf") tokenizer, processor = _load_tokenizer_module(model_name_or_path="llava-hf/llava-1.5-7b-hf")
llava_plugin = get_mm_plugin(name="llava", image_token="<image>")
image_seqlen = 576 image_seqlen = 576
check_inputs = {"plugin": llava_plugin, "tokenizer": tokenizer, "processor": processor}
mm_inputs = _get_mm_inputs(processor) check_inputs["expected_mm_messages"] = [
expected_mm_messages = [
{key: value.replace("<image>", "<image>" * image_seqlen) for key, value in message.items()} {key: value.replace("<image>", "<image>" * image_seqlen) for key, value in message.items()}
for message in MM_MESSAGES for message in MM_MESSAGES
] ]
check_inputs["expected_mm_inputs"] = _get_mm_inputs(processor)
llava_plugin = get_mm_plugin(name="llava", image_token="<image>") _check_plugin(**check_inputs)
# test mm_messages
assert llava_plugin.process_messages(MM_MESSAGES, IMAGES, processor) == expected_mm_messages
assert llava_plugin.process_token_ids(INPUT_IDS, LABELS, IMAGES, tokenizer, processor) == (INPUT_IDS, LABELS)
_is_close(llava_plugin.get_mm_inputs(IMAGES, IMGLENS, SEQLENS, processor), mm_inputs)
# test text_messages
assert llava_plugin.process_messages(TEXT_MESSAGES, NO_IMAGES, processor) == TEXT_MESSAGES
assert llava_plugin.process_token_ids(INPUT_IDS, LABELS, NO_IMAGES, tokenizer, processor) == (INPUT_IDS, LABELS)
_is_close(llava_plugin.get_mm_inputs(NO_IMAGES, NO_IMGLENS, SEQLENS, processor), {})
@pytest.mark.skipif(not HF_TOKEN, reason="Gated model.") @pytest.mark.skipif(not HF_TOKEN, reason="Gated model.")
def test_paligemma_plugin(): def test_paligemma_plugin():
tokenizer, processor = _load_tokenizer_module(model_name_or_path="google/paligemma-3b-pt-224") tokenizer, processor = _load_tokenizer_module(model_name_or_path="google/paligemma-3b-pt-224")
paligemma_plugin = get_mm_plugin(name="paligemma", image_token="<image>")
image_seqlen = 256 image_seqlen = 256
check_inputs = {"plugin": paligemma_plugin, "tokenizer": tokenizer, "processor": processor}
mm_inputs = _get_mm_inputs(processor) check_inputs["expected_mm_messages"] = [
mm_inputs["token_type_ids"] = [[0] * image_seqlen + [1] * (1024 - image_seqlen)]
expected_mm_messages = [
{key: value.replace("<image>", "") for key, value in message.items()} for message in MM_MESSAGES {key: value.replace("<image>", "") for key, value in message.items()} for message in MM_MESSAGES
] ]
expected_input_ids = [tokenizer.convert_tokens_to_ids("<image>")] * image_seqlen + INPUT_IDS check_inputs["expected_input_ids"] = [tokenizer.convert_tokens_to_ids("<image>")] * image_seqlen + INPUT_IDS
expected_labels = [-100] * image_seqlen + LABELS check_inputs["expected_labels"] = [-100] * image_seqlen + LABELS
check_inputs["expected_mm_inputs"] = _get_mm_inputs(processor)
paligemma_plugin = get_mm_plugin(name="paligemma", image_token="<image>") check_inputs["expected_mm_inputs"]["token_type_ids"] = [[0] * image_seqlen + [1] * (1024 - image_seqlen)]
# test mm_messages check_inputs["expected_no_mm_inputs"] = {"token_type_ids": [[1] * 1024]}
assert paligemma_plugin.process_messages(MM_MESSAGES, IMAGES, processor) == expected_mm_messages _check_plugin(**check_inputs)
assert paligemma_plugin.process_token_ids(INPUT_IDS, LABELS, IMAGES, tokenizer, processor) == (
expected_input_ids,
expected_labels,
)
_is_close(paligemma_plugin.get_mm_inputs(IMAGES, IMGLENS, SEQLENS, processor), mm_inputs)
# test text_messages
assert paligemma_plugin.process_messages(TEXT_MESSAGES, NO_IMAGES, processor) == TEXT_MESSAGES
assert paligemma_plugin.process_token_ids(INPUT_IDS, LABELS, NO_IMAGES, tokenizer, processor) == (
INPUT_IDS,
LABELS,
)
_is_close(
paligemma_plugin.get_mm_inputs(NO_IMAGES, NO_IMGLENS, SEQLENS, processor),
{"token_type_ids": [[1] * 1024]},
)
def test_qwen2_vl_plugin(): def test_qwen2_vl_plugin():
tokenizer, processor = _load_tokenizer_module(model_name_or_path="Qwen/Qwen2-VL-7B-Instruct") tokenizer, processor = _load_tokenizer_module(model_name_or_path="Qwen/Qwen2-VL-7B-Instruct")
qwen2_vl_plugin = get_mm_plugin(name="qwen2_vl", image_token="<|image_pad|>")
image_seqlen = 4 image_seqlen = 4
check_inputs = {"plugin": qwen2_vl_plugin, "tokenizer": tokenizer, "processor": processor}
mm_inputs = _get_mm_inputs(processor) check_inputs["expected_mm_messages"] = [
expected_mm_messages = [
{ {
key: value.replace("<image>", "<|vision_start|>{}<|vision_end|>".format("<|image_pad|>" * image_seqlen)) key: value.replace("<image>", "<|vision_start|>{}<|vision_end|>".format("<|image_pad|>" * image_seqlen))
for key, value in message.items() for key, value in message.items()
} }
for message in MM_MESSAGES for message in MM_MESSAGES
] ]
check_inputs["expected_mm_inputs"] = _get_mm_inputs(processor)
qwen2_vl_plugin = get_mm_plugin(name="qwen2_vl", image_token="<|image_pad|>") _check_plugin(**check_inputs)
# test mm_messages
assert qwen2_vl_plugin.process_messages(MM_MESSAGES, IMAGES, processor) == expected_mm_messages
assert qwen2_vl_plugin.process_token_ids(INPUT_IDS, LABELS, IMAGES, tokenizer, processor) == (INPUT_IDS, LABELS)
_is_close(qwen2_vl_plugin.get_mm_inputs(IMAGES, IMGLENS, SEQLENS, processor), mm_inputs)
# test text_messages
assert qwen2_vl_plugin.process_messages(TEXT_MESSAGES, NO_IMAGES, processor) == TEXT_MESSAGES
assert qwen2_vl_plugin.process_token_ids(INPUT_IDS, LABELS, NO_IMAGES, tokenizer, processor) == (INPUT_IDS, LABELS)
_is_close(qwen2_vl_plugin.get_mm_inputs(NO_IMAGES, NO_IMGLENS, SEQLENS, processor), {})