diff --git a/data/mllm_demo_data/1.mp4 b/data/mllm_demo_data/1.mp4 new file mode 100644 index 00000000..f3abd568 Binary files /dev/null and b/data/mllm_demo_data/1.mp4 differ diff --git a/data/mllm_demo_data/2.avi b/data/mllm_demo_data/2.avi new file mode 100644 index 00000000..bdb736c2 Binary files /dev/null and b/data/mllm_demo_data/2.avi differ diff --git a/data/mllm_demo_data/3.mp4 b/data/mllm_demo_data/3.mp4 new file mode 100644 index 00000000..48ce6f66 Binary files /dev/null and b/data/mllm_demo_data/3.mp4 differ diff --git a/src/llamafactory/chat/base_engine.py b/src/llamafactory/chat/base_engine.py index ccdf4c92..a95e0e3a 100644 --- a/src/llamafactory/chat/base_engine.py +++ b/src/llamafactory/chat/base_engine.py @@ -18,11 +18,11 @@ from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, List, Literal, Opti if TYPE_CHECKING: - from numpy.typing import NDArray from transformers import PreTrainedModel, PreTrainedTokenizer from vllm import AsyncLLMEngine from ..data import Template + from ..data.mm_plugin import ImageInput, VideoInput from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments @@ -56,7 +56,8 @@ class BaseEngine(ABC): messages: Sequence[Dict[str, str]], system: Optional[str] = None, tools: Optional[str] = None, - image: Optional["NDArray"] = None, + image: Optional["ImageInput"] = None, + video: Optional["VideoInput"] = None, **input_kwargs, ) -> List["Response"]: ... @@ -66,7 +67,8 @@ class BaseEngine(ABC): messages: Sequence[Dict[str, str]], system: Optional[str] = None, tools: Optional[str] = None, - image: Optional["NDArray"] = None, + image: Optional["ImageInput"] = None, + video: Optional["VideoInput"] = None, **input_kwargs, ) -> AsyncGenerator[str, None]: ... diff --git a/src/llamafactory/chat/chat_model.py b/src/llamafactory/chat/chat_model.py index 6df83b57..3eb8124f 100644 --- a/src/llamafactory/chat/chat_model.py +++ b/src/llamafactory/chat/chat_model.py @@ -27,8 +27,7 @@ from .vllm_engine import VllmEngine if TYPE_CHECKING: - from PIL.Image import Image - + from ..data.mm_plugin import ImageInput, VideoInput from .base_engine import BaseEngine, Response @@ -56,10 +55,13 @@ class ChatModel: messages: Sequence[Dict[str, str]], system: Optional[str] = None, tools: Optional[str] = None, - image: Optional["Image"] = None, + image: Optional["ImageInput"] = None, + video: Optional["VideoInput"] = None, **input_kwargs, ) -> List["Response"]: - task = asyncio.run_coroutine_threadsafe(self.achat(messages, system, tools, image, **input_kwargs), self._loop) + task = asyncio.run_coroutine_threadsafe( + self.achat(messages, system, tools, image, video, **input_kwargs), self._loop + ) return task.result() async def achat( @@ -67,20 +69,22 @@ class ChatModel: messages: Sequence[Dict[str, str]], system: Optional[str] = None, tools: Optional[str] = None, - image: Optional["Image"] = None, + image: Optional["ImageInput"] = None, + video: Optional["VideoInput"] = None, **input_kwargs, ) -> List["Response"]: - return await self.engine.chat(messages, system, tools, image, **input_kwargs) + return await self.engine.chat(messages, system, tools, image, video, **input_kwargs) def stream_chat( self, messages: Sequence[Dict[str, str]], system: Optional[str] = None, tools: Optional[str] = None, - image: Optional["Image"] = None, + image: Optional["ImageInput"] = None, + video: Optional["VideoInput"] = None, **input_kwargs, ) -> Generator[str, None, None]: - generator = self.astream_chat(messages, system, tools, image, **input_kwargs) + 
generator = self.astream_chat(messages, system, tools, image, video, **input_kwargs) while True: try: task = asyncio.run_coroutine_threadsafe(generator.__anext__(), self._loop) @@ -93,10 +97,11 @@ class ChatModel: messages: Sequence[Dict[str, str]], system: Optional[str] = None, tools: Optional[str] = None, - image: Optional["Image"] = None, + image: Optional["ImageInput"] = None, + video: Optional["VideoInput"] = None, **input_kwargs, ) -> AsyncGenerator[str, None]: - async for new_token in self.engine.stream_chat(messages, system, tools, image, **input_kwargs): + async for new_token in self.engine.stream_chat(messages, system, tools, image, video, **input_kwargs): yield new_token def get_scores( diff --git a/src/llamafactory/chat/hf_engine.py b/src/llamafactory/chat/hf_engine.py index 880e5803..b1a9b078 100644 --- a/src/llamafactory/chat/hf_engine.py +++ b/src/llamafactory/chat/hf_engine.py @@ -22,7 +22,7 @@ import torch from transformers import GenerationConfig, TextIteratorStreamer from ..data import get_template_and_fix_tokenizer -from ..extras.constants import IMAGE_PLACEHOLDER +from ..extras.constants import IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER from ..extras.logging import get_logger from ..extras.misc import get_logits_processor from ..model import load_model, load_tokenizer @@ -30,11 +30,11 @@ from .base_engine import BaseEngine, Response if TYPE_CHECKING: - from PIL.Image import Image from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin from trl import PreTrainedModelWrapper from ..data import Template + from ..data.mm_plugin import ImageInput, VideoInput from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments @@ -78,20 +78,30 @@ class HuggingfaceEngine(BaseEngine): messages: Sequence[Dict[str, str]], system: Optional[str] = None, tools: Optional[str] = None, - image: Optional["Image"] = None, + image: Optional["ImageInput"] = None, + video: Optional["VideoInput"] = None, input_kwargs: Optional[Dict[str, Any]] = {}, ) -> Tuple[Dict[str, Any], int]: + mm_input_dict = {"images": [], "videos": [], "imglens": [0], "vidlens": [0]} if image is not None: + mm_input_dict.update({"images": [image], "imglens": [1]}) if IMAGE_PLACEHOLDER not in messages[0]["content"]: messages[0]["content"] = IMAGE_PLACEHOLDER + messages[0]["content"] - messages = template.mm_plugin.process_messages(messages, [image], processor) + if video is not None: + mm_input_dict.update({"videos": [video], "vidlens": [1]}) + if VIDEO_PLACEHOLDER not in messages[0]["content"]: + messages[0]["content"] = VIDEO_PLACEHOLDER + messages[0]["content"] + messages = template.mm_plugin.process_messages( + messages, mm_input_dict["images"], mm_input_dict["videos"], processor + ) paired_messages = messages + [{"role": "assistant", "content": ""}] system = system or generating_args["default_system"] prompt_ids, _ = template.encode_oneturn(tokenizer, paired_messages, system, tools) - if image is not None: - prompt_ids, _ = template.mm_plugin.process_token_ids(prompt_ids, None, [image], tokenizer, processor) + prompt_ids, _ = template.mm_plugin.process_token_ids( + prompt_ids, None, mm_input_dict["images"], mm_input_dict["videos"], tokenizer, processor + ) prompt_length = len(prompt_ids) inputs = torch.tensor([prompt_ids], device=model.device) @@ -154,13 +164,10 @@ class HuggingfaceEngine(BaseEngine): logits_processor=get_logits_processor(), ) - if image is not None: - mm_inputs = template.mm_plugin.get_mm_inputs( - images=[image], imglens=[1], seqlens=[prompt_length], 
processor=processor - ) - for key, value in mm_inputs.items(): - value = value if isinstance(value, torch.Tensor) else torch.tensor(value) - gen_kwargs[key] = value.to(model.device) + mm_inputs = template.mm_plugin.get_mm_inputs(**mm_input_dict, seqlens=[prompt_length], processor=processor) + for key, value in mm_inputs.items(): + value = value if isinstance(value, torch.Tensor) else torch.tensor(value) + gen_kwargs[key] = value.to(model.device) return gen_kwargs, prompt_length @@ -175,11 +182,12 @@ class HuggingfaceEngine(BaseEngine): messages: Sequence[Dict[str, str]], system: Optional[str] = None, tools: Optional[str] = None, - image: Optional["Image"] = None, + image: Optional["ImageInput"] = None, + video: Optional["VideoInput"] = None, input_kwargs: Optional[Dict[str, Any]] = {}, ) -> List["Response"]: gen_kwargs, prompt_length = HuggingfaceEngine._process_args( - model, tokenizer, processor, template, generating_args, messages, system, tools, image, input_kwargs + model, tokenizer, processor, template, generating_args, messages, system, tools, image, video, input_kwargs ) generate_output = model.generate(**gen_kwargs) response_ids = generate_output[:, prompt_length:] @@ -210,11 +218,12 @@ class HuggingfaceEngine(BaseEngine): messages: Sequence[Dict[str, str]], system: Optional[str] = None, tools: Optional[str] = None, - image: Optional["Image"] = None, + image: Optional["ImageInput"] = None, + video: Optional["VideoInput"] = None, input_kwargs: Optional[Dict[str, Any]] = {}, ) -> Callable[[], str]: gen_kwargs, _ = HuggingfaceEngine._process_args( - model, tokenizer, processor, template, generating_args, messages, system, tools, image, input_kwargs + model, tokenizer, processor, template, generating_args, messages, system, tools, image, video, input_kwargs ) streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True) gen_kwargs["streamer"] = streamer @@ -267,7 +276,8 @@ class HuggingfaceEngine(BaseEngine): messages: Sequence[Dict[str, str]], system: Optional[str] = None, tools: Optional[str] = None, - image: Optional["Image"] = None, + image: Optional["ImageInput"] = None, + video: Optional["VideoInput"] = None, **input_kwargs, ) -> List["Response"]: if not self.can_generate: @@ -284,6 +294,7 @@ class HuggingfaceEngine(BaseEngine): system, tools, image, + video, input_kwargs, ) async with self.semaphore: @@ -295,7 +306,8 @@ class HuggingfaceEngine(BaseEngine): messages: Sequence[Dict[str, str]], system: Optional[str] = None, tools: Optional[str] = None, - image: Optional["Image"] = None, + image: Optional["ImageInput"] = None, + video: Optional["VideoInput"] = None, **input_kwargs, ) -> AsyncGenerator[str, None]: if not self.can_generate: @@ -312,6 +324,7 @@ class HuggingfaceEngine(BaseEngine): system, tools, image, + video, input_kwargs, ) async with self.semaphore: diff --git a/src/llamafactory/data/aligner.py b/src/llamafactory/data/aligner.py index 05ff7ef0..f7f6d0ad 100644 --- a/src/llamafactory/data/aligner.py +++ b/src/llamafactory/data/aligner.py @@ -25,7 +25,7 @@ if TYPE_CHECKING: from transformers import Seq2SeqTrainingArguments from ..hparams import DataArguments - from .mm_plugin import ImageInput + from .mm_plugin import ImageInput, VideoInput from .parser import DatasetAttr @@ -52,6 +52,26 @@ def _convert_images( return images +def _convert_videos( + videos: Sequence["VideoInput"], + dataset_attr: "DatasetAttr", + data_args: "DataArguments", +) -> Optional[List["VideoInput"]]: + r""" + Optionally concatenates video path to dataset dir when 
loading from local disk. + """ + if len(videos) == 0: + return None + + videos = videos[:] + if dataset_attr.load_from in ["script", "file"]: + for i in range(len(videos)): + if isinstance(videos[i], str) and os.path.isfile(os.path.join(data_args.dataset_dir, videos[i])): + videos[i] = os.path.join(data_args.dataset_dir, videos[i]) + + return videos + + def convert_alpaca( example: Dict[str, Any], dataset_attr: "DatasetAttr", @@ -96,12 +116,14 @@ def convert_alpaca( response = [] convert_images = partial(_convert_images, dataset_attr=dataset_attr, data_args=data_args) + convert_videos = partial(_convert_videos, dataset_attr=dataset_attr, data_args=data_args) output = { "_prompt": prompt, "_response": response, "_system": example[dataset_attr.system] if dataset_attr.system else "", "_tools": example[dataset_attr.tools] if dataset_attr.tools else "", "_images": convert_images(example[dataset_attr.images]) if dataset_attr.images else None, + "_videos": convert_videos(example[dataset_attr.videos]) if dataset_attr.videos else None, } return output @@ -187,12 +209,14 @@ def convert_sharegpt( prompt, response = [], [] convert_images = partial(_convert_images, dataset_attr=dataset_attr, data_args=data_args) + convert_videos = partial(_convert_videos, dataset_attr=dataset_attr, data_args=data_args) output = { "_prompt": prompt, "_response": response, "_system": system, "_tools": example[dataset_attr.tools] if dataset_attr.tools else "", "_images": convert_images(example[dataset_attr.images]) if dataset_attr.images else None, + "_videos": convert_videos(example[dataset_attr.videos]) if dataset_attr.videos else None, } return output @@ -210,6 +234,7 @@ def align_dataset( _system: "..." _tools: "...", _images: [], + _videos: [], """ if dataset_attr.formatting == "alpaca": convert_func = partial(convert_alpaca, dataset_attr=dataset_attr, data_args=data_args) diff --git a/src/llamafactory/data/collator.py b/src/llamafactory/data/collator.py index 73508b47..d86c5c43 100644 --- a/src/llamafactory/data/collator.py +++ b/src/llamafactory/data/collator.py @@ -79,14 +79,19 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq): processor: Optional["ProcessorMixin"] = None def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]: - batch_images, batch_imglens, batch_seqlens = [], [], [] + batch_images, batch_videos, batch_imglens, batch_vidlens, batch_seqlens = [], [], [], [], [] for feature in features: images = feature.pop("images") or [] # avoid NoneType + videos = feature.pop("videos") or [] batch_images.extend(images) + batch_videos.extend(videos) batch_imglens.append(len(images)) + batch_vidlens.append(len(videos)) batch_seqlens.append(len(feature["input_ids"])) - mm_inputs = self.template.mm_plugin.get_mm_inputs(batch_images, batch_imglens, batch_seqlens, self.processor) + mm_inputs = self.template.mm_plugin.get_mm_inputs( + batch_images, batch_videos, batch_imglens, batch_vidlens, batch_seqlens, self.processor + ) if "token_type_ids" in mm_inputs: token_type_ids = mm_inputs.pop("token_type_ids") for i, feature in enumerate(features): @@ -136,6 +141,7 @@ class PairwiseDataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq): "attention_mask": feature["{}_attention_mask".format(key)], "labels": feature["{}_labels".format(key)], "images": feature["images"], + "videos": feature["videos"], } concatenated_features.append(target_feature) @@ -158,12 +164,14 @@ class KTODataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq): "attention_mask": feature["attention_mask"], 
"labels": feature["labels"], "images": feature["images"], + "videos": feature["videos"], } kl_feature = { "input_ids": feature["kl_input_ids"], "attention_mask": feature["kl_attention_mask"], "labels": feature["kl_labels"], "images": feature["images"], + "videos": feature["videos"], } target_features.append(target_feature) kl_features.append(kl_feature) diff --git a/src/llamafactory/data/mm_plugin.py b/src/llamafactory/data/mm_plugin.py index 5e1b5bd8..33ab1328 100644 --- a/src/llamafactory/data/mm_plugin.py +++ b/src/llamafactory/data/mm_plugin.py @@ -2,11 +2,10 @@ from copy import deepcopy from io import BytesIO from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, TypedDict, Union -from PIL.Image import Image -from transformers import ProcessorMixin +import numpy as np -from ..extras.constants import IGNORE_INDEX, IMAGE_PLACEHOLDER -from ..extras.packages import is_pillow_available +from ..extras.constants import IGNORE_INDEX, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER +from ..extras.packages import is_pillow_available, is_pyav_available if is_pillow_available(): @@ -14,8 +13,13 @@ if is_pillow_available(): from PIL.Image import Image as ImageObject +if is_pyav_available(): + import av + + if TYPE_CHECKING: import torch + from numpy.typing import NDArray from transformers import PreTrainedTokenizer, ProcessorMixin from transformers.image_processing_utils import BaseImageProcessor @@ -24,13 +28,14 @@ if TYPE_CHECKING: bytes: Optional[bytes] ImageInput = Union[str, EncodedImage, ImageObject] + VideoInput = str def _regularize_images(images: Sequence["ImageInput"], processor: "ProcessorMixin") -> List["ImageObject"]: r""" Regularizes images to avoid error. Including reading, resizing and converting. """ - image_resolution = getattr(processor, "image_resolution", 512) + image_resolution: int = getattr(processor, "image_resolution", 512) results = [] for image in images: if isinstance(image, str): @@ -56,7 +61,37 @@ def _regularize_images(images: Sequence["ImageInput"], processor: "ProcessorMixi return results -def _get_mm_inputs(images: Sequence["ImageInput"], processor: "ProcessorMixin") -> Dict[str, "torch.Tensor"]: +def _regularize_videos(videos: Sequence["VideoInput"], processor: "ProcessorMixin") -> List["NDArray"]: + r""" + Regularizes videos to avoid error. Including reading, resizing and converting. + """ + video_fps: float = getattr(processor, "video_fps", 1.0) + video_factor: int = getattr(processor, "video_factor", 1) + results = [] + for video in videos: + container = av.open(video, "r") + video_stream = next(stream for stream in container.streams if stream.type == "video") + total_frames = video_stream.frames + sample_frames = float(video_stream.duration * video_stream.time_base) * video_fps + sample_frames = round(sample_frames / video_factor) * video_factor # for qwen2_vl + sample_indices = np.linspace(0, total_frames - 1, sample_frames).astype(np.int32) + frames: List["ImageObject"] = [] + container.seek(0) + for frame_idx, frame in enumerate(container.decode(video_stream)): + if frame_idx in sample_indices: + frames.append(frame.to_image()) + + frames = _regularize_images(frames, processor) + results.append(frames) + + return results + + +def _get_mm_inputs( + images: Sequence["ImageInput"], + videos: Sequence["VideoInput"], + processor: "ProcessorMixin", +) -> Dict[str, "torch.Tensor"]: r""" Processes visual inputs. 
@@ -70,13 +105,19 @@ def _get_mm_inputs(images: Sequence["ImageInput"], processor: "ProcessorMixin") It holds num_patches == torch.prod(image_grid_thw) """ image_processor: "BaseImageProcessor" = getattr(processor, "image_processor") + input_dict = {"images": None, "videos": None} if len(images) != 0: images = _regularize_images(images, processor) - image_inputs = image_processor(images=images, return_tensors="pt") - else: - image_inputs = {} + input_dict["images"] = images - return image_inputs + if len(videos) != 0: + videos = _regularize_videos(videos, processor) + input_dict["videos"] = videos + + if input_dict["images"] is not None or input_dict["videos"] is not None: + return image_processor(**input_dict, return_tensors="pt") + else: + return {} def _get_paligemma_token_type_ids( @@ -97,18 +138,32 @@ def _get_paligemma_token_type_ids( class BasePlugin: - def __init__(self, image_token: str) -> None: + def __init__(self, image_token: Optional[str], video_token: Optional[str]) -> None: self.image_token = image_token + self.video_token = video_token + + def _validate_input( + self, + images: Sequence["ImageInput"], + videos: Sequence["VideoInput"], + ) -> None: + if len(images) != 0 and self.image_token is None: + raise ValueError("This model does not support image input.") + + if len(videos) != 0 and self.video_token is None: + raise ValueError("This model does not support video input.") def process_messages( self, messages: Sequence[Dict[str, str]], images: Sequence["ImageInput"], + videos: Sequence["VideoInput"], processor: Optional["ProcessorMixin"], ) -> List[Dict[str, str]]: r""" Pre-processes input messages before tokenization for VLMs. """ + self._validate_input(images, videos) return messages def process_token_ids( @@ -116,24 +171,29 @@ class BasePlugin: input_ids: List[int], labels: Optional[List[int]], images: Sequence["ImageInput"], + videos: Sequence["VideoInput"], tokenizer: "PreTrainedTokenizer", processor: Optional["ProcessorMixin"], ) -> Tuple[List[int], Optional[List[int]]]: r""" Pre-processes token ids after tokenization for VLMs. """ + self._validate_input(images, videos) return input_ids, labels def get_mm_inputs( self, images: Sequence["ImageInput"], + videos: Sequence["VideoInput"], imglens: Sequence[int], + vidlens: Sequence[int], seqlens: Sequence[int], processor: Optional["ProcessorMixin"], ) -> Dict[str, Union[List[int], "torch.Tensor"]]: r""" Builds batched multimodal inputs for VLMs. 
""" + self._validate_input(images, videos) return {} @@ -142,8 +202,10 @@ class LlavaPlugin(BasePlugin): self, messages: Sequence[Dict[str, str]], images: Sequence["ImageInput"], + videos: Sequence["VideoInput"], processor: Optional["ProcessorMixin"], ) -> List[Dict[str, str]]: + self._validate_input(images, videos) num_image_tokens = 0 image_seqlen = getattr(processor, "image_seqlen") messages = deepcopy(messages) @@ -163,11 +225,14 @@ class LlavaPlugin(BasePlugin): def get_mm_inputs( self, images: Sequence["ImageInput"], + videos: Sequence["VideoInput"], imglens: Sequence[int], + vidlens: Sequence[int], seqlens: Sequence[int], processor: Optional["ProcessorMixin"], ) -> Dict[str, Union[List[int], "torch.Tensor"]]: - return _get_mm_inputs(images, processor) + self._validate_input(images, videos) + return _get_mm_inputs(images, videos, processor) class PaliGemmaPlugin(BasePlugin): @@ -175,8 +240,10 @@ class PaliGemmaPlugin(BasePlugin): self, messages: Sequence[Dict[str, str]], images: Sequence["ImageInput"], + videos: Sequence["VideoInput"], processor: Optional["ProcessorMixin"], ) -> List[Dict[str, str]]: + self._validate_input(images, videos) num_image_tokens = 0 messages = deepcopy(messages) for message in messages: @@ -197,9 +264,11 @@ class PaliGemmaPlugin(BasePlugin): input_ids: List[int], labels: Optional[List[int]], images: Sequence["ImageInput"], + videos: Sequence["VideoInput"], tokenizer: "PreTrainedTokenizer", processor: Optional["ProcessorMixin"], ) -> Tuple[List[int], Optional[List[int]]]: + self._validate_input(images, videos) num_images = len(images) image_seqlen = num_images * getattr(processor, "image_seqlen") image_token_id = tokenizer.convert_tokens_to_ids(self.image_token) @@ -212,11 +281,14 @@ class PaliGemmaPlugin(BasePlugin): def get_mm_inputs( self, images: Sequence["ImageInput"], + videos: Sequence["VideoInput"], imglens: Sequence[int], + vidlens: Sequence[int], seqlens: Sequence[int], processor: Optional["ProcessorMixin"], ) -> Dict[str, Union[List[int], "torch.Tensor"]]: - mm_inputs = _get_mm_inputs(images, processor) + self._validate_input(images, videos) + mm_inputs = _get_mm_inputs(images, videos, processor) mm_inputs["token_type_ids"] = _get_paligemma_token_type_ids(imglens, seqlens, processor) return mm_inputs @@ -226,16 +298,17 @@ class Qwen2vlPlugin(BasePlugin): self, messages: Sequence[Dict[str, str]], images: Sequence["ImageInput"], + videos: Sequence["VideoInput"], processor: Optional["ProcessorMixin"], ) -> List[Dict[str, str]]: + self._validate_input(images, videos) image_processor: "BaseImageProcessor" = getattr(processor, "image_processor") merge_length: int = getattr(image_processor, "merge_size") ** 2 - if len(images) != 0: - image_grid_thw = _get_mm_inputs(images, processor)["image_grid_thw"] - else: - image_grid_thw = [] + mm_inputs = _get_mm_inputs(images, videos, processor) + image_grid_thw = mm_inputs.get("image_grid_thw", []) + video_grid_thw = mm_inputs.get("video_grid_thw", []) - num_image_tokens = 0 + num_image_tokens, num_video_tokens = 0, 0 messages = deepcopy(messages) for message in messages: content = message["content"] @@ -252,21 +325,40 @@ class Qwen2vlPlugin(BasePlugin): ) num_image_tokens += 1 + while VIDEO_PLACEHOLDER in content: + if num_video_tokens >= len(video_grid_thw): + raise ValueError("`len(videos)` is less than the number of {} tokens.".format(VIDEO_PLACEHOLDER)) + + content = content.replace( + VIDEO_PLACEHOLDER, + "<|vision_start|>{}<|vision_end|>".format( + self.video_token * 
(video_grid_thw[num_video_tokens].prod() // merge_length) + ), + 1, + ) + num_video_tokens += 1 + message["content"] = content if len(images) != num_image_tokens: raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER)) + if len(videos) != num_video_tokens: + raise ValueError("The number of videos does not match the number of {} tokens".format(VIDEO_PLACEHOLDER)) + return messages def get_mm_inputs( self, images: Sequence["ImageInput"], + videos: Sequence["VideoInput"], imglens: Sequence[int], + vidlens: Sequence[int], seqlens: Sequence[int], processor: Optional["ProcessorMixin"], ) -> Dict[str, Union[List[int], "torch.Tensor"]]: - return _get_mm_inputs(images, processor) + self._validate_input(images, videos) + return _get_mm_inputs(images, videos, processor) PLUGINS = { @@ -277,9 +369,13 @@ PLUGINS = { } -def get_mm_plugin(name: str, image_token: str) -> "BasePlugin": +def get_mm_plugin( + name: str, + image_token: Optional[str] = None, + video_token: Optional[str] = None, +) -> "BasePlugin": plugin_class = PLUGINS.get(name, None) if plugin_class is None: raise ValueError("Multimodal plugin `{}` not found.".format(name)) - return plugin_class(image_token) + return plugin_class(image_token, video_token) diff --git a/src/llamafactory/data/parser.py b/src/llamafactory/data/parser.py index 2dccfc5d..15a6eab8 100644 --- a/src/llamafactory/data/parser.py +++ b/src/llamafactory/data/parser.py @@ -43,6 +43,7 @@ class DatasetAttr: system: Optional[str] = None tools: Optional[str] = None images: Optional[str] = None + videos: Optional[str] = None # rlhf columns chosen: Optional[str] = None rejected: Optional[str] = None @@ -126,7 +127,7 @@ def get_dataset_list(dataset_names: Optional[Sequence[str]], dataset_dir: str) - dataset_attr.set_attr("num_samples", dataset_info[name]) if "columns" in dataset_info[name]: - column_names = ["system", "tools", "images", "chosen", "rejected", "kto_tag"] + column_names = ["system", "tools", "images", "videos", "chosen", "rejected", "kto_tag"] if dataset_attr.formatting == "alpaca": column_names.extend(["prompt", "query", "response", "history"]) else: diff --git a/src/llamafactory/data/processors/feedback.py b/src/llamafactory/data/processors/feedback.py index 045182d9..a437c688 100644 --- a/src/llamafactory/data/processors/feedback.py +++ b/src/llamafactory/data/processors/feedback.py @@ -24,7 +24,7 @@ if TYPE_CHECKING: from transformers import PreTrainedTokenizer, ProcessorMixin from ...hparams import DataArguments - from ..mm_plugin import ImageInput + from ..mm_plugin import ImageInput, VideoInput from ..template import Template @@ -38,6 +38,7 @@ def _encode_feedback_example( system: Optional[str], tools: Optional[str], images: Sequence["ImageInput"], + videos: Sequence["VideoInput"], template: "Template", tokenizer: "PreTrainedTokenizer", processor: Optional["ProcessorMixin"], @@ -55,8 +56,8 @@ def _encode_feedback_example( else: kl_messages = prompt + [kl_response[1]] - messages = template.mm_plugin.process_messages(messages, images, processor) - kl_messages = template.mm_plugin.process_messages(kl_messages, images, processor) + messages = template.mm_plugin.process_messages(messages, images, videos, processor) + kl_messages = template.mm_plugin.process_messages(kl_messages, images, videos, processor) prompt_ids, response_ids = template.encode_oneturn(tokenizer, messages, system, tools) kl_prompt_ids, kl_response_ids = template.encode_oneturn(tokenizer, kl_messages, system, tools) @@ -64,8 +65,8 @@ def 
_encode_feedback_example( response_ids += [tokenizer.eos_token_id] kl_response_ids += [tokenizer.eos_token_id] - prompt_ids, _ = template.mm_plugin.process_token_ids(prompt_ids, None, images, tokenizer, processor) - kl_prompt_ids, _ = template.mm_plugin.process_token_ids(kl_prompt_ids, None, images, tokenizer, processor) + prompt_ids, _ = template.mm_plugin.process_token_ids(prompt_ids, None, images, videos, tokenizer, processor) + kl_prompt_ids, _ = template.mm_plugin.process_token_ids(kl_prompt_ids, None, images, videos, tokenizer, processor) source_len, target_len = infer_seqlen(len(prompt_ids), len(response_ids), cutoff_len) prompt_ids = prompt_ids[:source_len] @@ -103,6 +104,7 @@ def preprocess_feedback_dataset( system=examples["_system"][i], tools=examples["_tools"][i], images=examples["_images"][i] or [], + videos=examples["_videos"][i] or [], template=template, tokenizer=tokenizer, processor=processor, @@ -116,6 +118,7 @@ def preprocess_feedback_dataset( model_inputs["kl_labels"].append(kl_labels) model_inputs["kto_tags"].append(kto_tag) model_inputs["images"].append(examples["_images"][i]) + model_inputs["videos"].append(examples["_videos"][i]) desirable_num = sum([1 for tag in model_inputs["kto_tags"] if tag]) undesirable_num = len(model_inputs["kto_tags"]) - desirable_num diff --git a/src/llamafactory/data/processors/pairwise.py b/src/llamafactory/data/processors/pairwise.py index fa7e3fd2..05702fa3 100644 --- a/src/llamafactory/data/processors/pairwise.py +++ b/src/llamafactory/data/processors/pairwise.py @@ -24,7 +24,7 @@ if TYPE_CHECKING: from transformers import PreTrainedTokenizer, ProcessorMixin from ...hparams import DataArguments - from ..mm_plugin import ImageInput + from ..mm_plugin import ImageInput, VideoInput from ..template import Template @@ -37,13 +37,14 @@ def _encode_pairwise_example( system: Optional[str], tools: Optional[str], images: Sequence["ImageInput"], + videos: Sequence["VideoInput"], template: "Template", tokenizer: "PreTrainedTokenizer", processor: Optional["ProcessorMixin"], cutoff_len: int, ) -> Tuple[List[int], List[int], List[int], List[int]]: - chosen_messages = template.mm_plugin.process_messages(prompt + [response[0]], images, processor) - rejected_messages = template.mm_plugin.process_messages(prompt + [response[1]], images, processor) + chosen_messages = template.mm_plugin.process_messages(prompt + [response[0]], images, videos, processor) + rejected_messages = template.mm_plugin.process_messages(prompt + [response[1]], images, videos, processor) prompt_ids, chosen_ids = template.encode_oneturn(tokenizer, chosen_messages, system, tools) _, rejected_ids = template.encode_oneturn(tokenizer, rejected_messages, system, tools) @@ -51,7 +52,7 @@ def _encode_pairwise_example( chosen_ids += [tokenizer.eos_token_id] rejected_ids += [tokenizer.eos_token_id] - prompt_ids, _ = template.mm_plugin.process_token_ids(prompt_ids, None, images, tokenizer, processor) + prompt_ids, _ = template.mm_plugin.process_token_ids(prompt_ids, None, images, videos, tokenizer, processor) # consider the response is more important source_len, target_len = infer_seqlen(len(prompt_ids), max(len(chosen_ids), len(rejected_ids)), cutoff_len) prompt_ids = prompt_ids[:source_len] @@ -85,6 +86,7 @@ def preprocess_pairwise_dataset( system=examples["_system"][i], tools=examples["_tools"][i], images=examples["_images"][i] or [], + videos=examples["_videos"][i] or [], template=template, tokenizer=tokenizer, processor=processor, @@ -97,6 +99,7 @@ def preprocess_pairwise_dataset( 
model_inputs["rejected_attention_mask"].append([1] * len(rejected_input_ids)) model_inputs["rejected_labels"].append(rejected_labels) model_inputs["images"].append(examples["_images"][i]) + model_inputs["videos"].append(examples["_videos"][i]) return model_inputs diff --git a/src/llamafactory/data/processors/supervised.py b/src/llamafactory/data/processors/supervised.py index 00e5ed44..66625640 100644 --- a/src/llamafactory/data/processors/supervised.py +++ b/src/llamafactory/data/processors/supervised.py @@ -24,7 +24,7 @@ if TYPE_CHECKING: from transformers import PreTrainedTokenizer, ProcessorMixin from ...hparams import DataArguments - from ..mm_plugin import ImageInput + from ..mm_plugin import ImageInput, VideoInput from ..template import Template @@ -37,6 +37,7 @@ def _encode_supervised_example( system: Optional[str], tools: Optional[str], images: Sequence["ImageInput"], + videos: Sequence["VideoInput"], template: "Template", tokenizer: "PreTrainedTokenizer", processor: Optional["ProcessorMixin"], @@ -44,8 +45,8 @@ def _encode_supervised_example( train_on_prompt: bool, mask_history: bool, ) -> Tuple[List[int], List[int]]: - messages = template.mm_plugin.process_messages(prompt + response, images, processor) - input_ids, labels = template.mm_plugin.process_token_ids([], [], images, tokenizer, processor) + messages = template.mm_plugin.process_messages(prompt + response, images, videos, processor) + input_ids, labels = template.mm_plugin.process_token_ids([], [], images, videos, tokenizer, processor) encoded_pairs = template.encode_multiturn(tokenizer, messages, system, tools) total_length = len(input_ids) + (1 if template.efficient_eos else 0) if mask_history: @@ -107,6 +108,7 @@ def preprocess_supervised_dataset( system=examples["_system"][i], tools=examples["_tools"][i], images=examples["_images"][i] or [], + videos=examples["_videos"][i] or [], template=template, tokenizer=tokenizer, processor=processor, @@ -118,6 +120,7 @@ def preprocess_supervised_dataset( model_inputs["attention_mask"].append([1] * len(input_ids)) model_inputs["labels"].append(labels) model_inputs["images"].append(examples["_images"][i]) + model_inputs["videos"].append(examples["_videos"][i]) return model_inputs @@ -132,11 +135,8 @@ def preprocess_packed_supervised_dataset( # TODO: use `position_ids` to achieve packing # build inputs with format ` X1 Y1 X2 Y2 ` # and labels with format ` ... Y1 ... 
Y2 ` - if processor is not None: - raise NotImplementedError("`packing` have not been implemented for multimodal datasets.") - valid_num = 0 - batch_input_ids, batch_labels = [], [] + batch_input_ids, batch_labels, batch_images, batch_videos = [], [], [], [] lengths = [] length2indexes = defaultdict(list) for i in range(len(examples["_prompt"])): @@ -150,9 +150,10 @@ def preprocess_packed_supervised_dataset( system=examples["_system"][i], tools=examples["_tools"][i], images=examples["_images"][i] or [], + videos=examples["_videos"][i] or [], template=template, tokenizer=tokenizer, - processor=None, + processor=processor, cutoff_len=data_args.cutoff_len - 1, # reserved for the padding token train_on_prompt=data_args.train_on_prompt, mask_history=data_args.mask_history, @@ -165,16 +166,21 @@ def preprocess_packed_supervised_dataset( length2indexes[length].append(valid_num) batch_input_ids.append(input_ids) batch_labels.append(labels) + batch_images.append(examples["_images"][i] or []) + batch_videos.append(examples["_videos"][i] or []) valid_num += 1 model_inputs = defaultdict(list) knapsacks = greedy_knapsack(lengths, data_args.cutoff_len - 1) # reserved for the padding token for knapsack in knapsacks: packed_input_ids, packed_attention_masks, packed_labels = [], [], [] + packed_images, packed_videos = [], [] for i, length in enumerate(knapsack): index = length2indexes[length].pop() packed_input_ids += batch_input_ids[index] packed_labels += batch_labels[index] + packed_images += batch_images[index] + packed_videos += batch_videos[index] if data_args.neat_packing: packed_attention_masks += [i + 1] * len(batch_input_ids[index]) # start from 1 else: @@ -195,7 +201,8 @@ def preprocess_packed_supervised_dataset( model_inputs["input_ids"].append(packed_input_ids) model_inputs["attention_mask"].append(packed_attention_masks) model_inputs["labels"].append(packed_labels) - model_inputs["images"].append(examples["_images"][i]) + model_inputs["images"].append(packed_images or None) + model_inputs["videos"].append(packed_videos or None) return model_inputs diff --git a/src/llamafactory/data/processors/unsupervised.py b/src/llamafactory/data/processors/unsupervised.py index 6f251969..0a96935b 100644 --- a/src/llamafactory/data/processors/unsupervised.py +++ b/src/llamafactory/data/processors/unsupervised.py @@ -24,7 +24,7 @@ if TYPE_CHECKING: from transformers import PreTrainedTokenizer, ProcessorMixin from ...hparams import DataArguments - from ..mm_plugin import ImageInput + from ..mm_plugin import ImageInput, VideoInput from ..template import Template @@ -37,6 +37,7 @@ def _encode_unsupervised_example( system: Optional[str], tools: Optional[str], images: Sequence["ImageInput"], + videos: Sequence["VideoInput"], template: "Template", tokenizer: "PreTrainedTokenizer", processor: Optional["ProcessorMixin"], @@ -47,12 +48,12 @@ def _encode_unsupervised_example( else: messages = prompt + [{"role": Role.ASSISTANT.value, "content": ""}] - messages = template.mm_plugin.process_messages(messages, images, processor) + messages = template.mm_plugin.process_messages(messages, images, videos, processor) input_ids, labels = template.encode_oneturn(tokenizer, messages, system, tools) if template.efficient_eos: labels += [tokenizer.eos_token_id] - input_ids, _ = template.mm_plugin.process_token_ids(input_ids, None, images, tokenizer, processor) + input_ids, _ = template.mm_plugin.process_token_ids(input_ids, None, images, videos, tokenizer, processor) source_len, target_len = infer_seqlen(len(input_ids), 
len(labels), cutoff_len) input_ids = input_ids[:source_len] labels = labels[:target_len] @@ -79,6 +80,7 @@ def preprocess_unsupervised_dataset( system=examples["_system"][i], tools=examples["_tools"][i], images=examples["_images"][i] or [], + videos=examples["_videos"][i] or [], template=template, tokenizer=tokenizer, processor=processor, @@ -88,6 +90,7 @@ def preprocess_unsupervised_dataset( model_inputs["attention_mask"].append([1] * len(input_ids)) model_inputs["labels"].append(labels) model_inputs["images"].append(examples["_images"][i]) + model_inputs["videos"].append(examples["_videos"][i]) return model_inputs diff --git a/src/llamafactory/data/template.py b/src/llamafactory/data/template.py index 818e5625..5160e0c2 100644 --- a/src/llamafactory/data/template.py +++ b/src/llamafactory/data/template.py @@ -17,7 +17,6 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Union from transformers.utils.versions import require_version -from ..extras.constants import IMAGE_PLACEHOLDER from ..extras.logging import get_logger from .data_utils import Role from .formatter import EmptyFormatter, FunctionFormatter, StringFormatter, ToolFormatter @@ -213,7 +212,7 @@ def _register_template( stop_words: Sequence[str] = [], efficient_eos: bool = False, replace_eos: bool = False, - mm_plugin: "BasePlugin" = get_mm_plugin(name="base", image_token=IMAGE_PLACEHOLDER), + mm_plugin: "BasePlugin" = get_mm_plugin(name="base"), ) -> None: r""" Registers a chat template. @@ -826,7 +825,7 @@ _register_template( default_system="You are a helpful assistant.", stop_words=["<|im_end|>"], replace_eos=True, - mm_plugin=get_mm_plugin(name="qwen2_vl", image_token="<|image_pad|>"), + mm_plugin=get_mm_plugin(name="qwen2_vl", image_token="<|image_pad|>", video_token="<|video_pad|>"), ) diff --git a/src/llamafactory/extras/constants.py b/src/llamafactory/extras/constants.py index fc2d3460..60940b34 100644 --- a/src/llamafactory/extras/constants.py +++ b/src/llamafactory/extras/constants.py @@ -95,6 +95,8 @@ SUPPORTED_CLASS_FOR_BLOCK_DIAG_ATTN = { SUPPORTED_CLASS_FOR_S2ATTN = {"llama"} +VIDEO_PLACEHOLDER = "
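The sampling and token-budget arithmetic introduced in `_regularize_videos` and `Qwen2vlPlugin.process_messages` can be sanity-checked with a small sketch; the fps, factor, and grid values below are made-up assumptions for illustration, not defaults taken from this patch.

```python
import numpy as np

# _regularize_videos: a 10-second clip read with video_fps = 2.0 and video_factor = 2
# is sampled to round(10 * 2.0 / 2) * 2 = 20 frames, chosen uniformly over the stream.
duration_s, video_fps, video_factor, total_frames = 10.0, 2.0, 2, 300
sample_frames = round(duration_s * video_fps / video_factor) * video_factor
sample_indices = np.linspace(0, total_frames - 1, sample_frames).astype(np.int32)
assert sample_frames == 20 and len(sample_indices) == 20

# Qwen2vlPlugin.process_messages: a video grid (t, h, w) = (8, 36, 24) with
# merge_size = 2 expands one VIDEO_PLACEHOLDER into 8 * 36 * 24 // 2**2 = 1728
# copies of the video token, wrapped in <|vision_start|> ... <|vision_end|>.
t, h, w, merge_size = 8, 36, 24, 2
assert (t * h * w) // merge_size**2 == 1728
```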
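On the data side, `DatasetAttr`, the aligner, and the processors now carry an optional `videos` column; a compatible sharegpt-style record and its `dataset_info.json` entry might look like the sketch below (the dataset name, file layout, and the `<video>` placeholder literal are assumptions for illustration).

```python
import json

# One sharegpt-style sample: each <video> placeholder in a message is paired
# positionally with an entry of the "videos" column; relative paths are resolved
# against `dataset_dir` by _convert_videos when loading from a local file.
sample = {
    "messages": [
        {"role": "user", "content": "<video>What is happening in this clip?"},
        {"role": "assistant", "content": "A short description of the clip."},
    ],
    "videos": ["mllm_demo_data/1.mp4"],
}

# A matching dataset_info.json entry that maps the new "videos" column.
dataset_info_entry = {
    "mllm_video_demo": {
        "file_name": "mllm_video_demo.json",
        "formatting": "sharegpt",
        "columns": {"messages": "messages", "videos": "videos"},
        "tags": {
            "role_tag": "role",
            "content_tag": "content",
            "user_tag": "user",
            "assistant_tag": "assistant",
        },
    }
}

with open("data/mllm_video_demo.json", "w", encoding="utf-8") as f:
    json.dump([sample], f, indent=2, ensure_ascii=False)
```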
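For inference, the new `video` argument threads from `ChatModel` through both engine implementations; a minimal usage sketch, assuming a video-capable checkpoint such as Qwen2-VL and the HuggingFace backend (the model name and generation settings are placeholders, and the video files are the demo clips added in this patch).

```python
from llamafactory.chat import ChatModel

chat_model = ChatModel(
    dict(
        model_name_or_path="Qwen/Qwen2-VL-7B-Instruct",  # placeholder; any qwen2_vl checkpoint
        template="qwen2_vl",
        infer_backend="huggingface",
    )
)

# VideoInput is a plain path string; HuggingfaceEngine prepends the video
# placeholder to the first message when it is not already present.
messages = [{"role": "user", "content": "What happens in this clip?"}]

responses = chat_model.chat(messages, video="data/mllm_demo_data/1.mp4")
print(responses[0].response_text)

# The streaming path accepts the same argument.
for token in chat_model.stream_chat(messages, video="data/mllm_demo_data/2.avi"):
    print(token, end="", flush=True)
```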