mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-02 11:42:49 +08:00
fix mixed mm inputs and rlhf-v
Former-commit-id: 9967ccb3aef3ca557ad6eafb78c6c99866857008
This commit is contained in:
parent
34dc36462c
commit
cb776752f6
@ -33,7 +33,7 @@ Dependency graph:
|
|||||||
transformers>=4.41.2,<=4.44.3
|
transformers>=4.41.2,<=4.44.3
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from .cli import VERSION
|
from .extras.env import VERSION
|
||||||
|
|
||||||
|
|
||||||
__version__ = VERSION
|
__version__ = VERSION
|
||||||
|
@ -89,11 +89,9 @@ class HuggingfaceEngine(BaseEngine):
|
|||||||
|
|
||||||
paired_messages = messages + [{"role": "assistant", "content": ""}]
|
paired_messages = messages + [{"role": "assistant", "content": ""}]
|
||||||
system = system or generating_args["default_system"]
|
system = system or generating_args["default_system"]
|
||||||
prompt_ids, _ = template.encode_oneturn(
|
prompt_ids, _ = template.encode_oneturn(tokenizer, paired_messages, system, tools)
|
||||||
tokenizer=tokenizer, messages=paired_messages, system=system, tools=tools
|
|
||||||
)
|
|
||||||
if image is not None:
|
if image is not None:
|
||||||
prompt_ids, _ = template.mm_plugin.process_token_ids(prompt_ids, None, tokenizer, processor)
|
prompt_ids, _ = template.mm_plugin.process_token_ids(prompt_ids, None, [image], tokenizer, processor)
|
||||||
|
|
||||||
prompt_length = len(prompt_ids)
|
prompt_length = len(prompt_ids)
|
||||||
inputs = torch.tensor([prompt_ids], device=model.device)
|
inputs = torch.tensor([prompt_ids], device=model.device)
|
||||||
|
@ -124,9 +124,7 @@ class VllmEngine(BaseEngine):
|
|||||||
|
|
||||||
paired_messages = messages + [{"role": "assistant", "content": ""}]
|
paired_messages = messages + [{"role": "assistant", "content": ""}]
|
||||||
system = system or self.generating_args["default_system"]
|
system = system or self.generating_args["default_system"]
|
||||||
prompt_ids, _ = self.template.encode_oneturn(
|
prompt_ids, _ = self.template.encode_oneturn(self.tokenizer, paired_messages, system, tools)
|
||||||
tokenizer=self.tokenizer, messages=paired_messages, system=system, tools=tools
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.processor is not None and image is not None: # add image features
|
if self.processor is not None and image is not None: # add image features
|
||||||
image_processor: "BaseImageProcessor" = getattr(self.processor, "image_processor")
|
image_processor: "BaseImageProcessor" = getattr(self.processor, "image_processor")
|
||||||
|
@ -13,8 +13,8 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from .collator import (
|
from .collator import (
|
||||||
CustomDataCollatorForSeq2Seq,
|
|
||||||
KTODataCollatorWithPadding,
|
KTODataCollatorWithPadding,
|
||||||
|
MultiModalDataCollatorForSeq2Seq,
|
||||||
PairwiseDataCollatorWithPadding,
|
PairwiseDataCollatorWithPadding,
|
||||||
SFTDataCollatorWith4DAttentionMask,
|
SFTDataCollatorWith4DAttentionMask,
|
||||||
)
|
)
|
||||||
@ -24,8 +24,8 @@ from .template import TEMPLATES, Template, get_template_and_fix_tokenizer
|
|||||||
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"CustomDataCollatorForSeq2Seq",
|
|
||||||
"KTODataCollatorWithPadding",
|
"KTODataCollatorWithPadding",
|
||||||
|
"MultiModalDataCollatorForSeq2Seq",
|
||||||
"PairwiseDataCollatorWithPadding",
|
"PairwiseDataCollatorWithPadding",
|
||||||
"SFTDataCollatorWith4DAttentionMask",
|
"SFTDataCollatorWith4DAttentionMask",
|
||||||
"Role",
|
"Role",
|
||||||
|
@ -62,44 +62,49 @@ def prepare_4d_attention_mask(attention_mask_with_indices: "torch.Tensor", dtype
|
|||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class CustomDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
||||||
r"""
|
r"""
|
||||||
Data collator for custom models (like Qwen2-VL).
|
Data collator that supports VLMs.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
|
def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
|
||||||
image_grid_thw = None # TODO: better handle various VLMs
|
if "token_type_ids" in features[0].keys():
|
||||||
if "image_grid_thw" in features[0]:
|
for feature in features:
|
||||||
image_grid_thw_list = [
|
feature["token_type_ids"] = feature["token_type_ids"][0]
|
||||||
torch.Tensor(feature["image_grid_thw"]).long()
|
|
||||||
for feature in features
|
|
||||||
if feature["image_grid_thw"][0][0] > 0
|
|
||||||
]
|
|
||||||
pixel_values_list = [
|
|
||||||
torch.Tensor(feature["pixel_values"]) for feature in features if feature["image_grid_thw"][0][0] > 0
|
|
||||||
]
|
|
||||||
if image_grid_thw_list:
|
|
||||||
image_grid_thw = torch.cat(image_grid_thw_list, dim=0)
|
|
||||||
pixel_values = torch.cat(pixel_values_list, dim=0)
|
|
||||||
else:
|
|
||||||
image_grid_thw = None
|
|
||||||
pixel_values = None
|
|
||||||
|
|
||||||
features = [
|
extra_features = {}
|
||||||
{key: feature[key] for key in feature if key not in ["image_grid_thw", "pixel_values"]}
|
if "pixel_values" in features[0].keys():
|
||||||
for feature in features
|
pixel_values = []
|
||||||
]
|
for feature in features:
|
||||||
|
if feature["pixel_values"] is None:
|
||||||
|
pixel_values.append(torch.zeros(0, dtype=torch.float))
|
||||||
|
else:
|
||||||
|
pixel_values.append(torch.tensor(feature["pixel_values"], dtype=torch.float))
|
||||||
|
|
||||||
features = super().__call__(features)
|
extra_features["pixel_values"] = torch.cat(pixel_values, dim=0)
|
||||||
if image_grid_thw is not None:
|
if extra_features["pixel_values"].numel() == 0:
|
||||||
features["image_grid_thw"] = image_grid_thw
|
extra_features["pixel_values"] = None
|
||||||
features["pixel_values"] = pixel_values
|
|
||||||
|
|
||||||
|
if "image_grid_thw" in features[0].keys():
|
||||||
|
image_grid_thw = []
|
||||||
|
for feature in features:
|
||||||
|
if feature["image_grid_thw"] is None:
|
||||||
|
image_grid_thw.append(torch.zeros(0, dtype=torch.long))
|
||||||
|
else:
|
||||||
|
image_grid_thw.append(torch.tensor(feature["image_grid_thw"], dtype=torch.long))
|
||||||
|
|
||||||
|
extra_features["image_grid_thw"] = torch.cat(pixel_values, dim=0)
|
||||||
|
if extra_features["image_grid_thw"].numel() == 0:
|
||||||
|
extra_features["image_grid_thw"] = None
|
||||||
|
|
||||||
|
features = [{key: feature[key] for key in feature if key not in extra_features.keys()} for feature in features]
|
||||||
|
features: Dict[str, "torch.Tensor"] = super().__call__(features)
|
||||||
|
features.update({key: value for key, value in extra_features.items() if value is not None})
|
||||||
return features
|
return features
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class SFTDataCollatorWith4DAttentionMask(CustomDataCollatorForSeq2Seq):
|
class SFTDataCollatorWith4DAttentionMask(MultiModalDataCollatorForSeq2Seq):
|
||||||
r"""
|
r"""
|
||||||
Data collator for 4d attention mask.
|
Data collator for 4d attention mask.
|
||||||
"""
|
"""
|
||||||
@ -117,7 +122,7 @@ class SFTDataCollatorWith4DAttentionMask(CustomDataCollatorForSeq2Seq):
|
|||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class PairwiseDataCollatorWithPadding(CustomDataCollatorForSeq2Seq):
|
class PairwiseDataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
|
||||||
r"""
|
r"""
|
||||||
Data collator for pairwise data.
|
Data collator for pairwise data.
|
||||||
"""
|
"""
|
||||||
@ -152,7 +157,7 @@ class PairwiseDataCollatorWithPadding(CustomDataCollatorForSeq2Seq):
|
|||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class KTODataCollatorWithPadding(CustomDataCollatorForSeq2Seq):
|
class KTODataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
|
||||||
r"""
|
r"""
|
||||||
Data collator for KTO data.
|
Data collator for KTO data.
|
||||||
"""
|
"""
|
||||||
|
@ -16,16 +16,16 @@ import json
|
|||||||
import re
|
import re
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import List, Literal, Optional, Tuple, Union
|
from typing import List, Optional, Tuple, Union
|
||||||
|
|
||||||
from .data_utils import SLOTS
|
from .data_utils import SLOTS
|
||||||
from .tool_utils import DefaultToolUtils, GLM4ToolUtils
|
from .tool_utils import get_tool_utils
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Formatter(ABC):
|
class Formatter(ABC):
|
||||||
slots: SLOTS = field(default_factory=list)
|
slots: SLOTS = field(default_factory=list)
|
||||||
tool_format: Optional[Literal["default", "glm4"]] = None
|
tool_format: Optional[str] = None
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def apply(self, **kwargs) -> SLOTS: ...
|
def apply(self, **kwargs) -> SLOTS: ...
|
||||||
@ -81,12 +81,7 @@ class StringFormatter(Formatter):
|
|||||||
@dataclass
|
@dataclass
|
||||||
class FunctionFormatter(Formatter):
|
class FunctionFormatter(Formatter):
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
if self.tool_format == "default":
|
self.slots = get_tool_utils(self.tool_format).get_function_slots() + self.slots
|
||||||
self.slots = DefaultToolUtils.get_function_slots() + self.slots
|
|
||||||
elif self.tool_format == "glm4":
|
|
||||||
self.slots = GLM4ToolUtils.get_function_slots() + self.slots
|
|
||||||
else:
|
|
||||||
raise NotImplementedError("Tool format {} was not found.".format(self.tool_format))
|
|
||||||
|
|
||||||
def apply(self, **kwargs) -> SLOTS:
|
def apply(self, **kwargs) -> SLOTS:
|
||||||
content = kwargs.pop("content")
|
content = kwargs.pop("content")
|
||||||
@ -119,22 +114,15 @@ class FunctionFormatter(Formatter):
|
|||||||
@dataclass
|
@dataclass
|
||||||
class ToolFormatter(Formatter):
|
class ToolFormatter(Formatter):
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
if self.tool_format == "default":
|
self.tool_utils = get_tool_utils(self.tool_format)
|
||||||
self._tool_formatter = DefaultToolUtils.tool_formatter
|
|
||||||
self._tool_extractor = DefaultToolUtils.tool_extractor
|
|
||||||
elif self.tool_format == "glm4":
|
|
||||||
self._tool_formatter = GLM4ToolUtils.tool_formatter
|
|
||||||
self._tool_extractor = GLM4ToolUtils.tool_extractor
|
|
||||||
else:
|
|
||||||
raise NotImplementedError("Tool format {} was not found.".format(self.tool_format))
|
|
||||||
|
|
||||||
def apply(self, **kwargs) -> SLOTS:
|
def apply(self, **kwargs) -> SLOTS:
|
||||||
content = kwargs.pop("content")
|
content = kwargs.pop("content")
|
||||||
try:
|
try:
|
||||||
tools = json.loads(content)
|
tools = json.loads(content)
|
||||||
return [self._tool_formatter(tools) if len(tools) != 0 else ""]
|
return [self.tool_utils.tool_formatter(tools) if len(tools) != 0 else ""]
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
return [""]
|
return [""]
|
||||||
|
|
||||||
def extract(self, content: str) -> Union[str, List[Tuple[str, str]]]:
|
def extract(self, content: str) -> Union[str, List[Tuple[str, str]]]:
|
||||||
return self._tool_extractor(content)
|
return self.tool_utils.tool_extractor(content)
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
from copy import deepcopy
|
||||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
|
||||||
|
|
||||||
from PIL.Image import Image
|
from PIL.Image import Image
|
||||||
@ -27,32 +28,33 @@ def _get_mm_inputs(images: Sequence["ImageObject"], processor: "ProcessorMixin")
|
|||||||
|
|
||||||
Returns: (qwen2-vl)
|
Returns: (qwen2-vl)
|
||||||
pixel_values: tensor with shape (num_patches, patch_dim)
|
pixel_values: tensor with shape (num_patches, patch_dim)
|
||||||
image_grid_thw: tensot with shape (num_images, 3), where the three numbers are time, width, height
|
image_grid_thw: tensor with shape (num_images, 3), where the three numbers are time, width, height
|
||||||
|
|
||||||
It holds num_patches == torch.prod(image_grid_thw)
|
It holds num_patches == torch.prod(image_grid_thw)
|
||||||
"""
|
"""
|
||||||
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
|
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
|
||||||
if len(images) != 0:
|
if len(images) != 0:
|
||||||
image_inputs = image_processor(images=images, return_tensors="pt")
|
image_inputs = image_processor(images=images, return_tensors="pt")
|
||||||
else:
|
else: # add NoneType for fake images
|
||||||
image = Image.new("RGB", (56, 56), (255, 255, 255))
|
image = Image.new("RGB", (64, 64), (255, 255, 255))
|
||||||
image_inputs = image_processor(images=[image], return_tensors="pt")
|
image_inputs = image_processor(images=[image], return_tensors="pt")
|
||||||
if "image_grid_thw" in image_inputs: # fake image for qwen2-vl
|
image_inputs = {key: None for key in image_inputs.keys()}
|
||||||
image_inputs["image_grid_thw"][0][0] = 0
|
|
||||||
|
|
||||||
return image_inputs
|
return image_inputs
|
||||||
|
|
||||||
|
|
||||||
def _get_paligemma_token_type_ids(input_len: int, processor: "ProcessorMixin") -> List[List[int]]:
|
def _get_paligemma_token_type_ids(
|
||||||
|
images: Sequence["ImageObject"], input_len: int, processor: "ProcessorMixin"
|
||||||
|
) -> List[List[int]]:
|
||||||
r"""
|
r"""
|
||||||
Gets paligemma token type ids for computing loss.
|
Gets paligemma token type ids for computing loss.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
token_type_ids: shape (1, seq_len)
|
token_type_ids: shape (1, seq_len)
|
||||||
"""
|
"""
|
||||||
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
|
num_images = len(images)
|
||||||
image_seq_length: int = getattr(image_processor, "image_seq_length")
|
image_seqlen = num_images * getattr(processor, "image_seqlen")
|
||||||
return [[0] * image_seq_length + [1] * (input_len - image_seq_length)]
|
return [[0] * image_seqlen + [1] * (input_len - image_seqlen)]
|
||||||
|
|
||||||
|
|
||||||
class BasePlugin:
|
class BasePlugin:
|
||||||
@ -74,6 +76,7 @@ class BasePlugin:
|
|||||||
self,
|
self,
|
||||||
input_ids: List[int],
|
input_ids: List[int],
|
||||||
labels: Optional[List[int]],
|
labels: Optional[List[int]],
|
||||||
|
images: Sequence["ImageObject"],
|
||||||
tokenizer: "PreTrainedTokenizer",
|
tokenizer: "PreTrainedTokenizer",
|
||||||
processor: Optional["ProcessorMixin"],
|
processor: Optional["ProcessorMixin"],
|
||||||
) -> Tuple[List[int], Optional[List[int]]]:
|
) -> Tuple[List[int], Optional[List[int]]]:
|
||||||
@ -93,18 +96,6 @@ class BasePlugin:
|
|||||||
"""
|
"""
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
def process_model_inputs(
|
|
||||||
self,
|
|
||||||
model_inputs: Dict[str, List[Any]],
|
|
||||||
images: Sequence["ImageObject"],
|
|
||||||
feature_seqlens: Dict[str, int],
|
|
||||||
processor: Optional["ProcessorMixin"],
|
|
||||||
) -> None:
|
|
||||||
r"""
|
|
||||||
Appends multimodal inputs to model inputs for VLMs.
|
|
||||||
"""
|
|
||||||
return
|
|
||||||
|
|
||||||
|
|
||||||
class LlavaPlugin(BasePlugin):
|
class LlavaPlugin(BasePlugin):
|
||||||
def process_messages(
|
def process_messages(
|
||||||
@ -113,21 +104,21 @@ class LlavaPlugin(BasePlugin):
|
|||||||
images: Sequence["ImageObject"],
|
images: Sequence["ImageObject"],
|
||||||
processor: Optional["ProcessorMixin"],
|
processor: Optional["ProcessorMixin"],
|
||||||
) -> List[Dict[str, str]]:
|
) -> List[Dict[str, str]]:
|
||||||
image_count = 0
|
num_images = 0
|
||||||
new_messages = []
|
image_seqlen = getattr(processor, "image_seqlen")
|
||||||
|
messages = deepcopy(messages)
|
||||||
for message in messages:
|
for message in messages:
|
||||||
content = message["content"]
|
content = message["content"]
|
||||||
while IMAGE_PLACEHOLDER in content:
|
while IMAGE_PLACEHOLDER in content:
|
||||||
image_count += 1
|
num_images += 1
|
||||||
if image_count > 1:
|
|
||||||
raise ValueError("Llava model only accepts one image per sample.")
|
|
||||||
|
|
||||||
content = content.replace(IMAGE_PLACEHOLDER, "{{image}}", 1)
|
content = content.replace(IMAGE_PLACEHOLDER, "{{image}}", 1)
|
||||||
|
|
||||||
content = content.replace("{{image}}", self.image_token)
|
message["content"] = content.replace("{{image}}", self.image_token * image_seqlen)
|
||||||
new_messages.append({"role": message["role"], "content": content})
|
|
||||||
|
|
||||||
return new_messages
|
if len(images) != num_images:
|
||||||
|
raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER))
|
||||||
|
|
||||||
|
return messages
|
||||||
|
|
||||||
def get_mm_inputs(
|
def get_mm_inputs(
|
||||||
self,
|
self,
|
||||||
@ -137,17 +128,6 @@ class LlavaPlugin(BasePlugin):
|
|||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
return _get_mm_inputs(images, processor)
|
return _get_mm_inputs(images, processor)
|
||||||
|
|
||||||
def process_model_inputs(
|
|
||||||
self,
|
|
||||||
model_inputs: Dict[str, List[Any]],
|
|
||||||
images: Sequence["ImageObject"],
|
|
||||||
feature_seqlens: Dict[str, int],
|
|
||||||
processor: Optional["ProcessorMixin"],
|
|
||||||
) -> None:
|
|
||||||
mm_inputs = self.get_mm_inputs(images, feature_seqlens, processor)
|
|
||||||
for key, value in mm_inputs.items():
|
|
||||||
model_inputs[key].append(value[0])
|
|
||||||
|
|
||||||
|
|
||||||
class PaliGemmaPlugin(BasePlugin):
|
class PaliGemmaPlugin(BasePlugin):
|
||||||
def process_messages(
|
def process_messages(
|
||||||
@ -156,34 +136,35 @@ class PaliGemmaPlugin(BasePlugin):
|
|||||||
images: Sequence["ImageObject"],
|
images: Sequence["ImageObject"],
|
||||||
processor: Optional["ProcessorMixin"],
|
processor: Optional["ProcessorMixin"],
|
||||||
) -> List[Dict[str, str]]:
|
) -> List[Dict[str, str]]:
|
||||||
image_count = 0
|
num_images = 0
|
||||||
new_messages = []
|
messages = deepcopy(messages)
|
||||||
for message in messages:
|
for message in messages:
|
||||||
content = message["content"]
|
content = message["content"]
|
||||||
while IMAGE_PLACEHOLDER in content:
|
while IMAGE_PLACEHOLDER in content:
|
||||||
image_count += 1
|
num_images += 1
|
||||||
if image_count > 1:
|
content = content.replace(IMAGE_PLACEHOLDER, "{{image}}", 1)
|
||||||
raise ValueError("PaliGemma model only accepts one image per sample.")
|
|
||||||
|
|
||||||
content = content.replace(IMAGE_PLACEHOLDER, "", 1)
|
message["content"] = content.replace("{{image}}", "")
|
||||||
|
|
||||||
new_messages.append({"role": message["role"], "content": content})
|
if len(images) != num_images:
|
||||||
|
raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER))
|
||||||
|
|
||||||
return new_messages
|
return messages
|
||||||
|
|
||||||
def process_token_ids(
|
def process_token_ids(
|
||||||
self,
|
self,
|
||||||
input_ids: List[int],
|
input_ids: List[int],
|
||||||
labels: Optional[List[int]],
|
labels: Optional[List[int]],
|
||||||
|
images: Sequence["ImageObject"],
|
||||||
tokenizer: "PreTrainedTokenizer",
|
tokenizer: "PreTrainedTokenizer",
|
||||||
processor: Optional["ProcessorMixin"],
|
processor: Optional["ProcessorMixin"],
|
||||||
) -> Tuple[List[int], Optional[List[int]]]:
|
) -> Tuple[List[int], Optional[List[int]]]:
|
||||||
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
|
num_images = len(images)
|
||||||
image_seq_length: int = getattr(image_processor, "image_seq_length")
|
image_seqlen = num_images * getattr(processor, "image_seqlen")
|
||||||
image_token_id = tokenizer.convert_tokens_to_ids(self.image_token)
|
image_token_id = tokenizer.convert_tokens_to_ids(self.image_token)
|
||||||
input_ids = [image_token_id] * image_seq_length + input_ids
|
input_ids = [image_token_id] * image_seqlen + input_ids
|
||||||
if labels is not None:
|
if labels is not None:
|
||||||
labels = [IGNORE_INDEX] * image_seq_length + labels
|
labels = [IGNORE_INDEX] * image_seqlen + labels
|
||||||
|
|
||||||
return input_ids, labels
|
return input_ids, labels
|
||||||
|
|
||||||
@ -195,21 +176,10 @@ class PaliGemmaPlugin(BasePlugin):
|
|||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
mm_inputs = _get_mm_inputs(images, processor)
|
mm_inputs = _get_mm_inputs(images, processor)
|
||||||
for feature_name, feature_length in feature_seqlens.items():
|
for feature_name, feature_length in feature_seqlens.items():
|
||||||
mm_inputs[feature_name] = _get_paligemma_token_type_ids(feature_length, processor)
|
mm_inputs[feature_name] = _get_paligemma_token_type_ids(images, feature_length, processor)
|
||||||
|
|
||||||
return mm_inputs
|
return mm_inputs
|
||||||
|
|
||||||
def process_model_inputs(
|
|
||||||
self,
|
|
||||||
model_inputs: Dict[str, List[Any]],
|
|
||||||
images: Sequence["ImageObject"],
|
|
||||||
feature_seqlens: Dict[str, int],
|
|
||||||
processor: Optional["ProcessorMixin"],
|
|
||||||
) -> None:
|
|
||||||
mm_inputs = self.get_mm_inputs(images, feature_seqlens, processor)
|
|
||||||
for key, value in mm_inputs.items():
|
|
||||||
model_inputs[key].append(value[0])
|
|
||||||
|
|
||||||
|
|
||||||
class Qwen2vlPlugin(BasePlugin):
|
class Qwen2vlPlugin(BasePlugin):
|
||||||
def process_messages(
|
def process_messages(
|
||||||
@ -223,23 +193,26 @@ class Qwen2vlPlugin(BasePlugin):
|
|||||||
if len(images) > 0:
|
if len(images) > 0:
|
||||||
image_grid_thw = _get_mm_inputs(images, processor)["image_grid_thw"]
|
image_grid_thw = _get_mm_inputs(images, processor)["image_grid_thw"]
|
||||||
|
|
||||||
index = 0
|
num_images = 0
|
||||||
new_messages = []
|
messages = deepcopy(messages)
|
||||||
for message in messages:
|
for message in messages:
|
||||||
content = message["content"]
|
content = message["content"]
|
||||||
while IMAGE_PLACEHOLDER in content:
|
while IMAGE_PLACEHOLDER in content:
|
||||||
content = content.replace(
|
content = content.replace(
|
||||||
IMAGE_PLACEHOLDER,
|
IMAGE_PLACEHOLDER,
|
||||||
"<|vision_start|>{}<|vision_end|>".format(
|
"<|vision_start|>{}<|vision_end|>".format(
|
||||||
self.image_token * (image_grid_thw[index].prod() // merge_length)
|
self.image_token * (image_grid_thw[num_images].prod() // merge_length)
|
||||||
),
|
),
|
||||||
1,
|
1,
|
||||||
)
|
)
|
||||||
index += 1
|
num_images += 1
|
||||||
|
|
||||||
new_messages.append({"role": message["role"], "content": content})
|
message["content"] = content
|
||||||
|
|
||||||
return new_messages
|
if len(images) != num_images:
|
||||||
|
raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER))
|
||||||
|
|
||||||
|
return messages
|
||||||
|
|
||||||
def get_mm_inputs(
|
def get_mm_inputs(
|
||||||
self,
|
self,
|
||||||
@ -249,17 +222,6 @@ class Qwen2vlPlugin(BasePlugin):
|
|||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
return _get_mm_inputs(images, processor)
|
return _get_mm_inputs(images, processor)
|
||||||
|
|
||||||
def process_model_inputs(
|
|
||||||
self,
|
|
||||||
model_inputs: Dict[str, List[Any]],
|
|
||||||
images: Sequence["ImageObject"],
|
|
||||||
feature_seqlens: Dict[str, int],
|
|
||||||
processor: Optional["ProcessorMixin"],
|
|
||||||
) -> None:
|
|
||||||
mm_inputs = self.get_mm_inputs(images, feature_seqlens, processor)
|
|
||||||
for key, value in mm_inputs.items():
|
|
||||||
model_inputs[key].append(value) # support multi-image
|
|
||||||
|
|
||||||
|
|
||||||
PLUGINS = {
|
PLUGINS = {
|
||||||
"base": BasePlugin,
|
"base": BasePlugin,
|
||||||
@ -270,7 +232,8 @@ PLUGINS = {
|
|||||||
|
|
||||||
|
|
||||||
def get_mm_plugin(name: str, image_token: str) -> "BasePlugin":
|
def get_mm_plugin(name: str, image_token: str) -> "BasePlugin":
|
||||||
if name not in PLUGINS:
|
plugin_class = PLUGINS.get(name, None)
|
||||||
raise ValueError("{} not found.".format(name))
|
if plugin_class is None:
|
||||||
|
raise ValueError("Multimodal plugin `{}` not found.".format(name))
|
||||||
|
|
||||||
return PLUGINS[name](image_token)
|
return plugin_class(image_token)
|
||||||
|
@ -50,7 +50,7 @@ def get_preprocess_and_print_func(
|
|||||||
print_function = partial(print_unsupervised_dataset_example, tokenizer=tokenizer)
|
print_function = partial(print_unsupervised_dataset_example, tokenizer=tokenizer)
|
||||||
elif stage == "sft" and not do_generate:
|
elif stage == "sft" and not do_generate:
|
||||||
if data_args.packing:
|
if data_args.packing:
|
||||||
if data_args.neat_packing:
|
if data_args.neat_packing: # hack datasets to have int32 attention mask
|
||||||
from datasets.arrow_writer import OptimizedTypedSequence, TypedSequence
|
from datasets.arrow_writer import OptimizedTypedSequence, TypedSequence
|
||||||
|
|
||||||
def __init__(self, data, **kwargs):
|
def __init__(self, data, **kwargs):
|
||||||
@ -67,6 +67,7 @@ def get_preprocess_and_print_func(
|
|||||||
preprocess_packed_supervised_dataset,
|
preprocess_packed_supervised_dataset,
|
||||||
template=template,
|
template=template,
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
|
processor=processor,
|
||||||
data_args=data_args,
|
data_args=data_args,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
@ -21,6 +21,7 @@ from .processor_utils import infer_seqlen
|
|||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
from PIL.Image import Image
|
||||||
from transformers import PreTrainedTokenizer, ProcessorMixin
|
from transformers import PreTrainedTokenizer, ProcessorMixin
|
||||||
|
|
||||||
from ...hparams import DataArguments
|
from ...hparams import DataArguments
|
||||||
@ -36,11 +37,12 @@ def _encode_feedback_example(
|
|||||||
kl_response: Sequence[Dict[str, str]],
|
kl_response: Sequence[Dict[str, str]],
|
||||||
system: Optional[str],
|
system: Optional[str],
|
||||||
tools: Optional[str],
|
tools: Optional[str],
|
||||||
|
images: Sequence["Image"],
|
||||||
template: "Template",
|
template: "Template",
|
||||||
tokenizer: "PreTrainedTokenizer",
|
tokenizer: "PreTrainedTokenizer",
|
||||||
processor: Optional["ProcessorMixin"],
|
processor: Optional["ProcessorMixin"],
|
||||||
cutoff_len: int,
|
cutoff_len: int,
|
||||||
) -> Tuple[List[int], List[int], List[int], List[int], bool]:
|
) -> Tuple[List[int], List[int], List[int], List[int], bool, Dict[str, Any]]:
|
||||||
if response[0]["content"]: # desired example
|
if response[0]["content"]: # desired example
|
||||||
kto_tag = True
|
kto_tag = True
|
||||||
messages = prompt + [response[0]]
|
messages = prompt + [response[0]]
|
||||||
@ -53,6 +55,8 @@ def _encode_feedback_example(
|
|||||||
else:
|
else:
|
||||||
kl_messages = prompt + [kl_response[1]]
|
kl_messages = prompt + [kl_response[1]]
|
||||||
|
|
||||||
|
messages = template.mm_plugin.process_messages(messages, images, processor)
|
||||||
|
kl_messages = template.mm_plugin.process_messages(kl_messages, images, processor)
|
||||||
prompt_ids, response_ids = template.encode_oneturn(tokenizer, messages, system, tools)
|
prompt_ids, response_ids = template.encode_oneturn(tokenizer, messages, system, tools)
|
||||||
kl_prompt_ids, kl_response_ids = template.encode_oneturn(tokenizer, kl_messages, system, tools)
|
kl_prompt_ids, kl_response_ids = template.encode_oneturn(tokenizer, kl_messages, system, tools)
|
||||||
|
|
||||||
@ -60,8 +64,8 @@ def _encode_feedback_example(
|
|||||||
response_ids += [tokenizer.eos_token_id]
|
response_ids += [tokenizer.eos_token_id]
|
||||||
kl_response_ids += [tokenizer.eos_token_id]
|
kl_response_ids += [tokenizer.eos_token_id]
|
||||||
|
|
||||||
prompt_ids, _ = template.mm_plugin.process_token_ids(prompt_ids, None, tokenizer, processor)
|
prompt_ids, _ = template.mm_plugin.process_token_ids(prompt_ids, None, images, tokenizer, processor)
|
||||||
kl_prompt_ids, _ = template.mm_plugin.process_token_ids(kl_prompt_ids, None, tokenizer, processor)
|
kl_prompt_ids, _ = template.mm_plugin.process_token_ids(kl_prompt_ids, None, images, tokenizer, processor)
|
||||||
|
|
||||||
source_len, target_len = infer_seqlen(len(prompt_ids), len(response_ids), cutoff_len)
|
source_len, target_len = infer_seqlen(len(prompt_ids), len(response_ids), cutoff_len)
|
||||||
prompt_ids = prompt_ids[:source_len]
|
prompt_ids = prompt_ids[:source_len]
|
||||||
@ -74,8 +78,15 @@ def _encode_feedback_example(
|
|||||||
labels = [IGNORE_INDEX] * source_len + response_ids
|
labels = [IGNORE_INDEX] * source_len + response_ids
|
||||||
kl_input_ids = kl_prompt_ids + kl_response_ids
|
kl_input_ids = kl_prompt_ids + kl_response_ids
|
||||||
kl_labels = [IGNORE_INDEX] * kl_source_len + kl_response_ids
|
kl_labels = [IGNORE_INDEX] * kl_source_len + kl_response_ids
|
||||||
|
extra_inputs = template.mm_plugin.get_mm_inputs(
|
||||||
return input_ids, labels, kl_input_ids, kl_labels, kto_tag
|
images=images,
|
||||||
|
feature_seqlens={
|
||||||
|
"token_type_ids": len(input_ids),
|
||||||
|
"kl_token_type_ids": len(kl_input_ids),
|
||||||
|
},
|
||||||
|
processor=processor,
|
||||||
|
)
|
||||||
|
return input_ids, labels, kl_input_ids, kl_labels, kto_tag, extra_inputs
|
||||||
|
|
||||||
|
|
||||||
def preprocess_feedback_dataset(
|
def preprocess_feedback_dataset(
|
||||||
@ -93,13 +104,13 @@ def preprocess_feedback_dataset(
|
|||||||
logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
|
logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
|
||||||
continue
|
continue
|
||||||
|
|
||||||
prompt = template.mm_plugin.process_messages(examples["prompt"][i], examples["images"][i], processor)
|
input_ids, labels, kl_input_ids, kl_labels, kto_tag, extra_inputs = _encode_feedback_example(
|
||||||
input_ids, labels, kl_input_ids, kl_labels, kto_tag = _encode_feedback_example(
|
prompt=examples["prompt"][i],
|
||||||
prompt=prompt,
|
|
||||||
response=examples["response"][i],
|
response=examples["response"][i],
|
||||||
kl_response=kl_response[i],
|
kl_response=kl_response[i],
|
||||||
system=examples["system"][i],
|
system=examples["system"][i],
|
||||||
tools=examples["tools"][i],
|
tools=examples["tools"][i],
|
||||||
|
images=examples["images"][i],
|
||||||
template=template,
|
template=template,
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
processor=processor,
|
processor=processor,
|
||||||
@ -112,15 +123,8 @@ def preprocess_feedback_dataset(
|
|||||||
model_inputs["kl_attention_mask"].append([1] * len(kl_input_ids))
|
model_inputs["kl_attention_mask"].append([1] * len(kl_input_ids))
|
||||||
model_inputs["kl_labels"].append(kl_labels)
|
model_inputs["kl_labels"].append(kl_labels)
|
||||||
model_inputs["kto_tags"].append(kto_tag)
|
model_inputs["kto_tags"].append(kto_tag)
|
||||||
template.mm_plugin.process_model_inputs(
|
for key, value in extra_inputs.items():
|
||||||
model_inputs=model_inputs,
|
model_inputs[key].append(value)
|
||||||
images=examples["images"][i],
|
|
||||||
feature_seqlens={
|
|
||||||
"token_type_ids": len(input_ids),
|
|
||||||
"kl_token_type_ids": len(kl_input_ids),
|
|
||||||
},
|
|
||||||
processor=processor,
|
|
||||||
)
|
|
||||||
|
|
||||||
desirable_num = sum([1 for tag in model_inputs["kto_tags"] if tag])
|
desirable_num = sum([1 for tag in model_inputs["kto_tags"] if tag])
|
||||||
undesirable_num = len(model_inputs["kto_tags"]) - desirable_num
|
undesirable_num = len(model_inputs["kto_tags"]) - desirable_num
|
||||||
|
@ -21,6 +21,7 @@ from .processor_utils import infer_seqlen
|
|||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
from PIL.Image import Image
|
||||||
from transformers import PreTrainedTokenizer, ProcessorMixin
|
from transformers import PreTrainedTokenizer, ProcessorMixin
|
||||||
|
|
||||||
from ...hparams import DataArguments
|
from ...hparams import DataArguments
|
||||||
@ -35,13 +36,14 @@ def _encode_pairwise_example(
|
|||||||
response: Sequence[Dict[str, str]],
|
response: Sequence[Dict[str, str]],
|
||||||
system: Optional[str],
|
system: Optional[str],
|
||||||
tools: Optional[str],
|
tools: Optional[str],
|
||||||
|
images: Sequence["Image"],
|
||||||
template: "Template",
|
template: "Template",
|
||||||
tokenizer: "PreTrainedTokenizer",
|
tokenizer: "PreTrainedTokenizer",
|
||||||
processor: Optional["ProcessorMixin"],
|
processor: Optional["ProcessorMixin"],
|
||||||
cutoff_len: int,
|
cutoff_len: int,
|
||||||
) -> Tuple[List[int], List[int], List[int], List[int]]:
|
) -> Tuple[List[int], List[int], List[int], List[int], Dict[str, Any]]:
|
||||||
chosen_messages = prompt + [response[0]]
|
chosen_messages = template.mm_plugin.process_messages(prompt + [response[0]], images, processor)
|
||||||
rejected_messages = prompt + [response[1]]
|
rejected_messages = template.mm_plugin.process_messages(prompt + [response[1]], images, processor)
|
||||||
prompt_ids, chosen_ids = template.encode_oneturn(tokenizer, chosen_messages, system, tools)
|
prompt_ids, chosen_ids = template.encode_oneturn(tokenizer, chosen_messages, system, tools)
|
||||||
_, rejected_ids = template.encode_oneturn(tokenizer, rejected_messages, system, tools)
|
_, rejected_ids = template.encode_oneturn(tokenizer, rejected_messages, system, tools)
|
||||||
|
|
||||||
@ -49,7 +51,7 @@ def _encode_pairwise_example(
|
|||||||
chosen_ids += [tokenizer.eos_token_id]
|
chosen_ids += [tokenizer.eos_token_id]
|
||||||
rejected_ids += [tokenizer.eos_token_id]
|
rejected_ids += [tokenizer.eos_token_id]
|
||||||
|
|
||||||
prompt_ids, _ = template.mm_plugin.process_token_ids(prompt_ids, None, tokenizer, processor)
|
prompt_ids, _ = template.mm_plugin.process_token_ids(prompt_ids, None, images, tokenizer, processor)
|
||||||
# consider the response is more important
|
# consider the response is more important
|
||||||
source_len, target_len = infer_seqlen(len(prompt_ids), max(len(chosen_ids), len(rejected_ids)), cutoff_len)
|
source_len, target_len = infer_seqlen(len(prompt_ids), max(len(chosen_ids), len(rejected_ids)), cutoff_len)
|
||||||
prompt_ids = prompt_ids[:source_len]
|
prompt_ids = prompt_ids[:source_len]
|
||||||
@ -60,8 +62,15 @@ def _encode_pairwise_example(
|
|||||||
chosen_labels = [IGNORE_INDEX] * source_len + chosen_ids
|
chosen_labels = [IGNORE_INDEX] * source_len + chosen_ids
|
||||||
rejected_input_ids = prompt_ids + rejected_ids
|
rejected_input_ids = prompt_ids + rejected_ids
|
||||||
rejected_labels = [IGNORE_INDEX] * source_len + rejected_ids
|
rejected_labels = [IGNORE_INDEX] * source_len + rejected_ids
|
||||||
|
extra_inputs = template.mm_plugin.get_mm_inputs(
|
||||||
return chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels
|
images=images,
|
||||||
|
feature_seqlens={
|
||||||
|
"chosen_token_type_ids": len(chosen_input_ids),
|
||||||
|
"rejected_token_type_ids": len(rejected_input_ids),
|
||||||
|
},
|
||||||
|
processor=processor,
|
||||||
|
)
|
||||||
|
return chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels, extra_inputs
|
||||||
|
|
||||||
|
|
||||||
def preprocess_pairwise_dataset(
|
def preprocess_pairwise_dataset(
|
||||||
@ -78,12 +87,12 @@ def preprocess_pairwise_dataset(
|
|||||||
logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
|
logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
|
||||||
continue
|
continue
|
||||||
|
|
||||||
prompt = template.mm_plugin.process_messages(examples["prompt"][i], examples["images"][i], processor)
|
chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels, extra_inputs = _encode_pairwise_example(
|
||||||
chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels = _encode_pairwise_example(
|
prompt=examples["prompt"][i],
|
||||||
prompt=prompt,
|
|
||||||
response=examples["response"][i],
|
response=examples["response"][i],
|
||||||
system=examples["system"][i],
|
system=examples["system"][i],
|
||||||
tools=examples["tools"][i],
|
tools=examples["tools"][i],
|
||||||
|
images=examples["images"][i],
|
||||||
template=template,
|
template=template,
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
processor=processor,
|
processor=processor,
|
||||||
@ -95,15 +104,8 @@ def preprocess_pairwise_dataset(
|
|||||||
model_inputs["rejected_input_ids"].append(rejected_input_ids)
|
model_inputs["rejected_input_ids"].append(rejected_input_ids)
|
||||||
model_inputs["rejected_attention_mask"].append([1] * len(rejected_input_ids))
|
model_inputs["rejected_attention_mask"].append([1] * len(rejected_input_ids))
|
||||||
model_inputs["rejected_labels"].append(rejected_labels)
|
model_inputs["rejected_labels"].append(rejected_labels)
|
||||||
template.mm_plugin.process_model_inputs(
|
for key, value in extra_inputs.items():
|
||||||
model_inputs=model_inputs,
|
model_inputs[key].append(value)
|
||||||
images=examples["images"][i],
|
|
||||||
feature_seqlens={
|
|
||||||
"chosen_token_type_ids": len(chosen_input_ids),
|
|
||||||
"rejected_token_type_ids": len(rejected_input_ids),
|
|
||||||
},
|
|
||||||
processor=processor,
|
|
||||||
)
|
|
||||||
|
|
||||||
return model_inputs
|
return model_inputs
|
||||||
|
|
||||||
|
@ -21,6 +21,7 @@ from .processor_utils import greedy_knapsack, infer_seqlen
|
|||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
from PIL.Image import Image
|
||||||
from transformers import PreTrainedTokenizer, ProcessorMixin
|
from transformers import PreTrainedTokenizer, ProcessorMixin
|
||||||
|
|
||||||
from ...hparams import DataArguments
|
from ...hparams import DataArguments
|
||||||
@ -35,19 +36,18 @@ def _encode_supervised_example(
|
|||||||
response: Sequence[Dict[str, str]],
|
response: Sequence[Dict[str, str]],
|
||||||
system: Optional[str],
|
system: Optional[str],
|
||||||
tools: Optional[str],
|
tools: Optional[str],
|
||||||
|
images: Sequence["Image"],
|
||||||
template: "Template",
|
template: "Template",
|
||||||
tokenizer: "PreTrainedTokenizer",
|
tokenizer: "PreTrainedTokenizer",
|
||||||
processor: Optional["ProcessorMixin"],
|
processor: Optional["ProcessorMixin"],
|
||||||
cutoff_len: int,
|
cutoff_len: int,
|
||||||
train_on_prompt: bool,
|
train_on_prompt: bool,
|
||||||
mask_history: bool,
|
mask_history: bool,
|
||||||
) -> Tuple[List[int], List[int]]:
|
) -> Tuple[List[int], List[int], Dict[str, Any]]:
|
||||||
messages = prompt + response
|
messages = template.mm_plugin.process_messages(prompt + response, images, processor)
|
||||||
input_ids, labels = [], []
|
input_ids, labels = template.mm_plugin.process_token_ids([], [], images, tokenizer, processor)
|
||||||
input_ids, labels = template.mm_plugin.process_token_ids(input_ids, labels, tokenizer, processor)
|
|
||||||
|
|
||||||
encoded_pairs = template.encode_multiturn(tokenizer, messages, system, tools)
|
encoded_pairs = template.encode_multiturn(tokenizer, messages, system, tools)
|
||||||
total_length = 1 if template.efficient_eos else 0
|
total_length = len(input_ids) + (1 if template.efficient_eos else 0)
|
||||||
if mask_history:
|
if mask_history:
|
||||||
encoded_pairs = encoded_pairs[::-1] # high priority for last turns
|
encoded_pairs = encoded_pairs[::-1] # high priority for last turns
|
||||||
|
|
||||||
@ -83,7 +83,10 @@ def _encode_supervised_example(
|
|||||||
input_ids += [tokenizer.eos_token_id]
|
input_ids += [tokenizer.eos_token_id]
|
||||||
labels += [tokenizer.eos_token_id]
|
labels += [tokenizer.eos_token_id]
|
||||||
|
|
||||||
return input_ids, labels
|
extra_inputs = template.mm_plugin.get_mm_inputs(
|
||||||
|
images=images, feature_seqlens={"token_type_ids": len(input_ids)}, processor=processor
|
||||||
|
)
|
||||||
|
return input_ids, labels, extra_inputs
|
||||||
|
|
||||||
|
|
||||||
def preprocess_supervised_dataset(
|
def preprocess_supervised_dataset(
|
||||||
@ -101,12 +104,12 @@ def preprocess_supervised_dataset(
|
|||||||
logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
|
logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
|
||||||
continue
|
continue
|
||||||
|
|
||||||
prompt = template.mm_plugin.process_messages(examples["prompt"][i], examples["images"][i], processor)
|
input_ids, labels, extra_inputs = _encode_supervised_example(
|
||||||
input_ids, labels = _encode_supervised_example(
|
prompt=examples["prompt"][i],
|
||||||
prompt=prompt,
|
|
||||||
response=examples["response"][i],
|
response=examples["response"][i],
|
||||||
system=examples["system"][i],
|
system=examples["system"][i],
|
||||||
tools=examples["tools"][i],
|
tools=examples["tools"][i],
|
||||||
|
images=examples["images"][i],
|
||||||
template=template,
|
template=template,
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
processor=processor,
|
processor=processor,
|
||||||
@ -117,12 +120,8 @@ def preprocess_supervised_dataset(
|
|||||||
model_inputs["input_ids"].append(input_ids)
|
model_inputs["input_ids"].append(input_ids)
|
||||||
model_inputs["attention_mask"].append([1] * len(input_ids))
|
model_inputs["attention_mask"].append([1] * len(input_ids))
|
||||||
model_inputs["labels"].append(labels)
|
model_inputs["labels"].append(labels)
|
||||||
template.mm_plugin.process_model_inputs(
|
for key, value in extra_inputs.items():
|
||||||
model_inputs=model_inputs,
|
model_inputs[key].append(value)
|
||||||
images=examples["images"][i],
|
|
||||||
feature_seqlens={"token_type_ids": len(input_ids)},
|
|
||||||
processor=processor,
|
|
||||||
)
|
|
||||||
|
|
||||||
return model_inputs
|
return model_inputs
|
||||||
|
|
||||||
@ -131,10 +130,15 @@ def preprocess_packed_supervised_dataset(
|
|||||||
examples: Dict[str, List[Any]],
|
examples: Dict[str, List[Any]],
|
||||||
template: "Template",
|
template: "Template",
|
||||||
tokenizer: "PreTrainedTokenizer",
|
tokenizer: "PreTrainedTokenizer",
|
||||||
|
processor: Optional["ProcessorMixin"],
|
||||||
data_args: "DataArguments",
|
data_args: "DataArguments",
|
||||||
) -> Dict[str, List[Any]]:
|
) -> Dict[str, List[Any]]:
|
||||||
|
# TODO: use `position_ids` to achieve packing
|
||||||
# build inputs with format `<bos> X1 Y1 <eos> <bos> X2 Y2 <eos>`
|
# build inputs with format `<bos> X1 Y1 <eos> <bos> X2 Y2 <eos>`
|
||||||
# and labels with format `<ignore> ... <ignore> Y1 <eos> <ignore> ... <ignore> Y2 <eos>`
|
# and labels with format `<ignore> ... <ignore> Y1 <eos> <ignore> ... <ignore> Y2 <eos>`
|
||||||
|
if processor is not None:
|
||||||
|
raise NotImplementedError("`packing` have not been implemented for multimodal datasets.")
|
||||||
|
|
||||||
valid_num = 0
|
valid_num = 0
|
||||||
batch_input_ids, batch_labels = [], []
|
batch_input_ids, batch_labels = [], []
|
||||||
lengths = []
|
lengths = []
|
||||||
@ -149,6 +153,7 @@ def preprocess_packed_supervised_dataset(
|
|||||||
response=examples["response"][i],
|
response=examples["response"][i],
|
||||||
system=examples["system"][i],
|
system=examples["system"][i],
|
||||||
tools=examples["tools"][i],
|
tools=examples["tools"][i],
|
||||||
|
images=examples["images"][i],
|
||||||
template=template,
|
template=template,
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
processor=None,
|
processor=None,
|
||||||
|
@ -21,6 +21,7 @@ from .processor_utils import infer_seqlen
|
|||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
from PIL.Image import Image
|
||||||
from transformers import PreTrainedTokenizer, ProcessorMixin
|
from transformers import PreTrainedTokenizer, ProcessorMixin
|
||||||
|
|
||||||
from ...hparams import DataArguments
|
from ...hparams import DataArguments
|
||||||
@ -35,25 +36,30 @@ def _encode_unsupervised_example(
|
|||||||
response: Sequence[Dict[str, str]],
|
response: Sequence[Dict[str, str]],
|
||||||
system: Optional[str],
|
system: Optional[str],
|
||||||
tools: Optional[str],
|
tools: Optional[str],
|
||||||
|
images: Sequence["Image"],
|
||||||
template: "Template",
|
template: "Template",
|
||||||
tokenizer: "PreTrainedTokenizer",
|
tokenizer: "PreTrainedTokenizer",
|
||||||
processor: Optional["ProcessorMixin"],
|
processor: Optional["ProcessorMixin"],
|
||||||
cutoff_len: int,
|
cutoff_len: int,
|
||||||
) -> Tuple[List[int], List[int]]:
|
) -> Tuple[List[int], List[int], Dict[str, Any]]:
|
||||||
if len(response) == 1:
|
if len(response) == 1:
|
||||||
messages = prompt + response
|
messages = prompt + response
|
||||||
else:
|
else:
|
||||||
messages = prompt + [{"role": Role.ASSISTANT.value, "content": ""}]
|
messages = prompt + [{"role": Role.ASSISTANT.value, "content": ""}]
|
||||||
|
|
||||||
|
messages = template.mm_plugin.process_messages(messages, images, processor)
|
||||||
input_ids, labels = template.encode_oneturn(tokenizer, messages, system, tools)
|
input_ids, labels = template.encode_oneturn(tokenizer, messages, system, tools)
|
||||||
if template.efficient_eos:
|
if template.efficient_eos:
|
||||||
labels += [tokenizer.eos_token_id]
|
labels += [tokenizer.eos_token_id]
|
||||||
|
|
||||||
input_ids, _ = template.mm_plugin.process_token_ids(input_ids, None, tokenizer, processor)
|
input_ids, _ = template.mm_plugin.process_token_ids(input_ids, None, images, tokenizer, processor)
|
||||||
source_len, target_len = infer_seqlen(len(input_ids), len(labels), cutoff_len)
|
source_len, target_len = infer_seqlen(len(input_ids), len(labels), cutoff_len)
|
||||||
input_ids = input_ids[:source_len]
|
input_ids = input_ids[:source_len]
|
||||||
labels = labels[:target_len]
|
labels = labels[:target_len]
|
||||||
return input_ids, labels
|
extra_inputs = template.mm_plugin.get_mm_inputs(
|
||||||
|
images=images, feature_seqlens={"token_type_ids": len(input_ids)}, processor=processor
|
||||||
|
)
|
||||||
|
return input_ids, labels, extra_inputs
|
||||||
|
|
||||||
|
|
||||||
def preprocess_unsupervised_dataset(
|
def preprocess_unsupervised_dataset(
|
||||||
@ -70,12 +76,12 @@ def preprocess_unsupervised_dataset(
|
|||||||
logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
|
logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
|
||||||
continue
|
continue
|
||||||
|
|
||||||
prompt = template.mm_plugin.process_messages(examples["prompt"][i], examples["images"][i], processor)
|
input_ids, labels, extra_inputs = _encode_unsupervised_example(
|
||||||
input_ids, labels = _encode_unsupervised_example(
|
prompt=examples["prompt"][i],
|
||||||
prompt=prompt,
|
|
||||||
response=examples["response"][i],
|
response=examples["response"][i],
|
||||||
system=examples["system"][i],
|
system=examples["system"][i],
|
||||||
tools=examples["tools"][i],
|
tools=examples["tools"][i],
|
||||||
|
images=examples["images"][i],
|
||||||
template=template,
|
template=template,
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
processor=processor,
|
processor=processor,
|
||||||
@ -84,12 +90,8 @@ def preprocess_unsupervised_dataset(
|
|||||||
model_inputs["input_ids"].append(input_ids)
|
model_inputs["input_ids"].append(input_ids)
|
||||||
model_inputs["attention_mask"].append([1] * len(input_ids))
|
model_inputs["attention_mask"].append([1] * len(input_ids))
|
||||||
model_inputs["labels"].append(labels)
|
model_inputs["labels"].append(labels)
|
||||||
template.mm_plugin.process_model_inputs(
|
for key, value in extra_inputs.items():
|
||||||
model_inputs=model_inputs,
|
model_inputs[key].append(value)
|
||||||
images=examples["images"][i],
|
|
||||||
feature_seqlens={"token_type_ids": len(input_ids)},
|
|
||||||
processor=processor,
|
|
||||||
)
|
|
||||||
|
|
||||||
return model_inputs
|
return model_inputs
|
||||||
|
|
||||||
|
@ -15,6 +15,8 @@
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Union
|
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Union
|
||||||
|
|
||||||
|
from transformers.utils.versions import require_version
|
||||||
|
|
||||||
from ..extras.constants import IMAGE_PLACEHOLDER
|
from ..extras.constants import IMAGE_PLACEHOLDER
|
||||||
from ..extras.logging import get_logger
|
from ..extras.logging import get_logger
|
||||||
from .data_utils import Role
|
from .data_utils import Role
|
||||||
@ -347,6 +349,11 @@ def get_template_and_fix_tokenizer(
|
|||||||
name: Optional[str] = None,
|
name: Optional[str] = None,
|
||||||
tool_format: Optional[str] = None,
|
tool_format: Optional[str] = None,
|
||||||
) -> Template:
|
) -> Template:
|
||||||
|
if name == "qwen2_vl":
|
||||||
|
require_version(
|
||||||
|
"transformers>=4.45.0.dev0", "To fix: pip install git+https://github.com/huggingface/transformers.git"
|
||||||
|
)
|
||||||
|
|
||||||
if name is None:
|
if name is None:
|
||||||
template = TEMPLATES["empty"] # placeholder
|
template = TEMPLATES["empty"] # placeholder
|
||||||
else:
|
else:
|
||||||
@ -357,8 +364,8 @@ def get_template_and_fix_tokenizer(
|
|||||||
if tool_format is not None:
|
if tool_format is not None:
|
||||||
logger.info("Using tool format: {}.".format(tool_format))
|
logger.info("Using tool format: {}.".format(tool_format))
|
||||||
eos_slots = [] if template.efficient_eos else [{"eos_token"}]
|
eos_slots = [] if template.efficient_eos else [{"eos_token"}]
|
||||||
template.format_tools = ToolFormatter(tool_format=tool_format)
|
|
||||||
template.format_function = FunctionFormatter(slots=eos_slots, tool_format=tool_format)
|
template.format_function = FunctionFormatter(slots=eos_slots, tool_format=tool_format)
|
||||||
|
template.format_tools = ToolFormatter(tool_format=tool_format)
|
||||||
|
|
||||||
stop_words = template.stop_words
|
stop_words = template.stop_words
|
||||||
if template.replace_eos:
|
if template.replace_eos:
|
||||||
|
@ -138,3 +138,17 @@ class GLM4ToolUtils(ToolUtils):
|
|||||||
return content
|
return content
|
||||||
|
|
||||||
return [(tool_name, json.dumps(arguments, ensure_ascii=False))]
|
return [(tool_name, json.dumps(arguments, ensure_ascii=False))]
|
||||||
|
|
||||||
|
|
||||||
|
TOOLS = {
|
||||||
|
"default": DefaultToolUtils(),
|
||||||
|
"glm4": GLM4ToolUtils(),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def get_tool_utils(name: str) -> "ToolUtils":
|
||||||
|
tool_utils = TOOLS.get(name, None)
|
||||||
|
if tool_utils is None:
|
||||||
|
raise ValueError("Tool utils `{}` not found.".format(name))
|
||||||
|
|
||||||
|
return tool_utils
|
||||||
|
@ -195,6 +195,9 @@ def is_gpu_or_npu_available() -> bool:
|
|||||||
|
|
||||||
|
|
||||||
def numpify(inputs: Union["NDArray", "torch.Tensor"]) -> "NDArray":
|
def numpify(inputs: Union["NDArray", "torch.Tensor"]) -> "NDArray":
|
||||||
|
r"""
|
||||||
|
Casts a torch tensor or a numpy array to a numpy array.
|
||||||
|
"""
|
||||||
if isinstance(inputs, torch.Tensor):
|
if isinstance(inputs, torch.Tensor):
|
||||||
inputs = inputs.cpu()
|
inputs = inputs.cpu()
|
||||||
if inputs.dtype == torch.bfloat16: # numpy does not support bfloat16 until 1.21.4
|
if inputs.dtype == torch.bfloat16: # numpy does not support bfloat16 until 1.21.4
|
||||||
@ -206,6 +209,9 @@ def numpify(inputs: Union["NDArray", "torch.Tensor"]) -> "NDArray":
|
|||||||
|
|
||||||
|
|
||||||
def skip_check_imports() -> None:
|
def skip_check_imports() -> None:
|
||||||
|
r"""
|
||||||
|
Avoids flash attention import error in custom model files.
|
||||||
|
"""
|
||||||
if os.environ.get("FORCE_CHECK_IMPORTS", "0").lower() not in ["true", "1"]:
|
if os.environ.get("FORCE_CHECK_IMPORTS", "0").lower() not in ["true", "1"]:
|
||||||
transformers.dynamic_module_utils.check_imports = get_relative_imports
|
transformers.dynamic_module_utils.check_imports = get_relative_imports
|
||||||
|
|
||||||
|
@ -12,7 +12,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from llamafactory.train.tuner import run_exp
|
from llamafactory.train.tuner import run_exp # use absolute import
|
||||||
|
|
||||||
|
|
||||||
def launch():
|
def launch():
|
||||||
|
@ -25,6 +25,7 @@ from .model_utils.misc import register_autoclass
|
|||||||
from .model_utils.mod import convert_pretrained_model_to_mod, load_mod_pretrained_model
|
from .model_utils.mod import convert_pretrained_model_to_mod, load_mod_pretrained_model
|
||||||
from .model_utils.unsloth import load_unsloth_pretrained_model
|
from .model_utils.unsloth import load_unsloth_pretrained_model
|
||||||
from .model_utils.valuehead import load_valuehead_params
|
from .model_utils.valuehead import load_valuehead_params
|
||||||
|
from .model_utils.visual import get_image_seqlen
|
||||||
from .patcher import patch_config, patch_model, patch_tokenizer, patch_valuehead_model
|
from .patcher import patch_config, patch_model, patch_tokenizer, patch_valuehead_model
|
||||||
|
|
||||||
|
|
||||||
@ -65,6 +66,7 @@ def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
|
|||||||
Note: including inplace operation of model_args.
|
Note: including inplace operation of model_args.
|
||||||
"""
|
"""
|
||||||
init_kwargs = _get_init_kwargs(model_args)
|
init_kwargs = _get_init_kwargs(model_args)
|
||||||
|
config = load_config(model_args)
|
||||||
try:
|
try:
|
||||||
tokenizer = AutoTokenizer.from_pretrained(
|
tokenizer = AutoTokenizer.from_pretrained(
|
||||||
model_args.model_name_or_path,
|
model_args.model_name_or_path,
|
||||||
@ -96,6 +98,7 @@ def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
|
|||||||
try:
|
try:
|
||||||
processor = AutoProcessor.from_pretrained(model_args.model_name_or_path, **init_kwargs)
|
processor = AutoProcessor.from_pretrained(model_args.model_name_or_path, **init_kwargs)
|
||||||
setattr(processor, "tokenizer", tokenizer)
|
setattr(processor, "tokenizer", tokenizer)
|
||||||
|
setattr(processor, "image_seqlen", get_image_seqlen(config))
|
||||||
except Exception:
|
except Exception:
|
||||||
processor = None
|
processor = None
|
||||||
|
|
||||||
|
@ -82,7 +82,7 @@ class LlavaMultiModalProjectorForYiVLForVLLM(LlavaMultiModalProjectorForYiVL):
|
|||||||
|
|
||||||
def autocast_projector_dtype(model: "PreTrainedModel", model_args: "ModelArguments") -> None:
|
def autocast_projector_dtype(model: "PreTrainedModel", model_args: "ModelArguments") -> None:
|
||||||
r"""
|
r"""
|
||||||
Casts projector output to half precision for quantized VLMs.
|
Casts projector output to half precision for fine-tuning quantized VLMs.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def _mm_projector_forward_post_hook(
|
def _mm_projector_forward_post_hook(
|
||||||
@ -136,6 +136,22 @@ def get_forbidden_modules(config: "PretrainedConfig", finetuning_args: "Finetuni
|
|||||||
return forbidden_modules
|
return forbidden_modules
|
||||||
|
|
||||||
|
|
||||||
|
def get_image_seqlen(config: "PretrainedConfig") -> int:
|
||||||
|
r"""
|
||||||
|
Computes the number of special tokens per image.
|
||||||
|
"""
|
||||||
|
if getattr(config, "model_type", None) == "llava":
|
||||||
|
image_seqlen = (config.vision_config.image_size // config.vision_config.patch_size) ** 2
|
||||||
|
if getattr(config, "vision_feature_select_strategy", "default") == "full": # add [CLS] token
|
||||||
|
image_seqlen += 1
|
||||||
|
elif getattr(config, "model_type", None) == "paligemma":
|
||||||
|
image_seqlen = config.vision_config.num_image_tokens
|
||||||
|
elif getattr(config, "model_type", None) == "qwen2_vl": # variable length
|
||||||
|
image_seqlen = -1
|
||||||
|
|
||||||
|
return image_seqlen
|
||||||
|
|
||||||
|
|
||||||
def patch_target_modules(
|
def patch_target_modules(
|
||||||
config: "PretrainedConfig", finetuning_args: "FinetuningArguments", target_modules: Sequence[str]
|
config: "PretrainedConfig", finetuning_args: "FinetuningArguments", target_modules: Sequence[str]
|
||||||
) -> Union[str, List[str]]:
|
) -> Union[str, List[str]]:
|
||||||
|
@ -17,7 +17,7 @@
|
|||||||
|
|
||||||
from typing import TYPE_CHECKING, List, Optional
|
from typing import TYPE_CHECKING, List, Optional
|
||||||
|
|
||||||
from ...data import CustomDataCollatorForSeq2Seq, get_dataset
|
from ...data import MultiModalDataCollatorForSeq2Seq, get_dataset
|
||||||
from ...extras.ploting import plot_loss
|
from ...extras.ploting import plot_loss
|
||||||
from ...model import load_model, load_tokenizer
|
from ...model import load_model, load_tokenizer
|
||||||
from ..callbacks import fix_valuehead_checkpoint
|
from ..callbacks import fix_valuehead_checkpoint
|
||||||
@ -45,7 +45,7 @@ def run_ppo(
|
|||||||
model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train, add_valuehead=True)
|
model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train, add_valuehead=True)
|
||||||
|
|
||||||
tokenizer.padding_side = "left" # use left-padding in generation while using right-padding in training
|
tokenizer.padding_side = "left" # use left-padding in generation while using right-padding in training
|
||||||
data_collator = CustomDataCollatorForSeq2Seq(tokenizer=tokenizer)
|
data_collator = MultiModalDataCollatorForSeq2Seq(tokenizer=tokenizer)
|
||||||
|
|
||||||
# Create reference model and reward model
|
# Create reference model and reward model
|
||||||
ref_model = create_ref_model(model_args, finetuning_args, add_valuehead=True)
|
ref_model = create_ref_model(model_args, finetuning_args, add_valuehead=True)
|
||||||
|
@ -13,8 +13,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import os
|
import os
|
||||||
from collections import defaultdict
|
from typing import TYPE_CHECKING, Any, Dict, Tuple
|
||||||
from typing import TYPE_CHECKING, Any, Dict
|
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
@ -26,7 +25,7 @@ from llamafactory.model import load_tokenizer
|
|||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from transformers import ProcessorMixin
|
from transformers import PreTrainedTokenizer, ProcessorMixin
|
||||||
from transformers.image_processing_utils import BaseImageProcessor
|
from transformers.image_processing_utils import BaseImageProcessor
|
||||||
|
|
||||||
|
|
||||||
@ -34,13 +33,20 @@ HF_TOKEN = os.environ.get("HF_TOKEN", None)
|
|||||||
|
|
||||||
TINY_LLAMA = os.environ.get("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
|
TINY_LLAMA = os.environ.get("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
|
||||||
|
|
||||||
MESSAGES = [
|
MM_MESSAGES = [
|
||||||
{"role": "user", "content": "<image>What is in this image?"},
|
{"role": "user", "content": "<image>What is in this image?"},
|
||||||
{"role": "assistant", "content": "A cat."},
|
{"role": "assistant", "content": "A cat."},
|
||||||
]
|
]
|
||||||
|
|
||||||
|
TEXT_MESSAGES = [
|
||||||
|
{"role": "user", "content": "How are you"},
|
||||||
|
{"role": "assistant", "content": "I am fine!"},
|
||||||
|
]
|
||||||
|
|
||||||
IMAGES = [Image.new("RGB", (32, 32), (255, 255, 255))]
|
IMAGES = [Image.new("RGB", (32, 32), (255, 255, 255))]
|
||||||
|
|
||||||
|
NO_IMAGES = []
|
||||||
|
|
||||||
INPUT_IDS = [0, 1, 2, 3, 4]
|
INPUT_IDS = [0, 1, 2, 3, 4]
|
||||||
|
|
||||||
LABELS = [0, 1, 2, 3, 4]
|
LABELS = [0, 1, 2, 3, 4]
|
||||||
@ -53,99 +59,110 @@ def _get_mm_inputs(processor: "ProcessorMixin") -> Dict[str, "torch.Tensor"]:
|
|||||||
return image_processor(images=IMAGES, return_tensors="pt")
|
return image_processor(images=IMAGES, return_tensors="pt")
|
||||||
|
|
||||||
|
|
||||||
def _is_close(batch_a: Dict[str, Any], batch_b: Dict[str, Any]):
|
def _is_close(batch_a: Dict[str, Any], batch_b: Dict[str, Any]) -> None:
|
||||||
assert batch_a.keys() == batch_b.keys()
|
assert batch_a.keys() == batch_b.keys()
|
||||||
for key in batch_a.keys():
|
for key in batch_a.keys():
|
||||||
if isinstance(batch_a[key], list):
|
if isinstance(batch_a[key], torch.Tensor):
|
||||||
assert len(batch_a[key]) == len(batch_b[key])
|
|
||||||
for i in range(len(batch_a[key])):
|
|
||||||
if isinstance(batch_a[key][i], torch.Tensor):
|
|
||||||
assert torch.allclose(batch_a[key][i], batch_b[key][i], rtol=1e-4, atol=1e-5)
|
|
||||||
else:
|
|
||||||
assert batch_a[key][i] == batch_b[key][i]
|
|
||||||
elif isinstance(batch_a[key], torch.Tensor):
|
|
||||||
assert torch.allclose(batch_a[key], batch_b[key], rtol=1e-4, atol=1e-5)
|
assert torch.allclose(batch_a[key], batch_b[key], rtol=1e-4, atol=1e-5)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError
|
assert batch_a[key] == batch_b[key]
|
||||||
|
|
||||||
|
|
||||||
|
def _load_tokenizer_module(model_name_or_path: str) -> Tuple["PreTrainedTokenizer", "ProcessorMixin"]:
|
||||||
|
model_args = ModelArguments(model_name_or_path=model_name_or_path)
|
||||||
|
tokenizer_module = load_tokenizer(model_args)
|
||||||
|
return tokenizer_module["tokenizer"], tokenizer_module["processor"]
|
||||||
|
|
||||||
|
|
||||||
def test_base_plugin():
|
def test_base_plugin():
|
||||||
model_args = ModelArguments(model_name_or_path=TINY_LLAMA)
|
tokenizer, processor = _load_tokenizer_module(model_name_or_path=TINY_LLAMA)
|
||||||
tokenizer_module = load_tokenizer(model_args)
|
|
||||||
tokenizer = tokenizer_module["tokenizer"]
|
|
||||||
processor = tokenizer_module["processor"]
|
|
||||||
|
|
||||||
base_plugin = get_mm_plugin(name="base", image_token="<image>")
|
base_plugin = get_mm_plugin(name="base", image_token="<image>")
|
||||||
model_inputs = defaultdict(list)
|
# test mm_messages
|
||||||
base_plugin.process_model_inputs(model_inputs, IMAGES, FEATURE_SEQLENS, processor)
|
assert base_plugin.process_messages(MM_MESSAGES, IMAGES, processor) == MM_MESSAGES
|
||||||
|
assert base_plugin.process_token_ids(INPUT_IDS, LABELS, IMAGES, tokenizer, processor) == (INPUT_IDS, LABELS)
|
||||||
assert base_plugin.process_messages(MESSAGES, IMAGES, processor)
|
|
||||||
assert base_plugin.process_token_ids(INPUT_IDS, LABELS, tokenizer, processor) == (INPUT_IDS, LABELS)
|
|
||||||
_is_close(base_plugin.get_mm_inputs(IMAGES, FEATURE_SEQLENS, processor), {})
|
_is_close(base_plugin.get_mm_inputs(IMAGES, FEATURE_SEQLENS, processor), {})
|
||||||
_is_close(model_inputs, {})
|
# test text_messages
|
||||||
|
assert base_plugin.process_messages(TEXT_MESSAGES, NO_IMAGES, processor) == TEXT_MESSAGES
|
||||||
|
assert base_plugin.process_token_ids(INPUT_IDS, LABELS, NO_IMAGES, tokenizer, processor) == (INPUT_IDS, LABELS)
|
||||||
|
_is_close(base_plugin.get_mm_inputs(NO_IMAGES, FEATURE_SEQLENS, processor), {})
|
||||||
|
|
||||||
|
|
||||||
def test_llava_plugin():
|
def test_llava_plugin():
|
||||||
model_args = ModelArguments(model_name_or_path="llava-hf/llava-1.5-7b-hf")
|
tokenizer, processor = _load_tokenizer_module(model_name_or_path="llava-hf/llava-1.5-7b-hf")
|
||||||
tokenizer_module = load_tokenizer(model_args)
|
image_seqlen = 576
|
||||||
tokenizer = tokenizer_module["tokenizer"]
|
|
||||||
processor = tokenizer_module["processor"]
|
|
||||||
|
|
||||||
mm_inputs = _get_mm_inputs(processor)
|
mm_inputs = _get_mm_inputs(processor)
|
||||||
expected_model_inputs = {key: [value[0]] for key, value in mm_inputs.items()}
|
expected_mm_messages = [
|
||||||
|
{key: value.replace("<image>", "<image>" * image_seqlen) for key, value in message.items()}
|
||||||
|
for message in MM_MESSAGES
|
||||||
|
]
|
||||||
|
|
||||||
llava_plugin = get_mm_plugin(name="llava", image_token="<image>")
|
llava_plugin = get_mm_plugin(name="llava", image_token="<image>")
|
||||||
model_inputs = defaultdict(list)
|
# test mm_messages
|
||||||
llava_plugin.process_model_inputs(model_inputs, IMAGES, FEATURE_SEQLENS, processor)
|
assert llava_plugin.process_messages(MM_MESSAGES, IMAGES, processor) == expected_mm_messages
|
||||||
|
assert llava_plugin.process_token_ids(INPUT_IDS, LABELS, IMAGES, tokenizer, processor) == (INPUT_IDS, LABELS)
|
||||||
assert llava_plugin.process_messages(MESSAGES, IMAGES, processor)
|
|
||||||
assert llava_plugin.process_token_ids(INPUT_IDS, LABELS, tokenizer, processor) == (INPUT_IDS, LABELS)
|
|
||||||
_is_close(llava_plugin.get_mm_inputs(IMAGES, FEATURE_SEQLENS, processor), mm_inputs)
|
_is_close(llava_plugin.get_mm_inputs(IMAGES, FEATURE_SEQLENS, processor), mm_inputs)
|
||||||
_is_close(model_inputs, expected_model_inputs)
|
# test text_messages
|
||||||
|
assert llava_plugin.process_messages(TEXT_MESSAGES, NO_IMAGES, processor) == TEXT_MESSAGES
|
||||||
|
assert llava_plugin.process_token_ids(INPUT_IDS, LABELS, NO_IMAGES, tokenizer, processor) == (INPUT_IDS, LABELS)
|
||||||
|
_is_close(llava_plugin.get_mm_inputs(NO_IMAGES, FEATURE_SEQLENS, processor), {"pixel_values": None})
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(not HF_TOKEN, reason="Gated model.")
|
@pytest.mark.skipif(not HF_TOKEN, reason="Gated model.")
|
||||||
def test_paligemma_plugin():
|
def test_paligemma_plugin():
|
||||||
model_args = ModelArguments(model_name_or_path="google/paligemma-3b-pt-224")
|
tokenizer, processor = _load_tokenizer_module(model_name_or_path="google/paligemma-3b-pt-224")
|
||||||
tokenizer_module = load_tokenizer(model_args)
|
image_seqlen = 256
|
||||||
tokenizer = tokenizer_module["tokenizer"]
|
|
||||||
processor = tokenizer_module["processor"]
|
|
||||||
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
|
|
||||||
image_seq_length: int = getattr(image_processor, "image_seq_length")
|
|
||||||
|
|
||||||
mm_inputs = _get_mm_inputs(processor)
|
mm_inputs = _get_mm_inputs(processor)
|
||||||
mm_inputs["token_type_ids"] = [[0] * image_seq_length + [1] * (1024 - image_seq_length)]
|
mm_inputs["token_type_ids"] = [[0] * image_seqlen + [1] * (1024 - image_seqlen)]
|
||||||
expected_model_inputs = {key: [value[0]] for key, value in mm_inputs.items()}
|
expected_mm_messages = [
|
||||||
expected_input_ids = [tokenizer.convert_tokens_to_ids("<image>")] * image_seq_length + INPUT_IDS
|
{key: value.replace("<image>", "") for key, value in message.items()} for message in MM_MESSAGES
|
||||||
expected_labels = [-100] * image_seq_length + LABELS
|
]
|
||||||
|
expected_input_ids = [tokenizer.convert_tokens_to_ids("<image>")] * image_seqlen + INPUT_IDS
|
||||||
|
expected_labels = [-100] * image_seqlen + LABELS
|
||||||
|
|
||||||
paligemma_plugin = get_mm_plugin(name="paligemma", image_token="<image>")
|
paligemma_plugin = get_mm_plugin(name="paligemma", image_token="<image>")
|
||||||
model_inputs = defaultdict(list)
|
# test mm_messages
|
||||||
paligemma_plugin.process_model_inputs(model_inputs, IMAGES, FEATURE_SEQLENS, processor)
|
assert paligemma_plugin.process_messages(MM_MESSAGES, IMAGES, processor) == expected_mm_messages
|
||||||
|
assert paligemma_plugin.process_token_ids(INPUT_IDS, LABELS, IMAGES, tokenizer, processor) == (
|
||||||
assert paligemma_plugin.process_messages(MESSAGES, IMAGES, processor)
|
|
||||||
assert paligemma_plugin.process_token_ids(INPUT_IDS, LABELS, tokenizer, processor) == (
|
|
||||||
expected_input_ids,
|
expected_input_ids,
|
||||||
expected_labels,
|
expected_labels,
|
||||||
)
|
)
|
||||||
_is_close(paligemma_plugin.get_mm_inputs(IMAGES, FEATURE_SEQLENS, processor), mm_inputs)
|
_is_close(paligemma_plugin.get_mm_inputs(IMAGES, FEATURE_SEQLENS, processor), mm_inputs)
|
||||||
_is_close(model_inputs, expected_model_inputs)
|
# test text_messages
|
||||||
|
assert paligemma_plugin.process_messages(TEXT_MESSAGES, NO_IMAGES, processor) == TEXT_MESSAGES
|
||||||
|
assert paligemma_plugin.process_token_ids(INPUT_IDS, LABELS, NO_IMAGES, tokenizer, processor) == (
|
||||||
|
INPUT_IDS,
|
||||||
|
LABELS,
|
||||||
|
)
|
||||||
|
_is_close(
|
||||||
|
paligemma_plugin.get_mm_inputs(NO_IMAGES, FEATURE_SEQLENS, processor),
|
||||||
|
{"pixel_values": None, "token_type_ids": [[1] * 1024]},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_qwen2_vl_plugin():
|
def test_qwen2_vl_plugin():
|
||||||
model_args = ModelArguments(model_name_or_path="Qwen/Qwen2-VL-7B-Instruct")
|
tokenizer, processor = _load_tokenizer_module(model_name_or_path="Qwen/Qwen2-VL-7B-Instruct")
|
||||||
tokenizer_module = load_tokenizer(model_args)
|
image_seqlen = 4
|
||||||
tokenizer = tokenizer_module["tokenizer"]
|
|
||||||
processor = tokenizer_module["processor"]
|
|
||||||
|
|
||||||
mm_inputs = _get_mm_inputs(processor)
|
mm_inputs = _get_mm_inputs(processor)
|
||||||
expected_model_inputs = {key: [value] for key, value in mm_inputs.items()}
|
expected_mm_messages = [
|
||||||
|
{
|
||||||
|
key: value.replace("<image>", "<|vision_start|>{}<|vision_end|>".format("<|image_pad|>" * image_seqlen))
|
||||||
|
for key, value in message.items()
|
||||||
|
}
|
||||||
|
for message in MM_MESSAGES
|
||||||
|
]
|
||||||
|
|
||||||
llava_plugin = get_mm_plugin(name="qwen2_vl", image_token="<|image_pad|>")
|
qwen2_vl_plugin = get_mm_plugin(name="qwen2_vl", image_token="<|image_pad|>")
|
||||||
model_inputs = defaultdict(list)
|
# test mm_messages
|
||||||
llava_plugin.process_model_inputs(model_inputs, IMAGES, FEATURE_SEQLENS, processor)
|
assert qwen2_vl_plugin.process_messages(MM_MESSAGES, IMAGES, processor) == expected_mm_messages
|
||||||
|
assert qwen2_vl_plugin.process_token_ids(INPUT_IDS, LABELS, IMAGES, tokenizer, processor) == (INPUT_IDS, LABELS)
|
||||||
assert llava_plugin.process_messages(MESSAGES, IMAGES, processor)
|
_is_close(qwen2_vl_plugin.get_mm_inputs(IMAGES, FEATURE_SEQLENS, processor), mm_inputs)
|
||||||
assert llava_plugin.process_token_ids(INPUT_IDS, LABELS, tokenizer, processor) == (INPUT_IDS, LABELS)
|
# test text_messages
|
||||||
_is_close(llava_plugin.get_mm_inputs(IMAGES, FEATURE_SEQLENS, processor), mm_inputs)
|
assert qwen2_vl_plugin.process_messages(TEXT_MESSAGES, NO_IMAGES, processor) == TEXT_MESSAGES
|
||||||
_is_close(model_inputs, expected_model_inputs)
|
assert qwen2_vl_plugin.process_token_ids(INPUT_IDS, LABELS, NO_IMAGES, tokenizer, processor) == (INPUT_IDS, LABELS)
|
||||||
|
_is_close(
|
||||||
|
qwen2_vl_plugin.get_mm_inputs(NO_IMAGES, FEATURE_SEQLENS, processor),
|
||||||
|
{"pixel_values": None, "image_grid_thw": None},
|
||||||
|
)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user