diff --git a/src/llamafactory/__init__.py b/src/llamafactory/__init__.py index fde2f568..bb98bf5c 100644 --- a/src/llamafactory/__init__.py +++ b/src/llamafactory/__init__.py @@ -33,7 +33,7 @@ Dependency graph: transformers>=4.41.2,<=4.44.3 """ -from .cli import VERSION +from .extras.env import VERSION __version__ = VERSION diff --git a/src/llamafactory/chat/hf_engine.py b/src/llamafactory/chat/hf_engine.py index dabfca2a..7be6c2ef 100644 --- a/src/llamafactory/chat/hf_engine.py +++ b/src/llamafactory/chat/hf_engine.py @@ -89,11 +89,9 @@ class HuggingfaceEngine(BaseEngine): paired_messages = messages + [{"role": "assistant", "content": ""}] system = system or generating_args["default_system"] - prompt_ids, _ = template.encode_oneturn( - tokenizer=tokenizer, messages=paired_messages, system=system, tools=tools - ) + prompt_ids, _ = template.encode_oneturn(tokenizer, paired_messages, system, tools) if image is not None: - prompt_ids, _ = template.mm_plugin.process_token_ids(prompt_ids, None, tokenizer, processor) + prompt_ids, _ = template.mm_plugin.process_token_ids(prompt_ids, None, [image], tokenizer, processor) prompt_length = len(prompt_ids) inputs = torch.tensor([prompt_ids], device=model.device) diff --git a/src/llamafactory/chat/vllm_engine.py b/src/llamafactory/chat/vllm_engine.py index 26f2896e..9220230f 100644 --- a/src/llamafactory/chat/vllm_engine.py +++ b/src/llamafactory/chat/vllm_engine.py @@ -124,9 +124,7 @@ class VllmEngine(BaseEngine): paired_messages = messages + [{"role": "assistant", "content": ""}] system = system or self.generating_args["default_system"] - prompt_ids, _ = self.template.encode_oneturn( - tokenizer=self.tokenizer, messages=paired_messages, system=system, tools=tools - ) + prompt_ids, _ = self.template.encode_oneturn(self.tokenizer, paired_messages, system, tools) if self.processor is not None and image is not None: # add image features image_processor: "BaseImageProcessor" = getattr(self.processor, "image_processor") diff --git a/src/llamafactory/data/__init__.py b/src/llamafactory/data/__init__.py index 9bfd9708..ea1a02f2 100644 --- a/src/llamafactory/data/__init__.py +++ b/src/llamafactory/data/__init__.py @@ -13,8 +13,8 @@ # limitations under the License. from .collator import ( - CustomDataCollatorForSeq2Seq, KTODataCollatorWithPadding, + MultiModalDataCollatorForSeq2Seq, PairwiseDataCollatorWithPadding, SFTDataCollatorWith4DAttentionMask, ) @@ -24,8 +24,8 @@ from .template import TEMPLATES, Template, get_template_and_fix_tokenizer __all__ = [ - "CustomDataCollatorForSeq2Seq", "KTODataCollatorWithPadding", + "MultiModalDataCollatorForSeq2Seq", "PairwiseDataCollatorWithPadding", "SFTDataCollatorWith4DAttentionMask", "Role", diff --git a/src/llamafactory/data/collator.py b/src/llamafactory/data/collator.py index 29bbc9eb..5e67899b 100644 --- a/src/llamafactory/data/collator.py +++ b/src/llamafactory/data/collator.py @@ -62,44 +62,49 @@ def prepare_4d_attention_mask(attention_mask_with_indices: "torch.Tensor", dtype @dataclass -class CustomDataCollatorForSeq2Seq(DataCollatorForSeq2Seq): +class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq): r""" - Data collator for custom models (like Qwen2-VL). + Data collator that supports VLMs. 
""" def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]: - image_grid_thw = None # TODO: better handle various VLMs - if "image_grid_thw" in features[0]: - image_grid_thw_list = [ - torch.Tensor(feature["image_grid_thw"]).long() - for feature in features - if feature["image_grid_thw"][0][0] > 0 - ] - pixel_values_list = [ - torch.Tensor(feature["pixel_values"]) for feature in features if feature["image_grid_thw"][0][0] > 0 - ] - if image_grid_thw_list: - image_grid_thw = torch.cat(image_grid_thw_list, dim=0) - pixel_values = torch.cat(pixel_values_list, dim=0) - else: - image_grid_thw = None - pixel_values = None + if "token_type_ids" in features[0].keys(): + for feature in features: + feature["token_type_ids"] = feature["token_type_ids"][0] - features = [ - {key: feature[key] for key in feature if key not in ["image_grid_thw", "pixel_values"]} - for feature in features - ] + extra_features = {} + if "pixel_values" in features[0].keys(): + pixel_values = [] + for feature in features: + if feature["pixel_values"] is None: + pixel_values.append(torch.zeros(0, dtype=torch.float)) + else: + pixel_values.append(torch.tensor(feature["pixel_values"], dtype=torch.float)) - features = super().__call__(features) - if image_grid_thw is not None: - features["image_grid_thw"] = image_grid_thw - features["pixel_values"] = pixel_values + extra_features["pixel_values"] = torch.cat(pixel_values, dim=0) + if extra_features["pixel_values"].numel() == 0: + extra_features["pixel_values"] = None + if "image_grid_thw" in features[0].keys(): + image_grid_thw = [] + for feature in features: + if feature["image_grid_thw"] is None: + image_grid_thw.append(torch.zeros(0, dtype=torch.long)) + else: + image_grid_thw.append(torch.tensor(feature["image_grid_thw"], dtype=torch.long)) + + extra_features["image_grid_thw"] = torch.cat(pixel_values, dim=0) + if extra_features["image_grid_thw"].numel() == 0: + extra_features["image_grid_thw"] = None + + features = [{key: feature[key] for key in feature if key not in extra_features.keys()} for feature in features] + features: Dict[str, "torch.Tensor"] = super().__call__(features) + features.update({key: value for key, value in extra_features.items() if value is not None}) return features @dataclass -class SFTDataCollatorWith4DAttentionMask(CustomDataCollatorForSeq2Seq): +class SFTDataCollatorWith4DAttentionMask(MultiModalDataCollatorForSeq2Seq): r""" Data collator for 4d attention mask. """ @@ -117,7 +122,7 @@ class SFTDataCollatorWith4DAttentionMask(CustomDataCollatorForSeq2Seq): @dataclass -class PairwiseDataCollatorWithPadding(CustomDataCollatorForSeq2Seq): +class PairwiseDataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq): r""" Data collator for pairwise data. """ @@ -152,7 +157,7 @@ class PairwiseDataCollatorWithPadding(CustomDataCollatorForSeq2Seq): @dataclass -class KTODataCollatorWithPadding(CustomDataCollatorForSeq2Seq): +class KTODataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq): r""" Data collator for KTO data. 
""" diff --git a/src/llamafactory/data/formatter.py b/src/llamafactory/data/formatter.py index c1653a76..71b40789 100644 --- a/src/llamafactory/data/formatter.py +++ b/src/llamafactory/data/formatter.py @@ -16,16 +16,16 @@ import json import re from abc import ABC, abstractmethod from dataclasses import dataclass, field -from typing import List, Literal, Optional, Tuple, Union +from typing import List, Optional, Tuple, Union from .data_utils import SLOTS -from .tool_utils import DefaultToolUtils, GLM4ToolUtils +from .tool_utils import get_tool_utils @dataclass class Formatter(ABC): slots: SLOTS = field(default_factory=list) - tool_format: Optional[Literal["default", "glm4"]] = None + tool_format: Optional[str] = None @abstractmethod def apply(self, **kwargs) -> SLOTS: ... @@ -81,12 +81,7 @@ class StringFormatter(Formatter): @dataclass class FunctionFormatter(Formatter): def __post_init__(self): - if self.tool_format == "default": - self.slots = DefaultToolUtils.get_function_slots() + self.slots - elif self.tool_format == "glm4": - self.slots = GLM4ToolUtils.get_function_slots() + self.slots - else: - raise NotImplementedError("Tool format {} was not found.".format(self.tool_format)) + self.slots = get_tool_utils(self.tool_format).get_function_slots() + self.slots def apply(self, **kwargs) -> SLOTS: content = kwargs.pop("content") @@ -119,22 +114,15 @@ class FunctionFormatter(Formatter): @dataclass class ToolFormatter(Formatter): def __post_init__(self): - if self.tool_format == "default": - self._tool_formatter = DefaultToolUtils.tool_formatter - self._tool_extractor = DefaultToolUtils.tool_extractor - elif self.tool_format == "glm4": - self._tool_formatter = GLM4ToolUtils.tool_formatter - self._tool_extractor = GLM4ToolUtils.tool_extractor - else: - raise NotImplementedError("Tool format {} was not found.".format(self.tool_format)) + self.tool_utils = get_tool_utils(self.tool_format) def apply(self, **kwargs) -> SLOTS: content = kwargs.pop("content") try: tools = json.loads(content) - return [self._tool_formatter(tools) if len(tools) != 0 else ""] + return [self.tool_utils.tool_formatter(tools) if len(tools) != 0 else ""] except json.JSONDecodeError: return [""] def extract(self, content: str) -> Union[str, List[Tuple[str, str]]]: - return self._tool_extractor(content) + return self.tool_utils.tool_extractor(content) diff --git a/src/llamafactory/data/mm_plugin.py b/src/llamafactory/data/mm_plugin.py index 7631c937..acd81ca0 100644 --- a/src/llamafactory/data/mm_plugin.py +++ b/src/llamafactory/data/mm_plugin.py @@ -1,3 +1,4 @@ +from copy import deepcopy from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple from PIL.Image import Image @@ -27,32 +28,33 @@ def _get_mm_inputs(images: Sequence["ImageObject"], processor: "ProcessorMixin") Returns: (qwen2-vl) pixel_values: tensor with shape (num_patches, patch_dim) - image_grid_thw: tensot with shape (num_images, 3), where the three numbers are time, width, height + image_grid_thw: tensor with shape (num_images, 3), where the three numbers are time, width, height It holds num_patches == torch.prod(image_grid_thw) """ image_processor: "BaseImageProcessor" = getattr(processor, "image_processor") if len(images) != 0: image_inputs = image_processor(images=images, return_tensors="pt") - else: - image = Image.new("RGB", (56, 56), (255, 255, 255)) + else: # add NoneType for fake images + image = Image.new("RGB", (64, 64), (255, 255, 255)) image_inputs = image_processor(images=[image], return_tensors="pt") - if "image_grid_thw" 
in image_inputs: # fake image for qwen2-vl - image_inputs["image_grid_thw"][0][0] = 0 + image_inputs = {key: None for key in image_inputs.keys()} return image_inputs -def _get_paligemma_token_type_ids(input_len: int, processor: "ProcessorMixin") -> List[List[int]]: +def _get_paligemma_token_type_ids( + images: Sequence["ImageObject"], input_len: int, processor: "ProcessorMixin" +) -> List[List[int]]: r""" Gets paligemma token type ids for computing loss. Returns: token_type_ids: shape (1, seq_len) """ - image_processor: "BaseImageProcessor" = getattr(processor, "image_processor") - image_seq_length: int = getattr(image_processor, "image_seq_length") - return [[0] * image_seq_length + [1] * (input_len - image_seq_length)] + num_images = len(images) + image_seqlen = num_images * getattr(processor, "image_seqlen") + return [[0] * image_seqlen + [1] * (input_len - image_seqlen)] class BasePlugin: @@ -74,6 +76,7 @@ class BasePlugin: self, input_ids: List[int], labels: Optional[List[int]], + images: Sequence["ImageObject"], tokenizer: "PreTrainedTokenizer", processor: Optional["ProcessorMixin"], ) -> Tuple[List[int], Optional[List[int]]]: @@ -93,18 +96,6 @@ class BasePlugin: """ return {} - def process_model_inputs( - self, - model_inputs: Dict[str, List[Any]], - images: Sequence["ImageObject"], - feature_seqlens: Dict[str, int], - processor: Optional["ProcessorMixin"], - ) -> None: - r""" - Appends multimodal inputs to model inputs for VLMs. - """ - return - class LlavaPlugin(BasePlugin): def process_messages( @@ -113,21 +104,21 @@ class LlavaPlugin(BasePlugin): images: Sequence["ImageObject"], processor: Optional["ProcessorMixin"], ) -> List[Dict[str, str]]: - image_count = 0 - new_messages = [] + num_images = 0 + image_seqlen = getattr(processor, "image_seqlen") + messages = deepcopy(messages) for message in messages: content = message["content"] while IMAGE_PLACEHOLDER in content: - image_count += 1 - if image_count > 1: - raise ValueError("Llava model only accepts one image per sample.") - + num_images += 1 content = content.replace(IMAGE_PLACEHOLDER, "{{image}}", 1) - content = content.replace("{{image}}", self.image_token) - new_messages.append({"role": message["role"], "content": content}) + message["content"] = content.replace("{{image}}", self.image_token * image_seqlen) - return new_messages + if len(images) != num_images: + raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER)) + + return messages def get_mm_inputs( self, @@ -137,17 +128,6 @@ class LlavaPlugin(BasePlugin): ) -> Dict[str, Any]: return _get_mm_inputs(images, processor) - def process_model_inputs( - self, - model_inputs: Dict[str, List[Any]], - images: Sequence["ImageObject"], - feature_seqlens: Dict[str, int], - processor: Optional["ProcessorMixin"], - ) -> None: - mm_inputs = self.get_mm_inputs(images, feature_seqlens, processor) - for key, value in mm_inputs.items(): - model_inputs[key].append(value[0]) - class PaliGemmaPlugin(BasePlugin): def process_messages( @@ -156,34 +136,35 @@ class PaliGemmaPlugin(BasePlugin): images: Sequence["ImageObject"], processor: Optional["ProcessorMixin"], ) -> List[Dict[str, str]]: - image_count = 0 - new_messages = [] + num_images = 0 + messages = deepcopy(messages) for message in messages: content = message["content"] while IMAGE_PLACEHOLDER in content: - image_count += 1 - if image_count > 1: - raise ValueError("PaliGemma model only accepts one image per sample.") + num_images += 1 + content = content.replace(IMAGE_PLACEHOLDER, 
"{{image}}", 1) - content = content.replace(IMAGE_PLACEHOLDER, "", 1) + message["content"] = content.replace("{{image}}", "") - new_messages.append({"role": message["role"], "content": content}) + if len(images) != num_images: + raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER)) - return new_messages + return messages def process_token_ids( self, input_ids: List[int], labels: Optional[List[int]], + images: Sequence["ImageObject"], tokenizer: "PreTrainedTokenizer", processor: Optional["ProcessorMixin"], ) -> Tuple[List[int], Optional[List[int]]]: - image_processor: "BaseImageProcessor" = getattr(processor, "image_processor") - image_seq_length: int = getattr(image_processor, "image_seq_length") + num_images = len(images) + image_seqlen = num_images * getattr(processor, "image_seqlen") image_token_id = tokenizer.convert_tokens_to_ids(self.image_token) - input_ids = [image_token_id] * image_seq_length + input_ids + input_ids = [image_token_id] * image_seqlen + input_ids if labels is not None: - labels = [IGNORE_INDEX] * image_seq_length + labels + labels = [IGNORE_INDEX] * image_seqlen + labels return input_ids, labels @@ -195,21 +176,10 @@ class PaliGemmaPlugin(BasePlugin): ) -> Dict[str, Any]: mm_inputs = _get_mm_inputs(images, processor) for feature_name, feature_length in feature_seqlens.items(): - mm_inputs[feature_name] = _get_paligemma_token_type_ids(feature_length, processor) + mm_inputs[feature_name] = _get_paligemma_token_type_ids(images, feature_length, processor) return mm_inputs - def process_model_inputs( - self, - model_inputs: Dict[str, List[Any]], - images: Sequence["ImageObject"], - feature_seqlens: Dict[str, int], - processor: Optional["ProcessorMixin"], - ) -> None: - mm_inputs = self.get_mm_inputs(images, feature_seqlens, processor) - for key, value in mm_inputs.items(): - model_inputs[key].append(value[0]) - class Qwen2vlPlugin(BasePlugin): def process_messages( @@ -223,23 +193,26 @@ class Qwen2vlPlugin(BasePlugin): if len(images) > 0: image_grid_thw = _get_mm_inputs(images, processor)["image_grid_thw"] - index = 0 - new_messages = [] + num_images = 0 + messages = deepcopy(messages) for message in messages: content = message["content"] while IMAGE_PLACEHOLDER in content: content = content.replace( IMAGE_PLACEHOLDER, "<|vision_start|>{}<|vision_end|>".format( - self.image_token * (image_grid_thw[index].prod() // merge_length) + self.image_token * (image_grid_thw[num_images].prod() // merge_length) ), 1, ) - index += 1 + num_images += 1 - new_messages.append({"role": message["role"], "content": content}) + message["content"] = content - return new_messages + if len(images) != num_images: + raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER)) + + return messages def get_mm_inputs( self, @@ -249,17 +222,6 @@ class Qwen2vlPlugin(BasePlugin): ) -> Dict[str, Any]: return _get_mm_inputs(images, processor) - def process_model_inputs( - self, - model_inputs: Dict[str, List[Any]], - images: Sequence["ImageObject"], - feature_seqlens: Dict[str, int], - processor: Optional["ProcessorMixin"], - ) -> None: - mm_inputs = self.get_mm_inputs(images, feature_seqlens, processor) - for key, value in mm_inputs.items(): - model_inputs[key].append(value) # support multi-image - PLUGINS = { "base": BasePlugin, @@ -270,7 +232,8 @@ PLUGINS = { def get_mm_plugin(name: str, image_token: str) -> "BasePlugin": - if name not in PLUGINS: - raise ValueError("{} not found.".format(name)) + 
plugin_class = PLUGINS.get(name, None) + if plugin_class is None: + raise ValueError("Multimodal plugin `{}` not found.".format(name)) - return PLUGINS[name](image_token) + return plugin_class(image_token) diff --git a/src/llamafactory/data/preprocess.py b/src/llamafactory/data/preprocess.py index caf4a9b8..9f015b38 100644 --- a/src/llamafactory/data/preprocess.py +++ b/src/llamafactory/data/preprocess.py @@ -50,7 +50,7 @@ def get_preprocess_and_print_func( print_function = partial(print_unsupervised_dataset_example, tokenizer=tokenizer) elif stage == "sft" and not do_generate: if data_args.packing: - if data_args.neat_packing: + if data_args.neat_packing: # hack datasets to have int32 attention mask from datasets.arrow_writer import OptimizedTypedSequence, TypedSequence def __init__(self, data, **kwargs): @@ -67,6 +67,7 @@ def get_preprocess_and_print_func( preprocess_packed_supervised_dataset, template=template, tokenizer=tokenizer, + processor=processor, data_args=data_args, ) else: diff --git a/src/llamafactory/data/processors/feedback.py b/src/llamafactory/data/processors/feedback.py index 826919bc..19539f3c 100644 --- a/src/llamafactory/data/processors/feedback.py +++ b/src/llamafactory/data/processors/feedback.py @@ -21,6 +21,7 @@ from .processor_utils import infer_seqlen if TYPE_CHECKING: + from PIL.Image import Image from transformers import PreTrainedTokenizer, ProcessorMixin from ...hparams import DataArguments @@ -36,11 +37,12 @@ def _encode_feedback_example( kl_response: Sequence[Dict[str, str]], system: Optional[str], tools: Optional[str], + images: Sequence["Image"], template: "Template", tokenizer: "PreTrainedTokenizer", processor: Optional["ProcessorMixin"], cutoff_len: int, -) -> Tuple[List[int], List[int], List[int], List[int], bool]: +) -> Tuple[List[int], List[int], List[int], List[int], bool, Dict[str, Any]]: if response[0]["content"]: # desired example kto_tag = True messages = prompt + [response[0]] @@ -53,6 +55,8 @@ def _encode_feedback_example( else: kl_messages = prompt + [kl_response[1]] + messages = template.mm_plugin.process_messages(messages, images, processor) + kl_messages = template.mm_plugin.process_messages(kl_messages, images, processor) prompt_ids, response_ids = template.encode_oneturn(tokenizer, messages, system, tools) kl_prompt_ids, kl_response_ids = template.encode_oneturn(tokenizer, kl_messages, system, tools) @@ -60,8 +64,8 @@ def _encode_feedback_example( response_ids += [tokenizer.eos_token_id] kl_response_ids += [tokenizer.eos_token_id] - prompt_ids, _ = template.mm_plugin.process_token_ids(prompt_ids, None, tokenizer, processor) - kl_prompt_ids, _ = template.mm_plugin.process_token_ids(kl_prompt_ids, None, tokenizer, processor) + prompt_ids, _ = template.mm_plugin.process_token_ids(prompt_ids, None, images, tokenizer, processor) + kl_prompt_ids, _ = template.mm_plugin.process_token_ids(kl_prompt_ids, None, images, tokenizer, processor) source_len, target_len = infer_seqlen(len(prompt_ids), len(response_ids), cutoff_len) prompt_ids = prompt_ids[:source_len] @@ -74,8 +78,15 @@ def _encode_feedback_example( labels = [IGNORE_INDEX] * source_len + response_ids kl_input_ids = kl_prompt_ids + kl_response_ids kl_labels = [IGNORE_INDEX] * kl_source_len + kl_response_ids - - return input_ids, labels, kl_input_ids, kl_labels, kto_tag + extra_inputs = template.mm_plugin.get_mm_inputs( + images=images, + feature_seqlens={ + "token_type_ids": len(input_ids), + "kl_token_type_ids": len(kl_input_ids), + }, + processor=processor, + ) + return input_ids, 
labels, kl_input_ids, kl_labels, kto_tag, extra_inputs def preprocess_feedback_dataset( @@ -93,13 +104,13 @@ def preprocess_feedback_dataset( logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i])) continue - prompt = template.mm_plugin.process_messages(examples["prompt"][i], examples["images"][i], processor) - input_ids, labels, kl_input_ids, kl_labels, kto_tag = _encode_feedback_example( - prompt=prompt, + input_ids, labels, kl_input_ids, kl_labels, kto_tag, extra_inputs = _encode_feedback_example( + prompt=examples["prompt"][i], response=examples["response"][i], kl_response=kl_response[i], system=examples["system"][i], tools=examples["tools"][i], + images=examples["images"][i], template=template, tokenizer=tokenizer, processor=processor, @@ -112,15 +123,8 @@ def preprocess_feedback_dataset( model_inputs["kl_attention_mask"].append([1] * len(kl_input_ids)) model_inputs["kl_labels"].append(kl_labels) model_inputs["kto_tags"].append(kto_tag) - template.mm_plugin.process_model_inputs( - model_inputs=model_inputs, - images=examples["images"][i], - feature_seqlens={ - "token_type_ids": len(input_ids), - "kl_token_type_ids": len(kl_input_ids), - }, - processor=processor, - ) + for key, value in extra_inputs.items(): + model_inputs[key].append(value) desirable_num = sum([1 for tag in model_inputs["kto_tags"] if tag]) undesirable_num = len(model_inputs["kto_tags"]) - desirable_num diff --git a/src/llamafactory/data/processors/pairwise.py b/src/llamafactory/data/processors/pairwise.py index ad625d33..9c5565d9 100644 --- a/src/llamafactory/data/processors/pairwise.py +++ b/src/llamafactory/data/processors/pairwise.py @@ -21,6 +21,7 @@ from .processor_utils import infer_seqlen if TYPE_CHECKING: + from PIL.Image import Image from transformers import PreTrainedTokenizer, ProcessorMixin from ...hparams import DataArguments @@ -35,13 +36,14 @@ def _encode_pairwise_example( response: Sequence[Dict[str, str]], system: Optional[str], tools: Optional[str], + images: Sequence["Image"], template: "Template", tokenizer: "PreTrainedTokenizer", processor: Optional["ProcessorMixin"], cutoff_len: int, -) -> Tuple[List[int], List[int], List[int], List[int]]: - chosen_messages = prompt + [response[0]] - rejected_messages = prompt + [response[1]] +) -> Tuple[List[int], List[int], List[int], List[int], Dict[str, Any]]: + chosen_messages = template.mm_plugin.process_messages(prompt + [response[0]], images, processor) + rejected_messages = template.mm_plugin.process_messages(prompt + [response[1]], images, processor) prompt_ids, chosen_ids = template.encode_oneturn(tokenizer, chosen_messages, system, tools) _, rejected_ids = template.encode_oneturn(tokenizer, rejected_messages, system, tools) @@ -49,7 +51,7 @@ def _encode_pairwise_example( chosen_ids += [tokenizer.eos_token_id] rejected_ids += [tokenizer.eos_token_id] - prompt_ids, _ = template.mm_plugin.process_token_ids(prompt_ids, None, tokenizer, processor) + prompt_ids, _ = template.mm_plugin.process_token_ids(prompt_ids, None, images, tokenizer, processor) # consider the response is more important source_len, target_len = infer_seqlen(len(prompt_ids), max(len(chosen_ids), len(rejected_ids)), cutoff_len) prompt_ids = prompt_ids[:source_len] @@ -60,8 +62,15 @@ def _encode_pairwise_example( chosen_labels = [IGNORE_INDEX] * source_len + chosen_ids rejected_input_ids = prompt_ids + rejected_ids rejected_labels = [IGNORE_INDEX] * source_len + rejected_ids - - return chosen_input_ids, chosen_labels, rejected_input_ids, 
rejected_labels + extra_inputs = template.mm_plugin.get_mm_inputs( + images=images, + feature_seqlens={ + "chosen_token_type_ids": len(chosen_input_ids), + "rejected_token_type_ids": len(rejected_input_ids), + }, + processor=processor, + ) + return chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels, extra_inputs def preprocess_pairwise_dataset( @@ -78,12 +87,12 @@ def preprocess_pairwise_dataset( logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i])) continue - prompt = template.mm_plugin.process_messages(examples["prompt"][i], examples["images"][i], processor) - chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels = _encode_pairwise_example( - prompt=prompt, + chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels, extra_inputs = _encode_pairwise_example( + prompt=examples["prompt"][i], response=examples["response"][i], system=examples["system"][i], tools=examples["tools"][i], + images=examples["images"][i], template=template, tokenizer=tokenizer, processor=processor, @@ -95,15 +104,8 @@ def preprocess_pairwise_dataset( model_inputs["rejected_input_ids"].append(rejected_input_ids) model_inputs["rejected_attention_mask"].append([1] * len(rejected_input_ids)) model_inputs["rejected_labels"].append(rejected_labels) - template.mm_plugin.process_model_inputs( - model_inputs=model_inputs, - images=examples["images"][i], - feature_seqlens={ - "chosen_token_type_ids": len(chosen_input_ids), - "rejected_token_type_ids": len(rejected_input_ids), - }, - processor=processor, - ) + for key, value in extra_inputs.items(): + model_inputs[key].append(value) return model_inputs diff --git a/src/llamafactory/data/processors/supervised.py b/src/llamafactory/data/processors/supervised.py index 6f857d24..3981fa12 100644 --- a/src/llamafactory/data/processors/supervised.py +++ b/src/llamafactory/data/processors/supervised.py @@ -21,6 +21,7 @@ from .processor_utils import greedy_knapsack, infer_seqlen if TYPE_CHECKING: + from PIL.Image import Image from transformers import PreTrainedTokenizer, ProcessorMixin from ...hparams import DataArguments @@ -35,19 +36,18 @@ def _encode_supervised_example( response: Sequence[Dict[str, str]], system: Optional[str], tools: Optional[str], + images: Sequence["Image"], template: "Template", tokenizer: "PreTrainedTokenizer", processor: Optional["ProcessorMixin"], cutoff_len: int, train_on_prompt: bool, mask_history: bool, -) -> Tuple[List[int], List[int]]: - messages = prompt + response - input_ids, labels = [], [] - input_ids, labels = template.mm_plugin.process_token_ids(input_ids, labels, tokenizer, processor) - +) -> Tuple[List[int], List[int], Dict[str, Any]]: + messages = template.mm_plugin.process_messages(prompt + response, images, processor) + input_ids, labels = template.mm_plugin.process_token_ids([], [], images, tokenizer, processor) encoded_pairs = template.encode_multiturn(tokenizer, messages, system, tools) - total_length = 1 if template.efficient_eos else 0 + total_length = len(input_ids) + (1 if template.efficient_eos else 0) if mask_history: encoded_pairs = encoded_pairs[::-1] # high priority for last turns @@ -83,7 +83,10 @@ def _encode_supervised_example( input_ids += [tokenizer.eos_token_id] labels += [tokenizer.eos_token_id] - return input_ids, labels + extra_inputs = template.mm_plugin.get_mm_inputs( + images=images, feature_seqlens={"token_type_ids": len(input_ids)}, processor=processor + ) + return input_ids, labels, extra_inputs def 
preprocess_supervised_dataset( @@ -101,12 +104,12 @@ def preprocess_supervised_dataset( logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i])) continue - prompt = template.mm_plugin.process_messages(examples["prompt"][i], examples["images"][i], processor) - input_ids, labels = _encode_supervised_example( - prompt=prompt, + input_ids, labels, extra_inputs = _encode_supervised_example( + prompt=examples["prompt"][i], response=examples["response"][i], system=examples["system"][i], tools=examples["tools"][i], + images=examples["images"][i], template=template, tokenizer=tokenizer, processor=processor, @@ -117,12 +120,8 @@ def preprocess_supervised_dataset( model_inputs["input_ids"].append(input_ids) model_inputs["attention_mask"].append([1] * len(input_ids)) model_inputs["labels"].append(labels) - template.mm_plugin.process_model_inputs( - model_inputs=model_inputs, - images=examples["images"][i], - feature_seqlens={"token_type_ids": len(input_ids)}, - processor=processor, - ) + for key, value in extra_inputs.items(): + model_inputs[key].append(value) return model_inputs @@ -131,10 +130,15 @@ def preprocess_packed_supervised_dataset( examples: Dict[str, List[Any]], template: "Template", tokenizer: "PreTrainedTokenizer", + processor: Optional["ProcessorMixin"], data_args: "DataArguments", ) -> Dict[str, List[Any]]: + # TODO: use `position_ids` to achieve packing # build inputs with format ` X1 Y1 X2 Y2 ` # and labels with format ` ... Y1 ... Y2 ` + if processor is not None: + raise NotImplementedError("`packing` have not been implemented for multimodal datasets.") + valid_num = 0 batch_input_ids, batch_labels = [], [] lengths = [] @@ -149,6 +153,7 @@ def preprocess_packed_supervised_dataset( response=examples["response"][i], system=examples["system"][i], tools=examples["tools"][i], + images=examples["images"][i], template=template, tokenizer=tokenizer, processor=None, diff --git a/src/llamafactory/data/processors/unsupervised.py b/src/llamafactory/data/processors/unsupervised.py index 49a29aa6..67cbb7b6 100644 --- a/src/llamafactory/data/processors/unsupervised.py +++ b/src/llamafactory/data/processors/unsupervised.py @@ -21,6 +21,7 @@ from .processor_utils import infer_seqlen if TYPE_CHECKING: + from PIL.Image import Image from transformers import PreTrainedTokenizer, ProcessorMixin from ...hparams import DataArguments @@ -35,25 +36,30 @@ def _encode_unsupervised_example( response: Sequence[Dict[str, str]], system: Optional[str], tools: Optional[str], + images: Sequence["Image"], template: "Template", tokenizer: "PreTrainedTokenizer", processor: Optional["ProcessorMixin"], cutoff_len: int, -) -> Tuple[List[int], List[int]]: +) -> Tuple[List[int], List[int], Dict[str, Any]]: if len(response) == 1: messages = prompt + response else: messages = prompt + [{"role": Role.ASSISTANT.value, "content": ""}] + messages = template.mm_plugin.process_messages(messages, images, processor) input_ids, labels = template.encode_oneturn(tokenizer, messages, system, tools) if template.efficient_eos: labels += [tokenizer.eos_token_id] - input_ids, _ = template.mm_plugin.process_token_ids(input_ids, None, tokenizer, processor) + input_ids, _ = template.mm_plugin.process_token_ids(input_ids, None, images, tokenizer, processor) source_len, target_len = infer_seqlen(len(input_ids), len(labels), cutoff_len) input_ids = input_ids[:source_len] labels = labels[:target_len] - return input_ids, labels + extra_inputs = template.mm_plugin.get_mm_inputs( + images=images, 
feature_seqlens={"token_type_ids": len(input_ids)}, processor=processor + ) + return input_ids, labels, extra_inputs def preprocess_unsupervised_dataset( @@ -70,12 +76,12 @@ def preprocess_unsupervised_dataset( logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i])) continue - prompt = template.mm_plugin.process_messages(examples["prompt"][i], examples["images"][i], processor) - input_ids, labels = _encode_unsupervised_example( - prompt=prompt, + input_ids, labels, extra_inputs = _encode_unsupervised_example( + prompt=examples["prompt"][i], response=examples["response"][i], system=examples["system"][i], tools=examples["tools"][i], + images=examples["images"][i], template=template, tokenizer=tokenizer, processor=processor, @@ -84,12 +90,8 @@ def preprocess_unsupervised_dataset( model_inputs["input_ids"].append(input_ids) model_inputs["attention_mask"].append([1] * len(input_ids)) model_inputs["labels"].append(labels) - template.mm_plugin.process_model_inputs( - model_inputs=model_inputs, - images=examples["images"][i], - feature_seqlens={"token_type_ids": len(input_ids)}, - processor=processor, - ) + for key, value in extra_inputs.items(): + model_inputs[key].append(value) return model_inputs diff --git a/src/llamafactory/data/template.py b/src/llamafactory/data/template.py index 63564e8f..70f24114 100644 --- a/src/llamafactory/data/template.py +++ b/src/llamafactory/data/template.py @@ -15,6 +15,8 @@ from dataclasses import dataclass from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Union +from transformers.utils.versions import require_version + from ..extras.constants import IMAGE_PLACEHOLDER from ..extras.logging import get_logger from .data_utils import Role @@ -347,6 +349,11 @@ def get_template_and_fix_tokenizer( name: Optional[str] = None, tool_format: Optional[str] = None, ) -> Template: + if name == "qwen2_vl": + require_version( + "transformers>=4.45.0.dev0", "To fix: pip install git+https://github.com/huggingface/transformers.git" + ) + if name is None: template = TEMPLATES["empty"] # placeholder else: @@ -357,8 +364,8 @@ def get_template_and_fix_tokenizer( if tool_format is not None: logger.info("Using tool format: {}.".format(tool_format)) eos_slots = [] if template.efficient_eos else [{"eos_token"}] - template.format_tools = ToolFormatter(tool_format=tool_format) template.format_function = FunctionFormatter(slots=eos_slots, tool_format=tool_format) + template.format_tools = ToolFormatter(tool_format=tool_format) stop_words = template.stop_words if template.replace_eos: diff --git a/src/llamafactory/data/tool_utils.py b/src/llamafactory/data/tool_utils.py index efda86f5..0d7133fe 100644 --- a/src/llamafactory/data/tool_utils.py +++ b/src/llamafactory/data/tool_utils.py @@ -138,3 +138,17 @@ class GLM4ToolUtils(ToolUtils): return content return [(tool_name, json.dumps(arguments, ensure_ascii=False))] + + +TOOLS = { + "default": DefaultToolUtils(), + "glm4": GLM4ToolUtils(), +} + + +def get_tool_utils(name: str) -> "ToolUtils": + tool_utils = TOOLS.get(name, None) + if tool_utils is None: + raise ValueError("Tool utils `{}` not found.".format(name)) + + return tool_utils diff --git a/src/llamafactory/extras/misc.py b/src/llamafactory/extras/misc.py index d3105e65..8a3c125c 100644 --- a/src/llamafactory/extras/misc.py +++ b/src/llamafactory/extras/misc.py @@ -195,6 +195,9 @@ def is_gpu_or_npu_available() -> bool: def numpify(inputs: Union["NDArray", "torch.Tensor"]) -> "NDArray": + r""" + Casts a torch tensor or a 
numpy array to a numpy array. + """ if isinstance(inputs, torch.Tensor): inputs = inputs.cpu() if inputs.dtype == torch.bfloat16: # numpy does not support bfloat16 until 1.21.4 @@ -206,6 +209,9 @@ def numpify(inputs: Union["NDArray", "torch.Tensor"]) -> "NDArray": def skip_check_imports() -> None: + r""" + Avoids flash attention import error in custom model files. + """ if os.environ.get("FORCE_CHECK_IMPORTS", "0").lower() not in ["true", "1"]: transformers.dynamic_module_utils.check_imports = get_relative_imports diff --git a/src/llamafactory/launcher.py b/src/llamafactory/launcher.py index 65e0b68f..b93f2ad1 100644 --- a/src/llamafactory/launcher.py +++ b/src/llamafactory/launcher.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from llamafactory.train.tuner import run_exp +from llamafactory.train.tuner import run_exp # use absolute import def launch(): diff --git a/src/llamafactory/model/loader.py b/src/llamafactory/model/loader.py index 374748d0..8ca8efdf 100644 --- a/src/llamafactory/model/loader.py +++ b/src/llamafactory/model/loader.py @@ -25,6 +25,7 @@ from .model_utils.misc import register_autoclass from .model_utils.mod import convert_pretrained_model_to_mod, load_mod_pretrained_model from .model_utils.unsloth import load_unsloth_pretrained_model from .model_utils.valuehead import load_valuehead_params +from .model_utils.visual import get_image_seqlen from .patcher import patch_config, patch_model, patch_tokenizer, patch_valuehead_model @@ -65,6 +66,7 @@ def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule": Note: including inplace operation of model_args. """ init_kwargs = _get_init_kwargs(model_args) + config = load_config(model_args) try: tokenizer = AutoTokenizer.from_pretrained( model_args.model_name_or_path, @@ -96,6 +98,7 @@ def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule": try: processor = AutoProcessor.from_pretrained(model_args.model_name_or_path, **init_kwargs) setattr(processor, "tokenizer", tokenizer) + setattr(processor, "image_seqlen", get_image_seqlen(config)) except Exception: processor = None diff --git a/src/llamafactory/model/model_utils/visual.py b/src/llamafactory/model/model_utils/visual.py index b3103a2c..1fbf3400 100644 --- a/src/llamafactory/model/model_utils/visual.py +++ b/src/llamafactory/model/model_utils/visual.py @@ -82,7 +82,7 @@ class LlavaMultiModalProjectorForYiVLForVLLM(LlavaMultiModalProjectorForYiVL): def autocast_projector_dtype(model: "PreTrainedModel", model_args: "ModelArguments") -> None: r""" - Casts projector output to half precision for quantized VLMs. + Casts projector output to half precision for fine-tuning quantized VLMs. """ def _mm_projector_forward_post_hook( @@ -136,6 +136,22 @@ def get_forbidden_modules(config: "PretrainedConfig", finetuning_args: "Finetuni return forbidden_modules +def get_image_seqlen(config: "PretrainedConfig") -> int: + r""" + Computes the number of special tokens per image. 
+ """ + if getattr(config, "model_type", None) == "llava": + image_seqlen = (config.vision_config.image_size // config.vision_config.patch_size) ** 2 + if getattr(config, "vision_feature_select_strategy", "default") == "full": # add [CLS] token + image_seqlen += 1 + elif getattr(config, "model_type", None) == "paligemma": + image_seqlen = config.vision_config.num_image_tokens + elif getattr(config, "model_type", None) == "qwen2_vl": # variable length + image_seqlen = -1 + + return image_seqlen + + def patch_target_modules( config: "PretrainedConfig", finetuning_args: "FinetuningArguments", target_modules: Sequence[str] ) -> Union[str, List[str]]: diff --git a/src/llamafactory/train/ppo/workflow.py b/src/llamafactory/train/ppo/workflow.py index a2685f33..94c4320d 100644 --- a/src/llamafactory/train/ppo/workflow.py +++ b/src/llamafactory/train/ppo/workflow.py @@ -17,7 +17,7 @@ from typing import TYPE_CHECKING, List, Optional -from ...data import CustomDataCollatorForSeq2Seq, get_dataset +from ...data import MultiModalDataCollatorForSeq2Seq, get_dataset from ...extras.ploting import plot_loss from ...model import load_model, load_tokenizer from ..callbacks import fix_valuehead_checkpoint @@ -45,7 +45,7 @@ def run_ppo( model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train, add_valuehead=True) tokenizer.padding_side = "left" # use left-padding in generation while using right-padding in training - data_collator = CustomDataCollatorForSeq2Seq(tokenizer=tokenizer) + data_collator = MultiModalDataCollatorForSeq2Seq(tokenizer=tokenizer) # Create reference model and reward model ref_model = create_ref_model(model_args, finetuning_args, add_valuehead=True) diff --git a/tests/data/test_mm_plugin.py b/tests/data/test_mm_plugin.py index 52061075..c2950605 100644 --- a/tests/data/test_mm_plugin.py +++ b/tests/data/test_mm_plugin.py @@ -13,8 +13,7 @@ # limitations under the License. 
import os -from collections import defaultdict -from typing import TYPE_CHECKING, Any, Dict +from typing import TYPE_CHECKING, Any, Dict, Tuple import pytest import torch @@ -26,7 +25,7 @@ from llamafactory.model import load_tokenizer if TYPE_CHECKING: - from transformers import ProcessorMixin + from transformers import PreTrainedTokenizer, ProcessorMixin from transformers.image_processing_utils import BaseImageProcessor @@ -34,13 +33,20 @@ HF_TOKEN = os.environ.get("HF_TOKEN", None) TINY_LLAMA = os.environ.get("TINY_LLAMA", "llamafactory/tiny-random-Llama-3") -MESSAGES = [ +MM_MESSAGES = [ {"role": "user", "content": "What is in this image?"}, {"role": "assistant", "content": "A cat."}, ] +TEXT_MESSAGES = [ + {"role": "user", "content": "How are you"}, + {"role": "assistant", "content": "I am fine!"}, +] + IMAGES = [Image.new("RGB", (32, 32), (255, 255, 255))] +NO_IMAGES = [] + INPUT_IDS = [0, 1, 2, 3, 4] LABELS = [0, 1, 2, 3, 4] @@ -53,99 +59,110 @@ def _get_mm_inputs(processor: "ProcessorMixin") -> Dict[str, "torch.Tensor"]: return image_processor(images=IMAGES, return_tensors="pt") -def _is_close(batch_a: Dict[str, Any], batch_b: Dict[str, Any]): +def _is_close(batch_a: Dict[str, Any], batch_b: Dict[str, Any]) -> None: assert batch_a.keys() == batch_b.keys() for key in batch_a.keys(): - if isinstance(batch_a[key], list): - assert len(batch_a[key]) == len(batch_b[key]) - for i in range(len(batch_a[key])): - if isinstance(batch_a[key][i], torch.Tensor): - assert torch.allclose(batch_a[key][i], batch_b[key][i], rtol=1e-4, atol=1e-5) - else: - assert batch_a[key][i] == batch_b[key][i] - elif isinstance(batch_a[key], torch.Tensor): + if isinstance(batch_a[key], torch.Tensor): assert torch.allclose(batch_a[key], batch_b[key], rtol=1e-4, atol=1e-5) else: - raise NotImplementedError + assert batch_a[key] == batch_b[key] + + +def _load_tokenizer_module(model_name_or_path: str) -> Tuple["PreTrainedTokenizer", "ProcessorMixin"]: + model_args = ModelArguments(model_name_or_path=model_name_or_path) + tokenizer_module = load_tokenizer(model_args) + return tokenizer_module["tokenizer"], tokenizer_module["processor"] def test_base_plugin(): - model_args = ModelArguments(model_name_or_path=TINY_LLAMA) - tokenizer_module = load_tokenizer(model_args) - tokenizer = tokenizer_module["tokenizer"] - processor = tokenizer_module["processor"] - + tokenizer, processor = _load_tokenizer_module(model_name_or_path=TINY_LLAMA) base_plugin = get_mm_plugin(name="base", image_token="") - model_inputs = defaultdict(list) - base_plugin.process_model_inputs(model_inputs, IMAGES, FEATURE_SEQLENS, processor) - - assert base_plugin.process_messages(MESSAGES, IMAGES, processor) - assert base_plugin.process_token_ids(INPUT_IDS, LABELS, tokenizer, processor) == (INPUT_IDS, LABELS) + # test mm_messages + assert base_plugin.process_messages(MM_MESSAGES, IMAGES, processor) == MM_MESSAGES + assert base_plugin.process_token_ids(INPUT_IDS, LABELS, IMAGES, tokenizer, processor) == (INPUT_IDS, LABELS) _is_close(base_plugin.get_mm_inputs(IMAGES, FEATURE_SEQLENS, processor), {}) - _is_close(model_inputs, {}) + # test text_messages + assert base_plugin.process_messages(TEXT_MESSAGES, NO_IMAGES, processor) == TEXT_MESSAGES + assert base_plugin.process_token_ids(INPUT_IDS, LABELS, NO_IMAGES, tokenizer, processor) == (INPUT_IDS, LABELS) + _is_close(base_plugin.get_mm_inputs(NO_IMAGES, FEATURE_SEQLENS, processor), {}) def test_llava_plugin(): - model_args = ModelArguments(model_name_or_path="llava-hf/llava-1.5-7b-hf") - tokenizer_module = 
load_tokenizer(model_args) - tokenizer = tokenizer_module["tokenizer"] - processor = tokenizer_module["processor"] + tokenizer, processor = _load_tokenizer_module(model_name_or_path="llava-hf/llava-1.5-7b-hf") + image_seqlen = 576 mm_inputs = _get_mm_inputs(processor) - expected_model_inputs = {key: [value[0]] for key, value in mm_inputs.items()} + expected_mm_messages = [ + {key: value.replace("", "" * image_seqlen) for key, value in message.items()} + for message in MM_MESSAGES + ] llava_plugin = get_mm_plugin(name="llava", image_token="") - model_inputs = defaultdict(list) - llava_plugin.process_model_inputs(model_inputs, IMAGES, FEATURE_SEQLENS, processor) - - assert llava_plugin.process_messages(MESSAGES, IMAGES, processor) - assert llava_plugin.process_token_ids(INPUT_IDS, LABELS, tokenizer, processor) == (INPUT_IDS, LABELS) + # test mm_messages + assert llava_plugin.process_messages(MM_MESSAGES, IMAGES, processor) == expected_mm_messages + assert llava_plugin.process_token_ids(INPUT_IDS, LABELS, IMAGES, tokenizer, processor) == (INPUT_IDS, LABELS) _is_close(llava_plugin.get_mm_inputs(IMAGES, FEATURE_SEQLENS, processor), mm_inputs) - _is_close(model_inputs, expected_model_inputs) + # test text_messages + assert llava_plugin.process_messages(TEXT_MESSAGES, NO_IMAGES, processor) == TEXT_MESSAGES + assert llava_plugin.process_token_ids(INPUT_IDS, LABELS, NO_IMAGES, tokenizer, processor) == (INPUT_IDS, LABELS) + _is_close(llava_plugin.get_mm_inputs(NO_IMAGES, FEATURE_SEQLENS, processor), {"pixel_values": None}) @pytest.mark.skipif(not HF_TOKEN, reason="Gated model.") def test_paligemma_plugin(): - model_args = ModelArguments(model_name_or_path="google/paligemma-3b-pt-224") - tokenizer_module = load_tokenizer(model_args) - tokenizer = tokenizer_module["tokenizer"] - processor = tokenizer_module["processor"] - image_processor: "BaseImageProcessor" = getattr(processor, "image_processor") - image_seq_length: int = getattr(image_processor, "image_seq_length") + tokenizer, processor = _load_tokenizer_module(model_name_or_path="google/paligemma-3b-pt-224") + image_seqlen = 256 mm_inputs = _get_mm_inputs(processor) - mm_inputs["token_type_ids"] = [[0] * image_seq_length + [1] * (1024 - image_seq_length)] - expected_model_inputs = {key: [value[0]] for key, value in mm_inputs.items()} - expected_input_ids = [tokenizer.convert_tokens_to_ids("")] * image_seq_length + INPUT_IDS - expected_labels = [-100] * image_seq_length + LABELS + mm_inputs["token_type_ids"] = [[0] * image_seqlen + [1] * (1024 - image_seqlen)] + expected_mm_messages = [ + {key: value.replace("", "") for key, value in message.items()} for message in MM_MESSAGES + ] + expected_input_ids = [tokenizer.convert_tokens_to_ids("")] * image_seqlen + INPUT_IDS + expected_labels = [-100] * image_seqlen + LABELS paligemma_plugin = get_mm_plugin(name="paligemma", image_token="") - model_inputs = defaultdict(list) - paligemma_plugin.process_model_inputs(model_inputs, IMAGES, FEATURE_SEQLENS, processor) - - assert paligemma_plugin.process_messages(MESSAGES, IMAGES, processor) - assert paligemma_plugin.process_token_ids(INPUT_IDS, LABELS, tokenizer, processor) == ( + # test mm_messages + assert paligemma_plugin.process_messages(MM_MESSAGES, IMAGES, processor) == expected_mm_messages + assert paligemma_plugin.process_token_ids(INPUT_IDS, LABELS, IMAGES, tokenizer, processor) == ( expected_input_ids, expected_labels, ) _is_close(paligemma_plugin.get_mm_inputs(IMAGES, FEATURE_SEQLENS, processor), mm_inputs) - _is_close(model_inputs, 
expected_model_inputs) + # test text_messages + assert paligemma_plugin.process_messages(TEXT_MESSAGES, NO_IMAGES, processor) == TEXT_MESSAGES + assert paligemma_plugin.process_token_ids(INPUT_IDS, LABELS, NO_IMAGES, tokenizer, processor) == ( + INPUT_IDS, + LABELS, + ) + _is_close( + paligemma_plugin.get_mm_inputs(NO_IMAGES, FEATURE_SEQLENS, processor), + {"pixel_values": None, "token_type_ids": [[1] * 1024]}, + ) def test_qwen2_vl_plugin(): - model_args = ModelArguments(model_name_or_path="Qwen/Qwen2-VL-7B-Instruct") - tokenizer_module = load_tokenizer(model_args) - tokenizer = tokenizer_module["tokenizer"] - processor = tokenizer_module["processor"] + tokenizer, processor = _load_tokenizer_module(model_name_or_path="Qwen/Qwen2-VL-7B-Instruct") + image_seqlen = 4 mm_inputs = _get_mm_inputs(processor) - expected_model_inputs = {key: [value] for key, value in mm_inputs.items()} + expected_mm_messages = [ + { + key: value.replace("", "<|vision_start|>{}<|vision_end|>".format("<|image_pad|>" * image_seqlen)) + for key, value in message.items() + } + for message in MM_MESSAGES + ] - llava_plugin = get_mm_plugin(name="qwen2_vl", image_token="<|image_pad|>") - model_inputs = defaultdict(list) - llava_plugin.process_model_inputs(model_inputs, IMAGES, FEATURE_SEQLENS, processor) - - assert llava_plugin.process_messages(MESSAGES, IMAGES, processor) - assert llava_plugin.process_token_ids(INPUT_IDS, LABELS, tokenizer, processor) == (INPUT_IDS, LABELS) - _is_close(llava_plugin.get_mm_inputs(IMAGES, FEATURE_SEQLENS, processor), mm_inputs) - _is_close(model_inputs, expected_model_inputs) + qwen2_vl_plugin = get_mm_plugin(name="qwen2_vl", image_token="<|image_pad|>") + # test mm_messages + assert qwen2_vl_plugin.process_messages(MM_MESSAGES, IMAGES, processor) == expected_mm_messages + assert qwen2_vl_plugin.process_token_ids(INPUT_IDS, LABELS, IMAGES, tokenizer, processor) == (INPUT_IDS, LABELS) + _is_close(qwen2_vl_plugin.get_mm_inputs(IMAGES, FEATURE_SEQLENS, processor), mm_inputs) + # test text_messages + assert qwen2_vl_plugin.process_messages(TEXT_MESSAGES, NO_IMAGES, processor) == TEXT_MESSAGES + assert qwen2_vl_plugin.process_token_ids(INPUT_IDS, LABELS, NO_IMAGES, tokenizer, processor) == (INPUT_IDS, LABELS) + _is_close( + qwen2_vl_plugin.get_mm_inputs(NO_IMAGES, FEATURE_SEQLENS, processor), + {"pixel_values": None, "image_grid_thw": None}, + )
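
Not part of the patch, but for reviewers: a minimal sketch of how the refactored pieces are meant to compose after this change. The plugin's three hooks (`process_messages`, `process_token_ids`, `get_mm_inputs`) produce per-example features during preprocessing, and `MultiModalDataCollatorForSeq2Seq` merges the multimodal keys at batch time, dropping `None`-valued ones. The sketch uses the `base` plugin and the tiny test tokenizer (`llamafactory/tiny-random-Llama-3`, the same checkpoint the tests default to) so it runs without a real VLM processor; the signatures follow the ones introduced above.

```python
from transformers import AutoTokenizer

from llamafactory.data import MultiModalDataCollatorForSeq2Seq
from llamafactory.data.mm_plugin import get_mm_plugin

tokenizer = AutoTokenizer.from_pretrained("llamafactory/tiny-random-Llama-3")
plugin = get_mm_plugin(name="base", image_token="<image>")

messages = [
    {"role": "user", "content": "How are you"},
    {"role": "assistant", "content": "I am fine!"},
]
images = []  # no images, so the base plugin leaves everything unchanged

# 1. message-level hook: real VLM plugins expand <image> placeholders here
messages = plugin.process_messages(messages, images, processor=None)

# 2. token-level hook: e.g. PaliGemma prepends image tokens to input_ids/labels
input_ids, labels = plugin.process_token_ids([1, 2, 3], [1, 2, 3], images, tokenizer, processor=None)

# 3. feature-level hook: pixel_values / image_grid_thw / token_type_ids per example
extra_inputs = plugin.get_mm_inputs(images, feature_seqlens={"token_type_ids": len(input_ids)}, processor=None)

features = [{"input_ids": input_ids, "attention_mask": [1] * len(input_ids), "labels": labels, **extra_inputs}]
collator = MultiModalDataCollatorForSeq2Seq(tokenizer=tokenizer, label_pad_token_id=-100)
batch = collator(features)  # pads text fields; None-valued multimodal keys are dropped
print({key: value.shape for key, value in batch.items()})
```

This is the same flow the rewritten processors follow: they append `extra_inputs` per example instead of calling the removed `process_model_inputs`, and the collator concatenates the surviving multimodal tensors across the batch.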