Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-11-04 18:02:19 +08:00)

video datasets

Former-commit-id: 33f28ce82d9e44d2615909250dc56d6a4a03cd99
parent: 2c1eef34cb
commit: 1874d579c5

BIN  data/mllm_demo_data/1.mp4  Normal file (Binary file not shown.)
BIN  data/mllm_demo_data/2.avi  Normal file (Binary file not shown.)
BIN  data/mllm_demo_data/3.mp4  Normal file (Binary file not shown.)
@@ -18,11 +18,11 @@ from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, List, Literal, Opti
 
 
 if TYPE_CHECKING:
-    from numpy.typing import NDArray
     from transformers import PreTrainedModel, PreTrainedTokenizer
     from vllm import AsyncLLMEngine
 
     from ..data import Template
+    from ..data.mm_plugin import ImageInput, VideoInput
     from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
 
 
@@ -56,7 +56,8 @@ class BaseEngine(ABC):
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        image: Optional["NDArray"] = None,
+        image: Optional["ImageInput"] = None,
+        video: Optional["VideoInput"] = None,
         **input_kwargs,
     ) -> List["Response"]: ...
 
@@ -66,7 +67,8 @@ class BaseEngine(ABC):
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        image: Optional["NDArray"] = None,
+        image: Optional["ImageInput"] = None,
+        video: Optional["VideoInput"] = None,
         **input_kwargs,
     ) -> AsyncGenerator[str, None]: ...
 
 
@@ -27,8 +27,7 @@ from .vllm_engine import VllmEngine
 
 
 if TYPE_CHECKING:
-    from PIL.Image import Image
-
+    from ..data.mm_plugin import ImageInput, VideoInput
     from .base_engine import BaseEngine, Response
 
 
@@ -56,10 +55,13 @@ class ChatModel:
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        image: Optional["Image"] = None,
+        image: Optional["ImageInput"] = None,
+        video: Optional["VideoInput"] = None,
         **input_kwargs,
     ) -> List["Response"]:
-        task = asyncio.run_coroutine_threadsafe(self.achat(messages, system, tools, image, **input_kwargs), self._loop)
+        task = asyncio.run_coroutine_threadsafe(
+            self.achat(messages, system, tools, image, video, **input_kwargs), self._loop
+        )
         return task.result()
 
     async def achat(
@@ -67,20 +69,22 @@ class ChatModel:
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        image: Optional["Image"] = None,
+        image: Optional["ImageInput"] = None,
+        video: Optional["VideoInput"] = None,
         **input_kwargs,
     ) -> List["Response"]:
-        return await self.engine.chat(messages, system, tools, image, **input_kwargs)
+        return await self.engine.chat(messages, system, tools, image, video, **input_kwargs)
 
     def stream_chat(
         self,
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        image: Optional["Image"] = None,
+        image: Optional["ImageInput"] = None,
+        video: Optional["VideoInput"] = None,
         **input_kwargs,
     ) -> Generator[str, None, None]:
-        generator = self.astream_chat(messages, system, tools, image, **input_kwargs)
+        generator = self.astream_chat(messages, system, tools, image, video, **input_kwargs)
         while True:
             try:
                 task = asyncio.run_coroutine_threadsafe(generator.__anext__(), self._loop)
@@ -93,10 +97,11 @@ class ChatModel:
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        image: Optional["Image"] = None,
+        image: Optional["ImageInput"] = None,
+        video: Optional["VideoInput"] = None,
         **input_kwargs,
     ) -> AsyncGenerator[str, None]:
-        async for new_token in self.engine.stream_chat(messages, system, tools, image, **input_kwargs):
+        async for new_token in self.engine.stream_chat(messages, system, tools, image, video, **input_kwargs):
            yield new_token
 
     def get_scores(
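
The `video` argument now threads uniformly through `chat`, `achat`, `stream_chat`, and `astream_chat`. A minimal usage sketch (model path, template, and prompt are hypothetical; any video-capable checkpoint is handled the same way):

    from llamafactory.chat import ChatModel

    # Hypothetical arguments; a video-capable model such as Qwen2-VL is assumed.
    chat_model = ChatModel({"model_name_or_path": "Qwen/Qwen2-VL-7B-Instruct", "template": "qwen2_vl"})
    messages = [{"role": "user", "content": "<video>What happens in this clip?"}]
    for response in chat_model.chat(messages, video="data/mllm_demo_data/1.mp4"):
        print(response.response_text)
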
@@ -22,7 +22,7 @@ import torch
 from transformers import GenerationConfig, TextIteratorStreamer
 
 from ..data import get_template_and_fix_tokenizer
-from ..extras.constants import IMAGE_PLACEHOLDER
+from ..extras.constants import IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
 from ..extras.logging import get_logger
 from ..extras.misc import get_logits_processor
 from ..model import load_model, load_tokenizer
@@ -30,11 +30,11 @@ from .base_engine import BaseEngine, Response
 
 
 if TYPE_CHECKING:
-    from PIL.Image import Image
     from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
     from trl import PreTrainedModelWrapper
 
     from ..data import Template
+    from ..data.mm_plugin import ImageInput, VideoInput
     from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
 
 
@@ -78,20 +78,30 @@ class HuggingfaceEngine(BaseEngine):
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        image: Optional["Image"] = None,
+        image: Optional["ImageInput"] = None,
+        video: Optional["VideoInput"] = None,
         input_kwargs: Optional[Dict[str, Any]] = {},
     ) -> Tuple[Dict[str, Any], int]:
+        mm_input_dict = {"images": [], "videos": [], "imglens": [0], "vidlens": [0]}
         if image is not None:
+            mm_input_dict.update({"images": [image], "imglens": [1]})
             if IMAGE_PLACEHOLDER not in messages[0]["content"]:
                 messages[0]["content"] = IMAGE_PLACEHOLDER + messages[0]["content"]
 
-            messages = template.mm_plugin.process_messages(messages, [image], processor)
-
+        if video is not None:
+            mm_input_dict.update({"videos": [video], "vidlens": [1]})
+            if VIDEO_PLACEHOLDER not in messages[0]["content"]:
+                messages[0]["content"] = VIDEO_PLACEHOLDER + messages[0]["content"]
+
+        messages = template.mm_plugin.process_messages(
+            messages, mm_input_dict["images"], mm_input_dict["videos"], processor
+        )
         paired_messages = messages + [{"role": "assistant", "content": ""}]
         system = system or generating_args["default_system"]
         prompt_ids, _ = template.encode_oneturn(tokenizer, paired_messages, system, tools)
-        if image is not None:
-            prompt_ids, _ = template.mm_plugin.process_token_ids(prompt_ids, None, [image], tokenizer, processor)
+        prompt_ids, _ = template.mm_plugin.process_token_ids(
+            prompt_ids, None, mm_input_dict["images"], mm_input_dict["videos"], tokenizer, processor
+        )
 
         prompt_length = len(prompt_ids)
         inputs = torch.tensor([prompt_ids], device=model.device)
@@ -154,13 +164,10 @@ class HuggingfaceEngine(BaseEngine):
             logits_processor=get_logits_processor(),
         )
 
-        if image is not None:
-            mm_inputs = template.mm_plugin.get_mm_inputs(
-                images=[image], imglens=[1], seqlens=[prompt_length], processor=processor
-            )
-            for key, value in mm_inputs.items():
-                value = value if isinstance(value, torch.Tensor) else torch.tensor(value)
-                gen_kwargs[key] = value.to(model.device)
+        mm_inputs = template.mm_plugin.get_mm_inputs(**mm_input_dict, seqlens=[prompt_length], processor=processor)
+        for key, value in mm_inputs.items():
+            value = value if isinstance(value, torch.Tensor) else torch.tensor(value)
+            gen_kwargs[key] = value.to(model.device)
 
         return gen_kwargs, prompt_length
 
@@ -175,11 +182,12 @@ class HuggingfaceEngine(BaseEngine):
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        image: Optional["Image"] = None,
+        image: Optional["ImageInput"] = None,
+        video: Optional["VideoInput"] = None,
         input_kwargs: Optional[Dict[str, Any]] = {},
     ) -> List["Response"]:
         gen_kwargs, prompt_length = HuggingfaceEngine._process_args(
-            model, tokenizer, processor, template, generating_args, messages, system, tools, image, input_kwargs
+            model, tokenizer, processor, template, generating_args, messages, system, tools, image, video, input_kwargs
         )
         generate_output = model.generate(**gen_kwargs)
         response_ids = generate_output[:, prompt_length:]
@@ -210,11 +218,12 @@ class HuggingfaceEngine(BaseEngine):
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        image: Optional["Image"] = None,
+        image: Optional["ImageInput"] = None,
+        video: Optional["VideoInput"] = None,
         input_kwargs: Optional[Dict[str, Any]] = {},
     ) -> Callable[[], str]:
         gen_kwargs, _ = HuggingfaceEngine._process_args(
-            model, tokenizer, processor, template, generating_args, messages, system, tools, image, input_kwargs
+            model, tokenizer, processor, template, generating_args, messages, system, tools, image, video, input_kwargs
         )
         streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
         gen_kwargs["streamer"] = streamer
@@ -267,7 +276,8 @@ class HuggingfaceEngine(BaseEngine):
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        image: Optional["Image"] = None,
+        image: Optional["ImageInput"] = None,
+        video: Optional["VideoInput"] = None,
         **input_kwargs,
     ) -> List["Response"]:
         if not self.can_generate:
@@ -284,6 +294,7 @@ class HuggingfaceEngine(BaseEngine):
             system,
             tools,
             image,
+            video,
             input_kwargs,
         )
         async with self.semaphore:
@@ -295,7 +306,8 @@ class HuggingfaceEngine(BaseEngine):
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        image: Optional["Image"] = None,
+        image: Optional["ImageInput"] = None,
+        video: Optional["VideoInput"] = None,
         **input_kwargs,
     ) -> AsyncGenerator[str, None]:
         if not self.can_generate:
@@ -312,6 +324,7 @@ class HuggingfaceEngine(BaseEngine):
             system,
             tools,
             image,
+            video,
             input_kwargs,
         )
         async with self.semaphore:
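
Note how `_process_args` treats the two modalities symmetrically: when a video is supplied without its placeholder, `VIDEO_PLACEHOLDER` is prepended to the first message, exactly as `IMAGE_PLACEHOLDER` is for images. A standalone sketch of that guard (message text hypothetical):

    VIDEO_PLACEHOLDER = "<video>"  # added to extras.constants by this commit

    messages = [{"role": "user", "content": "Describe the clip."}]
    if VIDEO_PLACEHOLDER not in messages[0]["content"]:
        messages[0]["content"] = VIDEO_PLACEHOLDER + messages[0]["content"]
    assert messages[0]["content"] == "<video>Describe the clip."
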
@@ -25,7 +25,7 @@ if TYPE_CHECKING:
     from transformers import Seq2SeqTrainingArguments
 
     from ..hparams import DataArguments
-    from .mm_plugin import ImageInput
+    from .mm_plugin import ImageInput, VideoInput
     from .parser import DatasetAttr
 
 
@@ -52,6 +52,26 @@ def _convert_images(
     return images
 
 
+def _convert_videos(
+    videos: Sequence["VideoInput"],
+    dataset_attr: "DatasetAttr",
+    data_args: "DataArguments",
+) -> Optional[List["VideoInput"]]:
+    r"""
+    Optionally concatenates video path to dataset dir when loading from local disk.
+    """
+    if len(videos) == 0:
+        return None
+
+    videos = videos[:]
+    if dataset_attr.load_from in ["script", "file"]:
+        for i in range(len(videos)):
+            if isinstance(videos[i], str) and os.path.isfile(os.path.join(data_args.dataset_dir, videos[i])):
+                videos[i] = os.path.join(data_args.dataset_dir, videos[i])
+
+    return videos
+
+
 def convert_alpaca(
     example: Dict[str, Any],
     dataset_attr: "DatasetAttr",
@@ -96,12 +116,14 @@ def convert_alpaca(
         response = []
 
     convert_images = partial(_convert_images, dataset_attr=dataset_attr, data_args=data_args)
+    convert_videos = partial(_convert_videos, dataset_attr=dataset_attr, data_args=data_args)
     output = {
         "_prompt": prompt,
         "_response": response,
         "_system": example[dataset_attr.system] if dataset_attr.system else "",
         "_tools": example[dataset_attr.tools] if dataset_attr.tools else "",
         "_images": convert_images(example[dataset_attr.images]) if dataset_attr.images else None,
+        "_videos": convert_videos(example[dataset_attr.videos]) if dataset_attr.videos else None,
     }
     return output
 
@@ -187,12 +209,14 @@ def convert_sharegpt(
         prompt, response = [], []
 
     convert_images = partial(_convert_images, dataset_attr=dataset_attr, data_args=data_args)
+    convert_videos = partial(_convert_videos, dataset_attr=dataset_attr, data_args=data_args)
     output = {
         "_prompt": prompt,
         "_response": response,
         "_system": system,
         "_tools": example[dataset_attr.tools] if dataset_attr.tools else "",
         "_images": convert_images(example[dataset_attr.images]) if dataset_attr.images else None,
+        "_videos": convert_videos(example[dataset_attr.videos]) if dataset_attr.videos else None,
     }
     return output
 
@@ -210,6 +234,7 @@ def align_dataset(
         _system: "..."
         _tools: "...",
         _images: [],
+        _videos: [],
     """
     if dataset_attr.formatting == "alpaca":
         convert_func = partial(convert_alpaca, dataset_attr=dataset_attr, data_args=data_args)
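
`_convert_videos` mirrors `_convert_images`: a video entry is a plain path string, and entries that resolve to files under `data_args.dataset_dir` are expanded to full paths. A self-contained sketch of that resolution step (directory and file names hypothetical):

    import os

    dataset_dir = "data"                      # stands in for data_args.dataset_dir
    videos = ["mllm_demo_data/1.mp4"]         # the "videos" column of one example
    videos = [
        os.path.join(dataset_dir, v)
        if isinstance(v, str) and os.path.isfile(os.path.join(dataset_dir, v))
        else v
        for v in videos
    ]
    # -> ["data/mllm_demo_data/1.mp4"] when the demo file exists on disk
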
@@ -79,14 +79,19 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
     processor: Optional["ProcessorMixin"] = None
 
     def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
-        batch_images, batch_imglens, batch_seqlens = [], [], []
+        batch_images, batch_videos, batch_imglens, batch_vidlens, batch_seqlens = [], [], [], [], []
         for feature in features:
             images = feature.pop("images") or []  # avoid NoneType
+            videos = feature.pop("videos") or []
             batch_images.extend(images)
+            batch_videos.extend(videos)
             batch_imglens.append(len(images))
+            batch_vidlens.append(len(videos))
             batch_seqlens.append(len(feature["input_ids"]))
 
-        mm_inputs = self.template.mm_plugin.get_mm_inputs(batch_images, batch_imglens, batch_seqlens, self.processor)
+        mm_inputs = self.template.mm_plugin.get_mm_inputs(
+            batch_images, batch_videos, batch_imglens, batch_vidlens, batch_seqlens, self.processor
+        )
         if "token_type_ids" in mm_inputs:
             token_type_ids = mm_inputs.pop("token_type_ids")
             for i, feature in enumerate(features):
@@ -136,6 +141,7 @@ class PairwiseDataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
                     "attention_mask": feature["{}_attention_mask".format(key)],
                     "labels": feature["{}_labels".format(key)],
                     "images": feature["images"],
+                    "videos": feature["videos"],
                 }
                 concatenated_features.append(target_feature)
 
@@ -158,12 +164,14 @@ class KTODataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
                 "attention_mask": feature["attention_mask"],
                 "labels": feature["labels"],
                 "images": feature["images"],
+                "videos": feature["videos"],
             }
             kl_feature = {
                 "input_ids": feature["kl_input_ids"],
                 "attention_mask": feature["kl_attention_mask"],
                 "labels": feature["kl_labels"],
                 "images": feature["images"],
+                "videos": feature["videos"],
             }
             target_features.append(target_feature)
             kl_features.append(kl_feature)
 
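
A runnable sketch of the collator's new bookkeeping, with two hypothetical features (one image in the first, two videos in the second):

    features = [
        {"input_ids": [1, 2, 3], "images": ["img0"], "videos": None},
        {"input_ids": [4, 5], "images": None, "videos": ["vid0", "vid1"]},
    ]
    batch_images, batch_videos, batch_imglens, batch_vidlens, batch_seqlens = [], [], [], [], []
    for feature in features:
        images = feature.pop("images") or []  # avoid NoneType, as in the diff
        videos = feature.pop("videos") or []
        batch_images.extend(images)
        batch_videos.extend(videos)
        batch_imglens.append(len(images))
        batch_vidlens.append(len(videos))
        batch_seqlens.append(len(feature["input_ids"]))
    assert batch_imglens == [1, 0] and batch_vidlens == [0, 2] and batch_seqlens == [3, 2]
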
@@ -2,11 +2,10 @@ from copy import deepcopy
 from io import BytesIO
 from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, TypedDict, Union
 
-from PIL.Image import Image
-from transformers import ProcessorMixin
+import numpy as np
 
-from ..extras.constants import IGNORE_INDEX, IMAGE_PLACEHOLDER
-from ..extras.packages import is_pillow_available
+from ..extras.constants import IGNORE_INDEX, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
+from ..extras.packages import is_pillow_available, is_pyav_available
 
 
 if is_pillow_available():
@@ -14,8 +13,13 @@ if is_pillow_available():
     from PIL.Image import Image as ImageObject
 
 
+if is_pyav_available():
+    import av
+
+
 if TYPE_CHECKING:
     import torch
+    from numpy.typing import NDArray
     from transformers import PreTrainedTokenizer, ProcessorMixin
     from transformers.image_processing_utils import BaseImageProcessor
 
@@ -24,13 +28,14 @@ if TYPE_CHECKING:
         bytes: Optional[bytes]
 
     ImageInput = Union[str, EncodedImage, ImageObject]
+    VideoInput = str
 
 
 def _regularize_images(images: Sequence["ImageInput"], processor: "ProcessorMixin") -> List["ImageObject"]:
     r"""
     Regularizes images to avoid error. Including reading, resizing and converting.
     """
-    image_resolution = getattr(processor, "image_resolution", 512)
+    image_resolution: int = getattr(processor, "image_resolution", 512)
     results = []
     for image in images:
         if isinstance(image, str):
@@ -56,7 +61,37 @@ def _regularize_images(images: Sequence["ImageInput"], processor: "ProcessorMixi
     return results
 
 
-def _get_mm_inputs(images: Sequence["ImageInput"], processor: "ProcessorMixin") -> Dict[str, "torch.Tensor"]:
+def _regularize_videos(videos: Sequence["VideoInput"], processor: "ProcessorMixin") -> List["NDArray"]:
+    r"""
+    Regularizes videos to avoid error. Including reading, resizing and converting.
+    """
+    video_fps: float = getattr(processor, "video_fps", 1.0)
+    video_factor: int = getattr(processor, "video_factor", 1)
+    results = []
+    for video in videos:
+        container = av.open(video, "r")
+        video_stream = next(stream for stream in container.streams if stream.type == "video")
+        total_frames = video_stream.frames
+        sample_frames = float(video_stream.duration * video_stream.time_base) * video_fps
+        sample_frames = round(sample_frames / video_factor) * video_factor  # for qwen2_vl
+        sample_indices = np.linspace(0, total_frames - 1, sample_frames).astype(np.int32)
+        frames: List["ImageObject"] = []
+        container.seek(0)
+        for frame_idx, frame in enumerate(container.decode(video_stream)):
+            if frame_idx in sample_indices:
+                frames.append(frame.to_image())
+
+        frames = _regularize_images(frames, processor)
+        results.append(frames)
+
+    return results
+
+
+def _get_mm_inputs(
+    images: Sequence["ImageInput"],
+    videos: Sequence["VideoInput"],
+    processor: "ProcessorMixin",
+) -> Dict[str, "torch.Tensor"]:
     r"""
     Processes visual inputs.
 
@@ -70,13 +105,19 @@ def _get_mm_inputs(images: Sequence["ImageInput"], processor: "ProcessorMixin")
     It holds num_patches == torch.prod(image_grid_thw)
     """
     image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
+    input_dict = {"images": None, "videos": None}
     if len(images) != 0:
         images = _regularize_images(images, processor)
-        image_inputs = image_processor(images=images, return_tensors="pt")
-    else:
-        image_inputs = {}
+        input_dict["images"] = images
 
-    return image_inputs
+    if len(videos) != 0:
+        videos = _regularize_videos(videos, processor)
+        input_dict["videos"] = videos
+
+    if input_dict["images"] is not None or input_dict["videos"] is not None:
+        return image_processor(**input_dict, return_tensors="pt")
+    else:
+        return {}
 
 
 def _get_paligemma_token_type_ids(
@@ -97,18 +138,32 @@ def _get_paligemma_token_type_ids(
 
 
 class BasePlugin:
-    def __init__(self, image_token: str) -> None:
+    def __init__(self, image_token: Optional[str], video_token: Optional[str]) -> None:
         self.image_token = image_token
+        self.video_token = video_token
+
+    def _validate_input(
+        self,
+        images: Sequence["ImageInput"],
+        videos: Sequence["VideoInput"],
+    ) -> None:
+        if len(images) != 0 and self.image_token is None:
+            raise ValueError("This model does not support image input.")
+
+        if len(videos) != 0 and self.video_token is None:
+            raise ValueError("This model does not support video input.")
 
     def process_messages(
         self,
         messages: Sequence[Dict[str, str]],
         images: Sequence["ImageInput"],
+        videos: Sequence["VideoInput"],
         processor: Optional["ProcessorMixin"],
     ) -> List[Dict[str, str]]:
         r"""
         Pre-processes input messages before tokenization for VLMs.
         """
+        self._validate_input(images, videos)
         return messages
 
     def process_token_ids(
@@ -116,24 +171,29 @@ class BasePlugin:
         input_ids: List[int],
         labels: Optional[List[int]],
         images: Sequence["ImageInput"],
+        videos: Sequence["VideoInput"],
         tokenizer: "PreTrainedTokenizer",
         processor: Optional["ProcessorMixin"],
     ) -> Tuple[List[int], Optional[List[int]]]:
         r"""
         Pre-processes token ids after tokenization for VLMs.
         """
+        self._validate_input(images, videos)
         return input_ids, labels
 
     def get_mm_inputs(
         self,
         images: Sequence["ImageInput"],
+        videos: Sequence["VideoInput"],
         imglens: Sequence[int],
+        vidlens: Sequence[int],
         seqlens: Sequence[int],
         processor: Optional["ProcessorMixin"],
     ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
         r"""
         Builds batched multimodal inputs for VLMs.
         """
+        self._validate_input(images, videos)
         return {}
 
 
@@ -142,8 +202,10 @@ class LlavaPlugin(BasePlugin):
         self,
         messages: Sequence[Dict[str, str]],
         images: Sequence["ImageInput"],
+        videos: Sequence["VideoInput"],
         processor: Optional["ProcessorMixin"],
     ) -> List[Dict[str, str]]:
+        self._validate_input(images, videos)
         num_image_tokens = 0
         image_seqlen = getattr(processor, "image_seqlen")
         messages = deepcopy(messages)
@@ -163,11 +225,14 @@ class LlavaPlugin(BasePlugin):
     def get_mm_inputs(
         self,
         images: Sequence["ImageInput"],
+        videos: Sequence["VideoInput"],
         imglens: Sequence[int],
+        vidlens: Sequence[int],
         seqlens: Sequence[int],
         processor: Optional["ProcessorMixin"],
     ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
-        return _get_mm_inputs(images, processor)
+        self._validate_input(images, videos)
+        return _get_mm_inputs(images, videos, processor)
 
 
 class PaliGemmaPlugin(BasePlugin):
@@ -175,8 +240,10 @@ class PaliGemmaPlugin(BasePlugin):
         self,
         messages: Sequence[Dict[str, str]],
         images: Sequence["ImageInput"],
+        videos: Sequence["VideoInput"],
         processor: Optional["ProcessorMixin"],
     ) -> List[Dict[str, str]]:
+        self._validate_input(images, videos)
         num_image_tokens = 0
         messages = deepcopy(messages)
         for message in messages:
@@ -197,9 +264,11 @@ class PaliGemmaPlugin(BasePlugin):
         input_ids: List[int],
         labels: Optional[List[int]],
         images: Sequence["ImageInput"],
+        videos: Sequence["VideoInput"],
         tokenizer: "PreTrainedTokenizer",
         processor: Optional["ProcessorMixin"],
     ) -> Tuple[List[int], Optional[List[int]]]:
+        self._validate_input(images, videos)
         num_images = len(images)
         image_seqlen = num_images * getattr(processor, "image_seqlen")
         image_token_id = tokenizer.convert_tokens_to_ids(self.image_token)
@@ -212,11 +281,14 @@ class PaliGemmaPlugin(BasePlugin):
     def get_mm_inputs(
         self,
         images: Sequence["ImageInput"],
+        videos: Sequence["VideoInput"],
         imglens: Sequence[int],
+        vidlens: Sequence[int],
         seqlens: Sequence[int],
         processor: Optional["ProcessorMixin"],
     ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
-        mm_inputs = _get_mm_inputs(images, processor)
+        self._validate_input(images, videos)
+        mm_inputs = _get_mm_inputs(images, videos, processor)
         mm_inputs["token_type_ids"] = _get_paligemma_token_type_ids(imglens, seqlens, processor)
         return mm_inputs
 
@@ -226,16 +298,17 @@ class Qwen2vlPlugin(BasePlugin):
         self,
         messages: Sequence[Dict[str, str]],
         images: Sequence["ImageInput"],
+        videos: Sequence["VideoInput"],
         processor: Optional["ProcessorMixin"],
     ) -> List[Dict[str, str]]:
+        self._validate_input(images, videos)
         image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
         merge_length: int = getattr(image_processor, "merge_size") ** 2
-        if len(images) != 0:
-            image_grid_thw = _get_mm_inputs(images, processor)["image_grid_thw"]
-        else:
-            image_grid_thw = []
+        mm_inputs = _get_mm_inputs(images, videos, processor)
+        image_grid_thw = mm_inputs.get("image_grid_thw", [])
+        video_grid_thw = mm_inputs.get("video_grid_thw", [])
 
-        num_image_tokens = 0
+        num_image_tokens, num_video_tokens = 0, 0
         messages = deepcopy(messages)
         for message in messages:
             content = message["content"]
@@ -252,21 +325,40 @@ class Qwen2vlPlugin(BasePlugin):
                 )
                 num_image_tokens += 1
 
+            while VIDEO_PLACEHOLDER in content:
+                if num_video_tokens >= len(video_grid_thw):
+                    raise ValueError("`len(videos)` is less than the number of {} tokens.".format(VIDEO_PLACEHOLDER))
+
+                content = content.replace(
+                    VIDEO_PLACEHOLDER,
+                    "<|vision_start|>{}<|vision_end|>".format(
+                        self.video_token * (video_grid_thw[num_video_tokens].prod() // merge_length)
+                    ),
+                    1,
+                )
+                num_video_tokens += 1
+
             message["content"] = content
 
         if len(images) != num_image_tokens:
             raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER))
 
+        if len(videos) != num_video_tokens:
+            raise ValueError("The number of videos does not match the number of {} tokens".format(VIDEO_PLACEHOLDER))
+
         return messages
 
     def get_mm_inputs(
         self,
         images: Sequence["ImageInput"],
+        videos: Sequence["VideoInput"],
         imglens: Sequence[int],
+        vidlens: Sequence[int],
         seqlens: Sequence[int],
         processor: Optional["ProcessorMixin"],
     ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
-        return _get_mm_inputs(images, processor)
+        self._validate_input(images, videos)
+        return _get_mm_inputs(images, videos, processor)
 
 
 PLUGINS = {
@@ -277,9 +369,13 @@ PLUGINS = {
 }
 
 
-def get_mm_plugin(name: str, image_token: str) -> "BasePlugin":
+def get_mm_plugin(
+    name: str,
+    image_token: Optional[str] = None,
+    video_token: Optional[str] = None,
+) -> "BasePlugin":
     plugin_class = PLUGINS.get(name, None)
     if plugin_class is None:
         raise ValueError("Multimodal plugin `{}` not found.".format(name))
 
-    return plugin_class(image_token)
+    return plugin_class(image_token, video_token)
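
The sampling arithmetic in `_regularize_videos` deserves a worked example. For a hypothetical 10-second clip decoded at 30 fps, with the values this commit wires up for Qwen2-VL (`video_fps=2.0`, `video_factor=2`):

    import numpy as np

    duration_s, total_frames = 10.0, 300      # hypothetical clip: 10 s at 30 fps
    video_fps, video_factor = 2.0, 2          # qwen2_vl settings set by load_tokenizer

    sample_frames = duration_s * video_fps                              # 20.0 frames to keep
    sample_frames = round(sample_frames / video_factor) * video_factor  # 20, factor-aligned
    sample_indices = np.linspace(0, total_frames - 1, sample_frames).astype(np.int32)
    # 20 evenly spaced frame indices in [0, 299]; only these frames are decoded to images
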
@@ -43,6 +43,7 @@ class DatasetAttr:
     system: Optional[str] = None
     tools: Optional[str] = None
     images: Optional[str] = None
+    videos: Optional[str] = None
     # rlhf columns
     chosen: Optional[str] = None
     rejected: Optional[str] = None
@@ -126,7 +127,7 @@ def get_dataset_list(dataset_names: Optional[Sequence[str]], dataset_dir: str) -
         dataset_attr.set_attr("num_samples", dataset_info[name])
 
         if "columns" in dataset_info[name]:
-            column_names = ["system", "tools", "images", "chosen", "rejected", "kto_tag"]
+            column_names = ["system", "tools", "images", "videos", "chosen", "rejected", "kto_tag"]
             if dataset_attr.formatting == "alpaca":
                 column_names.extend(["prompt", "query", "response", "history"])
             else:
@@ -24,7 +24,7 @@ if TYPE_CHECKING:
     from transformers import PreTrainedTokenizer, ProcessorMixin
 
     from ...hparams import DataArguments
-    from ..mm_plugin import ImageInput
+    from ..mm_plugin import ImageInput, VideoInput
     from ..template import Template
 
 
@@ -38,6 +38,7 @@ def _encode_feedback_example(
     system: Optional[str],
     tools: Optional[str],
     images: Sequence["ImageInput"],
+    videos: Sequence["VideoInput"],
     template: "Template",
     tokenizer: "PreTrainedTokenizer",
     processor: Optional["ProcessorMixin"],
@@ -55,8 +56,8 @@ def _encode_feedback_example(
     else:
         kl_messages = prompt + [kl_response[1]]
 
-    messages = template.mm_plugin.process_messages(messages, images, processor)
-    kl_messages = template.mm_plugin.process_messages(kl_messages, images, processor)
+    messages = template.mm_plugin.process_messages(messages, images, videos, processor)
+    kl_messages = template.mm_plugin.process_messages(kl_messages, images, videos, processor)
     prompt_ids, response_ids = template.encode_oneturn(tokenizer, messages, system, tools)
     kl_prompt_ids, kl_response_ids = template.encode_oneturn(tokenizer, kl_messages, system, tools)
 
@@ -64,8 +65,8 @@ def _encode_feedback_example(
         response_ids += [tokenizer.eos_token_id]
         kl_response_ids += [tokenizer.eos_token_id]
 
-    prompt_ids, _ = template.mm_plugin.process_token_ids(prompt_ids, None, images, tokenizer, processor)
-    kl_prompt_ids, _ = template.mm_plugin.process_token_ids(kl_prompt_ids, None, images, tokenizer, processor)
+    prompt_ids, _ = template.mm_plugin.process_token_ids(prompt_ids, None, images, videos, tokenizer, processor)
+    kl_prompt_ids, _ = template.mm_plugin.process_token_ids(kl_prompt_ids, None, images, videos, tokenizer, processor)
 
     source_len, target_len = infer_seqlen(len(prompt_ids), len(response_ids), cutoff_len)
     prompt_ids = prompt_ids[:source_len]
@@ -103,6 +104,7 @@ def preprocess_feedback_dataset(
             system=examples["_system"][i],
             tools=examples["_tools"][i],
             images=examples["_images"][i] or [],
+            videos=examples["_videos"][i] or [],
             template=template,
             tokenizer=tokenizer,
             processor=processor,
@@ -116,6 +118,7 @@ def preprocess_feedback_dataset(
         model_inputs["kl_labels"].append(kl_labels)
         model_inputs["kto_tags"].append(kto_tag)
         model_inputs["images"].append(examples["_images"][i])
+        model_inputs["videos"].append(examples["_videos"][i])
 
     desirable_num = sum([1 for tag in model_inputs["kto_tags"] if tag])
     undesirable_num = len(model_inputs["kto_tags"]) - desirable_num
@@ -24,7 +24,7 @@ if TYPE_CHECKING:
     from transformers import PreTrainedTokenizer, ProcessorMixin
 
     from ...hparams import DataArguments
-    from ..mm_plugin import ImageInput
+    from ..mm_plugin import ImageInput, VideoInput
     from ..template import Template
 
 
@@ -37,13 +37,14 @@ def _encode_pairwise_example(
     system: Optional[str],
     tools: Optional[str],
     images: Sequence["ImageInput"],
+    videos: Sequence["VideoInput"],
     template: "Template",
     tokenizer: "PreTrainedTokenizer",
     processor: Optional["ProcessorMixin"],
     cutoff_len: int,
 ) -> Tuple[List[int], List[int], List[int], List[int]]:
-    chosen_messages = template.mm_plugin.process_messages(prompt + [response[0]], images, processor)
-    rejected_messages = template.mm_plugin.process_messages(prompt + [response[1]], images, processor)
+    chosen_messages = template.mm_plugin.process_messages(prompt + [response[0]], images, videos, processor)
+    rejected_messages = template.mm_plugin.process_messages(prompt + [response[1]], images, videos, processor)
     prompt_ids, chosen_ids = template.encode_oneturn(tokenizer, chosen_messages, system, tools)
     _, rejected_ids = template.encode_oneturn(tokenizer, rejected_messages, system, tools)
 
@@ -51,7 +52,7 @@ def _encode_pairwise_example(
         chosen_ids += [tokenizer.eos_token_id]
         rejected_ids += [tokenizer.eos_token_id]
 
-    prompt_ids, _ = template.mm_plugin.process_token_ids(prompt_ids, None, images, tokenizer, processor)
+    prompt_ids, _ = template.mm_plugin.process_token_ids(prompt_ids, None, images, videos, tokenizer, processor)
     # consider the response is more important
     source_len, target_len = infer_seqlen(len(prompt_ids), max(len(chosen_ids), len(rejected_ids)), cutoff_len)
     prompt_ids = prompt_ids[:source_len]
@@ -85,6 +86,7 @@ def preprocess_pairwise_dataset(
             system=examples["_system"][i],
             tools=examples["_tools"][i],
             images=examples["_images"][i] or [],
+            videos=examples["_videos"][i] or [],
             template=template,
             tokenizer=tokenizer,
             processor=processor,
@@ -97,6 +99,7 @@ def preprocess_pairwise_dataset(
         model_inputs["rejected_attention_mask"].append([1] * len(rejected_input_ids))
         model_inputs["rejected_labels"].append(rejected_labels)
         model_inputs["images"].append(examples["_images"][i])
+        model_inputs["videos"].append(examples["_videos"][i])
 
     return model_inputs
 
@@ -24,7 +24,7 @@ if TYPE_CHECKING:
     from transformers import PreTrainedTokenizer, ProcessorMixin
 
     from ...hparams import DataArguments
-    from ..mm_plugin import ImageInput
+    from ..mm_plugin import ImageInput, VideoInput
     from ..template import Template
 
 
@@ -37,6 +37,7 @@ def _encode_supervised_example(
     system: Optional[str],
     tools: Optional[str],
     images: Sequence["ImageInput"],
+    videos: Sequence["VideoInput"],
     template: "Template",
     tokenizer: "PreTrainedTokenizer",
     processor: Optional["ProcessorMixin"],
@@ -44,8 +45,8 @@ def _encode_supervised_example(
     train_on_prompt: bool,
     mask_history: bool,
 ) -> Tuple[List[int], List[int]]:
-    messages = template.mm_plugin.process_messages(prompt + response, images, processor)
-    input_ids, labels = template.mm_plugin.process_token_ids([], [], images, tokenizer, processor)
+    messages = template.mm_plugin.process_messages(prompt + response, images, videos, processor)
+    input_ids, labels = template.mm_plugin.process_token_ids([], [], images, videos, tokenizer, processor)
     encoded_pairs = template.encode_multiturn(tokenizer, messages, system, tools)
     total_length = len(input_ids) + (1 if template.efficient_eos else 0)
     if mask_history:
@@ -107,6 +108,7 @@ def preprocess_supervised_dataset(
             system=examples["_system"][i],
             tools=examples["_tools"][i],
             images=examples["_images"][i] or [],
+            videos=examples["_videos"][i] or [],
             template=template,
             tokenizer=tokenizer,
             processor=processor,
@@ -118,6 +120,7 @@ def preprocess_supervised_dataset(
         model_inputs["attention_mask"].append([1] * len(input_ids))
         model_inputs["labels"].append(labels)
         model_inputs["images"].append(examples["_images"][i])
+        model_inputs["videos"].append(examples["_videos"][i])
 
     return model_inputs
 
@@ -132,11 +135,8 @@ def preprocess_packed_supervised_dataset(
     # TODO: use `position_ids` to achieve packing
     # build inputs with format `<bos> X1 Y1 <eos> <bos> X2 Y2 <eos>`
     # and labels with format `<ignore> ... <ignore> Y1 <eos> <ignore> ... <ignore> Y2 <eos>`
-    if processor is not None:
-        raise NotImplementedError("`packing` have not been implemented for multimodal datasets.")
-
     valid_num = 0
-    batch_input_ids, batch_labels = [], []
+    batch_input_ids, batch_labels, batch_images, batch_videos = [], [], [], []
     lengths = []
     length2indexes = defaultdict(list)
     for i in range(len(examples["_prompt"])):
@@ -150,9 +150,10 @@ def preprocess_packed_supervised_dataset(
             system=examples["_system"][i],
             tools=examples["_tools"][i],
             images=examples["_images"][i] or [],
+            videos=examples["_videos"][i] or [],
             template=template,
             tokenizer=tokenizer,
-            processor=None,
+            processor=processor,
             cutoff_len=data_args.cutoff_len - 1,  # reserved for the padding token
             train_on_prompt=data_args.train_on_prompt,
             mask_history=data_args.mask_history,
@@ -165,16 +166,21 @@ def preprocess_packed_supervised_dataset(
             length2indexes[length].append(valid_num)
             batch_input_ids.append(input_ids)
             batch_labels.append(labels)
+            batch_images.append(examples["_images"][i] or [])
+            batch_videos.append(examples["_videos"][i] or [])
             valid_num += 1
 
     model_inputs = defaultdict(list)
     knapsacks = greedy_knapsack(lengths, data_args.cutoff_len - 1)  # reserved for the padding token
     for knapsack in knapsacks:
         packed_input_ids, packed_attention_masks, packed_labels = [], [], []
+        packed_images, packed_videos = [], []
         for i, length in enumerate(knapsack):
             index = length2indexes[length].pop()
             packed_input_ids += batch_input_ids[index]
             packed_labels += batch_labels[index]
+            packed_images += batch_images[index]
+            packed_videos += batch_videos[index]
             if data_args.neat_packing:
                 packed_attention_masks += [i + 1] * len(batch_input_ids[index])  # start from 1
             else:
@@ -195,7 +201,8 @@ def preprocess_packed_supervised_dataset(
         model_inputs["input_ids"].append(packed_input_ids)
         model_inputs["attention_mask"].append(packed_attention_masks)
         model_inputs["labels"].append(packed_labels)
-        model_inputs["images"].append(examples["_images"][i])
+        model_inputs["images"].append(packed_images or None)
+        model_inputs["videos"].append(packed_videos or None)
 
     return model_inputs
 
@@ -24,7 +24,7 @@ if TYPE_CHECKING:
     from transformers import PreTrainedTokenizer, ProcessorMixin
 
     from ...hparams import DataArguments
-    from ..mm_plugin import ImageInput
+    from ..mm_plugin import ImageInput, VideoInput
     from ..template import Template
 
 
@@ -37,6 +37,7 @@ def _encode_unsupervised_example(
     system: Optional[str],
     tools: Optional[str],
     images: Sequence["ImageInput"],
+    videos: Sequence["VideoInput"],
     template: "Template",
     tokenizer: "PreTrainedTokenizer",
     processor: Optional["ProcessorMixin"],
@@ -47,12 +48,12 @@ def _encode_unsupervised_example(
     else:
         messages = prompt + [{"role": Role.ASSISTANT.value, "content": ""}]
 
-    messages = template.mm_plugin.process_messages(messages, images, processor)
+    messages = template.mm_plugin.process_messages(messages, images, videos, processor)
     input_ids, labels = template.encode_oneturn(tokenizer, messages, system, tools)
     if template.efficient_eos:
         labels += [tokenizer.eos_token_id]
 
-    input_ids, _ = template.mm_plugin.process_token_ids(input_ids, None, images, tokenizer, processor)
+    input_ids, _ = template.mm_plugin.process_token_ids(input_ids, None, images, videos, tokenizer, processor)
     source_len, target_len = infer_seqlen(len(input_ids), len(labels), cutoff_len)
     input_ids = input_ids[:source_len]
     labels = labels[:target_len]
@@ -79,6 +80,7 @@ def preprocess_unsupervised_dataset(
             system=examples["_system"][i],
             tools=examples["_tools"][i],
             images=examples["_images"][i] or [],
+            videos=examples["_videos"][i] or [],
             template=template,
             tokenizer=tokenizer,
             processor=processor,
@@ -88,6 +90,7 @@ def preprocess_unsupervised_dataset(
         model_inputs["attention_mask"].append([1] * len(input_ids))
         model_inputs["labels"].append(labels)
         model_inputs["images"].append(examples["_images"][i])
+        model_inputs["videos"].append(examples["_videos"][i])
 
     return model_inputs
 
@ -17,7 +17,6 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Union

from transformers.utils.versions import require_version

from ..extras.constants import IMAGE_PLACEHOLDER
from ..extras.logging import get_logger
from .data_utils import Role
from .formatter import EmptyFormatter, FunctionFormatter, StringFormatter, ToolFormatter
@ -213,7 +212,7 @@ def _register_template(
    stop_words: Sequence[str] = [],
    efficient_eos: bool = False,
    replace_eos: bool = False,
    mm_plugin: "BasePlugin" = get_mm_plugin(name="base", image_token=IMAGE_PLACEHOLDER),
    mm_plugin: "BasePlugin" = get_mm_plugin(name="base"),
) -> None:
    r"""
    Registers a chat template.
@ -826,7 +825,7 @@ _register_template(
    default_system="You are a helpful assistant.",
    stop_words=["<|im_end|>"],
    replace_eos=True,
    mm_plugin=get_mm_plugin(name="qwen2_vl", image_token="<|image_pad|>"),
    mm_plugin=get_mm_plugin(name="qwen2_vl", image_token="<|image_pad|>", video_token="<|video_pad|>"),
)
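Templates that support video now register a video_token next to the image_token; for Qwen2-VL these are its reserved pad tokens. A sketch of constructing the plugin directly (import path taken from the test file below):

from llamafactory.data.mm_plugin import get_mm_plugin

# Both tokens are expanded by the plugin when it rewrites multimodal messages.
plugin = get_mm_plugin(name="qwen2_vl", image_token="<|image_pad|>", video_token="<|video_pad|>")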
@ -95,6 +95,8 @@ SUPPORTED_CLASS_FOR_BLOCK_DIAG_ATTN = {

SUPPORTED_CLASS_FOR_S2ATTN = {"llama"}

VIDEO_PLACEHOLDER = "<video>"

V_HEAD_WEIGHTS_NAME = "value_head.bin"

V_HEAD_SAFE_WEIGHTS_NAME = "value_head.safetensors"
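VIDEO_PLACEHOLDER plays the same role for videos that the image placeholder plays for images: it marks where a clip attaches inside a message. A minimal sketch of a dataset row (field names follow the mllm demo convention and are an assumption, not shown in this hunk):

example = {
    "messages": [
        {"role": "user", "content": "<video>What is happening in this clip?"},
        {"role": "assistant", "content": "A short description of the clip."},
    ],
    "videos": ["mllm_demo_data/1.mp4"],  # assumed to resolve relative to the dataset directory
}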
@ -38,6 +38,10 @@ def _get_package_version(name: str) -> "Version":
        return version.parse("0.0.0")


def is_pyav_available():
    return _is_package_available("av")


def is_fastapi_available():
    return _is_package_available("fastapi")
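PyAV (the "av" package) is the decoder needed to sample video frames, so the new check lets callers fail fast with a useful hint. A sketch of how such a guard is typically consumed (module path assumed; the error text is illustrative):

from llamafactory.extras.packages import is_pyav_available

if not is_pyav_available():
    raise ImportError("Video inputs require PyAV; install it with `pip install av`.")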
@ -142,6 +142,10 @@ class ModelArguments:
        default=512,
        metadata={"help": "Keeps the height or width of image below this resolution."},
    )
    video_fps: float = field(
        default=2.0,
        metadata={"help": "The frames to sample per second for video training."},
    )
    infer_backend: Literal["huggingface", "vllm"] = field(
        default="huggingface",
        metadata={"help": "Backend engine used at inference."},
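video_fps controls temporal subsampling rather than playback speed: the number of frames handed to the processor scales with clip length. Back-of-envelope, assuming the sampler takes roughly duration * fps frames (the exact sampling logic lives in the mm plugin, not in this hunk):

video_fps = 2.0        # default above
duration_sec = 16.0    # a hypothetical clip
num_frames = int(duration_sec * video_fps)
print(num_frames)      # -> 32 frames fed to the image processor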
@ -100,6 +100,11 @@ def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
        setattr(processor, "tokenizer", tokenizer)
        setattr(processor, "image_seqlen", get_image_seqlen(config))
        setattr(processor, "image_resolution", model_args.image_resolution)
        setattr(processor, "video_fps", model_args.video_fps)
        if getattr(config, "model_type", None) == "qwen2_vl":
            setattr(processor, "video_factor", 2)
        else:
            setattr(processor, "video_factor", 1)
    except Exception:
        processor = None
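Patching these values onto the processor lets downstream code read them back without re-plumbing ModelArguments. A standalone sketch of the read side (SimpleNamespace stands in for the patched processor; the use of video_factor as a frame-count rounding multiple is an assumption):

from types import SimpleNamespace

processor = SimpleNamespace(video_fps=2.0, video_factor=2)  # stand-in for the patched processor
video_fps = getattr(processor, "video_fps", 2.0)
video_factor = getattr(processor, "video_factor", 1)  # 2 for qwen2_vl, else 1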
@ -133,6 +133,7 @@ class WebChatModel(ChatModel):
        system: str,
        tools: str,
        image: Optional[Any],
        video: Optional[Any],
        max_new_tokens: int,
        top_p: float,
        temperature: float,
@ -140,7 +141,7 @@ class WebChatModel(ChatModel):
        chatbot[-1][1] = ""
        response = ""
        for new_text in self.stream_chat(
            messages, system, tools, image, max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature
            messages, system, tools, image, video, max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature
        ):
            response += new_text
            if tools:
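The streaming loop itself is unchanged: decoded chunks accumulate into one response string. A self-contained sketch of that pattern, with a stub generator standing in for stream_chat:

from typing import Iterator

def fake_stream() -> Iterator[str]:  # stand-in for self.stream_chat(...)
    yield from ["A cat ", "is playing."]

response = ""
for new_text in fake_stream():
    response += new_text
assert response == "A cat is playing."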
@ -43,8 +43,12 @@ def create_chat_box(
                        system = gr.Textbox(show_label=False)
                        tools = gr.Textbox(show_label=False, lines=3)

                    with gr.Column() as image_box:
                        image = gr.Image(sources=["upload"], type="pil")
                    with gr.Column() as mm_box:
                        with gr.Tab("Image"):
                            image = gr.Image(sources=["upload"], type="pil")

                        with gr.Tab("Video"):
                            video = gr.Video(sources=["upload"])

                query = gr.Textbox(show_label=False, lines=8)
                submit_btn = gr.Button(variant="primary")
@ -63,7 +67,7 @@ def create_chat_box(
        [chatbot, messages, query],
    ).then(
        engine.chatter.stream,
        [chatbot, messages, system, tools, image, max_new_tokens, top_p, temperature],
        [chatbot, messages, system, tools, image, video, max_new_tokens, top_p, temperature],
        [chatbot, messages],
    )
    clear_btn.click(lambda: ([], []), outputs=[chatbot, messages])
@ -76,8 +80,9 @@ def create_chat_box(
            role=role,
            system=system,
            tools=tools,
            image_box=image_box,
            mm_box=mm_box,
            image=image,
            video=video,
            query=query,
            submit_btn=submit_btn,
            max_new_tokens=max_new_tokens,
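The single-purpose image_box becomes a tabbed mm_box holding both upload widgets. A minimal standalone sketch of the same layout (gradio 4-style API, matching the calls above):

import gradio as gr

with gr.Blocks() as demo:
    with gr.Column(visible=False) as mm_box:  # hidden until a multimodal model is selected
        with gr.Tab("Image"):
            image = gr.Image(sources=["upload"], type="pil")
        with gr.Tab("Video"):
            video = gr.Video(sources=["upload"])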
@ -68,7 +68,7 @@ def create_infer_tab(engine: "Engine") -> Dict[str, "Component"]:
    engine.manager.get_elem_by_id("top.model_name").change(
        lambda model_name: gr.Column(visible=get_visual(model_name)),
        [engine.manager.get_elem_by_id("top.model_name")],
        [chat_elems["image_box"]],
        [chat_elems["mm_box"]],
    )

    return elem_dict
@ -59,7 +59,7 @@ class Engine:
            init_dict["train.output_dir"] = {"value": "train_{}".format(current_time)}
            init_dict["train.config_path"] = {"value": "{}.yaml".format(current_time)}
            init_dict["eval.output_dir"] = {"value": "eval_{}".format(current_time)}
            init_dict["infer.image_box"] = {"visible": False}
            init_dict["infer.mm_box"] = {"visible": False}

            if user_config.get("last_model", None):
                init_dict["top.model_name"] = {"value": user_config["last_model"]}
@ -1691,6 +1691,20 @@ LOCALES = {
            "label": "이미지 (선택 사항)",
        },
    },
    "video": {
        "en": {
            "label": "Video (optional)",
        },
        "ru": {
            "label": "Видео (по желанию)",
        },
        "zh": {
            "label": "视频(非必填)",
        },
        "ko": {
            "label": "비디오 (선택 사항)",
        },
    },
    "query": {
        "en": {
            "placeholder": "Input...",
@ -13,7 +13,7 @@
# limitations under the License.

import os
from typing import TYPE_CHECKING, Any, Dict, Tuple
from typing import TYPE_CHECKING, Any, Dict, List, Sequence, Tuple

import pytest
import torch
@ -28,6 +28,8 @@ if TYPE_CHECKING:
    from transformers import PreTrainedTokenizer, ProcessorMixin
    from transformers.image_processing_utils import BaseImageProcessor

    from llamafactory.data.mm_plugin import BasePlugin


HF_TOKEN = os.environ.get("HF_TOKEN", None)

@ -47,10 +49,14 @@ IMAGES = [Image.new("RGB", (32, 32), (255, 255, 255))]

NO_IMAGES = []

NO_VIDEOS = []

IMGLENS = [1]

NO_IMGLENS = [0]

NO_VIDLENS = [0]

INPUT_IDS = [0, 1, 2, 3, 4]

LABELS = [0, 1, 2, 3, 4]
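The new NO_VIDEOS and NO_VIDLENS fixtures mirror the image ones: the *LENS lists hold per-sample media counts aligned with the per-sample sequence lengths. A quick sketch of that invariant (variable names and lengths chosen for illustration only):

batch_imglens, batch_vidlens, batch_seqlens = [0], [0], [16]
assert len(batch_imglens) == len(batch_vidlens) == len(batch_seqlens)  # one entry per sample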
@ -78,92 +84,97 @@ def _load_tokenizer_module(model_name_or_path: str) -> Tuple["PreTrainedTokenize
    return tokenizer_module["tokenizer"], tokenizer_module["processor"]


def _check_plugin(
    plugin: "BasePlugin",
    tokenizer: "PreTrainedTokenizer",
    processor: "ProcessorMixin",
    expected_mm_messages: Sequence[Dict[str, str]],
    expected_input_ids: List[int],
    expected_labels: List[int],
    expected_mm_inputs: Dict[str, Any],
    expected_no_mm_inputs: Dict[str, Any],
) -> None:
    # test mm_messages
    assert plugin.process_messages(MM_MESSAGES, IMAGES, NO_VIDEOS, processor) == expected_mm_messages
    assert plugin.process_token_ids(INPUT_IDS, LABELS, IMAGES, NO_VIDEOS, tokenizer, processor) == (
        expected_input_ids,
        expected_labels,
    )
    _is_close(
        plugin.get_mm_inputs(IMAGES, NO_VIDEOS, IMGLENS, NO_VIDLENS, SEQLENS, processor),
        expected_mm_inputs,
    )
    # test text_messages
    assert plugin.process_messages(TEXT_MESSAGES, NO_IMAGES, NO_VIDEOS, processor) == TEXT_MESSAGES
    assert plugin.process_token_ids(INPUT_IDS, LABELS, NO_IMAGES, NO_VIDEOS, tokenizer, processor) == (
        INPUT_IDS,
        LABELS,
    )
    _is_close(
        plugin.get_mm_inputs(NO_IMAGES, NO_VIDEOS, NO_IMGLENS, NO_VIDLENS, SEQLENS, processor),
        expected_no_mm_inputs,
    )


def test_base_plugin():
    tokenizer, processor = _load_tokenizer_module(model_name_or_path=TINY_LLAMA)
    base_plugin = get_mm_plugin(name="base", image_token="<image>")
    # test mm_messages
    assert base_plugin.process_messages(MM_MESSAGES, IMAGES, processor) == MM_MESSAGES
    assert base_plugin.process_token_ids(INPUT_IDS, LABELS, IMAGES, tokenizer, processor) == (INPUT_IDS, LABELS)
    _is_close(base_plugin.get_mm_inputs(IMAGES, IMGLENS, SEQLENS, processor), {})
    # test text_messages
    assert base_plugin.process_messages(TEXT_MESSAGES, NO_IMAGES, processor) == TEXT_MESSAGES
    assert base_plugin.process_token_ids(INPUT_IDS, LABELS, NO_IMAGES, tokenizer, processor) == (INPUT_IDS, LABELS)
    _is_close(base_plugin.get_mm_inputs(NO_IMAGES, NO_IMGLENS, SEQLENS, processor), {})
    check_inputs = {"plugin": base_plugin, "tokenizer": tokenizer, "processor": processor}
    check_inputs["expected_mm_messages"] = MM_MESSAGES
    check_inputs["expected_input_ids"] = INPUT_IDS
    check_inputs["expected_labels"] = LABELS
    check_inputs["expected_mm_inputs"] = {}
    check_inputs["expected_no_mm_inputs"] = {}
    _check_plugin(**check_inputs)


def test_llava_plugin():
    tokenizer, processor = _load_tokenizer_module(model_name_or_path="llava-hf/llava-1.5-7b-hf")
    llava_plugin = get_mm_plugin(name="llava", image_token="<image>")
    image_seqlen = 576

    mm_inputs = _get_mm_inputs(processor)
    expected_mm_messages = [
    check_inputs = {"plugin": llava_plugin, "tokenizer": tokenizer, "processor": processor}
    check_inputs["expected_mm_messages"] = [
        {key: value.replace("<image>", "<image>" * image_seqlen) for key, value in message.items()}
        for message in MM_MESSAGES
    ]

    llava_plugin = get_mm_plugin(name="llava", image_token="<image>")
    # test mm_messages
    assert llava_plugin.process_messages(MM_MESSAGES, IMAGES, processor) == expected_mm_messages
    assert llava_plugin.process_token_ids(INPUT_IDS, LABELS, IMAGES, tokenizer, processor) == (INPUT_IDS, LABELS)
    _is_close(llava_plugin.get_mm_inputs(IMAGES, IMGLENS, SEQLENS, processor), mm_inputs)
    # test text_messages
    assert llava_plugin.process_messages(TEXT_MESSAGES, NO_IMAGES, processor) == TEXT_MESSAGES
    assert llava_plugin.process_token_ids(INPUT_IDS, LABELS, NO_IMAGES, tokenizer, processor) == (INPUT_IDS, LABELS)
    _is_close(llava_plugin.get_mm_inputs(NO_IMAGES, NO_IMGLENS, SEQLENS, processor), {})
    check_inputs["expected_input_ids"] = INPUT_IDS
    check_inputs["expected_labels"] = LABELS
    check_inputs["expected_mm_inputs"] = _get_mm_inputs(processor)
    check_inputs["expected_no_mm_inputs"] = {}
    _check_plugin(**check_inputs)


@pytest.mark.skipif(not HF_TOKEN, reason="Gated model.")
def test_paligemma_plugin():
    tokenizer, processor = _load_tokenizer_module(model_name_or_path="google/paligemma-3b-pt-224")
    paligemma_plugin = get_mm_plugin(name="paligemma", image_token="<image>")
    image_seqlen = 256

    mm_inputs = _get_mm_inputs(processor)
    mm_inputs["token_type_ids"] = [[0] * image_seqlen + [1] * (1024 - image_seqlen)]
    expected_mm_messages = [
    check_inputs = {"plugin": paligemma_plugin, "tokenizer": tokenizer, "processor": processor}
    check_inputs["expected_mm_messages"] = [
        {key: value.replace("<image>", "") for key, value in message.items()} for message in MM_MESSAGES
    ]
    expected_input_ids = [tokenizer.convert_tokens_to_ids("<image>")] * image_seqlen + INPUT_IDS
    expected_labels = [-100] * image_seqlen + LABELS

    paligemma_plugin = get_mm_plugin(name="paligemma", image_token="<image>")
    # test mm_messages
    assert paligemma_plugin.process_messages(MM_MESSAGES, IMAGES, processor) == expected_mm_messages
    assert paligemma_plugin.process_token_ids(INPUT_IDS, LABELS, IMAGES, tokenizer, processor) == (
        expected_input_ids,
        expected_labels,
    )
    _is_close(paligemma_plugin.get_mm_inputs(IMAGES, IMGLENS, SEQLENS, processor), mm_inputs)
    # test text_messages
    assert paligemma_plugin.process_messages(TEXT_MESSAGES, NO_IMAGES, processor) == TEXT_MESSAGES
    assert paligemma_plugin.process_token_ids(INPUT_IDS, LABELS, NO_IMAGES, tokenizer, processor) == (
        INPUT_IDS,
        LABELS,
    )
    _is_close(
        paligemma_plugin.get_mm_inputs(NO_IMAGES, NO_IMGLENS, SEQLENS, processor),
        {"token_type_ids": [[1] * 1024]},
    )
    check_inputs["expected_input_ids"] = [tokenizer.convert_tokens_to_ids("<image>")] * image_seqlen + INPUT_IDS
    check_inputs["expected_labels"] = [-100] * image_seqlen + LABELS
    check_inputs["expected_mm_inputs"] = _get_mm_inputs(processor)
    check_inputs["expected_mm_inputs"]["token_type_ids"] = [[0] * image_seqlen + [1] * (1024 - image_seqlen)]
    check_inputs["expected_no_mm_inputs"] = {"token_type_ids": [[1] * 1024]}
    _check_plugin(**check_inputs)


def test_qwen2_vl_plugin():
    tokenizer, processor = _load_tokenizer_module(model_name_or_path="Qwen/Qwen2-VL-7B-Instruct")
    qwen2_vl_plugin = get_mm_plugin(name="qwen2_vl", image_token="<|image_pad|>")
    image_seqlen = 4

    mm_inputs = _get_mm_inputs(processor)
    expected_mm_messages = [
    check_inputs = {"plugin": qwen2_vl_plugin, "tokenizer": tokenizer, "processor": processor}
    check_inputs["expected_mm_messages"] = [
        {
            key: value.replace("<image>", "<|vision_start|>{}<|vision_end|>".format("<|image_pad|>" * image_seqlen))
            for key, value in message.items()
        }
        for message in MM_MESSAGES
    ]

    qwen2_vl_plugin = get_mm_plugin(name="qwen2_vl", image_token="<|image_pad|>")
    # test mm_messages
    assert qwen2_vl_plugin.process_messages(MM_MESSAGES, IMAGES, processor) == expected_mm_messages
    assert qwen2_vl_plugin.process_token_ids(INPUT_IDS, LABELS, IMAGES, tokenizer, processor) == (INPUT_IDS, LABELS)
    _is_close(qwen2_vl_plugin.get_mm_inputs(IMAGES, IMGLENS, SEQLENS, processor), mm_inputs)
    # test text_messages
    assert qwen2_vl_plugin.process_messages(TEXT_MESSAGES, NO_IMAGES, processor) == TEXT_MESSAGES
    assert qwen2_vl_plugin.process_token_ids(INPUT_IDS, LABELS, NO_IMAGES, tokenizer, processor) == (INPUT_IDS, LABELS)
    _is_close(qwen2_vl_plugin.get_mm_inputs(NO_IMAGES, NO_IMGLENS, SEQLENS, processor), {})
    check_inputs["expected_input_ids"] = INPUT_IDS
    check_inputs["expected_labels"] = LABELS
    check_inputs["expected_mm_inputs"] = _get_mm_inputs(processor)
    check_inputs["expected_no_mm_inputs"] = {}
    _check_plugin(**check_inputs)
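A quick arithmetic check of the paligemma expectation above: with image_seqlen = 256 and the 1024-token sequence length appearing in the test's expectations, token_type_ids marks 256 image slots followed by 768 text slots.

image_seqlen, seqlen = 256, 1024
token_type_ids = [0] * image_seqlen + [1] * (seqlen - image_seqlen)
assert len(token_type_ids) == 1024 and sum(token_type_ids) == 768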