lazy image load: keep raw images in the dataset columns and move multimodal feature extraction into the data collator

Former-commit-id: 47ea97fb1ba77de2e8a561904aa8fdc27c3f5025
hiyouga 2024-09-04 02:27:08 +08:00
parent 5ef58eb655
commit 22deca0e9e
19 changed files with 353 additions and 366 deletions

View File

@@ -156,7 +156,7 @@ class HuggingfaceEngine(BaseEngine):
         if image is not None:
             mm_inputs = template.mm_plugin.get_mm_inputs(
-                images=[image], feature_seqlens={"token_type_ids": prompt_length}, processor=processor
+                images=[image], imglens=[1], seqlens=[prompt_length], processor=processor
             )
             for key, value in mm_inputs.items():
                 value = value if isinstance(value, torch.Tensor) else torch.tensor(value)
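The plugin interface change above replaces the per-feature `feature_seqlens` mapping with flat per-sample lists. A minimal sketch of the new call for a single-image request (names as in the hunk; the comments describe the intended meaning of each argument):

# Sketch of the new get_mm_inputs call shape (illustrative, not part of the diff).
mm_inputs = template.mm_plugin.get_mm_inputs(
    images=[image],           # raw image: a path, an encoded {"path"/"bytes"} dict, or a PIL image
    imglens=[1],              # number of images belonging to each sample in the batch
    seqlens=[prompt_length],  # tokenized length of each sample, used e.g. for token_type_ids
    processor=processor,
)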

View File

@ -14,9 +14,7 @@
import os import os
from functools import partial from functools import partial
from typing import TYPE_CHECKING, Any, Dict, List, Sequence, Union from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union
from datasets import Features
from ..extras.logging import get_logger from ..extras.logging import get_logger
from .data_utils import Role from .data_utils import Role
@ -27,16 +25,24 @@ if TYPE_CHECKING:
from transformers import Seq2SeqTrainingArguments from transformers import Seq2SeqTrainingArguments
from ..hparams import DataArguments from ..hparams import DataArguments
from .mm_plugin import ImageInput
from .parser import DatasetAttr from .parser import DatasetAttr
logger = get_logger(__name__) logger = get_logger(__name__)
def _convert_images(images: Sequence[Any], dataset_attr: "DatasetAttr", data_args: "DataArguments") -> List[Any]: def _convert_images(
images: Sequence["ImageInput"],
dataset_attr: "DatasetAttr",
data_args: "DataArguments",
) -> Optional[List["ImageInput"]]:
r""" r"""
Optionally concatenates image path to dataset dir when loading from local disk. Optionally concatenates image path to dataset dir when loading from local disk.
""" """
if len(images) == 0:
return None
images = images[:] images = images[:]
if dataset_attr.load_from in ["script", "file"]: if dataset_attr.load_from in ["script", "file"]:
for i in range(len(images)): for i in range(len(images)):
@ -47,66 +53,67 @@ def _convert_images(images: Sequence[Any], dataset_attr: "DatasetAttr", data_arg
def convert_alpaca( def convert_alpaca(
examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr", data_args: "DataArguments" example: Dict[str, Any],
) -> Dict[str, List[Any]]: dataset_attr: "DatasetAttr",
data_args: "DataArguments",
) -> Dict[str, Any]:
r""" r"""
Converts alpaca format dataset to the standard format. Converts alpaca format dataset to the standard format.
""" """
outputs = {"prompt": [], "response": [], "system": [], "tools": [], "images": []} prompt = []
if dataset_attr.history and isinstance(example[dataset_attr.history], list):
for old_prompt, old_response in example[dataset_attr.history]:
prompt.append({"role": Role.USER.value, "content": old_prompt})
prompt.append({"role": Role.ASSISTANT.value, "content": old_response})
query = []
if dataset_attr.prompt and example[dataset_attr.prompt]:
query.append(example[dataset_attr.prompt])
if dataset_attr.query and example[dataset_attr.query]:
query.append(example[dataset_attr.query])
prompt.append({"role": Role.USER.value, "content": "\n".join(query)}) # "prompt\nquery"
if dataset_attr.kto_tag and isinstance(example[dataset_attr.kto_tag], bool): # kto example
response = [{"role": Role.ASSISTANT.value, "content": example[dataset_attr.response]}]
if example[dataset_attr.kto_tag]:
response = response + [{"role": Role.ASSISTANT.value, "content": ""}]
else:
response = [{"role": Role.ASSISTANT.value, "content": ""}] + response
elif (
dataset_attr.ranking
and isinstance(example[dataset_attr.chosen], str)
and isinstance(example[dataset_attr.rejected], str)
): # pairwise example
response = [
{"role": Role.ASSISTANT.value, "content": example[dataset_attr.chosen]},
{"role": Role.ASSISTANT.value, "content": example[dataset_attr.rejected]},
]
elif dataset_attr.response and isinstance(example[dataset_attr.response], str): # normal example
response = [{"role": Role.ASSISTANT.value, "content": example[dataset_attr.response]}]
else: # unsupervised
response = []
convert_images = partial(_convert_images, dataset_attr=dataset_attr, data_args=data_args) convert_images = partial(_convert_images, dataset_attr=dataset_attr, data_args=data_args)
for i in range(len(examples[dataset_attr.prompt])): output = {
prompt = [] "_prompt": prompt,
if dataset_attr.history and isinstance(examples[dataset_attr.history][i], list): "_response": response,
for old_prompt, old_response in examples[dataset_attr.history][i]: "_system": example[dataset_attr.system] if dataset_attr.system else "",
prompt.append({"role": Role.USER.value, "content": old_prompt}) "_tools": example[dataset_attr.tools] if dataset_attr.tools else "",
prompt.append({"role": Role.ASSISTANT.value, "content": old_response}) "_images": convert_images(example[dataset_attr.images]) if dataset_attr.images else None,
}
content = [] return output
if dataset_attr.prompt and examples[dataset_attr.prompt][i]:
content.append(examples[dataset_attr.prompt][i])
if dataset_attr.query and examples[dataset_attr.query][i]:
content.append(examples[dataset_attr.query][i])
prompt.append({"role": Role.USER.value, "content": "\n".join(content)}) # "prompt\nquery"
if dataset_attr.kto_tag and isinstance(examples[dataset_attr.kto_tag][i], bool): # kto example
response = [{"role": Role.ASSISTANT.value, "content": examples[dataset_attr.response][i]}]
if examples[dataset_attr.kto_tag][i]:
response = response + [{"role": Role.ASSISTANT.value, "content": ""}]
else:
response = [{"role": Role.ASSISTANT.value, "content": ""}] + response
elif (
dataset_attr.ranking
and isinstance(examples[dataset_attr.chosen][i], str)
and isinstance(examples[dataset_attr.rejected][i], str)
): # pairwise example
response = [
{"role": Role.ASSISTANT.value, "content": examples[dataset_attr.chosen][i]},
{"role": Role.ASSISTANT.value, "content": examples[dataset_attr.rejected][i]},
]
elif dataset_attr.response and isinstance(examples[dataset_attr.response][i], str): # normal example
response = [{"role": Role.ASSISTANT.value, "content": examples[dataset_attr.response][i]}]
else: # unsupervised
response = []
outputs["prompt"].append(prompt)
outputs["response"].append(response)
outputs["system"].append(examples[dataset_attr.system][i] if dataset_attr.system else "")
outputs["tools"].append(examples[dataset_attr.tools][i] if dataset_attr.tools else "")
outputs["images"].append(convert_images(examples[dataset_attr.images][i]) if dataset_attr.images else [])
return outputs
def convert_sharegpt( def convert_sharegpt(
examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr", data_args: "DataArguments" example: Dict[str, Any],
) -> Dict[str, List[Any]]: dataset_attr: "DatasetAttr",
data_args: "DataArguments",
) -> Dict[str, Any]:
r""" r"""
Converts sharegpt format dataset to the standard format. Converts sharegpt format dataset to the standard format.
""" """
outputs = {"prompt": [], "response": [], "system": [], "tools": [], "images": []}
convert_images = partial(_convert_images, dataset_attr=dataset_attr, data_args=data_args)
tag_mapping = { tag_mapping = {
dataset_attr.user_tag: Role.USER.value, dataset_attr.user_tag: Role.USER.value,
dataset_attr.assistant_tag: Role.ASSISTANT.value, dataset_attr.assistant_tag: Role.ASSISTANT.value,
@ -117,74 +124,77 @@ def convert_sharegpt(
odd_tags = (dataset_attr.user_tag, dataset_attr.observation_tag) odd_tags = (dataset_attr.user_tag, dataset_attr.observation_tag)
even_tags = (dataset_attr.assistant_tag, dataset_attr.function_tag) even_tags = (dataset_attr.assistant_tag, dataset_attr.function_tag)
accept_tags = (odd_tags, even_tags) accept_tags = (odd_tags, even_tags)
for i, messages in enumerate(examples[dataset_attr.messages]): messages = example[dataset_attr.messages]
if len(messages) == 0: if (
continue dataset_attr.system_tag
and len(messages) != 0
and messages[0][dataset_attr.role_tag] == dataset_attr.system_tag
):
system = messages[0][dataset_attr.content_tag]
messages = messages[1:]
else:
system = example[dataset_attr.system] if dataset_attr.system else ""
if dataset_attr.system_tag and messages[0][dataset_attr.role_tag] == dataset_attr.system_tag: aligned_messages = []
system = messages[0][dataset_attr.content_tag] broken_data = False
messages = messages[1:] for turn_idx, message in enumerate(messages):
else: if message[dataset_attr.role_tag] not in accept_tags[turn_idx % 2]:
system = examples[dataset_attr.system][i] if dataset_attr.system else "" logger.warning("Invalid role tag in {}.".format(messages))
aligned_messages = []
broken_data = False
for turn_idx, message in enumerate(messages):
if message[dataset_attr.role_tag] not in accept_tags[turn_idx % 2]:
logger.warning("Invalid role tag in {}.".format(messages))
broken_data = True
aligned_messages.append(
{"role": tag_mapping[message[dataset_attr.role_tag]], "content": message[dataset_attr.content_tag]}
)
if (not dataset_attr.ranking and len(aligned_messages) % 2 != 0) or (
dataset_attr.ranking and len(aligned_messages) % 2 == 0
):
logger.warning("Invalid message count in {}.".format(messages))
broken_data = True broken_data = True
if dataset_attr.kto_tag and isinstance(examples[dataset_attr.kto_tag][i], bool): # kto example aligned_messages.append(
prompt = aligned_messages[:-1] {"role": tag_mapping[message[dataset_attr.role_tag]], "content": message[dataset_attr.content_tag]}
response = aligned_messages[-1:] )
if examples[dataset_attr.kto_tag][i]:
response = response + [{"role": Role.ASSISTANT.value, "content": ""}]
else:
response = [{"role": Role.ASSISTANT.value, "content": ""}] + response
elif (
dataset_attr.ranking
and isinstance(examples[dataset_attr.chosen][i], dict)
and isinstance(examples[dataset_attr.rejected][i], dict)
): # pairwise example
chosen = examples[dataset_attr.chosen][i]
rejected = examples[dataset_attr.rejected][i]
if (
chosen[dataset_attr.role_tag] not in accept_tags[-1]
or rejected[dataset_attr.role_tag] not in accept_tags[-1]
):
logger.warning("Invalid role tag in {}.".format([chosen, rejected]))
broken_data = True
prompt = aligned_messages if (not dataset_attr.ranking and len(aligned_messages) % 2 != 0) or (
response = [ dataset_attr.ranking and len(aligned_messages) % 2 == 0
{"role": tag_mapping[chosen[dataset_attr.role_tag]], "content": chosen[dataset_attr.content_tag]}, ):
{"role": tag_mapping[rejected[dataset_attr.role_tag]], "content": rejected[dataset_attr.content_tag]}, logger.warning("Invalid message count in {}.".format(messages))
] broken_data = True
else: # normal example
prompt = aligned_messages[:-1]
response = aligned_messages[-1:]
if broken_data: if dataset_attr.kto_tag and isinstance(example[dataset_attr.kto_tag], bool): # kto example
logger.warning("Skipping this abnormal example.") prompt = aligned_messages[:-1]
continue response = aligned_messages[-1:]
if example[dataset_attr.kto_tag]:
response = response + [{"role": Role.ASSISTANT.value, "content": ""}]
else:
response = [{"role": Role.ASSISTANT.value, "content": ""}] + response
elif (
dataset_attr.ranking
and isinstance(example[dataset_attr.chosen], dict)
and isinstance(example[dataset_attr.rejected], dict)
): # pairwise example
chosen = example[dataset_attr.chosen]
rejected = example[dataset_attr.rejected]
if (
chosen[dataset_attr.role_tag] not in accept_tags[-1]
or rejected[dataset_attr.role_tag] not in accept_tags[-1]
):
logger.warning("Invalid role tag in {}.".format([chosen, rejected]))
broken_data = True
outputs["prompt"].append(prompt) prompt = aligned_messages
outputs["response"].append(response) response = [
outputs["system"].append(system) {"role": tag_mapping[chosen[dataset_attr.role_tag]], "content": chosen[dataset_attr.content_tag]},
outputs["tools"].append(examples[dataset_attr.tools][i] if dataset_attr.tools else "") {"role": tag_mapping[rejected[dataset_attr.role_tag]], "content": rejected[dataset_attr.content_tag]},
outputs["images"].append(convert_images(examples[dataset_attr.images][i]) if dataset_attr.images else []) ]
else: # normal example
prompt = aligned_messages[:-1]
response = aligned_messages[-1:]
return outputs if broken_data:
logger.warning("Skipping this abnormal example.")
prompt, response = [], []
convert_images = partial(_convert_images, dataset_attr=dataset_attr, data_args=data_args)
output = {
"_prompt": prompt,
"_response": response,
"_system": system,
"_tools": example[dataset_attr.tools] if dataset_attr.tools else "",
"_images": convert_images(example[dataset_attr.images]) if dataset_attr.images else None,
}
return output
def align_dataset( def align_dataset(
@ -195,11 +205,11 @@ def align_dataset(
) -> Union["Dataset", "IterableDataset"]: ) -> Union["Dataset", "IterableDataset"]:
r""" r"""
Aligned dataset: Aligned dataset:
prompt: [{"role": "user", "content": "..."}] * (2T - 1) _prompt: [{"role": "user", "content": "..."}] * (2T - 1)
response: [{"role": "assistant", "content": "..."}] * N (N > 1 for ranking dataset) _response: [{"role": "assistant", "content": "..."}] * N (N > 1 for ranking dataset)
system: "..." _system: "..."
tools: "...", _tools: "...",
images: [], _images: [],
""" """
if dataset_attr.formatting == "alpaca": if dataset_attr.formatting == "alpaca":
convert_func = partial(convert_alpaca, dataset_attr=dataset_attr, data_args=data_args) convert_func = partial(convert_alpaca, dataset_attr=dataset_attr, data_args=data_args)
@ -207,19 +217,6 @@ def align_dataset(
convert_func = partial(convert_sharegpt, dataset_attr=dataset_attr, data_args=data_args) convert_func = partial(convert_sharegpt, dataset_attr=dataset_attr, data_args=data_args)
column_names = list(next(iter(dataset)).keys()) column_names = list(next(iter(dataset)).keys())
features = Features.from_dict(
{
"prompt": [
{"role": {"dtype": "string", "_type": "Value"}, "content": {"dtype": "string", "_type": "Value"}}
],
"response": [
{"role": {"dtype": "string", "_type": "Value"}, "content": {"dtype": "string", "_type": "Value"}}
],
"system": {"dtype": "string", "_type": "Value"},
"tools": {"dtype": "string", "_type": "Value"},
"images": [{"_type": "Image"}],
}
)
kwargs = {} kwargs = {}
if not data_args.streaming: if not data_args.streaming:
kwargs = dict( kwargs = dict(
@ -230,8 +227,7 @@ def align_dataset(
return dataset.map( return dataset.map(
convert_func, convert_func,
batched=True, batched=False,
remove_columns=column_names, remove_columns=column_names,
features=features,
**kwargs, **kwargs,
) )
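With `batched=False`, each raw record is converted independently and the multimodal column is kept as lightweight references instead of decoded images, so no `datasets.Features` schema with an `Image` type is needed. Roughly, a single alpaca-style record comes out as follows (a sketch with made-up values; the path join only happens for datasets loaded from a local script or file):

# Illustrative input/output of convert_alpaca for one record (values are examples only).
raw = {"instruction": "Describe the image.", "output": "A cat.", "images": ["cat.jpg"]}
aligned = {
    "_prompt": [{"role": "user", "content": "Describe the image."}],
    "_response": [{"role": "assistant", "content": "A cat."}],
    "_system": "",
    "_tools": "",
    "_images": ["path/to/dataset_dir/cat.jpg"],  # still a path; the file is not opened here
}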

View File

@ -16,12 +16,18 @@
# limitations under the License. # limitations under the License.
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Dict, Literal, Sequence from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Sequence
import torch import torch
from transformers import DataCollatorForSeq2Seq from transformers import DataCollatorForSeq2Seq
if TYPE_CHECKING:
from transformers import ProcessorMixin
from .template import Template
def prepare_4d_attention_mask(attention_mask_with_indices: "torch.Tensor", dtype: "torch.dtype") -> "torch.Tensor": def prepare_4d_attention_mask(attention_mask_with_indices: "torch.Tensor", dtype: "torch.dtype") -> "torch.Tensor":
r""" r"""
Expands the attention mask with indices from (batch_size, seq_len) to (batch_size, 1, seq_len, seq_len), Expands the attention mask with indices from (batch_size, seq_len) to (batch_size, 1, seq_len, seq_len),
@ -65,41 +71,29 @@ def prepare_4d_attention_mask(attention_mask_with_indices: "torch.Tensor", dtype
class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq): class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
r""" r"""
Data collator that supports VLMs. Data collator that supports VLMs.
Features should contain input_ids, attention_mask, labels and images.
""" """
template: Optional["Template"] = None
processor: Optional["ProcessorMixin"] = None
def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]: def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
if "token_type_ids" in features[0].keys(): batch_images, batch_imglens, batch_seqlens = [], [], []
for feature in features: for feature in features:
feature["token_type_ids"] = feature["token_type_ids"][0] images = feature.pop("images") or [] # avoid NoneType
batch_images.extend(images)
batch_imglens.append(len(images))
batch_seqlens.append(len(feature["input_ids"]))
extra_features = {} mm_inputs = self.template.mm_plugin.get_mm_inputs(batch_images, batch_imglens, batch_seqlens, self.processor)
if "pixel_values" in features[0].keys(): if "token_type_ids" in mm_inputs:
pixel_values = [] token_type_ids = mm_inputs.pop("token_type_ids")
for feature in features: for i, feature in enumerate(features):
if feature["pixel_values"] is None: feature["token_type_ids"] = token_type_ids[i]
pixel_values.append(torch.zeros(0, dtype=torch.float))
else:
pixel_values.append(torch.tensor(feature["pixel_values"], dtype=torch.float))
extra_features["pixel_values"] = torch.cat(pixel_values, dim=0)
if extra_features["pixel_values"].numel() == 0:
extra_features["pixel_values"] = None
if "image_grid_thw" in features[0].keys():
image_grid_thw = []
for feature in features:
if feature["image_grid_thw"] is None:
image_grid_thw.append(torch.zeros(0, dtype=torch.long))
else:
image_grid_thw.append(torch.tensor(feature["image_grid_thw"], dtype=torch.long))
extra_features["image_grid_thw"] = torch.cat(image_grid_thw, dim=0)
if extra_features["image_grid_thw"].numel() == 0:
extra_features["image_grid_thw"] = None
features = [{key: feature[key] for key in feature if key not in extra_features.keys()} for feature in features]
features: Dict[str, "torch.Tensor"] = super().__call__(features) features: Dict[str, "torch.Tensor"] = super().__call__(features)
features.update({key: value for key, value in extra_features.items() if value is not None}) features.update(mm_inputs)
return features return features
@ -141,16 +135,8 @@ class PairwiseDataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
"input_ids": feature["{}_input_ids".format(key)], "input_ids": feature["{}_input_ids".format(key)],
"attention_mask": feature["{}_attention_mask".format(key)], "attention_mask": feature["{}_attention_mask".format(key)],
"labels": feature["{}_labels".format(key)], "labels": feature["{}_labels".format(key)],
"images": feature["images"],
} }
if "{}_token_type_ids".format(key) in feature:
target_feature["token_type_ids"] = feature["{}_token_type_ids".format(key)]
if "pixel_values" in feature: # image data are same for chosen and rejected
target_feature["pixel_values"] = feature["pixel_values"]
if "image_grid_thw" in feature:
target_feature["image_grid_thw"] = feature["image_grid_thw"]
concatenated_features.append(target_feature) concatenated_features.append(target_feature)
return super().__call__(concatenated_features) return super().__call__(concatenated_features)
@ -171,22 +157,14 @@ class KTODataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
"input_ids": feature["input_ids"], "input_ids": feature["input_ids"],
"attention_mask": feature["attention_mask"], "attention_mask": feature["attention_mask"],
"labels": feature["labels"], "labels": feature["labels"],
"images": feature["images"],
} }
kl_feature = { kl_feature = {
"input_ids": feature["kl_input_ids"], "input_ids": feature["kl_input_ids"],
"attention_mask": feature["kl_attention_mask"], "attention_mask": feature["kl_attention_mask"],
"labels": feature["kl_labels"], "labels": feature["kl_labels"],
"images": feature["images"],
} }
if "token_type_ids" in feature:
target_feature["token_type_ids"] = feature["token_type_ids"]
kl_feature["token_type_ids"] = feature["kl_token_type_ids"]
if "pixel_values" in feature:
target_feature["pixel_values"] = feature["pixel_values"]
if "image_grid_thw" in feature:
target_feature["image_grid_thw"] = feature["image_grid_thw"]
target_features.append(target_feature) target_features.append(target_feature)
kl_features.append(kl_feature) kl_features.append(kl_feature)
kto_tags.append(feature["kto_tags"]) kto_tags.append(feature["kto_tags"])
@ -196,7 +174,7 @@ class KTODataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
batch["kl_input_ids"] = kl_batch["input_ids"] batch["kl_input_ids"] = kl_batch["input_ids"]
batch["kl_attention_mask"] = kl_batch["attention_mask"] batch["kl_attention_mask"] = kl_batch["attention_mask"]
batch["kl_labels"] = kl_batch["labels"] batch["kl_labels"] = kl_batch["labels"]
if "token_type_ids" in batch: if "token_type_ids" in kl_batch:
batch["kl_token_type_ids"] = kl_batch["token_type_ids"] batch["kl_token_type_ids"] = kl_batch["token_type_ids"]
batch["kto_tags"] = torch.tensor(kto_tags) batch["kto_tags"] = torch.tensor(kto_tags)
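This is the heart of the lazy-loading change: images stay untouched until batch collation, where the collator gathers them together with per-sample image counts and sequence lengths and lets the template plugin produce the visual tensors. Condensed into a standalone sketch (the helper names here are placeholders, not the repository's API):

# Minimal sketch of the lazy collation flow (placeholder helper, not part of the diff).
def collate(features, template, processor, base_collate):
    batch_images, batch_imglens, batch_seqlens = [], [], []
    for feature in features:
        images = feature.pop("images") or []             # None -> [] for text-only samples
        batch_images.extend(images)                      # flat list of images across the batch
        batch_imglens.append(len(images))                # how many images each sample owns
        batch_seqlens.append(len(feature["input_ids"]))  # tokenized length of each sample
    # image decoding and preprocessing happen here, at batch time, not during dataset.map
    mm_inputs = template.mm_plugin.get_mm_inputs(batch_images, batch_imglens, batch_seqlens, processor)
    token_type_ids = mm_inputs.pop("token_type_ids", None)  # PaliGemma-style models only
    if token_type_ids is not None:
        for i, feature in enumerate(features):
            feature["token_type_ids"] = token_type_ids[i]
    batch = base_collate(features)   # ordinary seq2seq padding of input_ids / labels
    batch.update(mm_inputs)          # e.g. pixel_values, image_grid_thw
    return batch

The practical effect is that the cached dataset stores references rather than pixel tensors, so image decoding is paid per batch instead of up front.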

View File

@ -14,7 +14,7 @@
import os import os
import sys import sys
from typing import TYPE_CHECKING, Dict, Literal, Optional, Sequence, Union from typing import TYPE_CHECKING, Dict, Literal, Optional, Sequence, Tuple, Union
import numpy as np import numpy as np
from datasets import DatasetDict, load_dataset, load_from_disk from datasets import DatasetDict, load_dataset, load_from_disk
@ -180,7 +180,13 @@ def _get_preprocessed_dataset(
desc="Running tokenizer on dataset", desc="Running tokenizer on dataset",
) )
dataset = dataset.map(preprocess_func, batched=True, remove_columns=column_names, **kwargs) dataset = dataset.map(
preprocess_func,
batched=True,
batch_size=data_args.preprocessing_batch_size,
remove_columns=column_names,
**kwargs,
)
if training_args.should_log: if training_args.should_log:
try: try:
@ -202,7 +208,7 @@ def get_dataset(
stage: Literal["pt", "sft", "rm", "ppo", "kto"], stage: Literal["pt", "sft", "rm", "ppo", "kto"],
tokenizer: "PreTrainedTokenizer", tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"] = None, processor: Optional["ProcessorMixin"] = None,
) -> "DatasetModule": ) -> Tuple["DatasetModule", "Template"]:
template = get_template_and_fix_tokenizer(tokenizer, data_args.template, data_args.tool_format) template = get_template_and_fix_tokenizer(tokenizer, data_args.template, data_args.tool_format)
if data_args.train_on_prompt and template.efficient_eos: if data_args.train_on_prompt and template.efficient_eos:
raise ValueError("Current template does not support `train_on_prompt`.") raise ValueError("Current template does not support `train_on_prompt`.")
@ -273,4 +279,4 @@ def get_dataset(
if "validation" in dataset_dict: if "validation" in dataset_dict:
dataset_module["eval_dataset"] = dataset_dict["validation"] dataset_module["eval_dataset"] = dataset_dict["validation"]
return dataset_module return dataset_module, template
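`get_dataset` now returns the template alongside the dataset module so the caller can pass it to the collator; stages that do not need it simply discard it, e.g. `dataset_module, _ = get_dataset(...)` for pretraining. A usage sketch mirroring the updated workflows:

# Call-site pattern after the signature change (mirrors the workflow hunks below).
dataset_module, template = get_dataset(model_args, data_args, training_args, stage="ppo", **tokenizer_module)
data_collator = MultiModalDataCollatorForSeq2Seq(template=template, **tokenizer_module)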

View File

@ -1,5 +1,6 @@
from copy import deepcopy from copy import deepcopy
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple from io import BytesIO
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, TypedDict, Union
from PIL.Image import Image from PIL.Image import Image
from transformers import ProcessorMixin from transformers import ProcessorMixin
@ -9,34 +10,53 @@ from ..extras.packages import is_pillow_available
if is_pillow_available(): if is_pillow_available():
import torch
from PIL import Image from PIL import Image
from PIL.Image import Image as ImageObject
if TYPE_CHECKING: if TYPE_CHECKING:
from PIL.Image import Image as ImageObject import torch
from transformers import PreTrainedTokenizer, ProcessorMixin from transformers import PreTrainedTokenizer, ProcessorMixin
from transformers.image_processing_utils import BaseImageProcessor from transformers.image_processing_utils import BaseImageProcessor
class EncodedImage(TypedDict):
path: Optional[str]
bytes: Optional[bytes]
def _regularize_images(images: Sequence["ImageObject"], processor: "ProcessorMixin") -> List["ImageObject"]: ImageInput = Union[str, EncodedImage, ImageObject]
def _regularize_images(images: Sequence["ImageInput"], processor: "ProcessorMixin") -> List["ImageObject"]:
r""" r"""
Regularizes images to avoid error. Including resizing and mode convert. Regularizes images to avoid error. Including reading, resizing and converting.
""" """
images = images[:]
image_resolution = getattr(processor, "image_resolution", 512) image_resolution = getattr(processor, "image_resolution", 512)
for i in range(len(images)): results = []
if max(images[i].width, images[i].height) > image_resolution: for image in images:
factor = image_resolution / max(images[i].width, images[i].height) if isinstance(image, str):
images[i] = images[i].resize((int(images[i].width * factor), int(images[i].height * factor))) image = Image.open(image)
elif isinstance(image, dict):
if image["bytes"] is not None:
image = Image.open(BytesIO(image["bytes"]))
else:
image = Image.open(image["path"])
if images[i].mode != "RGB": if not isinstance(image, ImageObject):
images[i] = images[i].convert("RGB") raise ValueError("Expect input is a list of Images, but got {}.".format(type(image)))
return images if max(image.width, image.height) > image_resolution:
factor = image_resolution / max(image.width, image.height)
image = image.resize((int(image.width * factor), int(image.height * factor)))
if image.mode != "RGB":
image = image.convert("RGB")
results.append(image)
return results
def _get_mm_inputs(images: Sequence["ImageObject"], processor: "ProcessorMixin") -> Dict[str, "torch.Tensor"]: def _get_mm_inputs(images: Sequence["ImageInput"], processor: "ProcessorMixin") -> Dict[str, "torch.Tensor"]:
r""" r"""
Processes visual inputs. Processes visual inputs.
@ -53,26 +73,27 @@ def _get_mm_inputs(images: Sequence["ImageObject"], processor: "ProcessorMixin")
if len(images) != 0: if len(images) != 0:
images = _regularize_images(images, processor) images = _regularize_images(images, processor)
image_inputs = image_processor(images=images, return_tensors="pt") image_inputs = image_processor(images=images, return_tensors="pt")
else: # add NoneType for fake images else:
image = Image.new("RGB", (64, 64), (255, 255, 255)) image_inputs = {}
image_inputs = image_processor(images=[image], return_tensors="pt")
image_inputs = {key: None for key in image_inputs.keys()}
return image_inputs return image_inputs
def _get_paligemma_token_type_ids( def _get_paligemma_token_type_ids(
images: Sequence["ImageObject"], input_len: int, processor: "ProcessorMixin" imglens: Sequence[int], seqlens: Sequence[int], processor: "ProcessorMixin"
) -> List[List[int]]: ) -> List[List[int]]:
r""" r"""
Gets paligemma token type ids for computing loss. Gets paligemma token type ids for computing loss.
Returns: Returns:
token_type_ids: shape (1, seq_len) batch_token_type_ids: shape (batch_size, sequence_length)
""" """
num_images = len(images) batch_token_type_ids = []
image_seqlen = num_images * getattr(processor, "image_seqlen") for imglen, seqlen in zip(imglens, seqlens):
return [[0] * image_seqlen + [1] * (input_len - image_seqlen)] image_seqlen = imglen * getattr(processor, "image_seqlen")
batch_token_type_ids.append([0] * image_seqlen + [1] * (seqlen - image_seqlen))
return batch_token_type_ids
class BasePlugin: class BasePlugin:
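The token type ids are now derived from per-sample image counts and sequence lengths rather than from a per-feature length mapping. A worked example with assumed sizes (PaliGemma-style, taking `processor.image_seqlen` to be 256):

# Assumed numbers, for illustration only.
# Sample 1: one image, 300 tokens -> 256 zeros (image positions) + 44 ones (text positions)
# Sample 2: no image, 128 tokens  -> 128 ones
batch_token_type_ids = _get_paligemma_token_type_ids(imglens=[1, 0], seqlens=[300, 128], processor=processor)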
@ -82,7 +103,7 @@ class BasePlugin:
def process_messages( def process_messages(
self, self,
messages: Sequence[Dict[str, str]], messages: Sequence[Dict[str, str]],
images: Sequence["ImageObject"], images: Sequence["ImageInput"],
processor: Optional["ProcessorMixin"], processor: Optional["ProcessorMixin"],
) -> List[Dict[str, str]]: ) -> List[Dict[str, str]]:
r""" r"""
@ -94,7 +115,7 @@ class BasePlugin:
self, self,
input_ids: List[int], input_ids: List[int],
labels: Optional[List[int]], labels: Optional[List[int]],
images: Sequence["ImageObject"], images: Sequence["ImageInput"],
tokenizer: "PreTrainedTokenizer", tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"], processor: Optional["ProcessorMixin"],
) -> Tuple[List[int], Optional[List[int]]]: ) -> Tuple[List[int], Optional[List[int]]]:
@ -105,10 +126,11 @@ class BasePlugin:
def get_mm_inputs( def get_mm_inputs(
self, self,
images: Sequence["ImageObject"], images: Sequence["ImageInput"],
feature_seqlens: Dict[str, int], imglens: Sequence[int],
seqlens: Sequence[int],
processor: Optional["ProcessorMixin"], processor: Optional["ProcessorMixin"],
) -> Dict[str, Any]: ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
r""" r"""
Builds batched multimodal inputs for VLMs. Builds batched multimodal inputs for VLMs.
""" """
@ -119,31 +141,32 @@ class LlavaPlugin(BasePlugin):
def process_messages( def process_messages(
self, self,
messages: Sequence[Dict[str, str]], messages: Sequence[Dict[str, str]],
images: Sequence["ImageObject"], images: Sequence["ImageInput"],
processor: Optional["ProcessorMixin"], processor: Optional["ProcessorMixin"],
) -> List[Dict[str, str]]: ) -> List[Dict[str, str]]:
num_images = 0 num_image_tokens = 0
image_seqlen = getattr(processor, "image_seqlen") image_seqlen = getattr(processor, "image_seqlen")
messages = deepcopy(messages) messages = deepcopy(messages)
for message in messages: for message in messages:
content = message["content"] content = message["content"]
while IMAGE_PLACEHOLDER in content: while IMAGE_PLACEHOLDER in content:
num_images += 1 num_image_tokens += 1
content = content.replace(IMAGE_PLACEHOLDER, "{{image}}", 1) content = content.replace(IMAGE_PLACEHOLDER, "{{image}}", 1)
message["content"] = content.replace("{{image}}", self.image_token * image_seqlen) message["content"] = content.replace("{{image}}", self.image_token * image_seqlen)
if len(images) != num_images: if len(images) != num_image_tokens:
raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER)) raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER))
return messages return messages
def get_mm_inputs( def get_mm_inputs(
self, self,
images: Sequence["ImageObject"], images: Sequence["ImageInput"],
feature_seqlens: Dict[str, int], imglens: Sequence[int],
seqlens: Sequence[int],
processor: Optional["ProcessorMixin"], processor: Optional["ProcessorMixin"],
) -> Dict[str, Any]: ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
return _get_mm_inputs(images, processor) return _get_mm_inputs(images, processor)
@ -151,20 +174,20 @@ class PaliGemmaPlugin(BasePlugin):
def process_messages( def process_messages(
self, self,
messages: Sequence[Dict[str, str]], messages: Sequence[Dict[str, str]],
images: Sequence["ImageObject"], images: Sequence["ImageInput"],
processor: Optional["ProcessorMixin"], processor: Optional["ProcessorMixin"],
) -> List[Dict[str, str]]: ) -> List[Dict[str, str]]:
num_images = 0 num_image_tokens = 0
messages = deepcopy(messages) messages = deepcopy(messages)
for message in messages: for message in messages:
content = message["content"] content = message["content"]
while IMAGE_PLACEHOLDER in content: while IMAGE_PLACEHOLDER in content:
num_images += 1 num_image_tokens += 1
content = content.replace(IMAGE_PLACEHOLDER, "{{image}}", 1) content = content.replace(IMAGE_PLACEHOLDER, "{{image}}", 1)
message["content"] = content.replace("{{image}}", "") message["content"] = content.replace("{{image}}", "")
if len(images) != num_images: if len(images) != num_image_tokens:
raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER)) raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER))
return messages return messages
@ -173,7 +196,7 @@ class PaliGemmaPlugin(BasePlugin):
self, self,
input_ids: List[int], input_ids: List[int],
labels: Optional[List[int]], labels: Optional[List[int]],
images: Sequence["ImageObject"], images: Sequence["ImageInput"],
tokenizer: "PreTrainedTokenizer", tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"], processor: Optional["ProcessorMixin"],
) -> Tuple[List[int], Optional[List[int]]]: ) -> Tuple[List[int], Optional[List[int]]]:
@ -188,14 +211,13 @@ class PaliGemmaPlugin(BasePlugin):
def get_mm_inputs( def get_mm_inputs(
self, self,
images: Sequence["ImageObject"], images: Sequence["ImageInput"],
feature_seqlens: Dict[str, int], imglens: Sequence[int],
seqlens: Sequence[int],
processor: Optional["ProcessorMixin"], processor: Optional["ProcessorMixin"],
) -> Dict[str, Any]: ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
mm_inputs = _get_mm_inputs(images, processor) mm_inputs = _get_mm_inputs(images, processor)
for feature_name, feature_length in feature_seqlens.items(): mm_inputs["token_type_ids"] = _get_paligemma_token_type_ids(imglens, seqlens, processor)
mm_inputs[feature_name] = _get_paligemma_token_type_ids(images, feature_length, processor)
return mm_inputs return mm_inputs
@ -203,7 +225,7 @@ class Qwen2vlPlugin(BasePlugin):
def process_messages( def process_messages(
self, self,
messages: Sequence[Dict[str, str]], messages: Sequence[Dict[str, str]],
images: Sequence["ImageObject"], images: Sequence["ImageInput"],
processor: Optional["ProcessorMixin"], processor: Optional["ProcessorMixin"],
) -> List[Dict[str, str]]: ) -> List[Dict[str, str]]:
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor") image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
@ -213,36 +235,37 @@ class Qwen2vlPlugin(BasePlugin):
else: else:
image_grid_thw = [] image_grid_thw = []
num_images = 0 num_image_tokens = 0
messages = deepcopy(messages) messages = deepcopy(messages)
for message in messages: for message in messages:
content = message["content"] content = message["content"]
while IMAGE_PLACEHOLDER in content: while IMAGE_PLACEHOLDER in content:
if num_images >= len(image_grid_thw): if num_image_tokens >= len(image_grid_thw):
raise ValueError("`len(images)` is less than the number of {} tokens.".format(IMAGE_PLACEHOLDER)) raise ValueError("`len(images)` is less than the number of {} tokens.".format(IMAGE_PLACEHOLDER))
content = content.replace( content = content.replace(
IMAGE_PLACEHOLDER, IMAGE_PLACEHOLDER,
"<|vision_start|>{}<|vision_end|>".format( "<|vision_start|>{}<|vision_end|>".format(
self.image_token * (image_grid_thw[num_images].prod() // merge_length) self.image_token * (image_grid_thw[num_image_tokens].prod() // merge_length)
), ),
1, 1,
) )
num_images += 1 num_image_tokens += 1
message["content"] = content message["content"] = content
if len(images) != num_images: if len(images) != num_image_tokens:
raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER)) raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER))
return messages return messages
def get_mm_inputs( def get_mm_inputs(
self, self,
images: Sequence["ImageObject"], images: Sequence["ImageInput"],
feature_seqlens: Dict[str, int], imglens: Sequence[int],
seqlens: Sequence[int],
processor: Optional["ProcessorMixin"], processor: Optional["ProcessorMixin"],
) -> Dict[str, Any]: ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
return _get_mm_inputs(images, processor) return _get_mm_inputs(images, processor)
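Because the dataset now stores paths, raw bytes, or PIL objects, `_regularize_images` is the point where images are actually opened, downscaled to `image_resolution`, and converted to RGB. The loading logic condensed into a self-contained sketch (error handling omitted; not the repository function itself):

# Condensed from the hunk above.
from io import BytesIO
from PIL import Image

def load_one(image, image_resolution=512):
    if isinstance(image, str):      # lazy path: the file is opened only now
        image = Image.open(image)
    elif isinstance(image, dict):   # Arrow-style encoded image: {"bytes": ..., "path": ...}
        image = Image.open(BytesIO(image["bytes"])) if image["bytes"] is not None else Image.open(image["path"])
    if max(image.width, image.height) > image_resolution:
        factor = image_resolution / max(image.width, image.height)
        image = image.resize((int(image.width * factor), int(image.height * factor)))
    return image.convert("RGB") if image.mode != "RGB" else image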

View File

@ -21,10 +21,10 @@ from .processor_utils import infer_seqlen
if TYPE_CHECKING: if TYPE_CHECKING:
from PIL.Image import Image
from transformers import PreTrainedTokenizer, ProcessorMixin from transformers import PreTrainedTokenizer, ProcessorMixin
from ...hparams import DataArguments from ...hparams import DataArguments
from ..mm_plugin import ImageInput
from ..template import Template from ..template import Template
@ -37,12 +37,12 @@ def _encode_feedback_example(
kl_response: Sequence[Dict[str, str]], kl_response: Sequence[Dict[str, str]],
system: Optional[str], system: Optional[str],
tools: Optional[str], tools: Optional[str],
images: Sequence["Image"], images: Sequence["ImageInput"],
template: "Template", template: "Template",
tokenizer: "PreTrainedTokenizer", tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"], processor: Optional["ProcessorMixin"],
cutoff_len: int, cutoff_len: int,
) -> Tuple[List[int], List[int], List[int], List[int], bool, Dict[str, Any]]: ) -> Tuple[List[int], List[int], List[int], List[int], bool]:
if response[0]["content"]: # desired example if response[0]["content"]: # desired example
kto_tag = True kto_tag = True
messages = prompt + [response[0]] messages = prompt + [response[0]]
@ -78,15 +78,7 @@ def _encode_feedback_example(
labels = [IGNORE_INDEX] * source_len + response_ids labels = [IGNORE_INDEX] * source_len + response_ids
kl_input_ids = kl_prompt_ids + kl_response_ids kl_input_ids = kl_prompt_ids + kl_response_ids
kl_labels = [IGNORE_INDEX] * kl_source_len + kl_response_ids kl_labels = [IGNORE_INDEX] * kl_source_len + kl_response_ids
extra_inputs = template.mm_plugin.get_mm_inputs( return input_ids, labels, kl_input_ids, kl_labels, kto_tag
images=images,
feature_seqlens={
"token_type_ids": len(input_ids),
"kl_token_type_ids": len(kl_input_ids),
},
processor=processor,
)
return input_ids, labels, kl_input_ids, kl_labels, kto_tag, extra_inputs
def preprocess_feedback_dataset( def preprocess_feedback_dataset(
@ -97,20 +89,20 @@ def preprocess_feedback_dataset(
data_args: "DataArguments", data_args: "DataArguments",
) -> Dict[str, List[Any]]: ) -> Dict[str, List[Any]]:
# create unrelated input-output pairs for estimating the KL term by flipping the matched pairs # create unrelated input-output pairs for estimating the KL term by flipping the matched pairs
kl_response = examples["response"][::-1] kl_response = examples["_response"][::-1]
model_inputs = defaultdict(list) model_inputs = defaultdict(list)
for i in range(len(examples["prompt"])): for i in range(len(examples["_prompt"])):
if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) < 2: if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) < 2:
logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i])) logger.warning("Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i]))
continue continue
input_ids, labels, kl_input_ids, kl_labels, kto_tag, extra_inputs = _encode_feedback_example( input_ids, labels, kl_input_ids, kl_labels, kto_tag = _encode_feedback_example(
prompt=examples["prompt"][i], prompt=examples["_prompt"][i],
response=examples["response"][i], response=examples["_response"][i],
kl_response=kl_response[i], kl_response=kl_response[i],
system=examples["system"][i], system=examples["_system"][i],
tools=examples["tools"][i], tools=examples["_tools"][i],
images=examples["images"][i], images=examples["_images"][i] or [],
template=template, template=template,
tokenizer=tokenizer, tokenizer=tokenizer,
processor=processor, processor=processor,
@ -123,8 +115,7 @@ def preprocess_feedback_dataset(
model_inputs["kl_attention_mask"].append([1] * len(kl_input_ids)) model_inputs["kl_attention_mask"].append([1] * len(kl_input_ids))
model_inputs["kl_labels"].append(kl_labels) model_inputs["kl_labels"].append(kl_labels)
model_inputs["kto_tags"].append(kto_tag) model_inputs["kto_tags"].append(kto_tag)
for key, value in extra_inputs.items(): model_inputs["images"].append(examples["_images"][i])
model_inputs[key].append(value)
desirable_num = sum([1 for tag in model_inputs["kto_tags"] if tag]) desirable_num = sum([1 for tag in model_inputs["kto_tags"] if tag])
undesirable_num = len(model_inputs["kto_tags"]) - desirable_num undesirable_num = len(model_inputs["kto_tags"]) - desirable_num
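The KL pairs for KTO are still built by reversing the response column within each preprocessing batch, which pairs every prompt with an unrelated response; only the column is now named `_response`. For example:

# Toy example of the flip: prompt i is paired with response (n - 1 - i) for the KL estimate.
responses = ["resp_0", "resp_1", "resp_2", "resp_3"]
kl_responses = responses[::-1]   # ["resp_3", "resp_2", "resp_1", "resp_0"]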

View File

@ -21,10 +21,10 @@ from .processor_utils import infer_seqlen
if TYPE_CHECKING: if TYPE_CHECKING:
from PIL.Image import Image
from transformers import PreTrainedTokenizer, ProcessorMixin from transformers import PreTrainedTokenizer, ProcessorMixin
from ...hparams import DataArguments from ...hparams import DataArguments
from ..mm_plugin import ImageInput
from ..template import Template from ..template import Template
@ -36,12 +36,12 @@ def _encode_pairwise_example(
response: Sequence[Dict[str, str]], response: Sequence[Dict[str, str]],
system: Optional[str], system: Optional[str],
tools: Optional[str], tools: Optional[str],
images: Sequence["Image"], images: Sequence["ImageInput"],
template: "Template", template: "Template",
tokenizer: "PreTrainedTokenizer", tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"], processor: Optional["ProcessorMixin"],
cutoff_len: int, cutoff_len: int,
) -> Tuple[List[int], List[int], List[int], List[int], Dict[str, Any]]: ) -> Tuple[List[int], List[int], List[int], List[int]]:
chosen_messages = template.mm_plugin.process_messages(prompt + [response[0]], images, processor) chosen_messages = template.mm_plugin.process_messages(prompt + [response[0]], images, processor)
rejected_messages = template.mm_plugin.process_messages(prompt + [response[1]], images, processor) rejected_messages = template.mm_plugin.process_messages(prompt + [response[1]], images, processor)
prompt_ids, chosen_ids = template.encode_oneturn(tokenizer, chosen_messages, system, tools) prompt_ids, chosen_ids = template.encode_oneturn(tokenizer, chosen_messages, system, tools)
@ -62,15 +62,7 @@ def _encode_pairwise_example(
chosen_labels = [IGNORE_INDEX] * source_len + chosen_ids chosen_labels = [IGNORE_INDEX] * source_len + chosen_ids
rejected_input_ids = prompt_ids + rejected_ids rejected_input_ids = prompt_ids + rejected_ids
rejected_labels = [IGNORE_INDEX] * source_len + rejected_ids rejected_labels = [IGNORE_INDEX] * source_len + rejected_ids
extra_inputs = template.mm_plugin.get_mm_inputs( return chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels
images=images,
feature_seqlens={
"chosen_token_type_ids": len(chosen_input_ids),
"rejected_token_type_ids": len(rejected_input_ids),
},
processor=processor,
)
return chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels, extra_inputs
def preprocess_pairwise_dataset( def preprocess_pairwise_dataset(
@ -82,17 +74,17 @@ def preprocess_pairwise_dataset(
) -> Dict[str, List[Any]]: ) -> Dict[str, List[Any]]:
# build input pairs with format `<bos> X`, `Y1 <eos>` and `Y2 <eos>` # build input pairs with format `<bos> X`, `Y1 <eos>` and `Y2 <eos>`
model_inputs = defaultdict(list) model_inputs = defaultdict(list)
for i in range(len(examples["prompt"])): for i in range(len(examples["_prompt"])):
if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) < 2: if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) < 2:
logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i])) logger.warning("Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i]))
continue continue
chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels, extra_inputs = _encode_pairwise_example( chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels = _encode_pairwise_example(
prompt=examples["prompt"][i], prompt=examples["_prompt"][i],
response=examples["response"][i], response=examples["_response"][i],
system=examples["system"][i], system=examples["_system"][i],
tools=examples["tools"][i], tools=examples["_tools"][i],
images=examples["images"][i], images=examples["_images"][i] or [],
template=template, template=template,
tokenizer=tokenizer, tokenizer=tokenizer,
processor=processor, processor=processor,
@ -104,8 +96,7 @@ def preprocess_pairwise_dataset(
model_inputs["rejected_input_ids"].append(rejected_input_ids) model_inputs["rejected_input_ids"].append(rejected_input_ids)
model_inputs["rejected_attention_mask"].append([1] * len(rejected_input_ids)) model_inputs["rejected_attention_mask"].append([1] * len(rejected_input_ids))
model_inputs["rejected_labels"].append(rejected_labels) model_inputs["rejected_labels"].append(rejected_labels)
for key, value in extra_inputs.items(): model_inputs["images"].append(examples["_images"][i])
model_inputs[key].append(value)
return model_inputs return model_inputs

View File

@@ -30,7 +30,7 @@ def preprocess_pretrain_dataset(
 ) -> Dict[str, List[Any]]:
     # build grouped texts with format `X1 X2 X3 ...` if packing is enabled
     eos_token = "<|end_of_text|>" if data_args.template == "llama3" else tokenizer.eos_token
-    text_examples = [messages[0]["content"] + eos_token for messages in examples["prompt"]]
+    text_examples = [messages[0]["content"] + eos_token for messages in examples["_prompt"]]

     if not data_args.packing:
         if data_args.template == "gemma":

View File

@ -21,10 +21,10 @@ from .processor_utils import greedy_knapsack, infer_seqlen
if TYPE_CHECKING: if TYPE_CHECKING:
from PIL.Image import Image
from transformers import PreTrainedTokenizer, ProcessorMixin from transformers import PreTrainedTokenizer, ProcessorMixin
from ...hparams import DataArguments from ...hparams import DataArguments
from ..mm_plugin import ImageInput
from ..template import Template from ..template import Template
@ -36,14 +36,14 @@ def _encode_supervised_example(
response: Sequence[Dict[str, str]], response: Sequence[Dict[str, str]],
system: Optional[str], system: Optional[str],
tools: Optional[str], tools: Optional[str],
images: Sequence["Image"], images: Sequence["ImageInput"],
template: "Template", template: "Template",
tokenizer: "PreTrainedTokenizer", tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"], processor: Optional["ProcessorMixin"],
cutoff_len: int, cutoff_len: int,
train_on_prompt: bool, train_on_prompt: bool,
mask_history: bool, mask_history: bool,
) -> Tuple[List[int], List[int], Dict[str, Any]]: ) -> Tuple[List[int], List[int]]:
messages = template.mm_plugin.process_messages(prompt + response, images, processor) messages = template.mm_plugin.process_messages(prompt + response, images, processor)
input_ids, labels = template.mm_plugin.process_token_ids([], [], images, tokenizer, processor) input_ids, labels = template.mm_plugin.process_token_ids([], [], images, tokenizer, processor)
encoded_pairs = template.encode_multiturn(tokenizer, messages, system, tools) encoded_pairs = template.encode_multiturn(tokenizer, messages, system, tools)
@ -83,10 +83,7 @@ def _encode_supervised_example(
input_ids += [tokenizer.eos_token_id] input_ids += [tokenizer.eos_token_id]
labels += [tokenizer.eos_token_id] labels += [tokenizer.eos_token_id]
extra_inputs = template.mm_plugin.get_mm_inputs( return input_ids, labels
images=images, feature_seqlens={"token_type_ids": len(input_ids)}, processor=processor
)
return input_ids, labels, extra_inputs
def preprocess_supervised_dataset( def preprocess_supervised_dataset(
@ -99,17 +96,17 @@ def preprocess_supervised_dataset(
# build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>` # build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>`
# for multiturn examples, we only mask the prompt part in each prompt-response pair. # for multiturn examples, we only mask the prompt part in each prompt-response pair.
model_inputs = defaultdict(list) model_inputs = defaultdict(list)
for i in range(len(examples["prompt"])): for i in range(len(examples["_prompt"])):
if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) != 1: if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) != 1:
logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i])) logger.warning("Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i]))
continue continue
input_ids, labels, extra_inputs = _encode_supervised_example( input_ids, labels = _encode_supervised_example(
prompt=examples["prompt"][i], prompt=examples["_prompt"][i],
response=examples["response"][i], response=examples["_response"][i],
system=examples["system"][i], system=examples["_system"][i],
tools=examples["tools"][i], tools=examples["_tools"][i],
images=examples["images"][i], images=examples["_images"][i] or [],
template=template, template=template,
tokenizer=tokenizer, tokenizer=tokenizer,
processor=processor, processor=processor,
@ -120,8 +117,7 @@ def preprocess_supervised_dataset(
model_inputs["input_ids"].append(input_ids) model_inputs["input_ids"].append(input_ids)
model_inputs["attention_mask"].append([1] * len(input_ids)) model_inputs["attention_mask"].append([1] * len(input_ids))
model_inputs["labels"].append(labels) model_inputs["labels"].append(labels)
for key, value in extra_inputs.items(): model_inputs["images"].append(examples["_images"][i])
model_inputs[key].append(value)
return model_inputs return model_inputs
@ -143,17 +139,17 @@ def preprocess_packed_supervised_dataset(
batch_input_ids, batch_labels = [], [] batch_input_ids, batch_labels = [], []
lengths = [] lengths = []
length2indexes = defaultdict(list) length2indexes = defaultdict(list)
for i in range(len(examples["prompt"])): for i in range(len(examples["_prompt"])):
if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) != 1: if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) != 1:
logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i])) logger.warning("Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i]))
continue continue
input_ids, labels, _ = _encode_supervised_example( input_ids, labels = _encode_supervised_example(
prompt=examples["prompt"][i], prompt=examples["_prompt"][i],
response=examples["response"][i], response=examples["_response"][i],
system=examples["system"][i], system=examples["_system"][i],
tools=examples["tools"][i], tools=examples["_tools"][i],
images=examples["images"][i], images=examples["_images"][i] or [],
template=template, template=template,
tokenizer=tokenizer, tokenizer=tokenizer,
processor=None, processor=None,
@ -199,6 +195,7 @@ def preprocess_packed_supervised_dataset(
model_inputs["input_ids"].append(packed_input_ids) model_inputs["input_ids"].append(packed_input_ids)
model_inputs["attention_mask"].append(packed_attention_masks) model_inputs["attention_mask"].append(packed_attention_masks)
model_inputs["labels"].append(packed_labels) model_inputs["labels"].append(packed_labels)
model_inputs["images"].append(examples["_images"][i])
return model_inputs return model_inputs
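After this change a tokenized SFT feature no longer carries `pixel_values` or `token_type_ids`; it only keeps the raw image references (or `None` for text-only data) for the collator to resolve later. Roughly (toy values):

# Illustrative shape of one preprocessed feature; token ids and path are made up.
feature = {
    "input_ids": [1, 8, 9, 2],        # <bos> X Y <eos>
    "attention_mask": [1, 1, 1, 1],
    "labels": [-100, -100, 9, 2],     # prompt positions masked with IGNORE_INDEX
    "images": ["path/to/cat.jpg"],    # raw reference only; None when the example has no image
}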

View File

@ -21,10 +21,10 @@ from .processor_utils import infer_seqlen
if TYPE_CHECKING: if TYPE_CHECKING:
from PIL.Image import Image
from transformers import PreTrainedTokenizer, ProcessorMixin from transformers import PreTrainedTokenizer, ProcessorMixin
from ...hparams import DataArguments from ...hparams import DataArguments
from ..mm_plugin import ImageInput
from ..template import Template from ..template import Template
@ -36,12 +36,12 @@ def _encode_unsupervised_example(
response: Sequence[Dict[str, str]], response: Sequence[Dict[str, str]],
system: Optional[str], system: Optional[str],
tools: Optional[str], tools: Optional[str],
images: Sequence["Image"], images: Sequence["ImageInput"],
template: "Template", template: "Template",
tokenizer: "PreTrainedTokenizer", tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"], processor: Optional["ProcessorMixin"],
cutoff_len: int, cutoff_len: int,
) -> Tuple[List[int], List[int], Dict[str, Any]]: ) -> Tuple[List[int], List[int]]:
if len(response) == 1: if len(response) == 1:
messages = prompt + response messages = prompt + response
else: else:
@ -56,10 +56,7 @@ def _encode_unsupervised_example(
source_len, target_len = infer_seqlen(len(input_ids), len(labels), cutoff_len) source_len, target_len = infer_seqlen(len(input_ids), len(labels), cutoff_len)
input_ids = input_ids[:source_len] input_ids = input_ids[:source_len]
labels = labels[:target_len] labels = labels[:target_len]
extra_inputs = template.mm_plugin.get_mm_inputs( return input_ids, labels
images=images, feature_seqlens={"token_type_ids": len(input_ids)}, processor=processor
)
return input_ids, labels, extra_inputs
def preprocess_unsupervised_dataset( def preprocess_unsupervised_dataset(
@ -71,17 +68,17 @@ def preprocess_unsupervised_dataset(
) -> Dict[str, List[Any]]: ) -> Dict[str, List[Any]]:
# build inputs with format `<bos> X` and labels with format `Y <eos>` # build inputs with format `<bos> X` and labels with format `Y <eos>`
model_inputs = defaultdict(list) model_inputs = defaultdict(list)
for i in range(len(examples["prompt"])): for i in range(len(examples["_prompt"])):
if len(examples["prompt"][i]) % 2 != 1: if len(examples["_prompt"][i]) % 2 != 1:
logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i])) logger.warning("Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i]))
continue continue
input_ids, labels, extra_inputs = _encode_unsupervised_example( input_ids, labels = _encode_unsupervised_example(
prompt=examples["prompt"][i], prompt=examples["_prompt"][i],
response=examples["response"][i], response=examples["_response"][i],
system=examples["system"][i], system=examples["_system"][i],
tools=examples["tools"][i], tools=examples["_tools"][i],
images=examples["images"][i], images=examples["_images"][i] or [],
template=template, template=template,
tokenizer=tokenizer, tokenizer=tokenizer,
processor=processor, processor=processor,
@ -90,8 +87,7 @@ def preprocess_unsupervised_dataset(
model_inputs["input_ids"].append(input_ids) model_inputs["input_ids"].append(input_ids)
model_inputs["attention_mask"].append([1] * len(input_ids)) model_inputs["attention_mask"].append([1] * len(input_ids))
model_inputs["labels"].append(labels) model_inputs["labels"].append(labels)
for key, value in extra_inputs.items(): model_inputs["images"].append(examples["_images"][i])
model_inputs[key].append(value)
return model_inputs return model_inputs

View File

@@ -73,6 +73,10 @@ class DataArguments:
         default=False,
         metadata={"help": "Overwrite the cached training and evaluation sets."},
     )
+    preprocessing_batch_size: int = field(
+        default=1000,
+        metadata={"help": "The number of examples in one group in pre-processing."},
+    )
     preprocessing_num_workers: Optional[int] = field(
         default=None,
         metadata={"help": "The number of processes to use for the pre-processing."},
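The new `preprocessing_batch_size` feeds straight into the `dataset.map` call shown in the loader hunk above, controlling how many examples each tokenization call receives (default 1000). A sketch of the resulting map arguments (values illustrative):

# How the new option is consumed (see the loader hunk); 200 and 8 are example values.
map_kwargs = dict(
    batched=True,
    batch_size=200,   # data_args.preprocessing_batch_size
    num_proc=8,       # data_args.preprocessing_num_workers
)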

View File

@ -41,13 +41,14 @@ def run_dpo(
): ):
tokenizer_module = load_tokenizer(model_args) tokenizer_module = load_tokenizer(model_args)
tokenizer = tokenizer_module["tokenizer"] tokenizer = tokenizer_module["tokenizer"]
dataset_module = get_dataset(model_args, data_args, training_args, stage="rm", **tokenizer_module) dataset_module, template = get_dataset(model_args, data_args, training_args, stage="rm", **tokenizer_module)
model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train) model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)
data_collator = PairwiseDataCollatorWithPadding( data_collator = PairwiseDataCollatorWithPadding(
tokenizer=tokenizer, template=template,
pad_to_multiple_of=8, pad_to_multiple_of=8,
label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id, label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id,
**tokenizer_module,
) )
# Create reference model # Create reference model
@ -60,7 +61,7 @@ def run_dpo(
ref_model = None ref_model = None
# Update arguments # Update arguments
training_args.remove_unused_columns = False # important for pairwise dataset training_args.remove_unused_columns = False # important for multimodal and pairwise dataset
# Initialize our Trainer # Initialize our Trainer
trainer = CustomDPOTrainer( trainer = CustomDPOTrainer(
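Each workflow now hands the collator the chat template plus the whole tokenizer module (so a processor, when present, travels along with the tokenizer) instead of a bare tokenizer. The pairwise case from the hunk above:

# New collator wiring (as in the DPO workflow above; IGNORE_INDEX is the label padding constant).
data_collator = PairwiseDataCollatorWithPadding(
    template=template,
    pad_to_multiple_of=8,
    label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id,
    **tokenizer_module,  # supplies tokenizer and, for VLMs, the processor
)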

View File

@ -41,13 +41,14 @@ def run_kto(
): ):
tokenizer_module = load_tokenizer(model_args) tokenizer_module = load_tokenizer(model_args)
tokenizer = tokenizer_module["tokenizer"] tokenizer = tokenizer_module["tokenizer"]
dataset_module = get_dataset(model_args, data_args, training_args, stage="kto", **tokenizer_module) dataset_module, template = get_dataset(model_args, data_args, training_args, stage="kto", **tokenizer_module)
model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train) model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)
data_collator = KTODataCollatorWithPadding( data_collator = KTODataCollatorWithPadding(
tokenizer=tokenizer, template=template,
pad_to_multiple_of=8, pad_to_multiple_of=8,
label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id, label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id,
**tokenizer_module,
) )
# Create reference model # Create reference model
@ -57,7 +58,7 @@ def run_kto(
ref_model = create_ref_model(model_args, finetuning_args) ref_model = create_ref_model(model_args, finetuning_args)
# Update arguments # Update arguments
training_args.remove_unused_columns = False # important for pairwise dataset training_args.remove_unused_columns = False # important for multimodal and pairwise dataset
# Initialize our Trainer # Initialize our Trainer
trainer = CustomKTOTrainer( trainer = CustomKTOTrainer(


@ -41,11 +41,11 @@ def run_ppo(
): ):
tokenizer_module = load_tokenizer(model_args) tokenizer_module = load_tokenizer(model_args)
tokenizer = tokenizer_module["tokenizer"] tokenizer = tokenizer_module["tokenizer"]
dataset_module = get_dataset(model_args, data_args, training_args, stage="ppo", **tokenizer_module) dataset_module, template = get_dataset(model_args, data_args, training_args, stage="ppo", **tokenizer_module)
model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train, add_valuehead=True) model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train, add_valuehead=True)
tokenizer.padding_side = "left" # use left-padding in generation while using right-padding in training tokenizer.padding_side = "left" # use left-padding in generation while using right-padding in training
data_collator = MultiModalDataCollatorForSeq2Seq(tokenizer=tokenizer) data_collator = MultiModalDataCollatorForSeq2Seq(template=template, **tokenizer_module)
# Create reference model and reward model # Create reference model and reward model
ref_model = create_ref_model(model_args, finetuning_args, add_valuehead=True) ref_model = create_ref_model(model_args, finetuning_args, add_valuehead=True)
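Note: the PPO path hands the same template to `MultiModalDataCollatorForSeq2Seq`, with the tokenizer switched to left padding because PPO generates during training. A condensed sketch, reusing the imports and argument objects assumed in the previous sketch:

    from llamafactory.data import MultiModalDataCollatorForSeq2Seq  # import path assumed

    dataset_module, template = get_dataset(model_args, data_args, training_args, stage="ppo", **tokenizer_module)
    tokenizer.padding_side = "left"  # left-pad prompts for generation; training batches stay right-padded
    data_collator = MultiModalDataCollatorForSeq2Seq(template=template, **tokenizer_module)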


@ -42,7 +42,7 @@ def run_pt(
): ):
tokenizer_module = load_tokenizer(model_args) tokenizer_module = load_tokenizer(model_args)
tokenizer = tokenizer_module["tokenizer"] tokenizer = tokenizer_module["tokenizer"]
dataset_module = get_dataset(model_args, data_args, training_args, stage="pt", **tokenizer_module) dataset_module, _ = get_dataset(model_args, data_args, training_args, stage="pt", **tokenizer_module)
model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train) model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False) data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)


@ -41,12 +41,12 @@ def run_rm(
): ):
tokenizer_module = load_tokenizer(model_args) tokenizer_module = load_tokenizer(model_args)
tokenizer = tokenizer_module["tokenizer"] tokenizer = tokenizer_module["tokenizer"]
dataset_module = get_dataset(model_args, data_args, training_args, stage="rm", **tokenizer_module) dataset_module, template = get_dataset(model_args, data_args, training_args, stage="rm", **tokenizer_module)
model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train, add_valuehead=True) model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train, add_valuehead=True)
data_collator = PairwiseDataCollatorWithPadding(tokenizer, pad_to_multiple_of=8) data_collator = PairwiseDataCollatorWithPadding(template=template, pad_to_multiple_of=8, **tokenizer_module)
# Update arguments # Update arguments
training_args.remove_unused_columns = False # important for pairwise dataset training_args.remove_unused_columns = False # important for multimodal and pairwise dataset
# Initialize our Trainer # Initialize our Trainer
trainer = PairwiseTrainer( trainer = PairwiseTrainer(


@ -43,24 +43,26 @@ def run_sft(
): ):
tokenizer_module = load_tokenizer(model_args) tokenizer_module = load_tokenizer(model_args)
tokenizer = tokenizer_module["tokenizer"] tokenizer = tokenizer_module["tokenizer"]
dataset_module = get_dataset(model_args, data_args, training_args, stage="sft", **tokenizer_module) dataset_module, template = get_dataset(model_args, data_args, training_args, stage="sft", **tokenizer_module)
model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train) model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)
if getattr(model, "is_quantized", False) and not training_args.do_train: if getattr(model, "is_quantized", False) and not training_args.do_train:
setattr(model, "_hf_peft_config_loaded", True) # hack here: make model compatible with prediction setattr(model, "_hf_peft_config_loaded", True) # hack here: make model compatible with prediction
data_collator = SFTDataCollatorWith4DAttentionMask( data_collator = SFTDataCollatorWith4DAttentionMask(
tokenizer=tokenizer, template=template,
pad_to_multiple_of=8 if training_args.do_train else None, # for shift short attention pad_to_multiple_of=8 if training_args.do_train else None, # for shift short attention
label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id, label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id,
block_diag_attn=model_args.block_diag_attn, block_diag_attn=model_args.block_diag_attn,
attn_implementation=getattr(model.config, "_attn_implementation", None), attn_implementation=getattr(model.config, "_attn_implementation", None),
compute_dtype=model_args.compute_dtype, compute_dtype=model_args.compute_dtype,
**tokenizer_module,
) )
# Override the decoding parameters of Seq2SeqTrainer # Override the decoding parameters of Seq2SeqTrainer
training_args.generation_max_length = training_args.generation_max_length or data_args.cutoff_len training_args.generation_max_length = training_args.generation_max_length or data_args.cutoff_len
training_args.generation_num_beams = data_args.eval_num_beams or training_args.generation_num_beams training_args.generation_num_beams = data_args.eval_num_beams or training_args.generation_num_beams
training_args.remove_unused_columns = False # important for multimodal and pairwise dataset
# Metric utils # Metric utils
metric_module = {} metric_module = {}
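Note: the added `remove_unused_columns = False` matters because HF `Trainer` otherwise drops any dataset column that is not an argument of the model's `forward`, which would now include the raw `images` column the multimodal collator needs. A tiny illustration of the flag itself (the column name is only an example):

    from transformers import TrainingArguments

    args = TrainingArguments(output_dir="out")  # remove_unused_columns defaults to True
    # With the default, a column such as "images" that forward() does not accept is
    # stripped before batching, so a multimodal collator would never see it.
    args.remove_unused_columns = False  # keep every column and let the collator decide what to use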


@ -105,7 +105,7 @@ def load_reference_model(
def load_train_dataset(**kwargs) -> "Dataset": def load_train_dataset(**kwargs) -> "Dataset":
model_args, data_args, training_args, _, _ = get_train_args(kwargs) model_args, data_args, training_args, _, _ = get_train_args(kwargs)
tokenizer_module = load_tokenizer(model_args) tokenizer_module = load_tokenizer(model_args)
dataset_module = get_dataset(model_args, data_args, training_args, stage=kwargs["stage"], **tokenizer_module) dataset_module, _ = get_dataset(model_args, data_args, training_args, stage=kwargs["stage"], **tokenizer_module)
return dataset_module["train_dataset"] return dataset_module["train_dataset"]


@ -47,11 +47,15 @@ IMAGES = [Image.new("RGB", (32, 32), (255, 255, 255))]
NO_IMAGES = [] NO_IMAGES = []
IMGLENS = [1]
NO_IMGLENS = [0]
INPUT_IDS = [0, 1, 2, 3, 4] INPUT_IDS = [0, 1, 2, 3, 4]
LABELS = [0, 1, 2, 3, 4] LABELS = [0, 1, 2, 3, 4]
FEATURE_SEQLENS = {"token_type_ids": 1024} SEQLENS = [1024]
def _get_mm_inputs(processor: "ProcessorMixin") -> Dict[str, "torch.Tensor"]: def _get_mm_inputs(processor: "ProcessorMixin") -> Dict[str, "torch.Tensor"]:
@ -80,11 +84,11 @@ def test_base_plugin():
# test mm_messages # test mm_messages
assert base_plugin.process_messages(MM_MESSAGES, IMAGES, processor) == MM_MESSAGES assert base_plugin.process_messages(MM_MESSAGES, IMAGES, processor) == MM_MESSAGES
assert base_plugin.process_token_ids(INPUT_IDS, LABELS, IMAGES, tokenizer, processor) == (INPUT_IDS, LABELS) assert base_plugin.process_token_ids(INPUT_IDS, LABELS, IMAGES, tokenizer, processor) == (INPUT_IDS, LABELS)
_is_close(base_plugin.get_mm_inputs(IMAGES, FEATURE_SEQLENS, processor), {}) _is_close(base_plugin.get_mm_inputs(IMAGES, IMGLENS, SEQLENS, processor), {})
# test text_messages # test text_messages
assert base_plugin.process_messages(TEXT_MESSAGES, NO_IMAGES, processor) == TEXT_MESSAGES assert base_plugin.process_messages(TEXT_MESSAGES, NO_IMAGES, processor) == TEXT_MESSAGES
assert base_plugin.process_token_ids(INPUT_IDS, LABELS, NO_IMAGES, tokenizer, processor) == (INPUT_IDS, LABELS) assert base_plugin.process_token_ids(INPUT_IDS, LABELS, NO_IMAGES, tokenizer, processor) == (INPUT_IDS, LABELS)
_is_close(base_plugin.get_mm_inputs(NO_IMAGES, FEATURE_SEQLENS, processor), {}) _is_close(base_plugin.get_mm_inputs(NO_IMAGES, NO_IMGLENS, SEQLENS, processor), {})
def test_llava_plugin(): def test_llava_plugin():
@ -101,11 +105,11 @@ def test_llava_plugin():
# test mm_messages # test mm_messages
assert llava_plugin.process_messages(MM_MESSAGES, IMAGES, processor) == expected_mm_messages assert llava_plugin.process_messages(MM_MESSAGES, IMAGES, processor) == expected_mm_messages
assert llava_plugin.process_token_ids(INPUT_IDS, LABELS, IMAGES, tokenizer, processor) == (INPUT_IDS, LABELS) assert llava_plugin.process_token_ids(INPUT_IDS, LABELS, IMAGES, tokenizer, processor) == (INPUT_IDS, LABELS)
_is_close(llava_plugin.get_mm_inputs(IMAGES, FEATURE_SEQLENS, processor), mm_inputs) _is_close(llava_plugin.get_mm_inputs(IMAGES, IMGLENS, SEQLENS, processor), mm_inputs)
# test text_messages # test text_messages
assert llava_plugin.process_messages(TEXT_MESSAGES, NO_IMAGES, processor) == TEXT_MESSAGES assert llava_plugin.process_messages(TEXT_MESSAGES, NO_IMAGES, processor) == TEXT_MESSAGES
assert llava_plugin.process_token_ids(INPUT_IDS, LABELS, NO_IMAGES, tokenizer, processor) == (INPUT_IDS, LABELS) assert llava_plugin.process_token_ids(INPUT_IDS, LABELS, NO_IMAGES, tokenizer, processor) == (INPUT_IDS, LABELS)
_is_close(llava_plugin.get_mm_inputs(NO_IMAGES, FEATURE_SEQLENS, processor), {"pixel_values": None}) _is_close(llava_plugin.get_mm_inputs(NO_IMAGES, NO_IMGLENS, SEQLENS, processor), {})
@pytest.mark.skipif(not HF_TOKEN, reason="Gated model.") @pytest.mark.skipif(not HF_TOKEN, reason="Gated model.")
@ -128,7 +132,7 @@ def test_paligemma_plugin():
expected_input_ids, expected_input_ids,
expected_labels, expected_labels,
) )
_is_close(paligemma_plugin.get_mm_inputs(IMAGES, FEATURE_SEQLENS, processor), mm_inputs) _is_close(paligemma_plugin.get_mm_inputs(IMAGES, IMGLENS, SEQLENS, processor), mm_inputs)
# test text_messages # test text_messages
assert paligemma_plugin.process_messages(TEXT_MESSAGES, NO_IMAGES, processor) == TEXT_MESSAGES assert paligemma_plugin.process_messages(TEXT_MESSAGES, NO_IMAGES, processor) == TEXT_MESSAGES
assert paligemma_plugin.process_token_ids(INPUT_IDS, LABELS, NO_IMAGES, tokenizer, processor) == ( assert paligemma_plugin.process_token_ids(INPUT_IDS, LABELS, NO_IMAGES, tokenizer, processor) == (
@ -136,8 +140,8 @@ def test_paligemma_plugin():
LABELS, LABELS,
) )
_is_close( _is_close(
paligemma_plugin.get_mm_inputs(NO_IMAGES, FEATURE_SEQLENS, processor), paligemma_plugin.get_mm_inputs(NO_IMAGES, NO_IMGLENS, SEQLENS, processor),
{"pixel_values": None, "token_type_ids": [[1] * 1024]}, {"token_type_ids": [[1] * 1024]},
) )
@ -158,11 +162,8 @@ def test_qwen2_vl_plugin():
# test mm_messages # test mm_messages
assert qwen2_vl_plugin.process_messages(MM_MESSAGES, IMAGES, processor) == expected_mm_messages assert qwen2_vl_plugin.process_messages(MM_MESSAGES, IMAGES, processor) == expected_mm_messages
assert qwen2_vl_plugin.process_token_ids(INPUT_IDS, LABELS, IMAGES, tokenizer, processor) == (INPUT_IDS, LABELS) assert qwen2_vl_plugin.process_token_ids(INPUT_IDS, LABELS, IMAGES, tokenizer, processor) == (INPUT_IDS, LABELS)
_is_close(qwen2_vl_plugin.get_mm_inputs(IMAGES, FEATURE_SEQLENS, processor), mm_inputs) _is_close(qwen2_vl_plugin.get_mm_inputs(IMAGES, IMGLENS, SEQLENS, processor), mm_inputs)
# test text_messages # test text_messages
assert qwen2_vl_plugin.process_messages(TEXT_MESSAGES, NO_IMAGES, processor) == TEXT_MESSAGES assert qwen2_vl_plugin.process_messages(TEXT_MESSAGES, NO_IMAGES, processor) == TEXT_MESSAGES
assert qwen2_vl_plugin.process_token_ids(INPUT_IDS, LABELS, NO_IMAGES, tokenizer, processor) == (INPUT_IDS, LABELS) assert qwen2_vl_plugin.process_token_ids(INPUT_IDS, LABELS, NO_IMAGES, tokenizer, processor) == (INPUT_IDS, LABELS)
_is_close( _is_close(qwen2_vl_plugin.get_mm_inputs(NO_IMAGES, NO_IMGLENS, SEQLENS, processor), {})
qwen2_vl_plugin.get_mm_inputs(NO_IMAGES, FEATURE_SEQLENS, processor),
{"pixel_values": None, "image_grid_thw": None},
)
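Note: the test changes above reflect the new `get_mm_inputs` signature: explicit per-sample image counts and sequence lengths (`imglens`, `seqlens`) replace the old `feature_seqlens` dict, and most plugins now return an empty dict instead of `None`-valued keys when no images are present. A usage sketch built from the fixtures; `plugin` and `processor` are assumed to be set up as in the tests:

    from PIL import Image

    images = [Image.new("RGB", (32, 32), (255, 255, 255))]
    imglens = [1]     # one image attached to the single sample
    seqlens = [1024]  # tokenized length of that sample

    mm_inputs = plugin.get_mm_inputs(images, imglens, seqlens, processor)  # e.g. {"pixel_values": ...}
    # with no images, the base/LLaVA/Qwen2-VL plugins are now expected to return {}
    assert plugin.get_mm_inputs([], [0], seqlens, processor) == {}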