mirror of https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-02 03:32:50 +08:00
fix mixed mm inputs and rlhf-v
Former-commit-id: 9967ccb3aef3ca557ad6eafb78c6c99866857008
This commit is contained in:
parent 34dc36462c
commit cb776752f6
@@ -33,7 +33,7 @@ Dependency graph:
transformers>=4.41.2,<=4.44.3
"""

from .cli import VERSION
from .extras.env import VERSION


__version__ = VERSION
@@ -89,11 +89,9 @@ class HuggingfaceEngine(BaseEngine):
        paired_messages = messages + [{"role": "assistant", "content": ""}]
        system = system or generating_args["default_system"]
        prompt_ids, _ = template.encode_oneturn(
            tokenizer=tokenizer, messages=paired_messages, system=system, tools=tools
        )
        prompt_ids, _ = template.encode_oneturn(tokenizer, paired_messages, system, tools)
        if image is not None:
            prompt_ids, _ = template.mm_plugin.process_token_ids(prompt_ids, None, tokenizer, processor)
            prompt_ids, _ = template.mm_plugin.process_token_ids(prompt_ids, None, [image], tokenizer, processor)

        prompt_length = len(prompt_ids)
        inputs = torch.tensor([prompt_ids], device=model.device)
@@ -124,9 +124,7 @@ class VllmEngine(BaseEngine):
        paired_messages = messages + [{"role": "assistant", "content": ""}]
        system = system or self.generating_args["default_system"]
        prompt_ids, _ = self.template.encode_oneturn(
            tokenizer=self.tokenizer, messages=paired_messages, system=system, tools=tools
        )
        prompt_ids, _ = self.template.encode_oneturn(self.tokenizer, paired_messages, system, tools)

        if self.processor is not None and image is not None:  # add image features
            image_processor: "BaseImageProcessor" = getattr(self.processor, "image_processor")
@@ -13,8 +13,8 @@
# limitations under the License.

from .collator import (
    CustomDataCollatorForSeq2Seq,
    KTODataCollatorWithPadding,
    MultiModalDataCollatorForSeq2Seq,
    PairwiseDataCollatorWithPadding,
    SFTDataCollatorWith4DAttentionMask,
)
@@ -24,8 +24,8 @@ from .template import TEMPLATES, Template, get_template_and_fix_tokenizer


__all__ = [
    "CustomDataCollatorForSeq2Seq",
    "KTODataCollatorWithPadding",
    "MultiModalDataCollatorForSeq2Seq",
    "PairwiseDataCollatorWithPadding",
    "SFTDataCollatorWith4DAttentionMask",
    "Role",
@@ -62,44 +62,49 @@ def prepare_4d_attention_mask(attention_mask_with_indices: "torch.Tensor", dtype


@dataclass
class CustomDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
    r"""
    Data collator for custom models (like Qwen2-VL).
    Data collator that supports VLMs.
    """

    def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
        image_grid_thw = None  # TODO: better handle various VLMs
        if "image_grid_thw" in features[0]:
            image_grid_thw_list = [
                torch.Tensor(feature["image_grid_thw"]).long()
                for feature in features
                if feature["image_grid_thw"][0][0] > 0
            ]
            pixel_values_list = [
                torch.Tensor(feature["pixel_values"]) for feature in features if feature["image_grid_thw"][0][0] > 0
            ]
            if image_grid_thw_list:
                image_grid_thw = torch.cat(image_grid_thw_list, dim=0)
                pixel_values = torch.cat(pixel_values_list, dim=0)
            else:
                image_grid_thw = None
                pixel_values = None
        if "token_type_ids" in features[0].keys():
            for feature in features:
                feature["token_type_ids"] = feature["token_type_ids"][0]

        features = [
            {key: feature[key] for key in feature if key not in ["image_grid_thw", "pixel_values"]}
            for feature in features
        ]
        extra_features = {}
        if "pixel_values" in features[0].keys():
            pixel_values = []
            for feature in features:
                if feature["pixel_values"] is None:
                    pixel_values.append(torch.zeros(0, dtype=torch.float))
                else:
                    pixel_values.append(torch.tensor(feature["pixel_values"], dtype=torch.float))

        features = super().__call__(features)
        if image_grid_thw is not None:
            features["image_grid_thw"] = image_grid_thw
            features["pixel_values"] = pixel_values
            extra_features["pixel_values"] = torch.cat(pixel_values, dim=0)
            if extra_features["pixel_values"].numel() == 0:
                extra_features["pixel_values"] = None

        if "image_grid_thw" in features[0].keys():
            image_grid_thw = []
            for feature in features:
                if feature["image_grid_thw"] is None:
                    image_grid_thw.append(torch.zeros(0, dtype=torch.long))
                else:
                    image_grid_thw.append(torch.tensor(feature["image_grid_thw"], dtype=torch.long))

            extra_features["image_grid_thw"] = torch.cat(image_grid_thw, dim=0)
            if extra_features["image_grid_thw"].numel() == 0:
                extra_features["image_grid_thw"] = None

        features = [{key: feature[key] for key in feature if key not in extra_features.keys()} for feature in features]
        features: Dict[str, "torch.Tensor"] = super().__call__(features)
        features.update({key: value for key, value in extra_features.items() if value is not None})
        return features


@dataclass
class SFTDataCollatorWith4DAttentionMask(CustomDataCollatorForSeq2Seq):
class SFTDataCollatorWith4DAttentionMask(MultiModalDataCollatorForSeq2Seq):
    r"""
    Data collator for 4d attention mask.
    """
@@ -117,7 +122,7 @@ class SFTDataCollatorWith4DAttentionMask(CustomDataCollatorForSeq2Seq):


@dataclass
class PairwiseDataCollatorWithPadding(CustomDataCollatorForSeq2Seq):
class PairwiseDataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
    r"""
    Data collator for pairwise data.
    """
@@ -152,7 +157,7 @@ class PairwiseDataCollatorWithPadding(CustomDataCollatorForSeq2Seq):


@dataclass
class KTODataCollatorWithPadding(CustomDataCollatorForSeq2Seq):
class KTODataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
    r"""
    Data collator for KTO data.
    """
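The mixed-batch handling above works because every example carries the multimodal keys: text-only rows store None, the collator turns them into zero-row tensors, and torch.cat() simply drops them, so the key is removed only when the whole batch is text. A minimal, self-contained sketch of that mechanism (the helper name and the 1176 patch dimension are illustrative assumptions, not part of the commit):

# Illustrative sketch, not part of the commit: mirrors how the collator merges
# pixel_values across a batch that mixes visual and text-only examples.
import torch

def merge_pixel_values(features, patch_dim=1176):
    pixel_values = []
    for feature in features:
        if feature.get("pixel_values") is None:  # text-only example
            pixel_values.append(torch.zeros(0, patch_dim, dtype=torch.float))
        else:
            pixel_values.append(torch.tensor(feature["pixel_values"], dtype=torch.float))

    merged = torch.cat(pixel_values, dim=0)
    return None if merged.numel() == 0 else merged  # drop the key for pure-text batches

batch = [{"pixel_values": [[0.1] * 1176] * 4}, {"pixel_values": None}]
print(merge_pixel_values(batch).shape)  # torch.Size([4, 1176])
print(merge_pixel_values([{"pixel_values": None}]))  # None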
@@ -16,16 +16,16 @@ import json
import re
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import List, Literal, Optional, Tuple, Union
from typing import List, Optional, Tuple, Union

from .data_utils import SLOTS
from .tool_utils import DefaultToolUtils, GLM4ToolUtils
from .tool_utils import get_tool_utils


@dataclass
class Formatter(ABC):
    slots: SLOTS = field(default_factory=list)
    tool_format: Optional[Literal["default", "glm4"]] = None
    tool_format: Optional[str] = None

    @abstractmethod
    def apply(self, **kwargs) -> SLOTS: ...
@@ -81,12 +81,7 @@ class StringFormatter(Formatter):
@dataclass
class FunctionFormatter(Formatter):
    def __post_init__(self):
        if self.tool_format == "default":
            self.slots = DefaultToolUtils.get_function_slots() + self.slots
        elif self.tool_format == "glm4":
            self.slots = GLM4ToolUtils.get_function_slots() + self.slots
        else:
            raise NotImplementedError("Tool format {} was not found.".format(self.tool_format))
        self.slots = get_tool_utils(self.tool_format).get_function_slots() + self.slots

    def apply(self, **kwargs) -> SLOTS:
        content = kwargs.pop("content")
@@ -119,22 +114,15 @@ class FunctionFormatter(Formatter):
@dataclass
class ToolFormatter(Formatter):
    def __post_init__(self):
        if self.tool_format == "default":
            self._tool_formatter = DefaultToolUtils.tool_formatter
            self._tool_extractor = DefaultToolUtils.tool_extractor
        elif self.tool_format == "glm4":
            self._tool_formatter = GLM4ToolUtils.tool_formatter
            self._tool_extractor = GLM4ToolUtils.tool_extractor
        else:
            raise NotImplementedError("Tool format {} was not found.".format(self.tool_format))
        self.tool_utils = get_tool_utils(self.tool_format)

    def apply(self, **kwargs) -> SLOTS:
        content = kwargs.pop("content")
        try:
            tools = json.loads(content)
            return [self._tool_formatter(tools) if len(tools) != 0 else ""]
            return [self.tool_utils.tool_formatter(tools) if len(tools) != 0 else ""]
        except json.JSONDecodeError:
            return [""]

    def extract(self, content: str) -> Union[str, List[Tuple[str, str]]]:
        return self._tool_extractor(content)
        return self.tool_utils.tool_extractor(content)
@@ -1,3 +1,4 @@
from copy import deepcopy
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple

from PIL.Image import Image
@@ -27,32 +28,33 @@ def _get_mm_inputs(images: Sequence["ImageObject"], processor: "ProcessorMixin")

    Returns: (qwen2-vl)
        pixel_values: tensor with shape (num_patches, patch_dim)
        image_grid_thw: tensot with shape (num_images, 3), where the three numbers are time, width, height
        image_grid_thw: tensor with shape (num_images, 3), where the three numbers are time, width, height

    It holds num_patches == torch.prod(image_grid_thw)
    """
    image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
    if len(images) != 0:
        image_inputs = image_processor(images=images, return_tensors="pt")
    else:
        image = Image.new("RGB", (56, 56), (255, 255, 255))
    else:  # add NoneType for fake images
        image = Image.new("RGB", (64, 64), (255, 255, 255))
        image_inputs = image_processor(images=[image], return_tensors="pt")
        if "image_grid_thw" in image_inputs:  # fake image for qwen2-vl
            image_inputs["image_grid_thw"][0][0] = 0
        image_inputs = {key: None for key in image_inputs.keys()}

    return image_inputs


def _get_paligemma_token_type_ids(input_len: int, processor: "ProcessorMixin") -> List[List[int]]:
def _get_paligemma_token_type_ids(
    images: Sequence["ImageObject"], input_len: int, processor: "ProcessorMixin"
) -> List[List[int]]:
    r"""
    Gets paligemma token type ids for computing loss.

    Returns:
        token_type_ids: shape (1, seq_len)
    """
    image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
    image_seq_length: int = getattr(image_processor, "image_seq_length")
    return [[0] * image_seq_length + [1] * (input_len - image_seq_length)]
    num_images = len(images)
    image_seqlen = num_images * getattr(processor, "image_seqlen")
    return [[0] * image_seqlen + [1] * (input_len - image_seqlen)]

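The new signature lets the token-type mask scale with the number of images instead of relying on a single image-processor constant. Quick worked arithmetic under assumed values (two images, image_seqlen of 256, a 1034-token sequence; none of these numbers come from the commit):

# Illustrative arithmetic, not part of the commit: PaliGemma-style token_type_ids.
num_images, image_seqlen, input_len = 2, 256, 1034
prefix = num_images * image_seqlen  # 512 image positions, masked out of the loss
token_type_ids = [[0] * prefix + [1] * (input_len - prefix)]
assert len(token_type_ids[0]) == input_len
print(prefix, input_len - prefix)  # 512 522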
class BasePlugin:
@@ -74,6 +76,7 @@ class BasePlugin:
        self,
        input_ids: List[int],
        labels: Optional[List[int]],
        images: Sequence["ImageObject"],
        tokenizer: "PreTrainedTokenizer",
        processor: Optional["ProcessorMixin"],
    ) -> Tuple[List[int], Optional[List[int]]]:
@@ -93,18 +96,6 @@ class BasePlugin:
        """
        return {}

    def process_model_inputs(
        self,
        model_inputs: Dict[str, List[Any]],
        images: Sequence["ImageObject"],
        feature_seqlens: Dict[str, int],
        processor: Optional["ProcessorMixin"],
    ) -> None:
        r"""
        Appends multimodal inputs to model inputs for VLMs.
        """
        return


class LlavaPlugin(BasePlugin):
    def process_messages(
@@ -113,21 +104,21 @@ class LlavaPlugin(BasePlugin):
        images: Sequence["ImageObject"],
        processor: Optional["ProcessorMixin"],
    ) -> List[Dict[str, str]]:
        image_count = 0
        new_messages = []
        num_images = 0
        image_seqlen = getattr(processor, "image_seqlen")
        messages = deepcopy(messages)
        for message in messages:
            content = message["content"]
            while IMAGE_PLACEHOLDER in content:
                image_count += 1
                if image_count > 1:
                    raise ValueError("Llava model only accepts one image per sample.")

                num_images += 1
                content = content.replace(IMAGE_PLACEHOLDER, "{{image}}", 1)

            content = content.replace("{{image}}", self.image_token)
            new_messages.append({"role": message["role"], "content": content})
            message["content"] = content.replace("{{image}}", self.image_token * image_seqlen)

        return new_messages
        if len(images) != num_images:
            raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER))

        return messages

    def get_mm_inputs(
        self,
@@ -137,17 +128,6 @@ class LlavaPlugin(BasePlugin):
    ) -> Dict[str, Any]:
        return _get_mm_inputs(images, processor)

    def process_model_inputs(
        self,
        model_inputs: Dict[str, List[Any]],
        images: Sequence["ImageObject"],
        feature_seqlens: Dict[str, int],
        processor: Optional["ProcessorMixin"],
    ) -> None:
        mm_inputs = self.get_mm_inputs(images, feature_seqlens, processor)
        for key, value in mm_inputs.items():
            model_inputs[key].append(value[0])


class PaliGemmaPlugin(BasePlugin):
    def process_messages(
@@ -156,34 +136,35 @@ class PaliGemmaPlugin(BasePlugin):
        images: Sequence["ImageObject"],
        processor: Optional["ProcessorMixin"],
    ) -> List[Dict[str, str]]:
        image_count = 0
        new_messages = []
        num_images = 0
        messages = deepcopy(messages)
        for message in messages:
            content = message["content"]
            while IMAGE_PLACEHOLDER in content:
                image_count += 1
                if image_count > 1:
                    raise ValueError("PaliGemma model only accepts one image per sample.")
                num_images += 1
                content = content.replace(IMAGE_PLACEHOLDER, "{{image}}", 1)

                content = content.replace(IMAGE_PLACEHOLDER, "", 1)
            message["content"] = content.replace("{{image}}", "")

            new_messages.append({"role": message["role"], "content": content})
        if len(images) != num_images:
            raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER))

        return new_messages
        return messages

    def process_token_ids(
        self,
        input_ids: List[int],
        labels: Optional[List[int]],
        images: Sequence["ImageObject"],
        tokenizer: "PreTrainedTokenizer",
        processor: Optional["ProcessorMixin"],
    ) -> Tuple[List[int], Optional[List[int]]]:
        image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
        image_seq_length: int = getattr(image_processor, "image_seq_length")
        num_images = len(images)
        image_seqlen = num_images * getattr(processor, "image_seqlen")
        image_token_id = tokenizer.convert_tokens_to_ids(self.image_token)
        input_ids = [image_token_id] * image_seq_length + input_ids
        input_ids = [image_token_id] * image_seqlen + input_ids
        if labels is not None:
            labels = [IGNORE_INDEX] * image_seq_length + labels
            labels = [IGNORE_INDEX] * image_seqlen + labels

        return input_ids, labels

@@ -195,21 +176,10 @@ class PaliGemmaPlugin(BasePlugin):
    ) -> Dict[str, Any]:
        mm_inputs = _get_mm_inputs(images, processor)
        for feature_name, feature_length in feature_seqlens.items():
            mm_inputs[feature_name] = _get_paligemma_token_type_ids(feature_length, processor)
            mm_inputs[feature_name] = _get_paligemma_token_type_ids(images, feature_length, processor)

        return mm_inputs

    def process_model_inputs(
        self,
        model_inputs: Dict[str, List[Any]],
        images: Sequence["ImageObject"],
        feature_seqlens: Dict[str, int],
        processor: Optional["ProcessorMixin"],
    ) -> None:
        mm_inputs = self.get_mm_inputs(images, feature_seqlens, processor)
        for key, value in mm_inputs.items():
            model_inputs[key].append(value[0])

class Qwen2vlPlugin(BasePlugin):
    def process_messages(
@@ -223,23 +193,26 @@ class Qwen2vlPlugin(BasePlugin):
        if len(images) > 0:
            image_grid_thw = _get_mm_inputs(images, processor)["image_grid_thw"]

        index = 0
        new_messages = []
        num_images = 0
        messages = deepcopy(messages)
        for message in messages:
            content = message["content"]
            while IMAGE_PLACEHOLDER in content:
                content = content.replace(
                    IMAGE_PLACEHOLDER,
                    "<|vision_start|>{}<|vision_end|>".format(
                        self.image_token * (image_grid_thw[index].prod() // merge_length)
                        self.image_token * (image_grid_thw[num_images].prod() // merge_length)
                    ),
                    1,
                )
                index += 1
                num_images += 1

            new_messages.append({"role": message["role"], "content": content})
            message["content"] = content

        return new_messages
        if len(images) != num_images:
            raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER))

        return messages

    def get_mm_inputs(
        self,
@@ -249,17 +222,6 @@ class Qwen2vlPlugin(BasePlugin):
    ) -> Dict[str, Any]:
        return _get_mm_inputs(images, processor)

    def process_model_inputs(
        self,
        model_inputs: Dict[str, List[Any]],
        images: Sequence["ImageObject"],
        feature_seqlens: Dict[str, int],
        processor: Optional["ProcessorMixin"],
    ) -> None:
        mm_inputs = self.get_mm_inputs(images, feature_seqlens, processor)
        for key, value in mm_inputs.items():
            model_inputs[key].append(value)  # support multi-image


PLUGINS = {
    "base": BasePlugin,
@@ -270,7 +232,8 @@ PLUGINS = {


def get_mm_plugin(name: str, image_token: str) -> "BasePlugin":
    if name not in PLUGINS:
        raise ValueError("{} not found.".format(name))
    plugin_class = PLUGINS.get(name, None)
    if plugin_class is None:
        raise ValueError("Multimodal plugin `{}` not found.".format(name))

    return PLUGINS[name](image_token)
    return plugin_class(image_token)
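For Qwen2-VL the number of <|image_pad|> tokens per image is data dependent: it is the product of that image's image_grid_thw entry divided by merge_size squared. A worked example under assumed values (a 1x16x16 patch grid and merge_size 2; neither number is taken from the commit):

# Illustrative arithmetic, not part of the commit: Qwen2-VL placeholder expansion.
merge_size = 2
merge_length = merge_size**2
image_grid_thw = (1, 16, 16)  # (time, width, height) patch counts for one image
num_pads = (image_grid_thw[0] * image_grid_thw[1] * image_grid_thw[2]) // merge_length
content = "<image>What is in this image?".replace(
    "<image>", "<|vision_start|>{}<|vision_end|>".format("<|image_pad|>" * num_pads), 1
)
print(num_pads)  # 64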
@@ -50,7 +50,7 @@ def get_preprocess_and_print_func(
        print_function = partial(print_unsupervised_dataset_example, tokenizer=tokenizer)
    elif stage == "sft" and not do_generate:
        if data_args.packing:
            if data_args.neat_packing:
            if data_args.neat_packing:  # hack datasets to have int32 attention mask
                from datasets.arrow_writer import OptimizedTypedSequence, TypedSequence

                def __init__(self, data, **kwargs):
@@ -67,6 +67,7 @@ def get_preprocess_and_print_func(
                preprocess_packed_supervised_dataset,
                template=template,
                tokenizer=tokenizer,
                processor=processor,
                data_args=data_args,
            )
        else:
@@ -21,6 +21,7 @@ from .processor_utils import infer_seqlen


if TYPE_CHECKING:
    from PIL.Image import Image
    from transformers import PreTrainedTokenizer, ProcessorMixin

    from ...hparams import DataArguments
@@ -36,11 +37,12 @@ def _encode_feedback_example(
    kl_response: Sequence[Dict[str, str]],
    system: Optional[str],
    tools: Optional[str],
    images: Sequence["Image"],
    template: "Template",
    tokenizer: "PreTrainedTokenizer",
    processor: Optional["ProcessorMixin"],
    cutoff_len: int,
) -> Tuple[List[int], List[int], List[int], List[int], bool]:
) -> Tuple[List[int], List[int], List[int], List[int], bool, Dict[str, Any]]:
    if response[0]["content"]:  # desired example
        kto_tag = True
        messages = prompt + [response[0]]
@@ -53,6 +55,8 @@ def _encode_feedback_example(
    else:
        kl_messages = prompt + [kl_response[1]]

    messages = template.mm_plugin.process_messages(messages, images, processor)
    kl_messages = template.mm_plugin.process_messages(kl_messages, images, processor)
    prompt_ids, response_ids = template.encode_oneturn(tokenizer, messages, system, tools)
    kl_prompt_ids, kl_response_ids = template.encode_oneturn(tokenizer, kl_messages, system, tools)

@@ -60,8 +64,8 @@ def _encode_feedback_example(
        response_ids += [tokenizer.eos_token_id]
        kl_response_ids += [tokenizer.eos_token_id]

    prompt_ids, _ = template.mm_plugin.process_token_ids(prompt_ids, None, tokenizer, processor)
    kl_prompt_ids, _ = template.mm_plugin.process_token_ids(kl_prompt_ids, None, tokenizer, processor)
    prompt_ids, _ = template.mm_plugin.process_token_ids(prompt_ids, None, images, tokenizer, processor)
    kl_prompt_ids, _ = template.mm_plugin.process_token_ids(kl_prompt_ids, None, images, tokenizer, processor)

    source_len, target_len = infer_seqlen(len(prompt_ids), len(response_ids), cutoff_len)
    prompt_ids = prompt_ids[:source_len]
@@ -74,8 +78,15 @@ def _encode_feedback_example(
    labels = [IGNORE_INDEX] * source_len + response_ids
    kl_input_ids = kl_prompt_ids + kl_response_ids
    kl_labels = [IGNORE_INDEX] * kl_source_len + kl_response_ids

    return input_ids, labels, kl_input_ids, kl_labels, kto_tag
    extra_inputs = template.mm_plugin.get_mm_inputs(
        images=images,
        feature_seqlens={
            "token_type_ids": len(input_ids),
            "kl_token_type_ids": len(kl_input_ids),
        },
        processor=processor,
    )
    return input_ids, labels, kl_input_ids, kl_labels, kto_tag, extra_inputs


def preprocess_feedback_dataset(
@@ -93,13 +104,13 @@ def preprocess_feedback_dataset(
            logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
            continue

        prompt = template.mm_plugin.process_messages(examples["prompt"][i], examples["images"][i], processor)
        input_ids, labels, kl_input_ids, kl_labels, kto_tag = _encode_feedback_example(
            prompt=prompt,
        input_ids, labels, kl_input_ids, kl_labels, kto_tag, extra_inputs = _encode_feedback_example(
            prompt=examples["prompt"][i],
            response=examples["response"][i],
            kl_response=kl_response[i],
            system=examples["system"][i],
            tools=examples["tools"][i],
            images=examples["images"][i],
            template=template,
            tokenizer=tokenizer,
            processor=processor,
@@ -112,15 +123,8 @@ def preprocess_feedback_dataset(
        model_inputs["kl_attention_mask"].append([1] * len(kl_input_ids))
        model_inputs["kl_labels"].append(kl_labels)
        model_inputs["kto_tags"].append(kto_tag)
        template.mm_plugin.process_model_inputs(
            model_inputs=model_inputs,
            images=examples["images"][i],
            feature_seqlens={
                "token_type_ids": len(input_ids),
                "kl_token_type_ids": len(kl_input_ids),
            },
            processor=processor,
        )
        for key, value in extra_inputs.items():
            model_inputs[key].append(value)

    desirable_num = sum([1 for tag in model_inputs["kto_tags"] if tag])
    undesirable_num = len(model_inputs["kto_tags"]) - desirable_num
@@ -21,6 +21,7 @@ from .processor_utils import infer_seqlen


if TYPE_CHECKING:
    from PIL.Image import Image
    from transformers import PreTrainedTokenizer, ProcessorMixin

    from ...hparams import DataArguments
@@ -35,13 +36,14 @@ def _encode_pairwise_example(
    response: Sequence[Dict[str, str]],
    system: Optional[str],
    tools: Optional[str],
    images: Sequence["Image"],
    template: "Template",
    tokenizer: "PreTrainedTokenizer",
    processor: Optional["ProcessorMixin"],
    cutoff_len: int,
) -> Tuple[List[int], List[int], List[int], List[int]]:
    chosen_messages = prompt + [response[0]]
    rejected_messages = prompt + [response[1]]
) -> Tuple[List[int], List[int], List[int], List[int], Dict[str, Any]]:
    chosen_messages = template.mm_plugin.process_messages(prompt + [response[0]], images, processor)
    rejected_messages = template.mm_plugin.process_messages(prompt + [response[1]], images, processor)
    prompt_ids, chosen_ids = template.encode_oneturn(tokenizer, chosen_messages, system, tools)
    _, rejected_ids = template.encode_oneturn(tokenizer, rejected_messages, system, tools)

@@ -49,7 +51,7 @@ def _encode_pairwise_example(
        chosen_ids += [tokenizer.eos_token_id]
        rejected_ids += [tokenizer.eos_token_id]

    prompt_ids, _ = template.mm_plugin.process_token_ids(prompt_ids, None, tokenizer, processor)
    prompt_ids, _ = template.mm_plugin.process_token_ids(prompt_ids, None, images, tokenizer, processor)
    # consider the response is more important
    source_len, target_len = infer_seqlen(len(prompt_ids), max(len(chosen_ids), len(rejected_ids)), cutoff_len)
    prompt_ids = prompt_ids[:source_len]
@@ -60,8 +62,15 @@ def _encode_pairwise_example(
    chosen_labels = [IGNORE_INDEX] * source_len + chosen_ids
    rejected_input_ids = prompt_ids + rejected_ids
    rejected_labels = [IGNORE_INDEX] * source_len + rejected_ids

    return chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels
    extra_inputs = template.mm_plugin.get_mm_inputs(
        images=images,
        feature_seqlens={
            "chosen_token_type_ids": len(chosen_input_ids),
            "rejected_token_type_ids": len(rejected_input_ids),
        },
        processor=processor,
    )
    return chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels, extra_inputs


def preprocess_pairwise_dataset(
@@ -78,12 +87,12 @@ def preprocess_pairwise_dataset(
            logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
            continue

        prompt = template.mm_plugin.process_messages(examples["prompt"][i], examples["images"][i], processor)
        chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels = _encode_pairwise_example(
            prompt=prompt,
        chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels, extra_inputs = _encode_pairwise_example(
            prompt=examples["prompt"][i],
            response=examples["response"][i],
            system=examples["system"][i],
            tools=examples["tools"][i],
            images=examples["images"][i],
            template=template,
            tokenizer=tokenizer,
            processor=processor,
@@ -95,15 +104,8 @@ def preprocess_pairwise_dataset(
        model_inputs["rejected_input_ids"].append(rejected_input_ids)
        model_inputs["rejected_attention_mask"].append([1] * len(rejected_input_ids))
        model_inputs["rejected_labels"].append(rejected_labels)
        template.mm_plugin.process_model_inputs(
            model_inputs=model_inputs,
            images=examples["images"][i],
            feature_seqlens={
                "chosen_token_type_ids": len(chosen_input_ids),
                "rejected_token_type_ids": len(rejected_input_ids),
            },
            processor=processor,
        )
        for key, value in extra_inputs.items():
            model_inputs[key].append(value)

    return model_inputs
@@ -21,6 +21,7 @@ from .processor_utils import greedy_knapsack, infer_seqlen


if TYPE_CHECKING:
    from PIL.Image import Image
    from transformers import PreTrainedTokenizer, ProcessorMixin

    from ...hparams import DataArguments
@@ -35,19 +36,18 @@ def _encode_supervised_example(
    response: Sequence[Dict[str, str]],
    system: Optional[str],
    tools: Optional[str],
    images: Sequence["Image"],
    template: "Template",
    tokenizer: "PreTrainedTokenizer",
    processor: Optional["ProcessorMixin"],
    cutoff_len: int,
    train_on_prompt: bool,
    mask_history: bool,
) -> Tuple[List[int], List[int]]:
    messages = prompt + response
    input_ids, labels = [], []
    input_ids, labels = template.mm_plugin.process_token_ids(input_ids, labels, tokenizer, processor)

) -> Tuple[List[int], List[int], Dict[str, Any]]:
    messages = template.mm_plugin.process_messages(prompt + response, images, processor)
    input_ids, labels = template.mm_plugin.process_token_ids([], [], images, tokenizer, processor)
    encoded_pairs = template.encode_multiturn(tokenizer, messages, system, tools)
    total_length = 1 if template.efficient_eos else 0
    total_length = len(input_ids) + (1 if template.efficient_eos else 0)
    if mask_history:
        encoded_pairs = encoded_pairs[::-1]  # high priority for last turns

@@ -83,7 +83,10 @@ def _encode_supervised_example(
        input_ids += [tokenizer.eos_token_id]
        labels += [tokenizer.eos_token_id]

    return input_ids, labels
    extra_inputs = template.mm_plugin.get_mm_inputs(
        images=images, feature_seqlens={"token_type_ids": len(input_ids)}, processor=processor
    )
    return input_ids, labels, extra_inputs


def preprocess_supervised_dataset(
@@ -101,12 +104,12 @@ def preprocess_supervised_dataset(
            logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
            continue

        prompt = template.mm_plugin.process_messages(examples["prompt"][i], examples["images"][i], processor)
        input_ids, labels = _encode_supervised_example(
            prompt=prompt,
        input_ids, labels, extra_inputs = _encode_supervised_example(
            prompt=examples["prompt"][i],
            response=examples["response"][i],
            system=examples["system"][i],
            tools=examples["tools"][i],
            images=examples["images"][i],
            template=template,
            tokenizer=tokenizer,
            processor=processor,
@@ -117,12 +120,8 @@ def preprocess_supervised_dataset(
        model_inputs["input_ids"].append(input_ids)
        model_inputs["attention_mask"].append([1] * len(input_ids))
        model_inputs["labels"].append(labels)
        template.mm_plugin.process_model_inputs(
            model_inputs=model_inputs,
            images=examples["images"][i],
            feature_seqlens={"token_type_ids": len(input_ids)},
            processor=processor,
        )
        for key, value in extra_inputs.items():
            model_inputs[key].append(value)

    return model_inputs

@@ -131,10 +130,15 @@ def preprocess_packed_supervised_dataset(
    examples: Dict[str, List[Any]],
    template: "Template",
    tokenizer: "PreTrainedTokenizer",
    processor: Optional["ProcessorMixin"],
    data_args: "DataArguments",
) -> Dict[str, List[Any]]:
    # TODO: use `position_ids` to achieve packing
    # build inputs with format `<bos> X1 Y1 <eos> <bos> X2 Y2 <eos>`
    # and labels with format `<ignore> ... <ignore> Y1 <eos> <ignore> ... <ignore> Y2 <eos>`
    if processor is not None:
        raise NotImplementedError("`packing` have not been implemented for multimodal datasets.")

    valid_num = 0
    batch_input_ids, batch_labels = [], []
    lengths = []
@@ -149,6 +153,7 @@ def preprocess_packed_supervised_dataset(
            response=examples["response"][i],
            system=examples["system"][i],
            tools=examples["tools"][i],
            images=examples["images"][i],
            template=template,
            tokenizer=tokenizer,
            processor=None,
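Because process_token_ids() now seeds input_ids with the image prefix before any turn is encoded, the truncation budget automatically accounts for those tokens. Rough arithmetic with assumed numbers (one image, a PaliGemma-like image_seqlen of 256, cutoff_len of 1024; none of these values come from the commit):

# Illustrative arithmetic, not part of the commit: how the image prefix eats into cutoff_len.
image_seqlen, num_images, cutoff_len, efficient_eos = 256, 1, 1024, False
prefix_len = num_images * image_seqlen             # tokens prepended by process_token_ids([], [], ...)
total_length = prefix_len + (1 if efficient_eos else 0)
text_budget = cutoff_len - total_length            # what remains for the prompt/response turns
print(prefix_len, text_budget)  # 256 768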
@@ -21,6 +21,7 @@ from .processor_utils import infer_seqlen


if TYPE_CHECKING:
    from PIL.Image import Image
    from transformers import PreTrainedTokenizer, ProcessorMixin

    from ...hparams import DataArguments
@@ -35,25 +36,30 @@ def _encode_unsupervised_example(
    response: Sequence[Dict[str, str]],
    system: Optional[str],
    tools: Optional[str],
    images: Sequence["Image"],
    template: "Template",
    tokenizer: "PreTrainedTokenizer",
    processor: Optional["ProcessorMixin"],
    cutoff_len: int,
) -> Tuple[List[int], List[int]]:
) -> Tuple[List[int], List[int], Dict[str, Any]]:
    if len(response) == 1:
        messages = prompt + response
    else:
        messages = prompt + [{"role": Role.ASSISTANT.value, "content": ""}]

    messages = template.mm_plugin.process_messages(messages, images, processor)
    input_ids, labels = template.encode_oneturn(tokenizer, messages, system, tools)
    if template.efficient_eos:
        labels += [tokenizer.eos_token_id]

    input_ids, _ = template.mm_plugin.process_token_ids(input_ids, None, tokenizer, processor)
    input_ids, _ = template.mm_plugin.process_token_ids(input_ids, None, images, tokenizer, processor)
    source_len, target_len = infer_seqlen(len(input_ids), len(labels), cutoff_len)
    input_ids = input_ids[:source_len]
    labels = labels[:target_len]
    return input_ids, labels
    extra_inputs = template.mm_plugin.get_mm_inputs(
        images=images, feature_seqlens={"token_type_ids": len(input_ids)}, processor=processor
    )
    return input_ids, labels, extra_inputs


def preprocess_unsupervised_dataset(
@@ -70,12 +76,12 @@ def preprocess_unsupervised_dataset(
            logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
            continue

        prompt = template.mm_plugin.process_messages(examples["prompt"][i], examples["images"][i], processor)
        input_ids, labels = _encode_unsupervised_example(
            prompt=prompt,
        input_ids, labels, extra_inputs = _encode_unsupervised_example(
            prompt=examples["prompt"][i],
            response=examples["response"][i],
            system=examples["system"][i],
            tools=examples["tools"][i],
            images=examples["images"][i],
            template=template,
            tokenizer=tokenizer,
            processor=processor,
@@ -84,12 +90,8 @@ def preprocess_unsupervised_dataset(
        model_inputs["input_ids"].append(input_ids)
        model_inputs["attention_mask"].append([1] * len(input_ids))
        model_inputs["labels"].append(labels)
        template.mm_plugin.process_model_inputs(
            model_inputs=model_inputs,
            images=examples["images"][i],
            feature_seqlens={"token_type_ids": len(input_ids)},
            processor=processor,
        )
        for key, value in extra_inputs.items():
            model_inputs[key].append(value)

    return model_inputs
@@ -15,6 +15,8 @@
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Union

from transformers.utils.versions import require_version

from ..extras.constants import IMAGE_PLACEHOLDER
from ..extras.logging import get_logger
from .data_utils import Role
@@ -347,6 +349,11 @@ def get_template_and_fix_tokenizer(
    name: Optional[str] = None,
    tool_format: Optional[str] = None,
) -> Template:
    if name == "qwen2_vl":
        require_version(
            "transformers>=4.45.0.dev0", "To fix: pip install git+https://github.com/huggingface/transformers.git"
        )

    if name is None:
        template = TEMPLATES["empty"]  # placeholder
    else:
@@ -357,8 +364,8 @@ def get_template_and_fix_tokenizer(
    if tool_format is not None:
        logger.info("Using tool format: {}.".format(tool_format))
        eos_slots = [] if template.efficient_eos else [{"eos_token"}]
        template.format_tools = ToolFormatter(tool_format=tool_format)
        template.format_function = FunctionFormatter(slots=eos_slots, tool_format=tool_format)
        template.format_tools = ToolFormatter(tool_format=tool_format)

    stop_words = template.stop_words
    if template.replace_eos:
@@ -138,3 +138,17 @@ class GLM4ToolUtils(ToolUtils):
            return content

        return [(tool_name, json.dumps(arguments, ensure_ascii=False))]


TOOLS = {
    "default": DefaultToolUtils(),
    "glm4": GLM4ToolUtils(),
}


def get_tool_utils(name: str) -> "ToolUtils":
    tool_utils = TOOLS.get(name, None)
    if tool_utils is None:
        raise ValueError("Tool utils `{}` not found.".format(name))

    return tool_utils
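Replacing the if/elif chains in formatter.py with this lookup means adding a tool format is now a one-line registry entry. A minimal sketch of the pattern with stubbed classes (illustrative only; the real ToolUtils subclasses live in tool_utils.py):

# Illustrative sketch, not part of the commit: the registry pattern behind get_tool_utils().
class DefaultToolUtils:
    pass

class GLM4ToolUtils:
    pass

TOOLS = {"default": DefaultToolUtils(), "glm4": GLM4ToolUtils()}

def get_tool_utils(name: str):
    tool_utils = TOOLS.get(name, None)
    if tool_utils is None:
        raise ValueError("Tool utils `{}` not found.".format(name))
    return tool_utils

print(type(get_tool_utils("glm4")).__name__)  # GLM4ToolUtils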
@@ -195,6 +195,9 @@ def is_gpu_or_npu_available() -> bool:


def numpify(inputs: Union["NDArray", "torch.Tensor"]) -> "NDArray":
    r"""
    Casts a torch tensor or a numpy array to a numpy array.
    """
    if isinstance(inputs, torch.Tensor):
        inputs = inputs.cpu()
        if inputs.dtype == torch.bfloat16:  # numpy does not support bfloat16 until 1.21.4
@@ -206,6 +209,9 @@ def numpify(inputs: Union["NDArray", "torch.Tensor"]) -> "NDArray":


def skip_check_imports() -> None:
    r"""
    Avoids flash attention import error in custom model files.
    """
    if os.environ.get("FORCE_CHECK_IMPORTS", "0").lower() not in ["true", "1"]:
        transformers.dynamic_module_utils.check_imports = get_relative_imports
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from llamafactory.train.tuner import run_exp
from llamafactory.train.tuner import run_exp  # use absolute import


def launch():
@@ -25,6 +25,7 @@ from .model_utils.misc import register_autoclass
from .model_utils.mod import convert_pretrained_model_to_mod, load_mod_pretrained_model
from .model_utils.unsloth import load_unsloth_pretrained_model
from .model_utils.valuehead import load_valuehead_params
from .model_utils.visual import get_image_seqlen
from .patcher import patch_config, patch_model, patch_tokenizer, patch_valuehead_model


@@ -65,6 +66,7 @@ def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
    Note: including inplace operation of model_args.
    """
    init_kwargs = _get_init_kwargs(model_args)
    config = load_config(model_args)
    try:
        tokenizer = AutoTokenizer.from_pretrained(
            model_args.model_name_or_path,
@@ -96,6 +98,7 @@ def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
    try:
        processor = AutoProcessor.from_pretrained(model_args.model_name_or_path, **init_kwargs)
        setattr(processor, "tokenizer", tokenizer)
        setattr(processor, "image_seqlen", get_image_seqlen(config))
    except Exception:
        processor = None
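load_tokenizer() now stamps the per-image token count onto the processor, and the multimodal plugins later read it back with getattr(processor, "image_seqlen"). A tiny sketch of that handshake with a stand-in object (illustrative only; not the transformers API):

# Illustrative sketch, not part of the commit: the processor attribute handshake.
class DummyProcessor:  # stand-in for a transformers ProcessorMixin instance
    pass

processor = DummyProcessor()
setattr(processor, "image_seqlen", 576)                      # done once in load_tokenizer()
expanded = "<image>" * getattr(processor, "image_seqlen")    # done later when expanding placeholders
print(len(expanded) // len("<image>"))  # 576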
@@ -82,7 +82,7 @@ class LlavaMultiModalProjectorForYiVLForVLLM(LlavaMultiModalProjectorForYiVL):

def autocast_projector_dtype(model: "PreTrainedModel", model_args: "ModelArguments") -> None:
    r"""
    Casts projector output to half precision for quantized VLMs.
    Casts projector output to half precision for fine-tuning quantized VLMs.
    """

    def _mm_projector_forward_post_hook(
@@ -136,6 +136,22 @@ def get_forbidden_modules(config: "PretrainedConfig", finetuning_args: "Finetuni
    return forbidden_modules


def get_image_seqlen(config: "PretrainedConfig") -> int:
    r"""
    Computes the number of special tokens per image.
    """
    if getattr(config, "model_type", None) == "llava":
        image_seqlen = (config.vision_config.image_size // config.vision_config.patch_size) ** 2
        if getattr(config, "vision_feature_select_strategy", "default") == "full":  # add [CLS] token
            image_seqlen += 1
    elif getattr(config, "model_type", None) == "paligemma":
        image_seqlen = config.vision_config.num_image_tokens
    elif getattr(config, "model_type", None) == "qwen2_vl":  # variable length
        image_seqlen = -1

    return image_seqlen


def patch_target_modules(
    config: "PretrainedConfig", finetuning_args: "FinetuningArguments", target_modules: Sequence[str]
) -> Union[str, List[str]]:
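For the llava branch the count is simply the patch grid size squared. Worked numbers assuming a CLIP ViT-L/14 vision tower at 336x336 input (the usual LLaVA-1.5 setup, and the same 576 the test file below uses):

# Worked example for the llava branch above (336 px image, 14 px patches assumed).
image_size, patch_size = 336, 14
image_seqlen = (image_size // patch_size) ** 2  # 24 * 24 = 576 tokens per image
image_seqlen_full = image_seqlen + 1            # +1 [CLS] when vision_feature_select_strategy == "full"
print(image_seqlen, image_seqlen_full)  # 576 577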
@@ -17,7 +17,7 @@

from typing import TYPE_CHECKING, List, Optional

from ...data import CustomDataCollatorForSeq2Seq, get_dataset
from ...data import MultiModalDataCollatorForSeq2Seq, get_dataset
from ...extras.ploting import plot_loss
from ...model import load_model, load_tokenizer
from ..callbacks import fix_valuehead_checkpoint
@@ -45,7 +45,7 @@ def run_ppo(
    model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train, add_valuehead=True)

    tokenizer.padding_side = "left"  # use left-padding in generation while using right-padding in training
    data_collator = CustomDataCollatorForSeq2Seq(tokenizer=tokenizer)
    data_collator = MultiModalDataCollatorForSeq2Seq(tokenizer=tokenizer)

    # Create reference model and reward model
    ref_model = create_ref_model(model_args, finetuning_args, add_valuehead=True)
@@ -13,8 +13,7 @@
# limitations under the License.

import os
from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict
from typing import TYPE_CHECKING, Any, Dict, Tuple

import pytest
import torch
@@ -26,7 +25,7 @@ from llamafactory.model import load_tokenizer


if TYPE_CHECKING:
    from transformers import ProcessorMixin
    from transformers import PreTrainedTokenizer, ProcessorMixin
    from transformers.image_processing_utils import BaseImageProcessor


@@ -34,13 +33,20 @@ HF_TOKEN = os.environ.get("HF_TOKEN", None)

TINY_LLAMA = os.environ.get("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")

MESSAGES = [
MM_MESSAGES = [
    {"role": "user", "content": "<image>What is in this image?"},
    {"role": "assistant", "content": "A cat."},
]

TEXT_MESSAGES = [
    {"role": "user", "content": "How are you"},
    {"role": "assistant", "content": "I am fine!"},
]

IMAGES = [Image.new("RGB", (32, 32), (255, 255, 255))]

NO_IMAGES = []

INPUT_IDS = [0, 1, 2, 3, 4]

LABELS = [0, 1, 2, 3, 4]
@@ -53,99 +59,110 @@ def _get_mm_inputs(processor: "ProcessorMixin") -> Dict[str, "torch.Tensor"]:
    return image_processor(images=IMAGES, return_tensors="pt")


def _is_close(batch_a: Dict[str, Any], batch_b: Dict[str, Any]):
def _is_close(batch_a: Dict[str, Any], batch_b: Dict[str, Any]) -> None:
    assert batch_a.keys() == batch_b.keys()
    for key in batch_a.keys():
        if isinstance(batch_a[key], list):
            assert len(batch_a[key]) == len(batch_b[key])
            for i in range(len(batch_a[key])):
                if isinstance(batch_a[key][i], torch.Tensor):
                    assert torch.allclose(batch_a[key][i], batch_b[key][i], rtol=1e-4, atol=1e-5)
                else:
                    assert batch_a[key][i] == batch_b[key][i]
        elif isinstance(batch_a[key], torch.Tensor):
        if isinstance(batch_a[key], torch.Tensor):
            assert torch.allclose(batch_a[key], batch_b[key], rtol=1e-4, atol=1e-5)
        else:
            raise NotImplementedError
            assert batch_a[key] == batch_b[key]


def _load_tokenizer_module(model_name_or_path: str) -> Tuple["PreTrainedTokenizer", "ProcessorMixin"]:
    model_args = ModelArguments(model_name_or_path=model_name_or_path)
    tokenizer_module = load_tokenizer(model_args)
    return tokenizer_module["tokenizer"], tokenizer_module["processor"]


def test_base_plugin():
    model_args = ModelArguments(model_name_or_path=TINY_LLAMA)
    tokenizer_module = load_tokenizer(model_args)
    tokenizer = tokenizer_module["tokenizer"]
    processor = tokenizer_module["processor"]

    tokenizer, processor = _load_tokenizer_module(model_name_or_path=TINY_LLAMA)
    base_plugin = get_mm_plugin(name="base", image_token="<image>")
    model_inputs = defaultdict(list)
    base_plugin.process_model_inputs(model_inputs, IMAGES, FEATURE_SEQLENS, processor)

    assert base_plugin.process_messages(MESSAGES, IMAGES, processor)
    assert base_plugin.process_token_ids(INPUT_IDS, LABELS, tokenizer, processor) == (INPUT_IDS, LABELS)
    # test mm_messages
    assert base_plugin.process_messages(MM_MESSAGES, IMAGES, processor) == MM_MESSAGES
    assert base_plugin.process_token_ids(INPUT_IDS, LABELS, IMAGES, tokenizer, processor) == (INPUT_IDS, LABELS)
    _is_close(base_plugin.get_mm_inputs(IMAGES, FEATURE_SEQLENS, processor), {})
    _is_close(model_inputs, {})
    # test text_messages
    assert base_plugin.process_messages(TEXT_MESSAGES, NO_IMAGES, processor) == TEXT_MESSAGES
    assert base_plugin.process_token_ids(INPUT_IDS, LABELS, NO_IMAGES, tokenizer, processor) == (INPUT_IDS, LABELS)
    _is_close(base_plugin.get_mm_inputs(NO_IMAGES, FEATURE_SEQLENS, processor), {})


def test_llava_plugin():
    model_args = ModelArguments(model_name_or_path="llava-hf/llava-1.5-7b-hf")
    tokenizer_module = load_tokenizer(model_args)
    tokenizer = tokenizer_module["tokenizer"]
    processor = tokenizer_module["processor"]
    tokenizer, processor = _load_tokenizer_module(model_name_or_path="llava-hf/llava-1.5-7b-hf")
    image_seqlen = 576

    mm_inputs = _get_mm_inputs(processor)
    expected_model_inputs = {key: [value[0]] for key, value in mm_inputs.items()}
    expected_mm_messages = [
        {key: value.replace("<image>", "<image>" * image_seqlen) for key, value in message.items()}
        for message in MM_MESSAGES
    ]

    llava_plugin = get_mm_plugin(name="llava", image_token="<image>")
    model_inputs = defaultdict(list)
    llava_plugin.process_model_inputs(model_inputs, IMAGES, FEATURE_SEQLENS, processor)

    assert llava_plugin.process_messages(MESSAGES, IMAGES, processor)
    assert llava_plugin.process_token_ids(INPUT_IDS, LABELS, tokenizer, processor) == (INPUT_IDS, LABELS)
    # test mm_messages
    assert llava_plugin.process_messages(MM_MESSAGES, IMAGES, processor) == expected_mm_messages
    assert llava_plugin.process_token_ids(INPUT_IDS, LABELS, IMAGES, tokenizer, processor) == (INPUT_IDS, LABELS)
    _is_close(llava_plugin.get_mm_inputs(IMAGES, FEATURE_SEQLENS, processor), mm_inputs)
    _is_close(model_inputs, expected_model_inputs)
    # test text_messages
    assert llava_plugin.process_messages(TEXT_MESSAGES, NO_IMAGES, processor) == TEXT_MESSAGES
    assert llava_plugin.process_token_ids(INPUT_IDS, LABELS, NO_IMAGES, tokenizer, processor) == (INPUT_IDS, LABELS)
    _is_close(llava_plugin.get_mm_inputs(NO_IMAGES, FEATURE_SEQLENS, processor), {"pixel_values": None})


@pytest.mark.skipif(not HF_TOKEN, reason="Gated model.")
def test_paligemma_plugin():
    model_args = ModelArguments(model_name_or_path="google/paligemma-3b-pt-224")
    tokenizer_module = load_tokenizer(model_args)
    tokenizer = tokenizer_module["tokenizer"]
    processor = tokenizer_module["processor"]
    image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
    image_seq_length: int = getattr(image_processor, "image_seq_length")
    tokenizer, processor = _load_tokenizer_module(model_name_or_path="google/paligemma-3b-pt-224")
    image_seqlen = 256

    mm_inputs = _get_mm_inputs(processor)
    mm_inputs["token_type_ids"] = [[0] * image_seq_length + [1] * (1024 - image_seq_length)]
    expected_model_inputs = {key: [value[0]] for key, value in mm_inputs.items()}
    expected_input_ids = [tokenizer.convert_tokens_to_ids("<image>")] * image_seq_length + INPUT_IDS
    expected_labels = [-100] * image_seq_length + LABELS
    mm_inputs["token_type_ids"] = [[0] * image_seqlen + [1] * (1024 - image_seqlen)]
    expected_mm_messages = [
        {key: value.replace("<image>", "") for key, value in message.items()} for message in MM_MESSAGES
    ]
    expected_input_ids = [tokenizer.convert_tokens_to_ids("<image>")] * image_seqlen + INPUT_IDS
    expected_labels = [-100] * image_seqlen + LABELS

    paligemma_plugin = get_mm_plugin(name="paligemma", image_token="<image>")
    model_inputs = defaultdict(list)
    paligemma_plugin.process_model_inputs(model_inputs, IMAGES, FEATURE_SEQLENS, processor)

    assert paligemma_plugin.process_messages(MESSAGES, IMAGES, processor)
    assert paligemma_plugin.process_token_ids(INPUT_IDS, LABELS, tokenizer, processor) == (
    # test mm_messages
    assert paligemma_plugin.process_messages(MM_MESSAGES, IMAGES, processor) == expected_mm_messages
    assert paligemma_plugin.process_token_ids(INPUT_IDS, LABELS, IMAGES, tokenizer, processor) == (
        expected_input_ids,
        expected_labels,
    )
    _is_close(paligemma_plugin.get_mm_inputs(IMAGES, FEATURE_SEQLENS, processor), mm_inputs)
    _is_close(model_inputs, expected_model_inputs)
    # test text_messages
    assert paligemma_plugin.process_messages(TEXT_MESSAGES, NO_IMAGES, processor) == TEXT_MESSAGES
    assert paligemma_plugin.process_token_ids(INPUT_IDS, LABELS, NO_IMAGES, tokenizer, processor) == (
        INPUT_IDS,
        LABELS,
    )
    _is_close(
        paligemma_plugin.get_mm_inputs(NO_IMAGES, FEATURE_SEQLENS, processor),
        {"pixel_values": None, "token_type_ids": [[1] * 1024]},
    )


def test_qwen2_vl_plugin():
    model_args = ModelArguments(model_name_or_path="Qwen/Qwen2-VL-7B-Instruct")
    tokenizer_module = load_tokenizer(model_args)
    tokenizer = tokenizer_module["tokenizer"]
    processor = tokenizer_module["processor"]
    tokenizer, processor = _load_tokenizer_module(model_name_or_path="Qwen/Qwen2-VL-7B-Instruct")
    image_seqlen = 4

    mm_inputs = _get_mm_inputs(processor)
    expected_model_inputs = {key: [value] for key, value in mm_inputs.items()}
    expected_mm_messages = [
        {
            key: value.replace("<image>", "<|vision_start|>{}<|vision_end|>".format("<|image_pad|>" * image_seqlen))
            for key, value in message.items()
        }
        for message in MM_MESSAGES
    ]

    llava_plugin = get_mm_plugin(name="qwen2_vl", image_token="<|image_pad|>")
    model_inputs = defaultdict(list)
    llava_plugin.process_model_inputs(model_inputs, IMAGES, FEATURE_SEQLENS, processor)

    assert llava_plugin.process_messages(MESSAGES, IMAGES, processor)
    assert llava_plugin.process_token_ids(INPUT_IDS, LABELS, tokenizer, processor) == (INPUT_IDS, LABELS)
    _is_close(llava_plugin.get_mm_inputs(IMAGES, FEATURE_SEQLENS, processor), mm_inputs)
    _is_close(model_inputs, expected_model_inputs)
    qwen2_vl_plugin = get_mm_plugin(name="qwen2_vl", image_token="<|image_pad|>")
    # test mm_messages
    assert qwen2_vl_plugin.process_messages(MM_MESSAGES, IMAGES, processor) == expected_mm_messages
    assert qwen2_vl_plugin.process_token_ids(INPUT_IDS, LABELS, IMAGES, tokenizer, processor) == (INPUT_IDS, LABELS)
    _is_close(qwen2_vl_plugin.get_mm_inputs(IMAGES, FEATURE_SEQLENS, processor), mm_inputs)
    # test text_messages
    assert qwen2_vl_plugin.process_messages(TEXT_MESSAGES, NO_IMAGES, processor) == TEXT_MESSAGES
    assert qwen2_vl_plugin.process_token_ids(INPUT_IDS, LABELS, NO_IMAGES, tokenizer, processor) == (INPUT_IDS, LABELS)
    _is_close(
        qwen2_vl_plugin.get_mm_inputs(NO_IMAGES, FEATURE_SEQLENS, processor),
        {"pixel_values": None, "image_grid_thw": None},
    )