fix mixed mm inputs and rlhf-v

Former-commit-id: 9967ccb3aef3ca557ad6eafb78c6c99866857008
hiyouga 2024-09-01 20:52:47 +08:00
parent 34dc36462c
commit cb776752f6
20 changed files with 306 additions and 277 deletions

View File

@@ -33,7 +33,7 @@ Dependency graph:
 transformers>=4.41.2,<=4.44.3
 """
-from .cli import VERSION
+from .extras.env import VERSION
 __version__ = VERSION

View File

@@ -89,11 +89,9 @@ class HuggingfaceEngine(BaseEngine):
         paired_messages = messages + [{"role": "assistant", "content": ""}]
         system = system or generating_args["default_system"]
-        prompt_ids, _ = template.encode_oneturn(
-            tokenizer=tokenizer, messages=paired_messages, system=system, tools=tools
-        )
+        prompt_ids, _ = template.encode_oneturn(tokenizer, paired_messages, system, tools)
         if image is not None:
-            prompt_ids, _ = template.mm_plugin.process_token_ids(prompt_ids, None, tokenizer, processor)
+            prompt_ids, _ = template.mm_plugin.process_token_ids(prompt_ids, None, [image], tokenizer, processor)
         prompt_length = len(prompt_ids)
         inputs = torch.tensor([prompt_ids], device=model.device)

View File

@@ -124,9 +124,7 @@ class VllmEngine(BaseEngine):
         paired_messages = messages + [{"role": "assistant", "content": ""}]
         system = system or self.generating_args["default_system"]
-        prompt_ids, _ = self.template.encode_oneturn(
-            tokenizer=self.tokenizer, messages=paired_messages, system=system, tools=tools
-        )
+        prompt_ids, _ = self.template.encode_oneturn(self.tokenizer, paired_messages, system, tools)
         if self.processor is not None and image is not None:  # add image features
             image_processor: "BaseImageProcessor" = getattr(self.processor, "image_processor")

View File

@@ -13,8 +13,8 @@
 # limitations under the License.
 from .collator import (
-    CustomDataCollatorForSeq2Seq,
     KTODataCollatorWithPadding,
+    MultiModalDataCollatorForSeq2Seq,
     PairwiseDataCollatorWithPadding,
     SFTDataCollatorWith4DAttentionMask,
 )
@@ -24,8 +24,8 @@ from .template import TEMPLATES, Template, get_template_and_fix_tokenizer
 __all__ = [
-    "CustomDataCollatorForSeq2Seq",
     "KTODataCollatorWithPadding",
+    "MultiModalDataCollatorForSeq2Seq",
     "PairwiseDataCollatorWithPadding",
     "SFTDataCollatorWith4DAttentionMask",
     "Role",

View File

@@ -62,44 +62,49 @@ def prepare_4d_attention_mask(attention_mask_with_indices: "torch.Tensor", dtype
 @dataclass
-class CustomDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
+class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
     r"""
-    Data collator for custom models (like Qwen2-VL).
+    Data collator that supports VLMs.
     """
     def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
-        image_grid_thw = None  # TODO: better handle various VLMs
-        if "image_grid_thw" in features[0]:
-            image_grid_thw_list = [
-                torch.Tensor(feature["image_grid_thw"]).long()
-                for feature in features
-                if feature["image_grid_thw"][0][0] > 0
-            ]
-            pixel_values_list = [
-                torch.Tensor(feature["pixel_values"]) for feature in features if feature["image_grid_thw"][0][0] > 0
-            ]
-            if image_grid_thw_list:
-                image_grid_thw = torch.cat(image_grid_thw_list, dim=0)
-                pixel_values = torch.cat(pixel_values_list, dim=0)
-            else:
-                image_grid_thw = None
-                pixel_values = None
-            features = [
-                {key: feature[key] for key in feature if key not in ["image_grid_thw", "pixel_values"]}
-                for feature in features
-            ]
-        features = super().__call__(features)
-        if image_grid_thw is not None:
-            features["image_grid_thw"] = image_grid_thw
-            features["pixel_values"] = pixel_values
+        if "token_type_ids" in features[0].keys():
+            for feature in features:
+                feature["token_type_ids"] = feature["token_type_ids"][0]
+        extra_features = {}
+        if "pixel_values" in features[0].keys():
+            pixel_values = []
+            for feature in features:
+                if feature["pixel_values"] is None:
+                    pixel_values.append(torch.zeros(0, dtype=torch.float))
+                else:
+                    pixel_values.append(torch.tensor(feature["pixel_values"], dtype=torch.float))
+            extra_features["pixel_values"] = torch.cat(pixel_values, dim=0)
+            if extra_features["pixel_values"].numel() == 0:
+                extra_features["pixel_values"] = None
+        if "image_grid_thw" in features[0].keys():
+            image_grid_thw = []
+            for feature in features:
+                if feature["image_grid_thw"] is None:
+                    image_grid_thw.append(torch.zeros(0, dtype=torch.long))
+                else:
+                    image_grid_thw.append(torch.tensor(feature["image_grid_thw"], dtype=torch.long))
+            extra_features["image_grid_thw"] = torch.cat(image_grid_thw, dim=0)
+            if extra_features["image_grid_thw"].numel() == 0:
+                extra_features["image_grid_thw"] = None
+        features = [{key: feature[key] for key in feature if key not in extra_features.keys()} for feature in features]
+        features: Dict[str, "torch.Tensor"] = super().__call__(features)
+        features.update({key: value for key, value in extra_features.items() if value is not None})
         return features
 @dataclass
-class SFTDataCollatorWith4DAttentionMask(CustomDataCollatorForSeq2Seq):
+class SFTDataCollatorWith4DAttentionMask(MultiModalDataCollatorForSeq2Seq):
     r"""
     Data collator for 4d attention mask.
     """
@@ -117,7 +122,7 @@ class SFTDataCollatorWith4DAttentionMask(CustomDataCollatorForSeq2Seq):
 @dataclass
-class PairwiseDataCollatorWithPadding(CustomDataCollatorForSeq2Seq):
+class PairwiseDataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
     r"""
     Data collator for pairwise data.
     """
@@ -152,7 +157,7 @@ class PairwiseDataCollatorWithPadding(CustomDataCollatorForSeq2Seq):
 @dataclass
-class KTODataCollatorWithPadding(CustomDataCollatorForSeq2Seq):
+class KTODataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
     r"""
     Data collator for KTO data.
     """

View File

@@ -16,16 +16,16 @@ import json
 import re
 from abc import ABC, abstractmethod
 from dataclasses import dataclass, field
-from typing import List, Literal, Optional, Tuple, Union
+from typing import List, Optional, Tuple, Union
 from .data_utils import SLOTS
-from .tool_utils import DefaultToolUtils, GLM4ToolUtils
+from .tool_utils import get_tool_utils
 @dataclass
 class Formatter(ABC):
     slots: SLOTS = field(default_factory=list)
-    tool_format: Optional[Literal["default", "glm4"]] = None
+    tool_format: Optional[str] = None
     @abstractmethod
     def apply(self, **kwargs) -> SLOTS: ...
@@ -81,12 +81,7 @@ class StringFormatter(Formatter):
 @dataclass
 class FunctionFormatter(Formatter):
     def __post_init__(self):
-        if self.tool_format == "default":
-            self.slots = DefaultToolUtils.get_function_slots() + self.slots
-        elif self.tool_format == "glm4":
-            self.slots = GLM4ToolUtils.get_function_slots() + self.slots
-        else:
-            raise NotImplementedError("Tool format {} was not found.".format(self.tool_format))
+        self.slots = get_tool_utils(self.tool_format).get_function_slots() + self.slots
     def apply(self, **kwargs) -> SLOTS:
         content = kwargs.pop("content")
@@ -119,22 +114,15 @@ class FunctionFormatter(Formatter):
 @dataclass
 class ToolFormatter(Formatter):
     def __post_init__(self):
-        if self.tool_format == "default":
-            self._tool_formatter = DefaultToolUtils.tool_formatter
-            self._tool_extractor = DefaultToolUtils.tool_extractor
-        elif self.tool_format == "glm4":
-            self._tool_formatter = GLM4ToolUtils.tool_formatter
-            self._tool_extractor = GLM4ToolUtils.tool_extractor
-        else:
-            raise NotImplementedError("Tool format {} was not found.".format(self.tool_format))
+        self.tool_utils = get_tool_utils(self.tool_format)
     def apply(self, **kwargs) -> SLOTS:
         content = kwargs.pop("content")
         try:
             tools = json.loads(content)
-            return [self._tool_formatter(tools) if len(tools) != 0 else ""]
+            return [self.tool_utils.tool_formatter(tools) if len(tools) != 0 else ""]
         except json.JSONDecodeError:
             return [""]
     def extract(self, content: str) -> Union[str, List[Tuple[str, str]]]:
-        return self._tool_extractor(content)
+        return self.tool_utils.tool_extractor(content)

View File

@@ -1,3 +1,4 @@
+from copy import deepcopy
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
 from PIL.Image import Image
@@ -27,32 +28,33 @@ def _get_mm_inputs(images: Sequence["ImageObject"], processor: "ProcessorMixin")
     Returns: (qwen2-vl)
         pixel_values: tensor with shape (num_patches, patch_dim)
-        image_grid_thw: tensot with shape (num_images, 3), where the three numbers are time, width, height
+        image_grid_thw: tensor with shape (num_images, 3), where the three numbers are time, width, height
     It holds num_patches == torch.prod(image_grid_thw)
     """
     image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
     if len(images) != 0:
         image_inputs = image_processor(images=images, return_tensors="pt")
-    else:
-        image = Image.new("RGB", (56, 56), (255, 255, 255))
+    else:  # add NoneType for fake images
+        image = Image.new("RGB", (64, 64), (255, 255, 255))
         image_inputs = image_processor(images=[image], return_tensors="pt")
-        if "image_grid_thw" in image_inputs:  # fake image for qwen2-vl
-            image_inputs["image_grid_thw"][0][0] = 0
+        image_inputs = {key: None for key in image_inputs.keys()}
     return image_inputs
-def _get_paligemma_token_type_ids(input_len: int, processor: "ProcessorMixin") -> List[List[int]]:
+def _get_paligemma_token_type_ids(
+    images: Sequence["ImageObject"], input_len: int, processor: "ProcessorMixin"
+) -> List[List[int]]:
     r"""
     Gets paligemma token type ids for computing loss.
     Returns:
         token_type_ids: shape (1, seq_len)
     """
-    image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
-    image_seq_length: int = getattr(image_processor, "image_seq_length")
-    return [[0] * image_seq_length + [1] * (input_len - image_seq_length)]
+    num_images = len(images)
+    image_seqlen = num_images * getattr(processor, "image_seqlen")
+    return [[0] * image_seqlen + [1] * (input_len - image_seqlen)]
 class BasePlugin:
@@ -74,6 +76,7 @@ class BasePlugin:
         self,
         input_ids: List[int],
         labels: Optional[List[int]],
+        images: Sequence["ImageObject"],
         tokenizer: "PreTrainedTokenizer",
         processor: Optional["ProcessorMixin"],
     ) -> Tuple[List[int], Optional[List[int]]]:
@@ -93,18 +96,6 @@ class BasePlugin:
         """
         return {}
-    def process_model_inputs(
-        self,
-        model_inputs: Dict[str, List[Any]],
-        images: Sequence["ImageObject"],
-        feature_seqlens: Dict[str, int],
-        processor: Optional["ProcessorMixin"],
-    ) -> None:
-        r"""
-        Appends multimodal inputs to model inputs for VLMs.
-        """
-        return
 class LlavaPlugin(BasePlugin):
     def process_messages(
@@ -113,21 +104,21 @@ class LlavaPlugin(BasePlugin):
         images: Sequence["ImageObject"],
         processor: Optional["ProcessorMixin"],
     ) -> List[Dict[str, str]]:
-        image_count = 0
-        new_messages = []
+        num_images = 0
+        image_seqlen = getattr(processor, "image_seqlen")
+        messages = deepcopy(messages)
         for message in messages:
             content = message["content"]
             while IMAGE_PLACEHOLDER in content:
-                image_count += 1
-                if image_count > 1:
-                    raise ValueError("Llava model only accepts one image per sample.")
+                num_images += 1
                 content = content.replace(IMAGE_PLACEHOLDER, "{{image}}", 1)
-            content = content.replace("{{image}}", self.image_token)
-            new_messages.append({"role": message["role"], "content": content})
-        return new_messages
+            message["content"] = content.replace("{{image}}", self.image_token * image_seqlen)
+        if len(images) != num_images:
+            raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER))
+        return messages
     def get_mm_inputs(
         self,
@@ -137,17 +128,6 @@ class LlavaPlugin(BasePlugin):
     ) -> Dict[str, Any]:
         return _get_mm_inputs(images, processor)
-    def process_model_inputs(
-        self,
-        model_inputs: Dict[str, List[Any]],
-        images: Sequence["ImageObject"],
-        feature_seqlens: Dict[str, int],
-        processor: Optional["ProcessorMixin"],
-    ) -> None:
-        mm_inputs = self.get_mm_inputs(images, feature_seqlens, processor)
-        for key, value in mm_inputs.items():
-            model_inputs[key].append(value[0])
 class PaliGemmaPlugin(BasePlugin):
     def process_messages(
@@ -156,34 +136,35 @@ class PaliGemmaPlugin(BasePlugin):
         images: Sequence["ImageObject"],
         processor: Optional["ProcessorMixin"],
     ) -> List[Dict[str, str]]:
-        image_count = 0
-        new_messages = []
+        num_images = 0
+        messages = deepcopy(messages)
         for message in messages:
             content = message["content"]
             while IMAGE_PLACEHOLDER in content:
-                image_count += 1
-                if image_count > 1:
-                    raise ValueError("PaliGemma model only accepts one image per sample.")
-                content = content.replace(IMAGE_PLACEHOLDER, "", 1)
-            new_messages.append({"role": message["role"], "content": content})
-        return new_messages
+                num_images += 1
+                content = content.replace(IMAGE_PLACEHOLDER, "{{image}}", 1)
+            message["content"] = content.replace("{{image}}", "")
+        if len(images) != num_images:
+            raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER))
+        return messages
     def process_token_ids(
         self,
         input_ids: List[int],
         labels: Optional[List[int]],
+        images: Sequence["ImageObject"],
         tokenizer: "PreTrainedTokenizer",
         processor: Optional["ProcessorMixin"],
     ) -> Tuple[List[int], Optional[List[int]]]:
-        image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
-        image_seq_length: int = getattr(image_processor, "image_seq_length")
+        num_images = len(images)
+        image_seqlen = num_images * getattr(processor, "image_seqlen")
         image_token_id = tokenizer.convert_tokens_to_ids(self.image_token)
-        input_ids = [image_token_id] * image_seq_length + input_ids
+        input_ids = [image_token_id] * image_seqlen + input_ids
         if labels is not None:
-            labels = [IGNORE_INDEX] * image_seq_length + labels
+            labels = [IGNORE_INDEX] * image_seqlen + labels
         return input_ids, labels
@@ -195,21 +176,10 @@ class PaliGemmaPlugin(BasePlugin):
     ) -> Dict[str, Any]:
         mm_inputs = _get_mm_inputs(images, processor)
         for feature_name, feature_length in feature_seqlens.items():
-            mm_inputs[feature_name] = _get_paligemma_token_type_ids(feature_length, processor)
+            mm_inputs[feature_name] = _get_paligemma_token_type_ids(images, feature_length, processor)
         return mm_inputs
-    def process_model_inputs(
-        self,
-        model_inputs: Dict[str, List[Any]],
-        images: Sequence["ImageObject"],
-        feature_seqlens: Dict[str, int],
-        processor: Optional["ProcessorMixin"],
-    ) -> None:
-        mm_inputs = self.get_mm_inputs(images, feature_seqlens, processor)
-        for key, value in mm_inputs.items():
-            model_inputs[key].append(value[0])
 class Qwen2vlPlugin(BasePlugin):
     def process_messages(
@@ -223,23 +193,26 @@ class Qwen2vlPlugin(BasePlugin):
         if len(images) > 0:
             image_grid_thw = _get_mm_inputs(images, processor)["image_grid_thw"]
-        index = 0
-        new_messages = []
+        num_images = 0
+        messages = deepcopy(messages)
         for message in messages:
             content = message["content"]
             while IMAGE_PLACEHOLDER in content:
                 content = content.replace(
                     IMAGE_PLACEHOLDER,
                     "<|vision_start|>{}<|vision_end|>".format(
-                        self.image_token * (image_grid_thw[index].prod() // merge_length)
+                        self.image_token * (image_grid_thw[num_images].prod() // merge_length)
                     ),
                     1,
                 )
-                index += 1
-            new_messages.append({"role": message["role"], "content": content})
-        return new_messages
+                num_images += 1
+            message["content"] = content
+        if len(images) != num_images:
+            raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER))
+        return messages
     def get_mm_inputs(
         self,
@@ -249,17 +222,6 @@ class Qwen2vlPlugin(BasePlugin):
     ) -> Dict[str, Any]:
         return _get_mm_inputs(images, processor)
-    def process_model_inputs(
-        self,
-        model_inputs: Dict[str, List[Any]],
-        images: Sequence["ImageObject"],
-        feature_seqlens: Dict[str, int],
-        processor: Optional["ProcessorMixin"],
-    ) -> None:
-        mm_inputs = self.get_mm_inputs(images, feature_seqlens, processor)
-        for key, value in mm_inputs.items():
-            model_inputs[key].append(value)  # support multi-image
 PLUGINS = {
     "base": BasePlugin,
@@ -270,7 +232,8 @@ PLUGINS = {
 def get_mm_plugin(name: str, image_token: str) -> "BasePlugin":
-    if name not in PLUGINS:
-        raise ValueError("{} not found.".format(name))
-    return PLUGINS[name](image_token)
+    plugin_class = PLUGINS.get(name, None)
+    if plugin_class is None:
+        raise ValueError("Multimodal plugin `{}` not found.".format(name))
+    return plugin_class(image_token)
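A short sketch of the reworked plugin API (illustrative, not part of the diff). It assumes get_mm_plugin is importable from llamafactory.data.mm_plugin, as the tests at the end of this commit do, and uses a dummy object standing in for an AutoProcessor that carries the image_seqlen attribute attached by load_tokenizer below.

    from PIL import Image
    from llamafactory.data.mm_plugin import get_mm_plugin

    class DummyProcessor:  # hypothetical stand-in for a processor with image_seqlen attached
        image_seqlen = 4

    plugin = get_mm_plugin(name="llava", image_token="<image>")
    messages = [{"role": "user", "content": "<image>Describe the picture."}]
    images = [Image.new("RGB", (64, 64), (255, 255, 255))]

    # "<image>" placeholders are expanded to image_seqlen copies of the image token, and a
    # mismatch between placeholders and supplied images now raises ValueError.
    processed = plugin.process_messages(messages, images, DummyProcessor())
    assert processed[0]["content"].startswith("<image>" * 4)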

View File

@@ -50,7 +50,7 @@ def get_preprocess_and_print_func(
         print_function = partial(print_unsupervised_dataset_example, tokenizer=tokenizer)
     elif stage == "sft" and not do_generate:
         if data_args.packing:
-            if data_args.neat_packing:
+            if data_args.neat_packing:  # hack datasets to have int32 attention mask
                 from datasets.arrow_writer import OptimizedTypedSequence, TypedSequence
                 def __init__(self, data, **kwargs):
@@ -67,6 +67,7 @@ def get_preprocess_and_print_func(
                 preprocess_packed_supervised_dataset,
                 template=template,
                 tokenizer=tokenizer,
+                processor=processor,
                 data_args=data_args,
             )
         else:

View File

@@ -21,6 +21,7 @@ from .processor_utils import infer_seqlen
 if TYPE_CHECKING:
+    from PIL.Image import Image
     from transformers import PreTrainedTokenizer, ProcessorMixin
     from ...hparams import DataArguments
@@ -36,11 +37,12 @@ def _encode_feedback_example(
     kl_response: Sequence[Dict[str, str]],
     system: Optional[str],
     tools: Optional[str],
+    images: Sequence["Image"],
     template: "Template",
     tokenizer: "PreTrainedTokenizer",
     processor: Optional["ProcessorMixin"],
     cutoff_len: int,
-) -> Tuple[List[int], List[int], List[int], List[int], bool]:
+) -> Tuple[List[int], List[int], List[int], List[int], bool, Dict[str, Any]]:
     if response[0]["content"]:  # desired example
         kto_tag = True
         messages = prompt + [response[0]]
@@ -53,6 +55,8 @@ def _encode_feedback_example(
     else:
         kl_messages = prompt + [kl_response[1]]
+    messages = template.mm_plugin.process_messages(messages, images, processor)
+    kl_messages = template.mm_plugin.process_messages(kl_messages, images, processor)
     prompt_ids, response_ids = template.encode_oneturn(tokenizer, messages, system, tools)
     kl_prompt_ids, kl_response_ids = template.encode_oneturn(tokenizer, kl_messages, system, tools)
@@ -60,8 +64,8 @@ def _encode_feedback_example(
         response_ids += [tokenizer.eos_token_id]
         kl_response_ids += [tokenizer.eos_token_id]
-    prompt_ids, _ = template.mm_plugin.process_token_ids(prompt_ids, None, tokenizer, processor)
-    kl_prompt_ids, _ = template.mm_plugin.process_token_ids(kl_prompt_ids, None, tokenizer, processor)
+    prompt_ids, _ = template.mm_plugin.process_token_ids(prompt_ids, None, images, tokenizer, processor)
+    kl_prompt_ids, _ = template.mm_plugin.process_token_ids(kl_prompt_ids, None, images, tokenizer, processor)
     source_len, target_len = infer_seqlen(len(prompt_ids), len(response_ids), cutoff_len)
     prompt_ids = prompt_ids[:source_len]
@@ -74,8 +78,15 @@ def _encode_feedback_example(
     labels = [IGNORE_INDEX] * source_len + response_ids
     kl_input_ids = kl_prompt_ids + kl_response_ids
     kl_labels = [IGNORE_INDEX] * kl_source_len + kl_response_ids
-    return input_ids, labels, kl_input_ids, kl_labels, kto_tag
+    extra_inputs = template.mm_plugin.get_mm_inputs(
+        images=images,
+        feature_seqlens={
+            "token_type_ids": len(input_ids),
+            "kl_token_type_ids": len(kl_input_ids),
+        },
+        processor=processor,
+    )
+    return input_ids, labels, kl_input_ids, kl_labels, kto_tag, extra_inputs
 def preprocess_feedback_dataset(
@@ -93,13 +104,13 @@ def preprocess_feedback_dataset(
             logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
             continue
-        prompt = template.mm_plugin.process_messages(examples["prompt"][i], examples["images"][i], processor)
-        input_ids, labels, kl_input_ids, kl_labels, kto_tag = _encode_feedback_example(
-            prompt=prompt,
+        input_ids, labels, kl_input_ids, kl_labels, kto_tag, extra_inputs = _encode_feedback_example(
+            prompt=examples["prompt"][i],
             response=examples["response"][i],
             kl_response=kl_response[i],
             system=examples["system"][i],
             tools=examples["tools"][i],
+            images=examples["images"][i],
             template=template,
             tokenizer=tokenizer,
             processor=processor,
@@ -112,15 +123,8 @@ def preprocess_feedback_dataset(
         model_inputs["kl_attention_mask"].append([1] * len(kl_input_ids))
         model_inputs["kl_labels"].append(kl_labels)
         model_inputs["kto_tags"].append(kto_tag)
-        template.mm_plugin.process_model_inputs(
-            model_inputs=model_inputs,
-            images=examples["images"][i],
-            feature_seqlens={
-                "token_type_ids": len(input_ids),
-                "kl_token_type_ids": len(kl_input_ids),
-            },
-            processor=processor,
-        )
+        for key, value in extra_inputs.items():
+            model_inputs[key].append(value)
     desirable_num = sum([1 for tag in model_inputs["kto_tags"] if tag])
     undesirable_num = len(model_inputs["kto_tags"]) - desirable_num

View File

@@ -21,6 +21,7 @@ from .processor_utils import infer_seqlen
 if TYPE_CHECKING:
+    from PIL.Image import Image
     from transformers import PreTrainedTokenizer, ProcessorMixin
     from ...hparams import DataArguments
@@ -35,13 +36,14 @@ def _encode_pairwise_example(
     response: Sequence[Dict[str, str]],
     system: Optional[str],
     tools: Optional[str],
+    images: Sequence["Image"],
     template: "Template",
     tokenizer: "PreTrainedTokenizer",
     processor: Optional["ProcessorMixin"],
     cutoff_len: int,
-) -> Tuple[List[int], List[int], List[int], List[int]]:
-    chosen_messages = prompt + [response[0]]
-    rejected_messages = prompt + [response[1]]
+) -> Tuple[List[int], List[int], List[int], List[int], Dict[str, Any]]:
+    chosen_messages = template.mm_plugin.process_messages(prompt + [response[0]], images, processor)
+    rejected_messages = template.mm_plugin.process_messages(prompt + [response[1]], images, processor)
     prompt_ids, chosen_ids = template.encode_oneturn(tokenizer, chosen_messages, system, tools)
     _, rejected_ids = template.encode_oneturn(tokenizer, rejected_messages, system, tools)
@@ -49,7 +51,7 @@ def _encode_pairwise_example(
         chosen_ids += [tokenizer.eos_token_id]
         rejected_ids += [tokenizer.eos_token_id]
-    prompt_ids, _ = template.mm_plugin.process_token_ids(prompt_ids, None, tokenizer, processor)
+    prompt_ids, _ = template.mm_plugin.process_token_ids(prompt_ids, None, images, tokenizer, processor)
     # consider the response is more important
     source_len, target_len = infer_seqlen(len(prompt_ids), max(len(chosen_ids), len(rejected_ids)), cutoff_len)
     prompt_ids = prompt_ids[:source_len]
@@ -60,8 +62,15 @@ def _encode_pairwise_example(
     chosen_labels = [IGNORE_INDEX] * source_len + chosen_ids
     rejected_input_ids = prompt_ids + rejected_ids
     rejected_labels = [IGNORE_INDEX] * source_len + rejected_ids
-    return chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels
+    extra_inputs = template.mm_plugin.get_mm_inputs(
+        images=images,
+        feature_seqlens={
+            "chosen_token_type_ids": len(chosen_input_ids),
+            "rejected_token_type_ids": len(rejected_input_ids),
+        },
+        processor=processor,
+    )
+    return chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels, extra_inputs
 def preprocess_pairwise_dataset(
@@ -78,12 +87,12 @@ def preprocess_pairwise_dataset(
             logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
             continue
-        prompt = template.mm_plugin.process_messages(examples["prompt"][i], examples["images"][i], processor)
-        chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels = _encode_pairwise_example(
-            prompt=prompt,
+        chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels, extra_inputs = _encode_pairwise_example(
+            prompt=examples["prompt"][i],
             response=examples["response"][i],
             system=examples["system"][i],
             tools=examples["tools"][i],
+            images=examples["images"][i],
             template=template,
             tokenizer=tokenizer,
             processor=processor,
@@ -95,15 +104,8 @@ def preprocess_pairwise_dataset(
         model_inputs["rejected_input_ids"].append(rejected_input_ids)
         model_inputs["rejected_attention_mask"].append([1] * len(rejected_input_ids))
         model_inputs["rejected_labels"].append(rejected_labels)
-        template.mm_plugin.process_model_inputs(
-            model_inputs=model_inputs,
-            images=examples["images"][i],
-            feature_seqlens={
-                "chosen_token_type_ids": len(chosen_input_ids),
-                "rejected_token_type_ids": len(rejected_input_ids),
-            },
-            processor=processor,
-        )
+        for key, value in extra_inputs.items():
+            model_inputs[key].append(value)
     return model_inputs

View File

@@ -21,6 +21,7 @@ from .processor_utils import greedy_knapsack, infer_seqlen
 if TYPE_CHECKING:
+    from PIL.Image import Image
     from transformers import PreTrainedTokenizer, ProcessorMixin
     from ...hparams import DataArguments
@@ -35,19 +36,18 @@ def _encode_supervised_example(
     response: Sequence[Dict[str, str]],
     system: Optional[str],
     tools: Optional[str],
+    images: Sequence["Image"],
     template: "Template",
     tokenizer: "PreTrainedTokenizer",
     processor: Optional["ProcessorMixin"],
     cutoff_len: int,
     train_on_prompt: bool,
     mask_history: bool,
-) -> Tuple[List[int], List[int]]:
-    messages = prompt + response
-    input_ids, labels = [], []
-    input_ids, labels = template.mm_plugin.process_token_ids(input_ids, labels, tokenizer, processor)
+) -> Tuple[List[int], List[int], Dict[str, Any]]:
+    messages = template.mm_plugin.process_messages(prompt + response, images, processor)
+    input_ids, labels = template.mm_plugin.process_token_ids([], [], images, tokenizer, processor)
     encoded_pairs = template.encode_multiturn(tokenizer, messages, system, tools)
-    total_length = 1 if template.efficient_eos else 0
+    total_length = len(input_ids) + (1 if template.efficient_eos else 0)
     if mask_history:
         encoded_pairs = encoded_pairs[::-1]  # high priority for last turns
@@ -83,7 +83,10 @@ def _encode_supervised_example(
         input_ids += [tokenizer.eos_token_id]
         labels += [tokenizer.eos_token_id]
-    return input_ids, labels
+    extra_inputs = template.mm_plugin.get_mm_inputs(
+        images=images, feature_seqlens={"token_type_ids": len(input_ids)}, processor=processor
+    )
+    return input_ids, labels, extra_inputs
 def preprocess_supervised_dataset(
@@ -101,12 +104,12 @@ def preprocess_supervised_dataset(
             logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
             continue
-        prompt = template.mm_plugin.process_messages(examples["prompt"][i], examples["images"][i], processor)
-        input_ids, labels = _encode_supervised_example(
-            prompt=prompt,
+        input_ids, labels, extra_inputs = _encode_supervised_example(
+            prompt=examples["prompt"][i],
             response=examples["response"][i],
             system=examples["system"][i],
             tools=examples["tools"][i],
+            images=examples["images"][i],
             template=template,
             tokenizer=tokenizer,
             processor=processor,
@@ -117,12 +120,8 @@ def preprocess_supervised_dataset(
         model_inputs["input_ids"].append(input_ids)
         model_inputs["attention_mask"].append([1] * len(input_ids))
         model_inputs["labels"].append(labels)
-        template.mm_plugin.process_model_inputs(
-            model_inputs=model_inputs,
-            images=examples["images"][i],
-            feature_seqlens={"token_type_ids": len(input_ids)},
-            processor=processor,
-        )
+        for key, value in extra_inputs.items():
+            model_inputs[key].append(value)
     return model_inputs
@@ -131,10 +130,15 @@ def preprocess_packed_supervised_dataset(
     examples: Dict[str, List[Any]],
     template: "Template",
     tokenizer: "PreTrainedTokenizer",
+    processor: Optional["ProcessorMixin"],
     data_args: "DataArguments",
 ) -> Dict[str, List[Any]]:
+    # TODO: use `position_ids` to achieve packing
     # build inputs with format `<bos> X1 Y1 <eos> <bos> X2 Y2 <eos>`
     # and labels with format `<ignore> ... <ignore> Y1 <eos> <ignore> ... <ignore> Y2 <eos>`
+    if processor is not None:
+        raise NotImplementedError("`packing` have not been implemented for multimodal datasets.")
     valid_num = 0
     batch_input_ids, batch_labels = [], []
     lengths = []
@@ -149,6 +153,7 @@ def preprocess_packed_supervised_dataset(
             response=examples["response"][i],
             system=examples["system"][i],
             tools=examples["tools"][i],
+            images=examples["images"][i],
             template=template,
             tokenizer=tokenizer,
             processor=None,

View File

@@ -21,6 +21,7 @@ from .processor_utils import infer_seqlen
 if TYPE_CHECKING:
+    from PIL.Image import Image
     from transformers import PreTrainedTokenizer, ProcessorMixin
     from ...hparams import DataArguments
@@ -35,25 +36,30 @@ def _encode_unsupervised_example(
     response: Sequence[Dict[str, str]],
     system: Optional[str],
     tools: Optional[str],
+    images: Sequence["Image"],
     template: "Template",
     tokenizer: "PreTrainedTokenizer",
     processor: Optional["ProcessorMixin"],
     cutoff_len: int,
-) -> Tuple[List[int], List[int]]:
+) -> Tuple[List[int], List[int], Dict[str, Any]]:
     if len(response) == 1:
         messages = prompt + response
     else:
         messages = prompt + [{"role": Role.ASSISTANT.value, "content": ""}]
+    messages = template.mm_plugin.process_messages(messages, images, processor)
     input_ids, labels = template.encode_oneturn(tokenizer, messages, system, tools)
     if template.efficient_eos:
         labels += [tokenizer.eos_token_id]
-    input_ids, _ = template.mm_plugin.process_token_ids(input_ids, None, tokenizer, processor)
+    input_ids, _ = template.mm_plugin.process_token_ids(input_ids, None, images, tokenizer, processor)
     source_len, target_len = infer_seqlen(len(input_ids), len(labels), cutoff_len)
     input_ids = input_ids[:source_len]
     labels = labels[:target_len]
-    return input_ids, labels
+    extra_inputs = template.mm_plugin.get_mm_inputs(
+        images=images, feature_seqlens={"token_type_ids": len(input_ids)}, processor=processor
+    )
+    return input_ids, labels, extra_inputs
 def preprocess_unsupervised_dataset(
@@ -70,12 +76,12 @@ def preprocess_unsupervised_dataset(
             logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
             continue
-        prompt = template.mm_plugin.process_messages(examples["prompt"][i], examples["images"][i], processor)
-        input_ids, labels = _encode_unsupervised_example(
-            prompt=prompt,
+        input_ids, labels, extra_inputs = _encode_unsupervised_example(
+            prompt=examples["prompt"][i],
             response=examples["response"][i],
             system=examples["system"][i],
             tools=examples["tools"][i],
+            images=examples["images"][i],
             template=template,
             tokenizer=tokenizer,
             processor=processor,
@@ -84,12 +90,8 @@ def preprocess_unsupervised_dataset(
         model_inputs["input_ids"].append(input_ids)
         model_inputs["attention_mask"].append([1] * len(input_ids))
         model_inputs["labels"].append(labels)
-        template.mm_plugin.process_model_inputs(
-            model_inputs=model_inputs,
-            images=examples["images"][i],
-            feature_seqlens={"token_type_ids": len(input_ids)},
-            processor=processor,
-        )
+        for key, value in extra_inputs.items():
+            model_inputs[key].append(value)
     return model_inputs

View File

@@ -15,6 +15,8 @@
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Union
+from transformers.utils.versions import require_version
 from ..extras.constants import IMAGE_PLACEHOLDER
 from ..extras.logging import get_logger
 from .data_utils import Role
@@ -347,6 +349,11 @@ def get_template_and_fix_tokenizer(
     name: Optional[str] = None,
     tool_format: Optional[str] = None,
 ) -> Template:
+    if name == "qwen2_vl":
+        require_version(
+            "transformers>=4.45.0.dev0", "To fix: pip install git+https://github.com/huggingface/transformers.git"
+        )
     if name is None:
         template = TEMPLATES["empty"]  # placeholder
     else:
@@ -357,8 +364,8 @@ def get_template_and_fix_tokenizer(
     if tool_format is not None:
         logger.info("Using tool format: {}.".format(tool_format))
         eos_slots = [] if template.efficient_eos else [{"eos_token"}]
-        template.format_tools = ToolFormatter(tool_format=tool_format)
         template.format_function = FunctionFormatter(slots=eos_slots, tool_format=tool_format)
+        template.format_tools = ToolFormatter(tool_format=tool_format)
     stop_words = template.stop_words
     if template.replace_eos:

View File

@@ -138,3 +138,17 @@ class GLM4ToolUtils(ToolUtils):
             return content
         return [(tool_name, json.dumps(arguments, ensure_ascii=False))]
+TOOLS = {
+    "default": DefaultToolUtils(),
+    "glm4": GLM4ToolUtils(),
+}
+def get_tool_utils(name: str) -> "ToolUtils":
+    tool_utils = TOOLS.get(name, None)
+    if tool_utils is None:
+        raise ValueError("Tool utils `{}` not found.".format(name))
+    return tool_utils
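The new module-level registry replaces the if/elif dispatch dropped from formatter.py above. A brief lookup sketch (illustrative, not part of the diff), assuming the module path llamafactory.data.tool_utils implied by the relative import in formatter.py:

    from llamafactory.data.tool_utils import get_tool_utils

    tool_utils = get_tool_utils("default")   # shared DefaultToolUtils() instance from TOOLS
    slots = tool_utils.get_function_slots()  # what FunctionFormatter.__post_init__ prepends
    # get_tool_utils("glm4") returns GLM4ToolUtils(); any unknown name raises ValueError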

View File

@@ -195,6 +195,9 @@ def is_gpu_or_npu_available() -> bool:
 def numpify(inputs: Union["NDArray", "torch.Tensor"]) -> "NDArray":
+    r"""
+    Casts a torch tensor or a numpy array to a numpy array.
+    """
     if isinstance(inputs, torch.Tensor):
         inputs = inputs.cpu()
         if inputs.dtype == torch.bfloat16:  # numpy does not support bfloat16 until 1.21.4
@@ -206,6 +209,9 @@ def numpify(inputs: Union["NDArray", "torch.Tensor"]) -> "NDArray":
 def skip_check_imports() -> None:
+    r"""
+    Avoids flash attention import error in custom model files.
+    """
     if os.environ.get("FORCE_CHECK_IMPORTS", "0").lower() not in ["true", "1"]:
         transformers.dynamic_module_utils.check_imports = get_relative_imports

View File

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from llamafactory.train.tuner import run_exp
+from llamafactory.train.tuner import run_exp  # use absolute import
 def launch():

View File

@@ -25,6 +25,7 @@ from .model_utils.misc import register_autoclass
 from .model_utils.mod import convert_pretrained_model_to_mod, load_mod_pretrained_model
 from .model_utils.unsloth import load_unsloth_pretrained_model
 from .model_utils.valuehead import load_valuehead_params
+from .model_utils.visual import get_image_seqlen
 from .patcher import patch_config, patch_model, patch_tokenizer, patch_valuehead_model
@@ -65,6 +66,7 @@ def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
     Note: including inplace operation of model_args.
     """
     init_kwargs = _get_init_kwargs(model_args)
+    config = load_config(model_args)
     try:
         tokenizer = AutoTokenizer.from_pretrained(
             model_args.model_name_or_path,
@@ -96,6 +98,7 @@ def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
     try:
         processor = AutoProcessor.from_pretrained(model_args.model_name_or_path, **init_kwargs)
         setattr(processor, "tokenizer", tokenizer)
+        setattr(processor, "image_seqlen", get_image_seqlen(config))
     except Exception:
         processor = None

View File

@@ -82,7 +82,7 @@ class LlavaMultiModalProjectorForYiVLForVLLM(LlavaMultiModalProjectorForYiVL):
 def autocast_projector_dtype(model: "PreTrainedModel", model_args: "ModelArguments") -> None:
     r"""
-    Casts projector output to half precision for quantized VLMs.
+    Casts projector output to half precision for fine-tuning quantized VLMs.
     """
     def _mm_projector_forward_post_hook(
@@ -136,6 +136,22 @@ def get_forbidden_modules(config: "PretrainedConfig", finetuning_args: "Finetuni
     return forbidden_modules
+def get_image_seqlen(config: "PretrainedConfig") -> int:
+    r"""
+    Computes the number of special tokens per image.
+    """
+    if getattr(config, "model_type", None) == "llava":
+        image_seqlen = (config.vision_config.image_size // config.vision_config.patch_size) ** 2
+        if getattr(config, "vision_feature_select_strategy", "default") == "full":  # add [CLS] token
+            image_seqlen += 1
+    elif getattr(config, "model_type", None) == "paligemma":
+        image_seqlen = config.vision_config.num_image_tokens
+    elif getattr(config, "model_type", None) == "qwen2_vl":  # variable length
+        image_seqlen = -1
+    return image_seqlen
 def patch_target_modules(
     config: "PretrainedConfig", finetuning_args: "FinetuningArguments", target_modules: Sequence[str]
 ) -> Union[str, List[str]]:
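A worked instance of the llava branch above (illustrative): llava-1.5 ships a CLIP vision tower with image_size 336 and patch_size 14, which yields the 576 tokens per image that the updated tests below hard-code, while the paligemma branch reads num_image_tokens directly (256 for the 224px checkpoint).

    # llava-1.5 vision config values (from the published checkpoint, assumed here)
    image_size, patch_size = 336, 14
    image_seqlen = (image_size // patch_size) ** 2  # 24 ** 2 == 576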

View File

@@ -17,7 +17,7 @@
 from typing import TYPE_CHECKING, List, Optional
-from ...data import CustomDataCollatorForSeq2Seq, get_dataset
+from ...data import MultiModalDataCollatorForSeq2Seq, get_dataset
 from ...extras.ploting import plot_loss
 from ...model import load_model, load_tokenizer
 from ..callbacks import fix_valuehead_checkpoint
@@ -45,7 +45,7 @@ def run_ppo(
     model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train, add_valuehead=True)
     tokenizer.padding_side = "left"  # use left-padding in generation while using right-padding in training
-    data_collator = CustomDataCollatorForSeq2Seq(tokenizer=tokenizer)
+    data_collator = MultiModalDataCollatorForSeq2Seq(tokenizer=tokenizer)
     # Create reference model and reward model
     ref_model = create_ref_model(model_args, finetuning_args, add_valuehead=True)

View File

@@ -13,8 +13,7 @@
 # limitations under the License.
 import os
-from collections import defaultdict
-from typing import TYPE_CHECKING, Any, Dict
+from typing import TYPE_CHECKING, Any, Dict, Tuple
 import pytest
 import torch
@@ -26,7 +25,7 @@ from llamafactory.model import load_tokenizer
 if TYPE_CHECKING:
-    from transformers import ProcessorMixin
+    from transformers import PreTrainedTokenizer, ProcessorMixin
     from transformers.image_processing_utils import BaseImageProcessor
@@ -34,13 +33,20 @@ HF_TOKEN = os.environ.get("HF_TOKEN", None)
 TINY_LLAMA = os.environ.get("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
-MESSAGES = [
+MM_MESSAGES = [
     {"role": "user", "content": "<image>What is in this image?"},
     {"role": "assistant", "content": "A cat."},
 ]
+TEXT_MESSAGES = [
+    {"role": "user", "content": "How are you"},
+    {"role": "assistant", "content": "I am fine!"},
+]
 IMAGES = [Image.new("RGB", (32, 32), (255, 255, 255))]
+NO_IMAGES = []
 INPUT_IDS = [0, 1, 2, 3, 4]
 LABELS = [0, 1, 2, 3, 4]
@@ -53,99 +59,110 @@ def _get_mm_inputs(processor: "ProcessorMixin") -> Dict[str, "torch.Tensor"]:
     return image_processor(images=IMAGES, return_tensors="pt")
-def _is_close(batch_a: Dict[str, Any], batch_b: Dict[str, Any]):
+def _is_close(batch_a: Dict[str, Any], batch_b: Dict[str, Any]) -> None:
     assert batch_a.keys() == batch_b.keys()
     for key in batch_a.keys():
-        if isinstance(batch_a[key], list):
-            assert len(batch_a[key]) == len(batch_b[key])
-            for i in range(len(batch_a[key])):
-                if isinstance(batch_a[key][i], torch.Tensor):
-                    assert torch.allclose(batch_a[key][i], batch_b[key][i], rtol=1e-4, atol=1e-5)
-                else:
-                    assert batch_a[key][i] == batch_b[key][i]
-        elif isinstance(batch_a[key], torch.Tensor):
+        if isinstance(batch_a[key], torch.Tensor):
             assert torch.allclose(batch_a[key], batch_b[key], rtol=1e-4, atol=1e-5)
         else:
-            raise NotImplementedError
+            assert batch_a[key] == batch_b[key]
+def _load_tokenizer_module(model_name_or_path: str) -> Tuple["PreTrainedTokenizer", "ProcessorMixin"]:
+    model_args = ModelArguments(model_name_or_path=model_name_or_path)
+    tokenizer_module = load_tokenizer(model_args)
+    return tokenizer_module["tokenizer"], tokenizer_module["processor"]
 def test_base_plugin():
-    model_args = ModelArguments(model_name_or_path=TINY_LLAMA)
-    tokenizer_module = load_tokenizer(model_args)
-    tokenizer = tokenizer_module["tokenizer"]
-    processor = tokenizer_module["processor"]
+    tokenizer, processor = _load_tokenizer_module(model_name_or_path=TINY_LLAMA)
     base_plugin = get_mm_plugin(name="base", image_token="<image>")
-    model_inputs = defaultdict(list)
-    base_plugin.process_model_inputs(model_inputs, IMAGES, FEATURE_SEQLENS, processor)
-    assert base_plugin.process_messages(MESSAGES, IMAGES, processor)
-    assert base_plugin.process_token_ids(INPUT_IDS, LABELS, tokenizer, processor) == (INPUT_IDS, LABELS)
+    # test mm_messages
+    assert base_plugin.process_messages(MM_MESSAGES, IMAGES, processor) == MM_MESSAGES
+    assert base_plugin.process_token_ids(INPUT_IDS, LABELS, IMAGES, tokenizer, processor) == (INPUT_IDS, LABELS)
     _is_close(base_plugin.get_mm_inputs(IMAGES, FEATURE_SEQLENS, processor), {})
-    _is_close(model_inputs, {})
+    # test text_messages
+    assert base_plugin.process_messages(TEXT_MESSAGES, NO_IMAGES, processor) == TEXT_MESSAGES
+    assert base_plugin.process_token_ids(INPUT_IDS, LABELS, NO_IMAGES, tokenizer, processor) == (INPUT_IDS, LABELS)
+    _is_close(base_plugin.get_mm_inputs(NO_IMAGES, FEATURE_SEQLENS, processor), {})
 def test_llava_plugin():
-    model_args = ModelArguments(model_name_or_path="llava-hf/llava-1.5-7b-hf")
-    tokenizer_module = load_tokenizer(model_args)
-    tokenizer = tokenizer_module["tokenizer"]
-    processor = tokenizer_module["processor"]
+    tokenizer, processor = _load_tokenizer_module(model_name_or_path="llava-hf/llava-1.5-7b-hf")
+    image_seqlen = 576
     mm_inputs = _get_mm_inputs(processor)
-    expected_model_inputs = {key: [value[0]] for key, value in mm_inputs.items()}
+    expected_mm_messages = [
+        {key: value.replace("<image>", "<image>" * image_seqlen) for key, value in message.items()}
+        for message in MM_MESSAGES
+    ]
     llava_plugin = get_mm_plugin(name="llava", image_token="<image>")
-    model_inputs = defaultdict(list)
-    llava_plugin.process_model_inputs(model_inputs, IMAGES, FEATURE_SEQLENS, processor)
-    assert llava_plugin.process_messages(MESSAGES, IMAGES, processor)
-    assert llava_plugin.process_token_ids(INPUT_IDS, LABELS, tokenizer, processor) == (INPUT_IDS, LABELS)
+    # test mm_messages
+    assert llava_plugin.process_messages(MM_MESSAGES, IMAGES, processor) == expected_mm_messages
+    assert llava_plugin.process_token_ids(INPUT_IDS, LABELS, IMAGES, tokenizer, processor) == (INPUT_IDS, LABELS)
     _is_close(llava_plugin.get_mm_inputs(IMAGES, FEATURE_SEQLENS, processor), mm_inputs)
-    _is_close(model_inputs, expected_model_inputs)
+    # test text_messages
+    assert llava_plugin.process_messages(TEXT_MESSAGES, NO_IMAGES, processor) == TEXT_MESSAGES
+    assert llava_plugin.process_token_ids(INPUT_IDS, LABELS, NO_IMAGES, tokenizer, processor) == (INPUT_IDS, LABELS)
+    _is_close(llava_plugin.get_mm_inputs(NO_IMAGES, FEATURE_SEQLENS, processor), {"pixel_values": None})
 @pytest.mark.skipif(not HF_TOKEN, reason="Gated model.")
 def test_paligemma_plugin():
-    model_args = ModelArguments(model_name_or_path="google/paligemma-3b-pt-224")
-    tokenizer_module = load_tokenizer(model_args)
-    tokenizer = tokenizer_module["tokenizer"]
-    processor = tokenizer_module["processor"]
-    image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
-    image_seq_length: int = getattr(image_processor, "image_seq_length")
+    tokenizer, processor = _load_tokenizer_module(model_name_or_path="google/paligemma-3b-pt-224")
+    image_seqlen = 256
     mm_inputs = _get_mm_inputs(processor)
-    mm_inputs["token_type_ids"] = [[0] * image_seq_length + [1] * (1024 - image_seq_length)]
-    expected_model_inputs = {key: [value[0]] for key, value in mm_inputs.items()}
-    expected_input_ids = [tokenizer.convert_tokens_to_ids("<image>")] * image_seq_length + INPUT_IDS
-    expected_labels = [-100] * image_seq_length + LABELS
+    mm_inputs["token_type_ids"] = [[0] * image_seqlen + [1] * (1024 - image_seqlen)]
+    expected_mm_messages = [
+        {key: value.replace("<image>", "") for key, value in message.items()} for message in MM_MESSAGES
+    ]
+    expected_input_ids = [tokenizer.convert_tokens_to_ids("<image>")] * image_seqlen + INPUT_IDS
+    expected_labels = [-100] * image_seqlen + LABELS
     paligemma_plugin = get_mm_plugin(name="paligemma", image_token="<image>")
-    model_inputs = defaultdict(list)
-    paligemma_plugin.process_model_inputs(model_inputs, IMAGES, FEATURE_SEQLENS, processor)
-    assert paligemma_plugin.process_messages(MESSAGES, IMAGES, processor)
-    assert paligemma_plugin.process_token_ids(INPUT_IDS, LABELS, tokenizer, processor) == (
+    # test mm_messages
+    assert paligemma_plugin.process_messages(MM_MESSAGES, IMAGES, processor) == expected_mm_messages
+    assert paligemma_plugin.process_token_ids(INPUT_IDS, LABELS, IMAGES, tokenizer, processor) == (
         expected_input_ids,
         expected_labels,
     )
     _is_close(paligemma_plugin.get_mm_inputs(IMAGES, FEATURE_SEQLENS, processor), mm_inputs)
-    _is_close(model_inputs, expected_model_inputs)
+    # test text_messages
+    assert paligemma_plugin.process_messages(TEXT_MESSAGES, NO_IMAGES, processor) == TEXT_MESSAGES
+    assert paligemma_plugin.process_token_ids(INPUT_IDS, LABELS, NO_IMAGES, tokenizer, processor) == (
+        INPUT_IDS,
+        LABELS,
+    )
+    _is_close(
+        paligemma_plugin.get_mm_inputs(NO_IMAGES, FEATURE_SEQLENS, processor),
+        {"pixel_values": None, "token_type_ids": [[1] * 1024]},
+    )
 def test_qwen2_vl_plugin():
-    model_args = ModelArguments(model_name_or_path="Qwen/Qwen2-VL-7B-Instruct")
-    tokenizer_module = load_tokenizer(model_args)
-    tokenizer = tokenizer_module["tokenizer"]
-    processor = tokenizer_module["processor"]
+    tokenizer, processor = _load_tokenizer_module(model_name_or_path="Qwen/Qwen2-VL-7B-Instruct")
+    image_seqlen = 4
     mm_inputs = _get_mm_inputs(processor)
-    expected_model_inputs = {key: [value] for key, value in mm_inputs.items()}
+    expected_mm_messages = [
+        {
+            key: value.replace("<image>", "<|vision_start|>{}<|vision_end|>".format("<|image_pad|>" * image_seqlen))
+            for key, value in message.items()
+        }
+        for message in MM_MESSAGES
+    ]
-    llava_plugin = get_mm_plugin(name="qwen2_vl", image_token="<|image_pad|>")
-    model_inputs = defaultdict(list)
-    llava_plugin.process_model_inputs(model_inputs, IMAGES, FEATURE_SEQLENS, processor)
-    assert llava_plugin.process_messages(MESSAGES, IMAGES, processor)
-    assert llava_plugin.process_token_ids(INPUT_IDS, LABELS, tokenizer, processor) == (INPUT_IDS, LABELS)
-    _is_close(llava_plugin.get_mm_inputs(IMAGES, FEATURE_SEQLENS, processor), mm_inputs)
-    _is_close(model_inputs, expected_model_inputs)
+    qwen2_vl_plugin = get_mm_plugin(name="qwen2_vl", image_token="<|image_pad|>")
+    # test mm_messages
+    assert qwen2_vl_plugin.process_messages(MM_MESSAGES, IMAGES, processor) == expected_mm_messages
+    assert qwen2_vl_plugin.process_token_ids(INPUT_IDS, LABELS, IMAGES, tokenizer, processor) == (INPUT_IDS, LABELS)
+    _is_close(qwen2_vl_plugin.get_mm_inputs(IMAGES, FEATURE_SEQLENS, processor), mm_inputs)
+    # test text_messages
+    assert qwen2_vl_plugin.process_messages(TEXT_MESSAGES, NO_IMAGES, processor) == TEXT_MESSAGES
+    assert qwen2_vl_plugin.process_token_ids(INPUT_IDS, LABELS, NO_IMAGES, tokenizer, processor) == (INPUT_IDS, LABELS)
    _is_close(
+        qwen2_vl_plugin.get_mm_inputs(NO_IMAGES, FEATURE_SEQLENS, processor),
+        {"pixel_values": None, "image_grid_thw": None},
+    )
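These plugin tests pull the referenced tokenizers and processors from the Hub and need HF_TOKEN for the gated PaliGemma checkpoint. A hedged invocation sketch (the test file path is assumed; it is not shown in this diff):

    # run only the multimodal plugin tests
    import pytest
    raise SystemExit(pytest.main(["-q", "tests/data/test_mm_plugin.py"]))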