fix mixed mm inputs and rlhf-v

Former-commit-id: 9967ccb3aef3ca557ad6eafb78c6c99866857008
This commit is contained in:
hiyouga 2024-09-01 20:52:47 +08:00
parent 34dc36462c
commit cb776752f6
20 changed files with 306 additions and 277 deletions

View File

@ -33,7 +33,7 @@ Dependency graph:
transformers>=4.41.2,<=4.44.3
"""
from .cli import VERSION
from .extras.env import VERSION
__version__ = VERSION

View File

@ -89,11 +89,9 @@ class HuggingfaceEngine(BaseEngine):
paired_messages = messages + [{"role": "assistant", "content": ""}]
system = system or generating_args["default_system"]
prompt_ids, _ = template.encode_oneturn(
tokenizer=tokenizer, messages=paired_messages, system=system, tools=tools
)
prompt_ids, _ = template.encode_oneturn(tokenizer, paired_messages, system, tools)
if image is not None:
prompt_ids, _ = template.mm_plugin.process_token_ids(prompt_ids, None, tokenizer, processor)
prompt_ids, _ = template.mm_plugin.process_token_ids(prompt_ids, None, [image], tokenizer, processor)
prompt_length = len(prompt_ids)
inputs = torch.tensor([prompt_ids], device=model.device)

View File

@ -124,9 +124,7 @@ class VllmEngine(BaseEngine):
paired_messages = messages + [{"role": "assistant", "content": ""}]
system = system or self.generating_args["default_system"]
prompt_ids, _ = self.template.encode_oneturn(
tokenizer=self.tokenizer, messages=paired_messages, system=system, tools=tools
)
prompt_ids, _ = self.template.encode_oneturn(self.tokenizer, paired_messages, system, tools)
if self.processor is not None and image is not None: # add image features
image_processor: "BaseImageProcessor" = getattr(self.processor, "image_processor")

View File

@ -13,8 +13,8 @@
# limitations under the License.
from .collator import (
CustomDataCollatorForSeq2Seq,
KTODataCollatorWithPadding,
MultiModalDataCollatorForSeq2Seq,
PairwiseDataCollatorWithPadding,
SFTDataCollatorWith4DAttentionMask,
)
@ -24,8 +24,8 @@ from .template import TEMPLATES, Template, get_template_and_fix_tokenizer
__all__ = [
"CustomDataCollatorForSeq2Seq",
"KTODataCollatorWithPadding",
"MultiModalDataCollatorForSeq2Seq",
"PairwiseDataCollatorWithPadding",
"SFTDataCollatorWith4DAttentionMask",
"Role",

View File

@ -62,44 +62,49 @@ def prepare_4d_attention_mask(attention_mask_with_indices: "torch.Tensor", dtype
@dataclass
class CustomDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
r"""
Data collator for custom models (like Qwen2-VL).
Data collator that supports VLMs.
"""
def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
image_grid_thw = None # TODO: better handle various VLMs
if "image_grid_thw" in features[0]:
image_grid_thw_list = [
torch.Tensor(feature["image_grid_thw"]).long()
for feature in features
if feature["image_grid_thw"][0][0] > 0
]
pixel_values_list = [
torch.Tensor(feature["pixel_values"]) for feature in features if feature["image_grid_thw"][0][0] > 0
]
if image_grid_thw_list:
image_grid_thw = torch.cat(image_grid_thw_list, dim=0)
pixel_values = torch.cat(pixel_values_list, dim=0)
else:
image_grid_thw = None
pixel_values = None
if "token_type_ids" in features[0].keys():
for feature in features:
feature["token_type_ids"] = feature["token_type_ids"][0]
features = [
{key: feature[key] for key in feature if key not in ["image_grid_thw", "pixel_values"]}
for feature in features
]
extra_features = {}
if "pixel_values" in features[0].keys():
pixel_values = []
for feature in features:
if feature["pixel_values"] is None:
pixel_values.append(torch.zeros(0, dtype=torch.float))
else:
pixel_values.append(torch.tensor(feature["pixel_values"], dtype=torch.float))
features = super().__call__(features)
if image_grid_thw is not None:
features["image_grid_thw"] = image_grid_thw
features["pixel_values"] = pixel_values
extra_features["pixel_values"] = torch.cat(pixel_values, dim=0)
if extra_features["pixel_values"].numel() == 0:
extra_features["pixel_values"] = None
if "image_grid_thw" in features[0].keys():
image_grid_thw = []
for feature in features:
if feature["image_grid_thw"] is None:
image_grid_thw.append(torch.zeros(0, dtype=torch.long))
else:
image_grid_thw.append(torch.tensor(feature["image_grid_thw"], dtype=torch.long))
extra_features["image_grid_thw"] = torch.cat(pixel_values, dim=0)
if extra_features["image_grid_thw"].numel() == 0:
extra_features["image_grid_thw"] = None
features = [{key: feature[key] for key in feature if key not in extra_features.keys()} for feature in features]
features: Dict[str, "torch.Tensor"] = super().__call__(features)
features.update({key: value for key, value in extra_features.items() if value is not None})
return features
@dataclass
class SFTDataCollatorWith4DAttentionMask(CustomDataCollatorForSeq2Seq):
class SFTDataCollatorWith4DAttentionMask(MultiModalDataCollatorForSeq2Seq):
r"""
Data collator for 4d attention mask.
"""
@ -117,7 +122,7 @@ class SFTDataCollatorWith4DAttentionMask(CustomDataCollatorForSeq2Seq):
@dataclass
class PairwiseDataCollatorWithPadding(CustomDataCollatorForSeq2Seq):
class PairwiseDataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
r"""
Data collator for pairwise data.
"""
@ -152,7 +157,7 @@ class PairwiseDataCollatorWithPadding(CustomDataCollatorForSeq2Seq):
@dataclass
class KTODataCollatorWithPadding(CustomDataCollatorForSeq2Seq):
class KTODataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
r"""
Data collator for KTO data.
"""

View File

@ -16,16 +16,16 @@ import json
import re
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import List, Literal, Optional, Tuple, Union
from typing import List, Optional, Tuple, Union
from .data_utils import SLOTS
from .tool_utils import DefaultToolUtils, GLM4ToolUtils
from .tool_utils import get_tool_utils
@dataclass
class Formatter(ABC):
slots: SLOTS = field(default_factory=list)
tool_format: Optional[Literal["default", "glm4"]] = None
tool_format: Optional[str] = None
@abstractmethod
def apply(self, **kwargs) -> SLOTS: ...
@ -81,12 +81,7 @@ class StringFormatter(Formatter):
@dataclass
class FunctionFormatter(Formatter):
def __post_init__(self):
if self.tool_format == "default":
self.slots = DefaultToolUtils.get_function_slots() + self.slots
elif self.tool_format == "glm4":
self.slots = GLM4ToolUtils.get_function_slots() + self.slots
else:
raise NotImplementedError("Tool format {} was not found.".format(self.tool_format))
self.slots = get_tool_utils(self.tool_format).get_function_slots() + self.slots
def apply(self, **kwargs) -> SLOTS:
content = kwargs.pop("content")
@ -119,22 +114,15 @@ class FunctionFormatter(Formatter):
@dataclass
class ToolFormatter(Formatter):
def __post_init__(self):
if self.tool_format == "default":
self._tool_formatter = DefaultToolUtils.tool_formatter
self._tool_extractor = DefaultToolUtils.tool_extractor
elif self.tool_format == "glm4":
self._tool_formatter = GLM4ToolUtils.tool_formatter
self._tool_extractor = GLM4ToolUtils.tool_extractor
else:
raise NotImplementedError("Tool format {} was not found.".format(self.tool_format))
self.tool_utils = get_tool_utils(self.tool_format)
def apply(self, **kwargs) -> SLOTS:
content = kwargs.pop("content")
try:
tools = json.loads(content)
return [self._tool_formatter(tools) if len(tools) != 0 else ""]
return [self.tool_utils.tool_formatter(tools) if len(tools) != 0 else ""]
except json.JSONDecodeError:
return [""]
def extract(self, content: str) -> Union[str, List[Tuple[str, str]]]:
return self._tool_extractor(content)
return self.tool_utils.tool_extractor(content)

View File

@ -1,3 +1,4 @@
from copy import deepcopy
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
from PIL.Image import Image
@ -27,32 +28,33 @@ def _get_mm_inputs(images: Sequence["ImageObject"], processor: "ProcessorMixin")
Returns: (qwen2-vl)
pixel_values: tensor with shape (num_patches, patch_dim)
image_grid_thw: tensot with shape (num_images, 3), where the three numbers are time, width, height
image_grid_thw: tensor with shape (num_images, 3), where the three numbers are time, width, height
It holds num_patches == torch.prod(image_grid_thw)
"""
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
if len(images) != 0:
image_inputs = image_processor(images=images, return_tensors="pt")
else:
image = Image.new("RGB", (56, 56), (255, 255, 255))
else: # add NoneType for fake images
image = Image.new("RGB", (64, 64), (255, 255, 255))
image_inputs = image_processor(images=[image], return_tensors="pt")
if "image_grid_thw" in image_inputs: # fake image for qwen2-vl
image_inputs["image_grid_thw"][0][0] = 0
image_inputs = {key: None for key in image_inputs.keys()}
return image_inputs
def _get_paligemma_token_type_ids(input_len: int, processor: "ProcessorMixin") -> List[List[int]]:
def _get_paligemma_token_type_ids(
images: Sequence["ImageObject"], input_len: int, processor: "ProcessorMixin"
) -> List[List[int]]:
r"""
Gets paligemma token type ids for computing loss.
Returns:
token_type_ids: shape (1, seq_len)
"""
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
image_seq_length: int = getattr(image_processor, "image_seq_length")
return [[0] * image_seq_length + [1] * (input_len - image_seq_length)]
num_images = len(images)
image_seqlen = num_images * getattr(processor, "image_seqlen")
return [[0] * image_seqlen + [1] * (input_len - image_seqlen)]
class BasePlugin:
@ -74,6 +76,7 @@ class BasePlugin:
self,
input_ids: List[int],
labels: Optional[List[int]],
images: Sequence["ImageObject"],
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
) -> Tuple[List[int], Optional[List[int]]]:
@ -93,18 +96,6 @@ class BasePlugin:
"""
return {}
def process_model_inputs(
self,
model_inputs: Dict[str, List[Any]],
images: Sequence["ImageObject"],
feature_seqlens: Dict[str, int],
processor: Optional["ProcessorMixin"],
) -> None:
r"""
Appends multimodal inputs to model inputs for VLMs.
"""
return
class LlavaPlugin(BasePlugin):
def process_messages(
@ -113,21 +104,21 @@ class LlavaPlugin(BasePlugin):
images: Sequence["ImageObject"],
processor: Optional["ProcessorMixin"],
) -> List[Dict[str, str]]:
image_count = 0
new_messages = []
num_images = 0
image_seqlen = getattr(processor, "image_seqlen")
messages = deepcopy(messages)
for message in messages:
content = message["content"]
while IMAGE_PLACEHOLDER in content:
image_count += 1
if image_count > 1:
raise ValueError("Llava model only accepts one image per sample.")
num_images += 1
content = content.replace(IMAGE_PLACEHOLDER, "{{image}}", 1)
content = content.replace("{{image}}", self.image_token)
new_messages.append({"role": message["role"], "content": content})
message["content"] = content.replace("{{image}}", self.image_token * image_seqlen)
return new_messages
if len(images) != num_images:
raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER))
return messages
def get_mm_inputs(
self,
@ -137,17 +128,6 @@ class LlavaPlugin(BasePlugin):
) -> Dict[str, Any]:
return _get_mm_inputs(images, processor)
def process_model_inputs(
self,
model_inputs: Dict[str, List[Any]],
images: Sequence["ImageObject"],
feature_seqlens: Dict[str, int],
processor: Optional["ProcessorMixin"],
) -> None:
mm_inputs = self.get_mm_inputs(images, feature_seqlens, processor)
for key, value in mm_inputs.items():
model_inputs[key].append(value[0])
class PaliGemmaPlugin(BasePlugin):
def process_messages(
@ -156,34 +136,35 @@ class PaliGemmaPlugin(BasePlugin):
images: Sequence["ImageObject"],
processor: Optional["ProcessorMixin"],
) -> List[Dict[str, str]]:
image_count = 0
new_messages = []
num_images = 0
messages = deepcopy(messages)
for message in messages:
content = message["content"]
while IMAGE_PLACEHOLDER in content:
image_count += 1
if image_count > 1:
raise ValueError("PaliGemma model only accepts one image per sample.")
num_images += 1
content = content.replace(IMAGE_PLACEHOLDER, "{{image}}", 1)
content = content.replace(IMAGE_PLACEHOLDER, "", 1)
message["content"] = content.replace("{{image}}", "")
new_messages.append({"role": message["role"], "content": content})
if len(images) != num_images:
raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER))
return new_messages
return messages
def process_token_ids(
self,
input_ids: List[int],
labels: Optional[List[int]],
images: Sequence["ImageObject"],
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
) -> Tuple[List[int], Optional[List[int]]]:
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
image_seq_length: int = getattr(image_processor, "image_seq_length")
num_images = len(images)
image_seqlen = num_images * getattr(processor, "image_seqlen")
image_token_id = tokenizer.convert_tokens_to_ids(self.image_token)
input_ids = [image_token_id] * image_seq_length + input_ids
input_ids = [image_token_id] * image_seqlen + input_ids
if labels is not None:
labels = [IGNORE_INDEX] * image_seq_length + labels
labels = [IGNORE_INDEX] * image_seqlen + labels
return input_ids, labels
@ -195,21 +176,10 @@ class PaliGemmaPlugin(BasePlugin):
) -> Dict[str, Any]:
mm_inputs = _get_mm_inputs(images, processor)
for feature_name, feature_length in feature_seqlens.items():
mm_inputs[feature_name] = _get_paligemma_token_type_ids(feature_length, processor)
mm_inputs[feature_name] = _get_paligemma_token_type_ids(images, feature_length, processor)
return mm_inputs
def process_model_inputs(
self,
model_inputs: Dict[str, List[Any]],
images: Sequence["ImageObject"],
feature_seqlens: Dict[str, int],
processor: Optional["ProcessorMixin"],
) -> None:
mm_inputs = self.get_mm_inputs(images, feature_seqlens, processor)
for key, value in mm_inputs.items():
model_inputs[key].append(value[0])
class Qwen2vlPlugin(BasePlugin):
def process_messages(
@ -223,23 +193,26 @@ class Qwen2vlPlugin(BasePlugin):
if len(images) > 0:
image_grid_thw = _get_mm_inputs(images, processor)["image_grid_thw"]
index = 0
new_messages = []
num_images = 0
messages = deepcopy(messages)
for message in messages:
content = message["content"]
while IMAGE_PLACEHOLDER in content:
content = content.replace(
IMAGE_PLACEHOLDER,
"<|vision_start|>{}<|vision_end|>".format(
self.image_token * (image_grid_thw[index].prod() // merge_length)
self.image_token * (image_grid_thw[num_images].prod() // merge_length)
),
1,
)
index += 1
num_images += 1
new_messages.append({"role": message["role"], "content": content})
message["content"] = content
return new_messages
if len(images) != num_images:
raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER))
return messages
def get_mm_inputs(
self,
@ -249,17 +222,6 @@ class Qwen2vlPlugin(BasePlugin):
) -> Dict[str, Any]:
return _get_mm_inputs(images, processor)
def process_model_inputs(
self,
model_inputs: Dict[str, List[Any]],
images: Sequence["ImageObject"],
feature_seqlens: Dict[str, int],
processor: Optional["ProcessorMixin"],
) -> None:
mm_inputs = self.get_mm_inputs(images, feature_seqlens, processor)
for key, value in mm_inputs.items():
model_inputs[key].append(value) # support multi-image
PLUGINS = {
"base": BasePlugin,
@ -270,7 +232,8 @@ PLUGINS = {
def get_mm_plugin(name: str, image_token: str) -> "BasePlugin":
if name not in PLUGINS:
raise ValueError("{} not found.".format(name))
plugin_class = PLUGINS.get(name, None)
if plugin_class is None:
raise ValueError("Multimodal plugin `{}` not found.".format(name))
return PLUGINS[name](image_token)
return plugin_class(image_token)
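
A hedged sketch of the reworked plugin flow, assuming the import paths mirror the test file at the end of this commit; the llava checkpoint and the toy inputs are only examples.

from PIL import Image
from llamafactory.data.mm_plugin import get_mm_plugin
from llamafactory.hparams import ModelArguments
from llamafactory.model import load_tokenizer

tokenizer_module = load_tokenizer(ModelArguments(model_name_or_path="llava-hf/llava-1.5-7b-hf"))
tokenizer, processor = tokenizer_module["tokenizer"], tokenizer_module["processor"]

plugin = get_mm_plugin(name="llava", image_token="<image>")
messages = [{"role": "user", "content": "<image>What is in this image?"}]
images = [Image.new("RGB", (32, 32), (255, 255, 255))]

# 1. expand each <image> placeholder into image_seqlen copies of the image token
messages = plugin.process_messages(messages, images, processor)
# 2. let the plugin adjust the token ids (prepends image tokens for paligemma, a no-op for llava)
input_ids, labels = plugin.process_token_ids([0, 1, 2], [0, 1, 2], images, tokenizer, processor)
# 3. build the tensor features once; the dataset preprocessors append them to model_inputs per example
mm_inputs = plugin.get_mm_inputs(images, {"token_type_ids": len(input_ids)}, processor)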

View File

@ -50,7 +50,7 @@ def get_preprocess_and_print_func(
print_function = partial(print_unsupervised_dataset_example, tokenizer=tokenizer)
elif stage == "sft" and not do_generate:
if data_args.packing:
if data_args.neat_packing:
if data_args.neat_packing: # hack datasets to have int32 attention mask
from datasets.arrow_writer import OptimizedTypedSequence, TypedSequence
def __init__(self, data, **kwargs):
@ -67,6 +67,7 @@ def get_preprocess_and_print_func(
preprocess_packed_supervised_dataset,
template=template,
tokenizer=tokenizer,
processor=processor,
data_args=data_args,
)
else:

View File

@ -21,6 +21,7 @@ from .processor_utils import infer_seqlen
if TYPE_CHECKING:
from PIL.Image import Image
from transformers import PreTrainedTokenizer, ProcessorMixin
from ...hparams import DataArguments
@ -36,11 +37,12 @@ def _encode_feedback_example(
kl_response: Sequence[Dict[str, str]],
system: Optional[str],
tools: Optional[str],
images: Sequence["Image"],
template: "Template",
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
cutoff_len: int,
) -> Tuple[List[int], List[int], List[int], List[int], bool]:
) -> Tuple[List[int], List[int], List[int], List[int], bool, Dict[str, Any]]:
if response[0]["content"]: # desired example
kto_tag = True
messages = prompt + [response[0]]
@ -53,6 +55,8 @@ def _encode_feedback_example(
else:
kl_messages = prompt + [kl_response[1]]
messages = template.mm_plugin.process_messages(messages, images, processor)
kl_messages = template.mm_plugin.process_messages(kl_messages, images, processor)
prompt_ids, response_ids = template.encode_oneturn(tokenizer, messages, system, tools)
kl_prompt_ids, kl_response_ids = template.encode_oneturn(tokenizer, kl_messages, system, tools)
@ -60,8 +64,8 @@ def _encode_feedback_example(
response_ids += [tokenizer.eos_token_id]
kl_response_ids += [tokenizer.eos_token_id]
prompt_ids, _ = template.mm_plugin.process_token_ids(prompt_ids, None, tokenizer, processor)
kl_prompt_ids, _ = template.mm_plugin.process_token_ids(kl_prompt_ids, None, tokenizer, processor)
prompt_ids, _ = template.mm_plugin.process_token_ids(prompt_ids, None, images, tokenizer, processor)
kl_prompt_ids, _ = template.mm_plugin.process_token_ids(kl_prompt_ids, None, images, tokenizer, processor)
source_len, target_len = infer_seqlen(len(prompt_ids), len(response_ids), cutoff_len)
prompt_ids = prompt_ids[:source_len]
@ -74,8 +78,15 @@ def _encode_feedback_example(
labels = [IGNORE_INDEX] * source_len + response_ids
kl_input_ids = kl_prompt_ids + kl_response_ids
kl_labels = [IGNORE_INDEX] * kl_source_len + kl_response_ids
return input_ids, labels, kl_input_ids, kl_labels, kto_tag
extra_inputs = template.mm_plugin.get_mm_inputs(
images=images,
feature_seqlens={
"token_type_ids": len(input_ids),
"kl_token_type_ids": len(kl_input_ids),
},
processor=processor,
)
return input_ids, labels, kl_input_ids, kl_labels, kto_tag, extra_inputs
def preprocess_feedback_dataset(
@ -93,13 +104,13 @@ def preprocess_feedback_dataset(
logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
continue
prompt = template.mm_plugin.process_messages(examples["prompt"][i], examples["images"][i], processor)
input_ids, labels, kl_input_ids, kl_labels, kto_tag = _encode_feedback_example(
prompt=prompt,
input_ids, labels, kl_input_ids, kl_labels, kto_tag, extra_inputs = _encode_feedback_example(
prompt=examples["prompt"][i],
response=examples["response"][i],
kl_response=kl_response[i],
system=examples["system"][i],
tools=examples["tools"][i],
images=examples["images"][i],
template=template,
tokenizer=tokenizer,
processor=processor,
@ -112,15 +123,8 @@ def preprocess_feedback_dataset(
model_inputs["kl_attention_mask"].append([1] * len(kl_input_ids))
model_inputs["kl_labels"].append(kl_labels)
model_inputs["kto_tags"].append(kto_tag)
template.mm_plugin.process_model_inputs(
model_inputs=model_inputs,
images=examples["images"][i],
feature_seqlens={
"token_type_ids": len(input_ids),
"kl_token_type_ids": len(kl_input_ids),
},
processor=processor,
)
for key, value in extra_inputs.items():
model_inputs[key].append(value)
desirable_num = sum([1 for tag in model_inputs["kto_tags"] if tag])
undesirable_num = len(model_inputs["kto_tags"]) - desirable_num

View File

@ -21,6 +21,7 @@ from .processor_utils import infer_seqlen
if TYPE_CHECKING:
from PIL.Image import Image
from transformers import PreTrainedTokenizer, ProcessorMixin
from ...hparams import DataArguments
@ -35,13 +36,14 @@ def _encode_pairwise_example(
response: Sequence[Dict[str, str]],
system: Optional[str],
tools: Optional[str],
images: Sequence["Image"],
template: "Template",
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
cutoff_len: int,
) -> Tuple[List[int], List[int], List[int], List[int]]:
chosen_messages = prompt + [response[0]]
rejected_messages = prompt + [response[1]]
) -> Tuple[List[int], List[int], List[int], List[int], Dict[str, Any]]:
chosen_messages = template.mm_plugin.process_messages(prompt + [response[0]], images, processor)
rejected_messages = template.mm_plugin.process_messages(prompt + [response[1]], images, processor)
prompt_ids, chosen_ids = template.encode_oneturn(tokenizer, chosen_messages, system, tools)
_, rejected_ids = template.encode_oneturn(tokenizer, rejected_messages, system, tools)
@ -49,7 +51,7 @@ def _encode_pairwise_example(
chosen_ids += [tokenizer.eos_token_id]
rejected_ids += [tokenizer.eos_token_id]
prompt_ids, _ = template.mm_plugin.process_token_ids(prompt_ids, None, tokenizer, processor)
prompt_ids, _ = template.mm_plugin.process_token_ids(prompt_ids, None, images, tokenizer, processor)
# consider the response is more important
source_len, target_len = infer_seqlen(len(prompt_ids), max(len(chosen_ids), len(rejected_ids)), cutoff_len)
prompt_ids = prompt_ids[:source_len]
@ -60,8 +62,15 @@ def _encode_pairwise_example(
chosen_labels = [IGNORE_INDEX] * source_len + chosen_ids
rejected_input_ids = prompt_ids + rejected_ids
rejected_labels = [IGNORE_INDEX] * source_len + rejected_ids
return chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels
extra_inputs = template.mm_plugin.get_mm_inputs(
images=images,
feature_seqlens={
"chosen_token_type_ids": len(chosen_input_ids),
"rejected_token_type_ids": len(rejected_input_ids),
},
processor=processor,
)
return chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels, extra_inputs
def preprocess_pairwise_dataset(
@ -78,12 +87,12 @@ def preprocess_pairwise_dataset(
logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
continue
prompt = template.mm_plugin.process_messages(examples["prompt"][i], examples["images"][i], processor)
chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels = _encode_pairwise_example(
prompt=prompt,
chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels, extra_inputs = _encode_pairwise_example(
prompt=examples["prompt"][i],
response=examples["response"][i],
system=examples["system"][i],
tools=examples["tools"][i],
images=examples["images"][i],
template=template,
tokenizer=tokenizer,
processor=processor,
@ -95,15 +104,8 @@ def preprocess_pairwise_dataset(
model_inputs["rejected_input_ids"].append(rejected_input_ids)
model_inputs["rejected_attention_mask"].append([1] * len(rejected_input_ids))
model_inputs["rejected_labels"].append(rejected_labels)
template.mm_plugin.process_model_inputs(
model_inputs=model_inputs,
images=examples["images"][i],
feature_seqlens={
"chosen_token_type_ids": len(chosen_input_ids),
"rejected_token_type_ids": len(rejected_input_ids),
},
processor=processor,
)
for key, value in extra_inputs.items():
model_inputs[key].append(value)
return model_inputs

View File

@ -21,6 +21,7 @@ from .processor_utils import greedy_knapsack, infer_seqlen
if TYPE_CHECKING:
from PIL.Image import Image
from transformers import PreTrainedTokenizer, ProcessorMixin
from ...hparams import DataArguments
@ -35,19 +36,18 @@ def _encode_supervised_example(
response: Sequence[Dict[str, str]],
system: Optional[str],
tools: Optional[str],
images: Sequence["Image"],
template: "Template",
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
cutoff_len: int,
train_on_prompt: bool,
mask_history: bool,
) -> Tuple[List[int], List[int]]:
messages = prompt + response
input_ids, labels = [], []
input_ids, labels = template.mm_plugin.process_token_ids(input_ids, labels, tokenizer, processor)
) -> Tuple[List[int], List[int], Dict[str, Any]]:
messages = template.mm_plugin.process_messages(prompt + response, images, processor)
input_ids, labels = template.mm_plugin.process_token_ids([], [], images, tokenizer, processor)
encoded_pairs = template.encode_multiturn(tokenizer, messages, system, tools)
total_length = 1 if template.efficient_eos else 0
total_length = len(input_ids) + (1 if template.efficient_eos else 0)
if mask_history:
encoded_pairs = encoded_pairs[::-1] # high priority for last turns
@ -83,7 +83,10 @@ def _encode_supervised_example(
input_ids += [tokenizer.eos_token_id]
labels += [tokenizer.eos_token_id]
return input_ids, labels
extra_inputs = template.mm_plugin.get_mm_inputs(
images=images, feature_seqlens={"token_type_ids": len(input_ids)}, processor=processor
)
return input_ids, labels, extra_inputs
def preprocess_supervised_dataset(
@ -101,12 +104,12 @@ def preprocess_supervised_dataset(
logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
continue
prompt = template.mm_plugin.process_messages(examples["prompt"][i], examples["images"][i], processor)
input_ids, labels = _encode_supervised_example(
prompt=prompt,
input_ids, labels, extra_inputs = _encode_supervised_example(
prompt=examples["prompt"][i],
response=examples["response"][i],
system=examples["system"][i],
tools=examples["tools"][i],
images=examples["images"][i],
template=template,
tokenizer=tokenizer,
processor=processor,
@ -117,12 +120,8 @@ def preprocess_supervised_dataset(
model_inputs["input_ids"].append(input_ids)
model_inputs["attention_mask"].append([1] * len(input_ids))
model_inputs["labels"].append(labels)
template.mm_plugin.process_model_inputs(
model_inputs=model_inputs,
images=examples["images"][i],
feature_seqlens={"token_type_ids": len(input_ids)},
processor=processor,
)
for key, value in extra_inputs.items():
model_inputs[key].append(value)
return model_inputs
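
Roughly, each preprocessed example now carries its own multimodal features instead of relying on a separate process_model_inputs pass; a sketch with illustrative values (the real ones come from the plugin's get_mm_inputs):

example = {
    "input_ids": [1, 32000, 32000, 5, 6],  # image placeholders already expanded by process_messages
    "attention_mask": [1, 1, 1, 1, 1],
    "labels": [-100, -100, -100, 5, 6],
    "pixel_values": None,  # None for a text-only row, a per-example array otherwise
}

The collator then decides per batch whether to keep or drop the image keys, which is what allows text-only and multimodal samples to coexist in one dataset.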
@ -131,10 +130,15 @@ def preprocess_packed_supervised_dataset(
examples: Dict[str, List[Any]],
template: "Template",
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
data_args: "DataArguments",
) -> Dict[str, List[Any]]:
# TODO: use `position_ids` to achieve packing
# build inputs with format `<bos> X1 Y1 <eos> <bos> X2 Y2 <eos>`
# and labels with format `<ignore> ... <ignore> Y1 <eos> <ignore> ... <ignore> Y2 <eos>`
if processor is not None:
raise NotImplementedError("`packing` has not been implemented for multimodal datasets.")
valid_num = 0
batch_input_ids, batch_labels = [], []
lengths = []
@ -149,6 +153,7 @@ def preprocess_packed_supervised_dataset(
response=examples["response"][i],
system=examples["system"][i],
tools=examples["tools"][i],
images=examples["images"][i],
template=template,
tokenizer=tokenizer,
processor=None,

View File

@ -21,6 +21,7 @@ from .processor_utils import infer_seqlen
if TYPE_CHECKING:
from PIL.Image import Image
from transformers import PreTrainedTokenizer, ProcessorMixin
from ...hparams import DataArguments
@ -35,25 +36,30 @@ def _encode_unsupervised_example(
response: Sequence[Dict[str, str]],
system: Optional[str],
tools: Optional[str],
images: Sequence["Image"],
template: "Template",
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
cutoff_len: int,
) -> Tuple[List[int], List[int]]:
) -> Tuple[List[int], List[int], Dict[str, Any]]:
if len(response) == 1:
messages = prompt + response
else:
messages = prompt + [{"role": Role.ASSISTANT.value, "content": ""}]
messages = template.mm_plugin.process_messages(messages, images, processor)
input_ids, labels = template.encode_oneturn(tokenizer, messages, system, tools)
if template.efficient_eos:
labels += [tokenizer.eos_token_id]
input_ids, _ = template.mm_plugin.process_token_ids(input_ids, None, tokenizer, processor)
input_ids, _ = template.mm_plugin.process_token_ids(input_ids, None, images, tokenizer, processor)
source_len, target_len = infer_seqlen(len(input_ids), len(labels), cutoff_len)
input_ids = input_ids[:source_len]
labels = labels[:target_len]
return input_ids, labels
extra_inputs = template.mm_plugin.get_mm_inputs(
images=images, feature_seqlens={"token_type_ids": len(input_ids)}, processor=processor
)
return input_ids, labels, extra_inputs
def preprocess_unsupervised_dataset(
@ -70,12 +76,12 @@ def preprocess_unsupervised_dataset(
logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
continue
prompt = template.mm_plugin.process_messages(examples["prompt"][i], examples["images"][i], processor)
input_ids, labels = _encode_unsupervised_example(
prompt=prompt,
input_ids, labels, extra_inputs = _encode_unsupervised_example(
prompt=examples["prompt"][i],
response=examples["response"][i],
system=examples["system"][i],
tools=examples["tools"][i],
images=examples["images"][i],
template=template,
tokenizer=tokenizer,
processor=processor,
@ -84,12 +90,8 @@ def preprocess_unsupervised_dataset(
model_inputs["input_ids"].append(input_ids)
model_inputs["attention_mask"].append([1] * len(input_ids))
model_inputs["labels"].append(labels)
template.mm_plugin.process_model_inputs(
model_inputs=model_inputs,
images=examples["images"][i],
feature_seqlens={"token_type_ids": len(input_ids)},
processor=processor,
)
for key, value in extra_inputs.items():
model_inputs[key].append(value)
return model_inputs

View File

@ -15,6 +15,8 @@
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Union
from transformers.utils.versions import require_version
from ..extras.constants import IMAGE_PLACEHOLDER
from ..extras.logging import get_logger
from .data_utils import Role
@ -347,6 +349,11 @@ def get_template_and_fix_tokenizer(
name: Optional[str] = None,
tool_format: Optional[str] = None,
) -> Template:
if name == "qwen2_vl":
require_version(
"transformers>=4.45.0.dev0", "To fix: pip install git+https://github.com/huggingface/transformers.git"
)
if name is None:
template = TEMPLATES["empty"] # placeholder
else:
@ -357,8 +364,8 @@ def get_template_and_fix_tokenizer(
if tool_format is not None:
logger.info("Using tool format: {}.".format(tool_format))
eos_slots = [] if template.efficient_eos else [{"eos_token"}]
template.format_tools = ToolFormatter(tool_format=tool_format)
template.format_function = FunctionFormatter(slots=eos_slots, tool_format=tool_format)
template.format_tools = ToolFormatter(tool_format=tool_format)
stop_words = template.stop_words
if template.replace_eos:

View File

@ -138,3 +138,17 @@ class GLM4ToolUtils(ToolUtils):
return content
return [(tool_name, json.dumps(arguments, ensure_ascii=False))]
TOOLS = {
"default": DefaultToolUtils(),
"glm4": GLM4ToolUtils(),
}
def get_tool_utils(name: str) -> "ToolUtils":
tool_utils = TOOLS.get(name, None)
if tool_utils is None:
raise ValueError("Tool utils `{}` not found.".format(name))
return tool_utils
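
A small illustration of the new registry; "default" and "glm4" are the registered names, and the weather tool below is a hypothetical schema, not part of this repository.

tools = [{"name": "get_weather", "description": "Query the weather", "parameters": {"type": "object", "properties": {}}}]
tool_utils = get_tool_utils("default")
system_prompt = tool_utils.tool_formatter(tools)  # prompt text injected by ToolFormatter
function_slots = tool_utils.get_function_slots()  # slots prepended by FunctionFormatter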

View File

@ -195,6 +195,9 @@ def is_gpu_or_npu_available() -> bool:
def numpify(inputs: Union["NDArray", "torch.Tensor"]) -> "NDArray":
r"""
Casts a torch tensor or a numpy array to a numpy array.
"""
if isinstance(inputs, torch.Tensor):
inputs = inputs.cpu()
if inputs.dtype == torch.bfloat16: # numpy does not support bfloat16 until 1.21.4
@ -206,6 +209,9 @@ def numpify(inputs: Union["NDArray", "torch.Tensor"]) -> "NDArray":
def skip_check_imports() -> None:
r"""
Avoids flash attention import error in custom model files.
"""
if os.environ.get("FORCE_CHECK_IMPORTS", "0").lower() not in ["true", "1"]:
transformers.dynamic_module_utils.check_imports = get_relative_imports

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from llamafactory.train.tuner import run_exp
from llamafactory.train.tuner import run_exp # use absolute import
def launch():

View File

@ -25,6 +25,7 @@ from .model_utils.misc import register_autoclass
from .model_utils.mod import convert_pretrained_model_to_mod, load_mod_pretrained_model
from .model_utils.unsloth import load_unsloth_pretrained_model
from .model_utils.valuehead import load_valuehead_params
from .model_utils.visual import get_image_seqlen
from .patcher import patch_config, patch_model, patch_tokenizer, patch_valuehead_model
@ -65,6 +66,7 @@ def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
Note: including inplace operation of model_args.
"""
init_kwargs = _get_init_kwargs(model_args)
config = load_config(model_args)
try:
tokenizer = AutoTokenizer.from_pretrained(
model_args.model_name_or_path,
@ -96,6 +98,7 @@ def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
try:
processor = AutoProcessor.from_pretrained(model_args.model_name_or_path, **init_kwargs)
setattr(processor, "tokenizer", tokenizer)
setattr(processor, "image_seqlen", get_image_seqlen(config))
except Exception:
processor = None

View File

@ -82,7 +82,7 @@ class LlavaMultiModalProjectorForYiVLForVLLM(LlavaMultiModalProjectorForYiVL):
def autocast_projector_dtype(model: "PreTrainedModel", model_args: "ModelArguments") -> None:
r"""
Casts projector output to half precision for quantized VLMs.
Casts projector output to half precision for fine-tuning quantized VLMs.
"""
def _mm_projector_forward_post_hook(
@ -136,6 +136,22 @@ def get_forbidden_modules(config: "PretrainedConfig", finetuning_args: "Finetuni
return forbidden_modules
def get_image_seqlen(config: "PretrainedConfig") -> int:
r"""
Computes the number of special tokens per image.
"""
if getattr(config, "model_type", None) == "llava":
image_seqlen = (config.vision_config.image_size // config.vision_config.patch_size) ** 2
if getattr(config, "vision_feature_select_strategy", "default") == "full": # add [CLS] token
image_seqlen += 1
elif getattr(config, "model_type", None) == "paligemma":
image_seqlen = config.vision_config.num_image_tokens
elif getattr(config, "model_type", None) == "qwen2_vl": # variable length
image_seqlen = -1
return image_seqlen
def patch_target_modules(
config: "PretrainedConfig", finetuning_args: "FinetuningArguments", target_modules: Sequence[str]
) -> Union[str, List[str]]:
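
As a rough sanity check on get_image_seqlen above: LLaVA-1.5 uses 336-pixel images with 14-pixel patches, so (336 // 14) ** 2 = 576 tokens per image, and PaliGemma's config reports num_image_tokens = 256; both numbers reappear as the hard-coded image_seqlen values in the plugin tests at the end of this commit. Qwen2-VL produces a variable number of image tokens per image, hence the -1 sentinel.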

View File

@ -17,7 +17,7 @@
from typing import TYPE_CHECKING, List, Optional
from ...data import CustomDataCollatorForSeq2Seq, get_dataset
from ...data import MultiModalDataCollatorForSeq2Seq, get_dataset
from ...extras.ploting import plot_loss
from ...model import load_model, load_tokenizer
from ..callbacks import fix_valuehead_checkpoint
@ -45,7 +45,7 @@ def run_ppo(
model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train, add_valuehead=True)
tokenizer.padding_side = "left" # use left-padding in generation while using right-padding in training
data_collator = CustomDataCollatorForSeq2Seq(tokenizer=tokenizer)
data_collator = MultiModalDataCollatorForSeq2Seq(tokenizer=tokenizer)
# Create reference model and reward model
ref_model = create_ref_model(model_args, finetuning_args, add_valuehead=True)

View File

@ -13,8 +13,7 @@
# limitations under the License.
import os
from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict
from typing import TYPE_CHECKING, Any, Dict, Tuple
import pytest
import torch
@ -26,7 +25,7 @@ from llamafactory.model import load_tokenizer
if TYPE_CHECKING:
from transformers import ProcessorMixin
from transformers import PreTrainedTokenizer, ProcessorMixin
from transformers.image_processing_utils import BaseImageProcessor
@ -34,13 +33,20 @@ HF_TOKEN = os.environ.get("HF_TOKEN", None)
TINY_LLAMA = os.environ.get("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
MESSAGES = [
MM_MESSAGES = [
{"role": "user", "content": "<image>What is in this image?"},
{"role": "assistant", "content": "A cat."},
]
TEXT_MESSAGES = [
{"role": "user", "content": "How are you"},
{"role": "assistant", "content": "I am fine!"},
]
IMAGES = [Image.new("RGB", (32, 32), (255, 255, 255))]
NO_IMAGES = []
INPUT_IDS = [0, 1, 2, 3, 4]
LABELS = [0, 1, 2, 3, 4]
@ -53,99 +59,110 @@ def _get_mm_inputs(processor: "ProcessorMixin") -> Dict[str, "torch.Tensor"]:
return image_processor(images=IMAGES, return_tensors="pt")
def _is_close(batch_a: Dict[str, Any], batch_b: Dict[str, Any]):
def _is_close(batch_a: Dict[str, Any], batch_b: Dict[str, Any]) -> None:
assert batch_a.keys() == batch_b.keys()
for key in batch_a.keys():
if isinstance(batch_a[key], list):
assert len(batch_a[key]) == len(batch_b[key])
for i in range(len(batch_a[key])):
if isinstance(batch_a[key][i], torch.Tensor):
assert torch.allclose(batch_a[key][i], batch_b[key][i], rtol=1e-4, atol=1e-5)
else:
assert batch_a[key][i] == batch_b[key][i]
elif isinstance(batch_a[key], torch.Tensor):
if isinstance(batch_a[key], torch.Tensor):
assert torch.allclose(batch_a[key], batch_b[key], rtol=1e-4, atol=1e-5)
else:
raise NotImplementedError
assert batch_a[key] == batch_b[key]
def _load_tokenizer_module(model_name_or_path: str) -> Tuple["PreTrainedTokenizer", "ProcessorMixin"]:
model_args = ModelArguments(model_name_or_path=model_name_or_path)
tokenizer_module = load_tokenizer(model_args)
return tokenizer_module["tokenizer"], tokenizer_module["processor"]
def test_base_plugin():
model_args = ModelArguments(model_name_or_path=TINY_LLAMA)
tokenizer_module = load_tokenizer(model_args)
tokenizer = tokenizer_module["tokenizer"]
processor = tokenizer_module["processor"]
tokenizer, processor = _load_tokenizer_module(model_name_or_path=TINY_LLAMA)
base_plugin = get_mm_plugin(name="base", image_token="<image>")
model_inputs = defaultdict(list)
base_plugin.process_model_inputs(model_inputs, IMAGES, FEATURE_SEQLENS, processor)
assert base_plugin.process_messages(MESSAGES, IMAGES, processor)
assert base_plugin.process_token_ids(INPUT_IDS, LABELS, tokenizer, processor) == (INPUT_IDS, LABELS)
# test mm_messages
assert base_plugin.process_messages(MM_MESSAGES, IMAGES, processor) == MM_MESSAGES
assert base_plugin.process_token_ids(INPUT_IDS, LABELS, IMAGES, tokenizer, processor) == (INPUT_IDS, LABELS)
_is_close(base_plugin.get_mm_inputs(IMAGES, FEATURE_SEQLENS, processor), {})
_is_close(model_inputs, {})
# test text_messages
assert base_plugin.process_messages(TEXT_MESSAGES, NO_IMAGES, processor) == TEXT_MESSAGES
assert base_plugin.process_token_ids(INPUT_IDS, LABELS, NO_IMAGES, tokenizer, processor) == (INPUT_IDS, LABELS)
_is_close(base_plugin.get_mm_inputs(NO_IMAGES, FEATURE_SEQLENS, processor), {})
def test_llava_plugin():
model_args = ModelArguments(model_name_or_path="llava-hf/llava-1.5-7b-hf")
tokenizer_module = load_tokenizer(model_args)
tokenizer = tokenizer_module["tokenizer"]
processor = tokenizer_module["processor"]
tokenizer, processor = _load_tokenizer_module(model_name_or_path="llava-hf/llava-1.5-7b-hf")
image_seqlen = 576
mm_inputs = _get_mm_inputs(processor)
expected_model_inputs = {key: [value[0]] for key, value in mm_inputs.items()}
expected_mm_messages = [
{key: value.replace("<image>", "<image>" * image_seqlen) for key, value in message.items()}
for message in MM_MESSAGES
]
llava_plugin = get_mm_plugin(name="llava", image_token="<image>")
model_inputs = defaultdict(list)
llava_plugin.process_model_inputs(model_inputs, IMAGES, FEATURE_SEQLENS, processor)
assert llava_plugin.process_messages(MESSAGES, IMAGES, processor)
assert llava_plugin.process_token_ids(INPUT_IDS, LABELS, tokenizer, processor) == (INPUT_IDS, LABELS)
# test mm_messages
assert llava_plugin.process_messages(MM_MESSAGES, IMAGES, processor) == expected_mm_messages
assert llava_plugin.process_token_ids(INPUT_IDS, LABELS, IMAGES, tokenizer, processor) == (INPUT_IDS, LABELS)
_is_close(llava_plugin.get_mm_inputs(IMAGES, FEATURE_SEQLENS, processor), mm_inputs)
_is_close(model_inputs, expected_model_inputs)
# test text_messages
assert llava_plugin.process_messages(TEXT_MESSAGES, NO_IMAGES, processor) == TEXT_MESSAGES
assert llava_plugin.process_token_ids(INPUT_IDS, LABELS, NO_IMAGES, tokenizer, processor) == (INPUT_IDS, LABELS)
_is_close(llava_plugin.get_mm_inputs(NO_IMAGES, FEATURE_SEQLENS, processor), {"pixel_values": None})
@pytest.mark.skipif(not HF_TOKEN, reason="Gated model.")
def test_paligemma_plugin():
model_args = ModelArguments(model_name_or_path="google/paligemma-3b-pt-224")
tokenizer_module = load_tokenizer(model_args)
tokenizer = tokenizer_module["tokenizer"]
processor = tokenizer_module["processor"]
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
image_seq_length: int = getattr(image_processor, "image_seq_length")
tokenizer, processor = _load_tokenizer_module(model_name_or_path="google/paligemma-3b-pt-224")
image_seqlen = 256
mm_inputs = _get_mm_inputs(processor)
mm_inputs["token_type_ids"] = [[0] * image_seq_length + [1] * (1024 - image_seq_length)]
expected_model_inputs = {key: [value[0]] for key, value in mm_inputs.items()}
expected_input_ids = [tokenizer.convert_tokens_to_ids("<image>")] * image_seq_length + INPUT_IDS
expected_labels = [-100] * image_seq_length + LABELS
mm_inputs["token_type_ids"] = [[0] * image_seqlen + [1] * (1024 - image_seqlen)]
expected_mm_messages = [
{key: value.replace("<image>", "") for key, value in message.items()} for message in MM_MESSAGES
]
expected_input_ids = [tokenizer.convert_tokens_to_ids("<image>")] * image_seqlen + INPUT_IDS
expected_labels = [-100] * image_seqlen + LABELS
paligemma_plugin = get_mm_plugin(name="paligemma", image_token="<image>")
model_inputs = defaultdict(list)
paligemma_plugin.process_model_inputs(model_inputs, IMAGES, FEATURE_SEQLENS, processor)
assert paligemma_plugin.process_messages(MESSAGES, IMAGES, processor)
assert paligemma_plugin.process_token_ids(INPUT_IDS, LABELS, tokenizer, processor) == (
# test mm_messages
assert paligemma_plugin.process_messages(MM_MESSAGES, IMAGES, processor) == expected_mm_messages
assert paligemma_plugin.process_token_ids(INPUT_IDS, LABELS, IMAGES, tokenizer, processor) == (
expected_input_ids,
expected_labels,
)
_is_close(paligemma_plugin.get_mm_inputs(IMAGES, FEATURE_SEQLENS, processor), mm_inputs)
_is_close(model_inputs, expected_model_inputs)
# test text_messages
assert paligemma_plugin.process_messages(TEXT_MESSAGES, NO_IMAGES, processor) == TEXT_MESSAGES
assert paligemma_plugin.process_token_ids(INPUT_IDS, LABELS, NO_IMAGES, tokenizer, processor) == (
INPUT_IDS,
LABELS,
)
_is_close(
paligemma_plugin.get_mm_inputs(NO_IMAGES, FEATURE_SEQLENS, processor),
{"pixel_values": None, "token_type_ids": [[1] * 1024]},
)
def test_qwen2_vl_plugin():
model_args = ModelArguments(model_name_or_path="Qwen/Qwen2-VL-7B-Instruct")
tokenizer_module = load_tokenizer(model_args)
tokenizer = tokenizer_module["tokenizer"]
processor = tokenizer_module["processor"]
tokenizer, processor = _load_tokenizer_module(model_name_or_path="Qwen/Qwen2-VL-7B-Instruct")
image_seqlen = 4
mm_inputs = _get_mm_inputs(processor)
expected_model_inputs = {key: [value] for key, value in mm_inputs.items()}
expected_mm_messages = [
{
key: value.replace("<image>", "<|vision_start|>{}<|vision_end|>".format("<|image_pad|>" * image_seqlen))
for key, value in message.items()
}
for message in MM_MESSAGES
]
llava_plugin = get_mm_plugin(name="qwen2_vl", image_token="<|image_pad|>")
model_inputs = defaultdict(list)
llava_plugin.process_model_inputs(model_inputs, IMAGES, FEATURE_SEQLENS, processor)
assert llava_plugin.process_messages(MESSAGES, IMAGES, processor)
assert llava_plugin.process_token_ids(INPUT_IDS, LABELS, tokenizer, processor) == (INPUT_IDS, LABELS)
_is_close(llava_plugin.get_mm_inputs(IMAGES, FEATURE_SEQLENS, processor), mm_inputs)
_is_close(model_inputs, expected_model_inputs)
qwen2_vl_plugin = get_mm_plugin(name="qwen2_vl", image_token="<|image_pad|>")
# test mm_messages
assert qwen2_vl_plugin.process_messages(MM_MESSAGES, IMAGES, processor) == expected_mm_messages
assert qwen2_vl_plugin.process_token_ids(INPUT_IDS, LABELS, IMAGES, tokenizer, processor) == (INPUT_IDS, LABELS)
_is_close(qwen2_vl_plugin.get_mm_inputs(IMAGES, FEATURE_SEQLENS, processor), mm_inputs)
# test text_messages
assert qwen2_vl_plugin.process_messages(TEXT_MESSAGES, NO_IMAGES, processor) == TEXT_MESSAGES
assert qwen2_vl_plugin.process_token_ids(INPUT_IDS, LABELS, NO_IMAGES, tokenizer, processor) == (INPUT_IDS, LABELS)
_is_close(
qwen2_vl_plugin.get_mm_inputs(NO_IMAGES, FEATURE_SEQLENS, processor),
{"pixel_values": None, "image_grid_thw": None},
)