Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-08-02 03:32:50 +08:00)

lazy image load

Former-commit-id: 47ea97fb1ba77de2e8a561904aa8fdc27c3f5025
parent: 5ef58eb655
commit: 22deca0e9e
@@ -156,7 +156,7 @@ class HuggingfaceEngine(BaseEngine):
if image is not None:
mm_inputs = template.mm_plugin.get_mm_inputs(
images=[image], feature_seqlens={"token_type_ids": prompt_length}, processor=processor
images=[image], imglens=[1], seqlens=[prompt_length], processor=processor
)
for key, value in mm_inputs.items():
value = value if isinstance(value, torch.Tensor) else torch.tensor(value)
@@ -14,9 +14,7 @@
import os
from functools import partial
from typing import TYPE_CHECKING, Any, Dict, List, Sequence, Union
from datasets import Features
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union
from ..extras.logging import get_logger
from .data_utils import Role
@@ -27,16 +25,24 @@ if TYPE_CHECKING:
from transformers import Seq2SeqTrainingArguments
from ..hparams import DataArguments
from .mm_plugin import ImageInput
from .parser import DatasetAttr
logger = get_logger(__name__)
def _convert_images(images: Sequence[Any], dataset_attr: "DatasetAttr", data_args: "DataArguments") -> List[Any]:
def _convert_images(
images: Sequence["ImageInput"],
dataset_attr: "DatasetAttr",
data_args: "DataArguments",
) -> Optional[List["ImageInput"]]:
r"""
Optionally concatenates image path to dataset dir when loading from local disk.
"""
if len(images) == 0:
return None
images = images[:]
if dataset_attr.load_from in ["script", "file"]:
for i in range(len(images)):
@@ -47,66 +53,67 @@ def _convert_images(images: Sequence[Any], dataset_attr: "DatasetAttr", data_arg
def convert_alpaca(
examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr", data_args: "DataArguments"
) -> Dict[str, List[Any]]:
example: Dict[str, Any],
dataset_attr: "DatasetAttr",
data_args: "DataArguments",
) -> Dict[str, Any]:
r"""
Converts alpaca format dataset to the standard format.
"""
outputs = {"prompt": [], "response": [], "system": [], "tools": [], "images": []}
prompt = []
if dataset_attr.history and isinstance(example[dataset_attr.history], list):
for old_prompt, old_response in example[dataset_attr.history]:
prompt.append({"role": Role.USER.value, "content": old_prompt})
prompt.append({"role": Role.ASSISTANT.value, "content": old_response})
query = []
if dataset_attr.prompt and example[dataset_attr.prompt]:
query.append(example[dataset_attr.prompt])
if dataset_attr.query and example[dataset_attr.query]:
query.append(example[dataset_attr.query])
prompt.append({"role": Role.USER.value, "content": "\n".join(query)}) # "prompt\nquery"
if dataset_attr.kto_tag and isinstance(example[dataset_attr.kto_tag], bool): # kto example
response = [{"role": Role.ASSISTANT.value, "content": example[dataset_attr.response]}]
if example[dataset_attr.kto_tag]:
response = response + [{"role": Role.ASSISTANT.value, "content": ""}]
else:
response = [{"role": Role.ASSISTANT.value, "content": ""}] + response
elif (
dataset_attr.ranking
and isinstance(example[dataset_attr.chosen], str)
and isinstance(example[dataset_attr.rejected], str)
): # pairwise example
response = [
{"role": Role.ASSISTANT.value, "content": example[dataset_attr.chosen]},
{"role": Role.ASSISTANT.value, "content": example[dataset_attr.rejected]},
]
elif dataset_attr.response and isinstance(example[dataset_attr.response], str): # normal example
response = [{"role": Role.ASSISTANT.value, "content": example[dataset_attr.response]}]
else: # unsupervised
response = []
convert_images = partial(_convert_images, dataset_attr=dataset_attr, data_args=data_args)
for i in range(len(examples[dataset_attr.prompt])):
prompt = []
if dataset_attr.history and isinstance(examples[dataset_attr.history][i], list):
for old_prompt, old_response in examples[dataset_attr.history][i]:
prompt.append({"role": Role.USER.value, "content": old_prompt})
prompt.append({"role": Role.ASSISTANT.value, "content": old_response})
content = []
if dataset_attr.prompt and examples[dataset_attr.prompt][i]:
content.append(examples[dataset_attr.prompt][i])
if dataset_attr.query and examples[dataset_attr.query][i]:
content.append(examples[dataset_attr.query][i])
prompt.append({"role": Role.USER.value, "content": "\n".join(content)}) # "prompt\nquery"
if dataset_attr.kto_tag and isinstance(examples[dataset_attr.kto_tag][i], bool): # kto example
response = [{"role": Role.ASSISTANT.value, "content": examples[dataset_attr.response][i]}]
if examples[dataset_attr.kto_tag][i]:
response = response + [{"role": Role.ASSISTANT.value, "content": ""}]
else:
response = [{"role": Role.ASSISTANT.value, "content": ""}] + response
elif (
dataset_attr.ranking
and isinstance(examples[dataset_attr.chosen][i], str)
and isinstance(examples[dataset_attr.rejected][i], str)
): # pairwise example
response = [
{"role": Role.ASSISTANT.value, "content": examples[dataset_attr.chosen][i]},
{"role": Role.ASSISTANT.value, "content": examples[dataset_attr.rejected][i]},
]
elif dataset_attr.response and isinstance(examples[dataset_attr.response][i], str): # normal example
response = [{"role": Role.ASSISTANT.value, "content": examples[dataset_attr.response][i]}]
else: # unsupervised
response = []
outputs["prompt"].append(prompt)
outputs["response"].append(response)
outputs["system"].append(examples[dataset_attr.system][i] if dataset_attr.system else "")
outputs["tools"].append(examples[dataset_attr.tools][i] if dataset_attr.tools else "")
outputs["images"].append(convert_images(examples[dataset_attr.images][i]) if dataset_attr.images else [])
return outputs
output = {
"_prompt": prompt,
"_response": response,
"_system": example[dataset_attr.system] if dataset_attr.system else "",
"_tools": example[dataset_attr.tools] if dataset_attr.tools else "",
"_images": convert_images(example[dataset_attr.images]) if dataset_attr.images else None,
}
return output
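For context, a minimal sketch (not part of the commit) of what the per-example alpaca converter now produces; the record and the DatasetAttr column names below are assumptions for illustration only:

# Hypothetical input row; field names assume dataset_attr.prompt="instruction",
# dataset_attr.query="input", dataset_attr.response="output".
example = {"instruction": "Translate to French", "input": "Hello", "output": "Bonjour"}

# convert_alpaca(example, ...) would then yield roughly:
converted = {
    "_prompt": [{"role": "user", "content": "Translate to French\nHello"}],
    "_response": [{"role": "assistant", "content": "Bonjour"}],
    "_system": "",
    "_tools": "",
    "_images": None,  # no image column configured, so nothing is loaded
}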
def convert_sharegpt(
examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr", data_args: "DataArguments"
) -> Dict[str, List[Any]]:
example: Dict[str, Any],
dataset_attr: "DatasetAttr",
data_args: "DataArguments",
) -> Dict[str, Any]:
r"""
Converts sharegpt format dataset to the standard format.
"""
outputs = {"prompt": [], "response": [], "system": [], "tools": [], "images": []}
convert_images = partial(_convert_images, dataset_attr=dataset_attr, data_args=data_args)
tag_mapping = {
dataset_attr.user_tag: Role.USER.value,
dataset_attr.assistant_tag: Role.ASSISTANT.value,
@@ -117,74 +124,77 @@ def convert_sharegpt(
odd_tags = (dataset_attr.user_tag, dataset_attr.observation_tag)
even_tags = (dataset_attr.assistant_tag, dataset_attr.function_tag)
accept_tags = (odd_tags, even_tags)
for i, messages in enumerate(examples[dataset_attr.messages]):
if len(messages) == 0:
continue
messages = example[dataset_attr.messages]
if (
dataset_attr.system_tag
and len(messages) != 0
and messages[0][dataset_attr.role_tag] == dataset_attr.system_tag
):
system = messages[0][dataset_attr.content_tag]
messages = messages[1:]
else:
system = example[dataset_attr.system] if dataset_attr.system else ""
if dataset_attr.system_tag and messages[0][dataset_attr.role_tag] == dataset_attr.system_tag:
system = messages[0][dataset_attr.content_tag]
messages = messages[1:]
else:
system = examples[dataset_attr.system][i] if dataset_attr.system else ""
aligned_messages = []
broken_data = False
for turn_idx, message in enumerate(messages):
if message[dataset_attr.role_tag] not in accept_tags[turn_idx % 2]:
logger.warning("Invalid role tag in {}.".format(messages))
broken_data = True
aligned_messages.append(
{"role": tag_mapping[message[dataset_attr.role_tag]], "content": message[dataset_attr.content_tag]}
)
if (not dataset_attr.ranking and len(aligned_messages) % 2 != 0) or (
dataset_attr.ranking and len(aligned_messages) % 2 == 0
):
logger.warning("Invalid message count in {}.".format(messages))
aligned_messages = []
broken_data = False
for turn_idx, message in enumerate(messages):
if message[dataset_attr.role_tag] not in accept_tags[turn_idx % 2]:
logger.warning("Invalid role tag in {}.".format(messages))
broken_data = True
if dataset_attr.kto_tag and isinstance(examples[dataset_attr.kto_tag][i], bool): # kto example
prompt = aligned_messages[:-1]
response = aligned_messages[-1:]
if examples[dataset_attr.kto_tag][i]:
response = response + [{"role": Role.ASSISTANT.value, "content": ""}]
else:
response = [{"role": Role.ASSISTANT.value, "content": ""}] + response
elif (
dataset_attr.ranking
and isinstance(examples[dataset_attr.chosen][i], dict)
and isinstance(examples[dataset_attr.rejected][i], dict)
): # pairwise example
chosen = examples[dataset_attr.chosen][i]
rejected = examples[dataset_attr.rejected][i]
if (
chosen[dataset_attr.role_tag] not in accept_tags[-1]
or rejected[dataset_attr.role_tag] not in accept_tags[-1]
):
logger.warning("Invalid role tag in {}.".format([chosen, rejected]))
broken_data = True
aligned_messages.append(
{"role": tag_mapping[message[dataset_attr.role_tag]], "content": message[dataset_attr.content_tag]}
)
prompt = aligned_messages
response = [
{"role": tag_mapping[chosen[dataset_attr.role_tag]], "content": chosen[dataset_attr.content_tag]},
{"role": tag_mapping[rejected[dataset_attr.role_tag]], "content": rejected[dataset_attr.content_tag]},
]
else: # normal example
prompt = aligned_messages[:-1]
response = aligned_messages[-1:]
if (not dataset_attr.ranking and len(aligned_messages) % 2 != 0) or (
dataset_attr.ranking and len(aligned_messages) % 2 == 0
):
logger.warning("Invalid message count in {}.".format(messages))
broken_data = True
if broken_data:
logger.warning("Skipping this abnormal example.")
continue
if dataset_attr.kto_tag and isinstance(example[dataset_attr.kto_tag], bool): # kto example
prompt = aligned_messages[:-1]
response = aligned_messages[-1:]
if example[dataset_attr.kto_tag]:
response = response + [{"role": Role.ASSISTANT.value, "content": ""}]
else:
response = [{"role": Role.ASSISTANT.value, "content": ""}] + response
elif (
dataset_attr.ranking
and isinstance(example[dataset_attr.chosen], dict)
and isinstance(example[dataset_attr.rejected], dict)
): # pairwise example
chosen = example[dataset_attr.chosen]
rejected = example[dataset_attr.rejected]
if (
chosen[dataset_attr.role_tag] not in accept_tags[-1]
or rejected[dataset_attr.role_tag] not in accept_tags[-1]
):
logger.warning("Invalid role tag in {}.".format([chosen, rejected]))
broken_data = True
outputs["prompt"].append(prompt)
outputs["response"].append(response)
outputs["system"].append(system)
outputs["tools"].append(examples[dataset_attr.tools][i] if dataset_attr.tools else "")
outputs["images"].append(convert_images(examples[dataset_attr.images][i]) if dataset_attr.images else [])
prompt = aligned_messages
response = [
{"role": tag_mapping[chosen[dataset_attr.role_tag]], "content": chosen[dataset_attr.content_tag]},
{"role": tag_mapping[rejected[dataset_attr.role_tag]], "content": rejected[dataset_attr.content_tag]},
]
else: # normal example
prompt = aligned_messages[:-1]
response = aligned_messages[-1:]
return outputs
if broken_data:
logger.warning("Skipping this abnormal example.")
prompt, response = [], []
convert_images = partial(_convert_images, dataset_attr=dataset_attr, data_args=data_args)
output = {
"_prompt": prompt,
"_response": response,
"_system": system,
"_tools": example[dataset_attr.tools] if dataset_attr.tools else "",
"_images": convert_images(example[dataset_attr.images]) if dataset_attr.images else None,
}
return output
def align_dataset(
@@ -195,11 +205,11 @@ def align_dataset(
) -> Union["Dataset", "IterableDataset"]:
r"""
Aligned dataset:
prompt: [{"role": "user", "content": "..."}] * (2T - 1)
response: [{"role": "assistant", "content": "..."}] * N (N > 1 for ranking dataset)
system: "..."
tools: "...",
images: [],
_prompt: [{"role": "user", "content": "..."}] * (2T - 1)
_response: [{"role": "assistant", "content": "..."}] * N (N > 1 for ranking dataset)
_system: "..."
_tools: "...",
_images: [],
"""
if dataset_attr.formatting == "alpaca":
convert_func = partial(convert_alpaca, dataset_attr=dataset_attr, data_args=data_args)
@@ -207,19 +217,6 @@ def align_dataset(
convert_func = partial(convert_sharegpt, dataset_attr=dataset_attr, data_args=data_args)
column_names = list(next(iter(dataset)).keys())
features = Features.from_dict(
{
"prompt": [
{"role": {"dtype": "string", "_type": "Value"}, "content": {"dtype": "string", "_type": "Value"}}
],
"response": [
{"role": {"dtype": "string", "_type": "Value"}, "content": {"dtype": "string", "_type": "Value"}}
],
"system": {"dtype": "string", "_type": "Value"},
"tools": {"dtype": "string", "_type": "Value"},
"images": [{"_type": "Image"}],
}
)
kwargs = {}
if not data_args.streaming:
kwargs = dict(
@@ -230,8 +227,7 @@ def align_dataset(
return dataset.map(
convert_func,
batched=True,
batched=False,
remove_columns=column_names,
features=features,
**kwargs,
)
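A minimal sketch, outside the commit, of the per-example mapping pattern the aligner switches to here; the toy converter and column names are assumptions for illustration only:

from datasets import Dataset

raw = Dataset.from_dict({"instruction": ["hi", "bye"], "output": ["hello", "goodbye"]})

def convert_one(example):  # hypothetical stand-in for convert_alpaca / convert_sharegpt
    return {
        "_prompt": [{"role": "user", "content": example["instruction"]}],
        "_response": [{"role": "assistant", "content": example["output"]}],
        "_system": "",
        "_tools": "",
        "_images": None,  # image references stay lazy until the collator loads them
    }

# batched=False feeds one example at a time, so no Image feature has to be decoded here.
aligned = raw.map(convert_one, batched=False, remove_columns=raw.column_names)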
@@ -16,12 +16,18 @@
# limitations under the License.
from dataclasses import dataclass
from typing import Any, Dict, Literal, Sequence
from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Sequence
import torch
from transformers import DataCollatorForSeq2Seq
if TYPE_CHECKING:
from transformers import ProcessorMixin
from .template import Template
def prepare_4d_attention_mask(attention_mask_with_indices: "torch.Tensor", dtype: "torch.dtype") -> "torch.Tensor":
r"""
Expands the attention mask with indices from (batch_size, seq_len) to (batch_size, 1, seq_len, seq_len),
@@ -65,41 +71,29 @@ def prepare_4d_attention_mask(attention_mask_with_indices: "torch.Tensor", dtype
class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
r"""
Data collator that supports VLMs.
Features should contain input_ids, attention_mask, labels and images.
"""
template: Optional["Template"] = None
processor: Optional["ProcessorMixin"] = None
def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
if "token_type_ids" in features[0].keys():
for feature in features:
feature["token_type_ids"] = feature["token_type_ids"][0]
batch_images, batch_imglens, batch_seqlens = [], [], []
for feature in features:
images = feature.pop("images") or [] # avoid NoneType
batch_images.extend(images)
batch_imglens.append(len(images))
batch_seqlens.append(len(feature["input_ids"]))
extra_features = {}
if "pixel_values" in features[0].keys():
pixel_values = []
for feature in features:
if feature["pixel_values"] is None:
pixel_values.append(torch.zeros(0, dtype=torch.float))
else:
pixel_values.append(torch.tensor(feature["pixel_values"], dtype=torch.float))
mm_inputs = self.template.mm_plugin.get_mm_inputs(batch_images, batch_imglens, batch_seqlens, self.processor)
if "token_type_ids" in mm_inputs:
token_type_ids = mm_inputs.pop("token_type_ids")
for i, feature in enumerate(features):
feature["token_type_ids"] = token_type_ids[i]
extra_features["pixel_values"] = torch.cat(pixel_values, dim=0)
if extra_features["pixel_values"].numel() == 0:
extra_features["pixel_values"] = None
if "image_grid_thw" in features[0].keys():
image_grid_thw = []
for feature in features:
if feature["image_grid_thw"] is None:
image_grid_thw.append(torch.zeros(0, dtype=torch.long))
else:
image_grid_thw.append(torch.tensor(feature["image_grid_thw"], dtype=torch.long))
extra_features["image_grid_thw"] = torch.cat(image_grid_thw, dim=0)
if extra_features["image_grid_thw"].numel() == 0:
extra_features["image_grid_thw"] = None
features = [{key: feature[key] for key in feature if key not in extra_features.keys()} for feature in features]
features: Dict[str, "torch.Tensor"] = super().__call__(features)
features.update({key: value for key, value in extra_features.items() if value is not None})
features.update(mm_inputs)
return features
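A rough sketch, with made-up feature dicts, of the batch bookkeeping the collator now performs before delegating to the template's mm_plugin; this only restates the loop above with toy data:

features = [
    {"input_ids": [1, 2, 3, 4], "images": ["img_a.jpg"]},  # one lazily stored image reference
    {"input_ids": [5, 6, 7], "images": None},               # text-only row
]

batch_images, batch_imglens, batch_seqlens = [], [], []
for feature in features:
    images = feature.pop("images") or []  # None becomes an empty list
    batch_images.extend(images)
    batch_imglens.append(len(images))
    batch_seqlens.append(len(feature["input_ids"]))

# batch_images == ["img_a.jpg"], batch_imglens == [1, 0], batch_seqlens == [4, 3]
# These three lists are what get_mm_inputs(batch_images, batch_imglens, batch_seqlens, processor) receives.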
@@ -141,16 +135,8 @@ class PairwiseDataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
"input_ids": feature["{}_input_ids".format(key)],
"attention_mask": feature["{}_attention_mask".format(key)],
"labels": feature["{}_labels".format(key)],
"images": feature["images"],
}
if "{}_token_type_ids".format(key) in feature:
target_feature["token_type_ids"] = feature["{}_token_type_ids".format(key)]
if "pixel_values" in feature: # image data are same for chosen and rejected
target_feature["pixel_values"] = feature["pixel_values"]
if "image_grid_thw" in feature:
target_feature["image_grid_thw"] = feature["image_grid_thw"]
concatenated_features.append(target_feature)
return super().__call__(concatenated_features)
@@ -171,22 +157,14 @@ class KTODataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
"input_ids": feature["input_ids"],
"attention_mask": feature["attention_mask"],
"labels": feature["labels"],
"images": feature["images"],
}
kl_feature = {
"input_ids": feature["kl_input_ids"],
"attention_mask": feature["kl_attention_mask"],
"labels": feature["kl_labels"],
"images": feature["images"],
}
if "token_type_ids" in feature:
target_feature["token_type_ids"] = feature["token_type_ids"]
kl_feature["token_type_ids"] = feature["kl_token_type_ids"]
if "pixel_values" in feature:
target_feature["pixel_values"] = feature["pixel_values"]
if "image_grid_thw" in feature:
target_feature["image_grid_thw"] = feature["image_grid_thw"]
target_features.append(target_feature)
kl_features.append(kl_feature)
kto_tags.append(feature["kto_tags"])
@@ -196,7 +174,7 @@ class KTODataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
batch["kl_input_ids"] = kl_batch["input_ids"]
batch["kl_attention_mask"] = kl_batch["attention_mask"]
batch["kl_labels"] = kl_batch["labels"]
if "token_type_ids" in batch:
if "token_type_ids" in kl_batch:
batch["kl_token_type_ids"] = kl_batch["token_type_ids"]
batch["kto_tags"] = torch.tensor(kto_tags)
@@ -14,7 +14,7 @@
import os
import sys
from typing import TYPE_CHECKING, Dict, Literal, Optional, Sequence, Union
from typing import TYPE_CHECKING, Dict, Literal, Optional, Sequence, Tuple, Union
import numpy as np
from datasets import DatasetDict, load_dataset, load_from_disk
@@ -180,7 +180,13 @@ def _get_preprocessed_dataset(
desc="Running tokenizer on dataset",
)
dataset = dataset.map(preprocess_func, batched=True, remove_columns=column_names, **kwargs)
dataset = dataset.map(
preprocess_func,
batched=True,
batch_size=data_args.preprocessing_batch_size,
remove_columns=column_names,
**kwargs,
)
if training_args.should_log:
try:
@@ -202,7 +208,7 @@ def get_dataset(
stage: Literal["pt", "sft", "rm", "ppo", "kto"],
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"] = None,
) -> "DatasetModule":
) -> Tuple["DatasetModule", "Template"]:
template = get_template_and_fix_tokenizer(tokenizer, data_args.template, data_args.tool_format)
if data_args.train_on_prompt and template.efficient_eos:
raise ValueError("Current template does not support `train_on_prompt`.")
@@ -273,4 +279,4 @@ def get_dataset(
if "validation" in dataset_dict:
dataset_module["eval_dataset"] = dataset_dict["validation"]
return dataset_module
return dataset_module, template
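A minimal sketch, assuming the workflow implied by the signature change above, of how a caller unpacks the new return value and hands the template to a collator; the argument objects are placeholders, and the collator's remaining arguments (shown in the run_sft hunk later in this commit) are omitted for brevity:

# Hypothetical call site; model_args / data_args / training_args come from argument parsing.
tokenizer_module = load_tokenizer(model_args)
dataset_module, template = get_dataset(model_args, data_args, training_args, stage="sft", **tokenizer_module)

# The template now travels with the data so the collator can build
# multimodal inputs (pixel_values, token_type_ids, ...) at batch time.
data_collator = SFTDataCollatorWith4DAttentionMask(template=template, **tokenizer_module)  # other args omitted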
@@ -1,5 +1,6 @@
from copy import deepcopy
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
from io import BytesIO
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, TypedDict, Union
from PIL.Image import Image
from transformers import ProcessorMixin
@@ -9,34 +10,53 @@ from ..extras.packages import is_pillow_available
if is_pillow_available():
import torch
from PIL import Image
from PIL.Image import Image as ImageObject
if TYPE_CHECKING:
from PIL.Image import Image as ImageObject
import torch
from transformers import PreTrainedTokenizer, ProcessorMixin
from transformers.image_processing_utils import BaseImageProcessor
class EncodedImage(TypedDict):
path: Optional[str]
bytes: Optional[bytes]
def _regularize_images(images: Sequence["ImageObject"], processor: "ProcessorMixin") -> List["ImageObject"]:
ImageInput = Union[str, EncodedImage, ImageObject]
def _regularize_images(images: Sequence["ImageInput"], processor: "ProcessorMixin") -> List["ImageObject"]:
r"""
Regularizes images to avoid error. Including resizing and mode convert.
Regularizes images to avoid error. Including reading, resizing and converting.
"""
images = images[:]
image_resolution = getattr(processor, "image_resolution", 512)
for i in range(len(images)):
if max(images[i].width, images[i].height) > image_resolution:
factor = image_resolution / max(images[i].width, images[i].height)
images[i] = images[i].resize((int(images[i].width * factor), int(images[i].height * factor)))
results = []
for image in images:
if isinstance(image, str):
image = Image.open(image)
elif isinstance(image, dict):
if image["bytes"] is not None:
image = Image.open(BytesIO(image["bytes"]))
else:
image = Image.open(image["path"])
if images[i].mode != "RGB":
images[i] = images[i].convert("RGB")
if not isinstance(image, ImageObject):
raise ValueError("Expect input is a list of Images, but got {}.".format(type(image)))
return images
if max(image.width, image.height) > image_resolution:
factor = image_resolution / max(image.width, image.height)
image = image.resize((int(image.width * factor), int(image.height * factor)))
if image.mode != "RGB":
image = image.convert("RGB")
results.append(image)
return results
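A minimal sketch, not from the repository, of the lazy-loading idea this hunk enables: dataset rows keep images as file paths or raw bytes, and they are only decoded into PIL objects when a batch is actually built (the sample inputs below are made up):

from io import BytesIO

from PIL import Image

def load_one(image):
    # Mirrors the ImageInput union handled above: str path, {"path"/"bytes"} dict, or a PIL image.
    if isinstance(image, str):
        return Image.open(image)
    if isinstance(image, dict):
        return Image.open(BytesIO(image["bytes"])) if image["bytes"] is not None else Image.open(image["path"])
    return image

lazy_images = ["data/images/0001.jpg", {"path": None, "bytes": b"..."}]  # cheap to store in the dataset
# decoded = [load_one(img) for img in lazy_images]  # decoding deferred until collation time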
def _get_mm_inputs(images: Sequence["ImageObject"], processor: "ProcessorMixin") -> Dict[str, "torch.Tensor"]:
def _get_mm_inputs(images: Sequence["ImageInput"], processor: "ProcessorMixin") -> Dict[str, "torch.Tensor"]:
r"""
Processes visual inputs.
@@ -53,26 +73,27 @@ def _get_mm_inputs(images: Sequence["ImageObject"], processor: "ProcessorMixin")
if len(images) != 0:
images = _regularize_images(images, processor)
image_inputs = image_processor(images=images, return_tensors="pt")
else: # add NoneType for fake images
image = Image.new("RGB", (64, 64), (255, 255, 255))
image_inputs = image_processor(images=[image], return_tensors="pt")
image_inputs = {key: None for key in image_inputs.keys()}
else:
image_inputs = {}
return image_inputs
def _get_paligemma_token_type_ids(
images: Sequence["ImageObject"], input_len: int, processor: "ProcessorMixin"
imglens: Sequence[int], seqlens: Sequence[int], processor: "ProcessorMixin"
) -> List[List[int]]:
r"""
Gets paligemma token type ids for computing loss.
Returns:
token_type_ids: shape (1, seq_len)
batch_token_type_ids: shape (batch_size, sequence_length)
"""
num_images = len(images)
image_seqlen = num_images * getattr(processor, "image_seqlen")
return [[0] * image_seqlen + [1] * (input_len - image_seqlen)]
batch_token_type_ids = []
for imglen, seqlen in zip(imglens, seqlens):
image_seqlen = imglen * getattr(processor, "image_seqlen")
batch_token_type_ids.append([0] * image_seqlen + [1] * (seqlen - image_seqlen))
return batch_token_type_ids
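A quick worked example of the batched token-type logic above, assuming processor.image_seqlen == 3 (the numbers are illustrative only):

# Toy illustration of _get_paligemma_token_type_ids; values are made up.
imglens, seqlens, image_seqlen = [1, 0], [8, 5], 3
token_type_ids = [[0] * (n * image_seqlen) + [1] * (L - n * image_seqlen) for n, L in zip(imglens, seqlens)]
# -> [[0, 0, 0, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1]]  (image positions masked to 0 in the loss)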
class BasePlugin:
@@ -82,7 +103,7 @@ class BasePlugin:
def process_messages(
self,
messages: Sequence[Dict[str, str]],
images: Sequence["ImageObject"],
images: Sequence["ImageInput"],
processor: Optional["ProcessorMixin"],
) -> List[Dict[str, str]]:
r"""
@@ -94,7 +115,7 @@ class BasePlugin:
self,
input_ids: List[int],
labels: Optional[List[int]],
images: Sequence["ImageObject"],
images: Sequence["ImageInput"],
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
) -> Tuple[List[int], Optional[List[int]]]:
@@ -105,10 +126,11 @@ class BasePlugin:
def get_mm_inputs(
self,
images: Sequence["ImageObject"],
feature_seqlens: Dict[str, int],
images: Sequence["ImageInput"],
imglens: Sequence[int],
seqlens: Sequence[int],
processor: Optional["ProcessorMixin"],
) -> Dict[str, Any]:
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
r"""
Builds batched multimodal inputs for VLMs.
"""
@@ -119,31 +141,32 @@ class LlavaPlugin(BasePlugin):
def process_messages(
self,
messages: Sequence[Dict[str, str]],
images: Sequence["ImageObject"],
images: Sequence["ImageInput"],
processor: Optional["ProcessorMixin"],
) -> List[Dict[str, str]]:
num_images = 0
num_image_tokens = 0
image_seqlen = getattr(processor, "image_seqlen")
messages = deepcopy(messages)
for message in messages:
content = message["content"]
while IMAGE_PLACEHOLDER in content:
num_images += 1
num_image_tokens += 1
content = content.replace(IMAGE_PLACEHOLDER, "{{image}}", 1)
message["content"] = content.replace("{{image}}", self.image_token * image_seqlen)
if len(images) != num_images:
if len(images) != num_image_tokens:
raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER))
return messages
def get_mm_inputs(
self,
images: Sequence["ImageObject"],
feature_seqlens: Dict[str, int],
images: Sequence["ImageInput"],
imglens: Sequence[int],
seqlens: Sequence[int],
processor: Optional["ProcessorMixin"],
) -> Dict[str, Any]:
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
return _get_mm_inputs(images, processor)
@@ -151,20 +174,20 @@ class PaliGemmaPlugin(BasePlugin):
def process_messages(
self,
messages: Sequence[Dict[str, str]],
images: Sequence["ImageObject"],
images: Sequence["ImageInput"],
processor: Optional["ProcessorMixin"],
) -> List[Dict[str, str]]:
num_images = 0
num_image_tokens = 0
messages = deepcopy(messages)
for message in messages:
content = message["content"]
while IMAGE_PLACEHOLDER in content:
num_images += 1
num_image_tokens += 1
content = content.replace(IMAGE_PLACEHOLDER, "{{image}}", 1)
message["content"] = content.replace("{{image}}", "")
if len(images) != num_images:
if len(images) != num_image_tokens:
raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER))
return messages
@@ -173,7 +196,7 @@ class PaliGemmaPlugin(BasePlugin):
self,
input_ids: List[int],
labels: Optional[List[int]],
images: Sequence["ImageObject"],
images: Sequence["ImageInput"],
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
) -> Tuple[List[int], Optional[List[int]]]:
@@ -188,14 +211,13 @@ class PaliGemmaPlugin(BasePlugin):
def get_mm_inputs(
self,
images: Sequence["ImageObject"],
feature_seqlens: Dict[str, int],
images: Sequence["ImageInput"],
imglens: Sequence[int],
seqlens: Sequence[int],
processor: Optional["ProcessorMixin"],
) -> Dict[str, Any]:
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
mm_inputs = _get_mm_inputs(images, processor)
for feature_name, feature_length in feature_seqlens.items():
mm_inputs[feature_name] = _get_paligemma_token_type_ids(images, feature_length, processor)
mm_inputs["token_type_ids"] = _get_paligemma_token_type_ids(imglens, seqlens, processor)
return mm_inputs
@@ -203,7 +225,7 @@ class Qwen2vlPlugin(BasePlugin):
def process_messages(
self,
messages: Sequence[Dict[str, str]],
images: Sequence["ImageObject"],
images: Sequence["ImageInput"],
processor: Optional["ProcessorMixin"],
) -> List[Dict[str, str]]:
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
@@ -213,36 +235,37 @@ class Qwen2vlPlugin(BasePlugin):
else:
image_grid_thw = []
num_images = 0
num_image_tokens = 0
messages = deepcopy(messages)
for message in messages:
content = message["content"]
while IMAGE_PLACEHOLDER in content:
if num_images >= len(image_grid_thw):
if num_image_tokens >= len(image_grid_thw):
raise ValueError("`len(images)` is less than the number of {} tokens.".format(IMAGE_PLACEHOLDER))
content = content.replace(
IMAGE_PLACEHOLDER,
"<|vision_start|>{}<|vision_end|>".format(
self.image_token * (image_grid_thw[num_images].prod() // merge_length)
self.image_token * (image_grid_thw[num_image_tokens].prod() // merge_length)
),
1,
)
num_images += 1
num_image_tokens += 1
message["content"] = content
if len(images) != num_images:
if len(images) != num_image_tokens:
raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER))
return messages
def get_mm_inputs(
self,
images: Sequence["ImageObject"],
feature_seqlens: Dict[str, int],
images: Sequence["ImageInput"],
imglens: Sequence[int],
seqlens: Sequence[int],
processor: Optional["ProcessorMixin"],
) -> Dict[str, Any]:
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
return _get_mm_inputs(images, processor)
@@ -21,10 +21,10 @@ from .processor_utils import infer_seqlen
if TYPE_CHECKING:
from PIL.Image import Image
from transformers import PreTrainedTokenizer, ProcessorMixin
from ...hparams import DataArguments
from ..mm_plugin import ImageInput
from ..template import Template
@@ -37,12 +37,12 @@ def _encode_feedback_example(
kl_response: Sequence[Dict[str, str]],
system: Optional[str],
tools: Optional[str],
images: Sequence["Image"],
images: Sequence["ImageInput"],
template: "Template",
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
cutoff_len: int,
) -> Tuple[List[int], List[int], List[int], List[int], bool, Dict[str, Any]]:
) -> Tuple[List[int], List[int], List[int], List[int], bool]:
if response[0]["content"]: # desired example
kto_tag = True
messages = prompt + [response[0]]
@@ -78,15 +78,7 @@ def _encode_feedback_example(
labels = [IGNORE_INDEX] * source_len + response_ids
kl_input_ids = kl_prompt_ids + kl_response_ids
kl_labels = [IGNORE_INDEX] * kl_source_len + kl_response_ids
extra_inputs = template.mm_plugin.get_mm_inputs(
images=images,
feature_seqlens={
"token_type_ids": len(input_ids),
"kl_token_type_ids": len(kl_input_ids),
},
processor=processor,
)
return input_ids, labels, kl_input_ids, kl_labels, kto_tag, extra_inputs
return input_ids, labels, kl_input_ids, kl_labels, kto_tag
def preprocess_feedback_dataset(
@@ -97,20 +89,20 @@ def preprocess_feedback_dataset(
data_args: "DataArguments",
) -> Dict[str, List[Any]]:
# create unrelated input-output pairs for estimating the KL term by flipping the matched pairs
kl_response = examples["response"][::-1]
kl_response = examples["_response"][::-1]
model_inputs = defaultdict(list)
for i in range(len(examples["prompt"])):
if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) < 2:
logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
for i in range(len(examples["_prompt"])):
if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) < 2:
logger.warning("Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i]))
continue
input_ids, labels, kl_input_ids, kl_labels, kto_tag, extra_inputs = _encode_feedback_example(
prompt=examples["prompt"][i],
response=examples["response"][i],
input_ids, labels, kl_input_ids, kl_labels, kto_tag = _encode_feedback_example(
prompt=examples["_prompt"][i],
response=examples["_response"][i],
kl_response=kl_response[i],
system=examples["system"][i],
tools=examples["tools"][i],
images=examples["images"][i],
system=examples["_system"][i],
tools=examples["_tools"][i],
images=examples["_images"][i] or [],
template=template,
tokenizer=tokenizer,
processor=processor,
@@ -123,8 +115,7 @@ def preprocess_feedback_dataset(
model_inputs["kl_attention_mask"].append([1] * len(kl_input_ids))
model_inputs["kl_labels"].append(kl_labels)
model_inputs["kto_tags"].append(kto_tag)
for key, value in extra_inputs.items():
model_inputs[key].append(value)
model_inputs["images"].append(examples["_images"][i])
desirable_num = sum([1 for tag in model_inputs["kto_tags"] if tag])
undesirable_num = len(model_inputs["kto_tags"]) - desirable_num
@@ -21,10 +21,10 @@ from .processor_utils import infer_seqlen
if TYPE_CHECKING:
from PIL.Image import Image
from transformers import PreTrainedTokenizer, ProcessorMixin
from ...hparams import DataArguments
from ..mm_plugin import ImageInput
from ..template import Template
@@ -36,12 +36,12 @@ def _encode_pairwise_example(
response: Sequence[Dict[str, str]],
system: Optional[str],
tools: Optional[str],
images: Sequence["Image"],
images: Sequence["ImageInput"],
template: "Template",
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
cutoff_len: int,
) -> Tuple[List[int], List[int], List[int], List[int], Dict[str, Any]]:
) -> Tuple[List[int], List[int], List[int], List[int]]:
chosen_messages = template.mm_plugin.process_messages(prompt + [response[0]], images, processor)
rejected_messages = template.mm_plugin.process_messages(prompt + [response[1]], images, processor)
prompt_ids, chosen_ids = template.encode_oneturn(tokenizer, chosen_messages, system, tools)
@@ -62,15 +62,7 @@ def _encode_pairwise_example(
chosen_labels = [IGNORE_INDEX] * source_len + chosen_ids
rejected_input_ids = prompt_ids + rejected_ids
rejected_labels = [IGNORE_INDEX] * source_len + rejected_ids
extra_inputs = template.mm_plugin.get_mm_inputs(
images=images,
feature_seqlens={
"chosen_token_type_ids": len(chosen_input_ids),
"rejected_token_type_ids": len(rejected_input_ids),
},
processor=processor,
)
return chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels, extra_inputs
return chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels
def preprocess_pairwise_dataset(
@@ -82,17 +74,17 @@ def preprocess_pairwise_dataset(
) -> Dict[str, List[Any]]:
# build input pairs with format `<bos> X`, `Y1 <eos>` and `Y2 <eos>`
model_inputs = defaultdict(list)
for i in range(len(examples["prompt"])):
if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) < 2:
logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
for i in range(len(examples["_prompt"])):
if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) < 2:
logger.warning("Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i]))
continue
chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels, extra_inputs = _encode_pairwise_example(
prompt=examples["prompt"][i],
response=examples["response"][i],
system=examples["system"][i],
tools=examples["tools"][i],
images=examples["images"][i],
chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels = _encode_pairwise_example(
prompt=examples["_prompt"][i],
response=examples["_response"][i],
system=examples["_system"][i],
tools=examples["_tools"][i],
images=examples["_images"][i] or [],
template=template,
tokenizer=tokenizer,
processor=processor,
@@ -104,8 +96,7 @@ def preprocess_pairwise_dataset(
model_inputs["rejected_input_ids"].append(rejected_input_ids)
model_inputs["rejected_attention_mask"].append([1] * len(rejected_input_ids))
model_inputs["rejected_labels"].append(rejected_labels)
for key, value in extra_inputs.items():
model_inputs[key].append(value)
model_inputs["images"].append(examples["_images"][i])
return model_inputs
@@ -30,7 +30,7 @@ def preprocess_pretrain_dataset(
) -> Dict[str, List[Any]]:
# build grouped texts with format `X1 X2 X3 ...` if packing is enabled
eos_token = "<|end_of_text|>" if data_args.template == "llama3" else tokenizer.eos_token
text_examples = [messages[0]["content"] + eos_token for messages in examples["prompt"]]
text_examples = [messages[0]["content"] + eos_token for messages in examples["_prompt"]]
if not data_args.packing:
if data_args.template == "gemma":
@@ -21,10 +21,10 @@ from .processor_utils import greedy_knapsack, infer_seqlen
if TYPE_CHECKING:
from PIL.Image import Image
from transformers import PreTrainedTokenizer, ProcessorMixin
from ...hparams import DataArguments
from ..mm_plugin import ImageInput
from ..template import Template
@@ -36,14 +36,14 @@ def _encode_supervised_example(
response: Sequence[Dict[str, str]],
system: Optional[str],
tools: Optional[str],
images: Sequence["Image"],
images: Sequence["ImageInput"],
template: "Template",
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
cutoff_len: int,
train_on_prompt: bool,
mask_history: bool,
) -> Tuple[List[int], List[int], Dict[str, Any]]:
) -> Tuple[List[int], List[int]]:
messages = template.mm_plugin.process_messages(prompt + response, images, processor)
input_ids, labels = template.mm_plugin.process_token_ids([], [], images, tokenizer, processor)
encoded_pairs = template.encode_multiturn(tokenizer, messages, system, tools)
@@ -83,10 +83,7 @@ def _encode_supervised_example(
input_ids += [tokenizer.eos_token_id]
labels += [tokenizer.eos_token_id]
extra_inputs = template.mm_plugin.get_mm_inputs(
images=images, feature_seqlens={"token_type_ids": len(input_ids)}, processor=processor
)
return input_ids, labels, extra_inputs
return input_ids, labels
def preprocess_supervised_dataset(
@@ -99,17 +96,17 @@ def preprocess_supervised_dataset(
# build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>`
# for multiturn examples, we only mask the prompt part in each prompt-response pair.
model_inputs = defaultdict(list)
for i in range(len(examples["prompt"])):
if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) != 1:
logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
for i in range(len(examples["_prompt"])):
if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) != 1:
logger.warning("Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i]))
continue
input_ids, labels, extra_inputs = _encode_supervised_example(
prompt=examples["prompt"][i],
response=examples["response"][i],
system=examples["system"][i],
tools=examples["tools"][i],
images=examples["images"][i],
input_ids, labels = _encode_supervised_example(
prompt=examples["_prompt"][i],
response=examples["_response"][i],
system=examples["_system"][i],
tools=examples["_tools"][i],
images=examples["_images"][i] or [],
template=template,
tokenizer=tokenizer,
processor=processor,
@@ -120,8 +117,7 @@ def preprocess_supervised_dataset(
model_inputs["input_ids"].append(input_ids)
model_inputs["attention_mask"].append([1] * len(input_ids))
model_inputs["labels"].append(labels)
for key, value in extra_inputs.items():
model_inputs[key].append(value)
model_inputs["images"].append(examples["_images"][i])
return model_inputs
@@ -143,17 +139,17 @@ def preprocess_packed_supervised_dataset(
batch_input_ids, batch_labels = [], []
lengths = []
length2indexes = defaultdict(list)
for i in range(len(examples["prompt"])):
if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) != 1:
logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
for i in range(len(examples["_prompt"])):
if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) != 1:
logger.warning("Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i]))
continue
input_ids, labels, _ = _encode_supervised_example(
prompt=examples["prompt"][i],
response=examples["response"][i],
system=examples["system"][i],
tools=examples["tools"][i],
images=examples["images"][i],
input_ids, labels = _encode_supervised_example(
prompt=examples["_prompt"][i],
response=examples["_response"][i],
system=examples["_system"][i],
tools=examples["_tools"][i],
images=examples["_images"][i] or [],
template=template,
tokenizer=tokenizer,
processor=None,
@@ -199,6 +195,7 @@ def preprocess_packed_supervised_dataset(
model_inputs["input_ids"].append(packed_input_ids)
model_inputs["attention_mask"].append(packed_attention_masks)
model_inputs["labels"].append(packed_labels)
model_inputs["images"].append(examples["_images"][i])
return model_inputs
@@ -21,10 +21,10 @@ from .processor_utils import infer_seqlen
if TYPE_CHECKING:
from PIL.Image import Image
from transformers import PreTrainedTokenizer, ProcessorMixin
from ...hparams import DataArguments
from ..mm_plugin import ImageInput
from ..template import Template
@@ -36,12 +36,12 @@ def _encode_unsupervised_example(
response: Sequence[Dict[str, str]],
system: Optional[str],
tools: Optional[str],
images: Sequence["Image"],
images: Sequence["ImageInput"],
template: "Template",
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
cutoff_len: int,
) -> Tuple[List[int], List[int], Dict[str, Any]]:
) -> Tuple[List[int], List[int]]:
if len(response) == 1:
messages = prompt + response
else:
@@ -56,10 +56,7 @@ def _encode_unsupervised_example(
source_len, target_len = infer_seqlen(len(input_ids), len(labels), cutoff_len)
input_ids = input_ids[:source_len]
labels = labels[:target_len]
extra_inputs = template.mm_plugin.get_mm_inputs(
images=images, feature_seqlens={"token_type_ids": len(input_ids)}, processor=processor
)
return input_ids, labels, extra_inputs
return input_ids, labels
def preprocess_unsupervised_dataset(
@@ -71,17 +68,17 @@ def preprocess_unsupervised_dataset(
) -> Dict[str, List[Any]]:
# build inputs with format `<bos> X` and labels with format `Y <eos>`
model_inputs = defaultdict(list)
for i in range(len(examples["prompt"])):
if len(examples["prompt"][i]) % 2 != 1:
logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
for i in range(len(examples["_prompt"])):
if len(examples["_prompt"][i]) % 2 != 1:
logger.warning("Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i]))
continue
input_ids, labels, extra_inputs = _encode_unsupervised_example(
prompt=examples["prompt"][i],
response=examples["response"][i],
system=examples["system"][i],
tools=examples["tools"][i],
images=examples["images"][i],
input_ids, labels = _encode_unsupervised_example(
prompt=examples["_prompt"][i],
response=examples["_response"][i],
system=examples["_system"][i],
tools=examples["_tools"][i],
images=examples["_images"][i] or [],
template=template,
tokenizer=tokenizer,
processor=processor,
@@ -90,8 +87,7 @@ def preprocess_unsupervised_dataset(
model_inputs["input_ids"].append(input_ids)
model_inputs["attention_mask"].append([1] * len(input_ids))
model_inputs["labels"].append(labels)
for key, value in extra_inputs.items():
model_inputs[key].append(value)
model_inputs["images"].append(examples["_images"][i])
return model_inputs
@@ -73,6 +73,10 @@ class DataArguments:
default=False,
metadata={"help": "Overwrite the cached training and evaluation sets."},
)
preprocessing_batch_size: int = field(
default=1000,
metadata={"help": "The number of examples in one group in pre-processing."},
)
preprocessing_num_workers: Optional[int] = field(
default=None,
metadata={"help": "The number of processes to use for the pre-processing."},
@@ -41,13 +41,14 @@ def run_dpo(
):
tokenizer_module = load_tokenizer(model_args)
tokenizer = tokenizer_module["tokenizer"]
dataset_module = get_dataset(model_args, data_args, training_args, stage="rm", **tokenizer_module)
dataset_module, template = get_dataset(model_args, data_args, training_args, stage="rm", **tokenizer_module)
model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)
data_collator = PairwiseDataCollatorWithPadding(
tokenizer=tokenizer,
template=template,
pad_to_multiple_of=8,
label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id,
**tokenizer_module,
)
# Create reference model
@@ -60,7 +61,7 @@ def run_dpo(
ref_model = None
# Update arguments
training_args.remove_unused_columns = False # important for pairwise dataset
training_args.remove_unused_columns = False # important for multimodal and pairwise dataset
# Initialize our Trainer
trainer = CustomDPOTrainer(
@@ -41,13 +41,14 @@ def run_kto(
):
tokenizer_module = load_tokenizer(model_args)
tokenizer = tokenizer_module["tokenizer"]
dataset_module = get_dataset(model_args, data_args, training_args, stage="kto", **tokenizer_module)
dataset_module, template = get_dataset(model_args, data_args, training_args, stage="kto", **tokenizer_module)
model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)
data_collator = KTODataCollatorWithPadding(
tokenizer=tokenizer,
template=template,
pad_to_multiple_of=8,
label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id,
**tokenizer_module,
)
# Create reference model
@@ -57,7 +58,7 @@ def run_kto(
ref_model = create_ref_model(model_args, finetuning_args)
# Update arguments
training_args.remove_unused_columns = False # important for pairwise dataset
training_args.remove_unused_columns = False # important for multimodal and pairwise dataset
# Initialize our Trainer
trainer = CustomKTOTrainer(
@@ -41,11 +41,11 @@ def run_ppo(
):
tokenizer_module = load_tokenizer(model_args)
tokenizer = tokenizer_module["tokenizer"]
dataset_module = get_dataset(model_args, data_args, training_args, stage="ppo", **tokenizer_module)
dataset_module, template = get_dataset(model_args, data_args, training_args, stage="ppo", **tokenizer_module)
model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train, add_valuehead=True)
tokenizer.padding_side = "left" # use left-padding in generation while using right-padding in training
data_collator = MultiModalDataCollatorForSeq2Seq(tokenizer=tokenizer)
data_collator = MultiModalDataCollatorForSeq2Seq(template=template, **tokenizer_module)
# Create reference model and reward model
ref_model = create_ref_model(model_args, finetuning_args, add_valuehead=True)
@@ -42,7 +42,7 @@ def run_pt(
):
tokenizer_module = load_tokenizer(model_args)
tokenizer = tokenizer_module["tokenizer"]
dataset_module = get_dataset(model_args, data_args, training_args, stage="pt", **tokenizer_module)
dataset_module, _ = get_dataset(model_args, data_args, training_args, stage="pt", **tokenizer_module)
model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
@@ -41,12 +41,12 @@ def run_rm(
):
tokenizer_module = load_tokenizer(model_args)
tokenizer = tokenizer_module["tokenizer"]
dataset_module = get_dataset(model_args, data_args, training_args, stage="rm", **tokenizer_module)
dataset_module, template = get_dataset(model_args, data_args, training_args, stage="rm", **tokenizer_module)
model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train, add_valuehead=True)
data_collator = PairwiseDataCollatorWithPadding(tokenizer, pad_to_multiple_of=8)
data_collator = PairwiseDataCollatorWithPadding(template=template, pad_to_multiple_of=8, **tokenizer_module)
# Update arguments
training_args.remove_unused_columns = False # important for pairwise dataset
training_args.remove_unused_columns = False # important for multimodal and pairwise dataset
# Initialize our Trainer
trainer = PairwiseTrainer(
@@ -43,24 +43,26 @@ def run_sft(
):
tokenizer_module = load_tokenizer(model_args)
tokenizer = tokenizer_module["tokenizer"]
dataset_module = get_dataset(model_args, data_args, training_args, stage="sft", **tokenizer_module)
dataset_module, template = get_dataset(model_args, data_args, training_args, stage="sft", **tokenizer_module)
model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)
if getattr(model, "is_quantized", False) and not training_args.do_train:
setattr(model, "_hf_peft_config_loaded", True) # hack here: make model compatible with prediction
data_collator = SFTDataCollatorWith4DAttentionMask(
tokenizer=tokenizer,
template=template,
pad_to_multiple_of=8 if training_args.do_train else None, # for shift short attention
label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id,
block_diag_attn=model_args.block_diag_attn,
attn_implementation=getattr(model.config, "_attn_implementation", None),
compute_dtype=model_args.compute_dtype,
**tokenizer_module,
)
# Override the decoding parameters of Seq2SeqTrainer
training_args.generation_max_length = training_args.generation_max_length or data_args.cutoff_len
training_args.generation_num_beams = data_args.eval_num_beams or training_args.generation_num_beams
training_args.remove_unused_columns = False # important for multimodal and pairwise dataset
# Metric utils
metric_module = {}
@@ -105,7 +105,7 @@ def load_reference_model(
def load_train_dataset(**kwargs) -> "Dataset":
model_args, data_args, training_args, _, _ = get_train_args(kwargs)
tokenizer_module = load_tokenizer(model_args)
dataset_module = get_dataset(model_args, data_args, training_args, stage=kwargs["stage"], **tokenizer_module)
dataset_module, _ = get_dataset(model_args, data_args, training_args, stage=kwargs["stage"], **tokenizer_module)
return dataset_module["train_dataset"]
@@ -47,11 +47,15 @@ IMAGES = [Image.new("RGB", (32, 32), (255, 255, 255))]
NO_IMAGES = []
IMGLENS = [1]
NO_IMGLENS = [0]
INPUT_IDS = [0, 1, 2, 3, 4]
LABELS = [0, 1, 2, 3, 4]
FEATURE_SEQLENS = {"token_type_ids": 1024}
SEQLENS = [1024]
def _get_mm_inputs(processor: "ProcessorMixin") -> Dict[str, "torch.Tensor"]:
@@ -80,11 +84,11 @@ def test_base_plugin():
# test mm_messages
assert base_plugin.process_messages(MM_MESSAGES, IMAGES, processor) == MM_MESSAGES
assert base_plugin.process_token_ids(INPUT_IDS, LABELS, IMAGES, tokenizer, processor) == (INPUT_IDS, LABELS)
_is_close(base_plugin.get_mm_inputs(IMAGES, FEATURE_SEQLENS, processor), {})
_is_close(base_plugin.get_mm_inputs(IMAGES, IMGLENS, SEQLENS, processor), {})
# test text_messages
assert base_plugin.process_messages(TEXT_MESSAGES, NO_IMAGES, processor) == TEXT_MESSAGES
assert base_plugin.process_token_ids(INPUT_IDS, LABELS, NO_IMAGES, tokenizer, processor) == (INPUT_IDS, LABELS)
_is_close(base_plugin.get_mm_inputs(NO_IMAGES, FEATURE_SEQLENS, processor), {})
_is_close(base_plugin.get_mm_inputs(NO_IMAGES, NO_IMGLENS, SEQLENS, processor), {})
def test_llava_plugin():
@@ -101,11 +105,11 @@ def test_llava_plugin():
# test mm_messages
assert llava_plugin.process_messages(MM_MESSAGES, IMAGES, processor) == expected_mm_messages
assert llava_plugin.process_token_ids(INPUT_IDS, LABELS, IMAGES, tokenizer, processor) == (INPUT_IDS, LABELS)
_is_close(llava_plugin.get_mm_inputs(IMAGES, FEATURE_SEQLENS, processor), mm_inputs)
_is_close(llava_plugin.get_mm_inputs(IMAGES, IMGLENS, SEQLENS, processor), mm_inputs)
# test text_messages
assert llava_plugin.process_messages(TEXT_MESSAGES, NO_IMAGES, processor) == TEXT_MESSAGES
assert llava_plugin.process_token_ids(INPUT_IDS, LABELS, NO_IMAGES, tokenizer, processor) == (INPUT_IDS, LABELS)
_is_close(llava_plugin.get_mm_inputs(NO_IMAGES, FEATURE_SEQLENS, processor), {"pixel_values": None})
_is_close(llava_plugin.get_mm_inputs(NO_IMAGES, NO_IMGLENS, SEQLENS, processor), {})
@pytest.mark.skipif(not HF_TOKEN, reason="Gated model.")
@@ -128,7 +132,7 @@ def test_paligemma_plugin():
expected_input_ids,
expected_labels,
)
_is_close(paligemma_plugin.get_mm_inputs(IMAGES, FEATURE_SEQLENS, processor), mm_inputs)
_is_close(paligemma_plugin.get_mm_inputs(IMAGES, IMGLENS, SEQLENS, processor), mm_inputs)
# test text_messages
assert paligemma_plugin.process_messages(TEXT_MESSAGES, NO_IMAGES, processor) == TEXT_MESSAGES
assert paligemma_plugin.process_token_ids(INPUT_IDS, LABELS, NO_IMAGES, tokenizer, processor) == (
@@ -136,8 +140,8 @@ def test_paligemma_plugin():
LABELS,
)
_is_close(
paligemma_plugin.get_mm_inputs(NO_IMAGES, FEATURE_SEQLENS, processor),
{"pixel_values": None, "token_type_ids": [[1] * 1024]},
paligemma_plugin.get_mm_inputs(NO_IMAGES, NO_IMGLENS, SEQLENS, processor),
{"token_type_ids": [[1] * 1024]},
)
@@ -158,11 +162,8 @@ def test_qwen2_vl_plugin():
# test mm_messages
assert qwen2_vl_plugin.process_messages(MM_MESSAGES, IMAGES, processor) == expected_mm_messages
assert qwen2_vl_plugin.process_token_ids(INPUT_IDS, LABELS, IMAGES, tokenizer, processor) == (INPUT_IDS, LABELS)
_is_close(qwen2_vl_plugin.get_mm_inputs(IMAGES, FEATURE_SEQLENS, processor), mm_inputs)
_is_close(qwen2_vl_plugin.get_mm_inputs(IMAGES, IMGLENS, SEQLENS, processor), mm_inputs)
# test text_messages
assert qwen2_vl_plugin.process_messages(TEXT_MESSAGES, NO_IMAGES, processor) == TEXT_MESSAGES
assert qwen2_vl_plugin.process_token_ids(INPUT_IDS, LABELS, NO_IMAGES, tokenizer, processor) == (INPUT_IDS, LABELS)
_is_close(
qwen2_vl_plugin.get_mm_inputs(NO_IMAGES, FEATURE_SEQLENS, processor),
{"pixel_values": None, "image_grid_thw": None},
)
_is_close(qwen2_vl_plugin.get_mm_inputs(NO_IMAGES, NO_IMGLENS, SEQLENS, processor), {})