Former-commit-id: 2ee8ba2f390551af1b865cfa813f5c8b7bbb41c5
This commit is contained in:
fzc8578 2025-01-10 20:27:06 +08:00
parent aeb4f82ef2
commit 7138b43873
5 changed files with 51 additions and 47 deletions

View File

@ -149,14 +149,14 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
features.update(mm_inputs) features.update(mm_inputs)
if isinstance(features.get("pixel_values"), list): # for pixtral inputs if isinstance(features.get("pixel_values"), list): # for pixtral inputs
features = features.data # use default_collate() instead of BatchEncoding.to() features = features.data # use default_collate() instead of BatchEncoding.to()
if "image_bound" in features: # for minicpmv inputs if "image_bound" in features: # for minicpmv inputs
features = self.template.mm_plugin.pad_data(features) features = self.template.mm_plugin.pad_data(features)
new_features = {} new_features = {}
new_features.update({"data": features}) new_features.update({"data": features})
new_features.update(features) new_features.update(features)
features = new_features features = new_features
return features return features

View File

@ -1,8 +1,8 @@
import math import math
import re
from copy import deepcopy from copy import deepcopy
from io import BytesIO from io import BytesIO
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, TypedDict, Union from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, TypedDict, Union
import re
import numpy as np import numpy as np
import torch import torch
@ -276,38 +276,39 @@ class CpmOPlugin(BasePlugin):
content = content.replace(IMAGE_PLACEHOLDER, "{{image}}", 1) content = content.replace(IMAGE_PLACEHOLDER, "{{image}}", 1)
message["content"] = content.replace("{{image}}", "(<image>./</image>)") message["content"] = content.replace("{{image}}", "(<image>./</image>)")
if num_image_tokens>0:
mm_inputs = self._get_mm_inputs(images, videos, processor)
pattern = "(<image>./</image>)"
images, image_sizes, tgt_sizes = mm_inputs["pixel_values"], mm_inputs["image_sizes"], mm_inputs["tgt_sizes"]
input_ids_list = [] if num_image_tokens > 0:
image_bounds_list = [] mm_inputs = self._get_mm_inputs(images, videos, processor)
pattern = "(<image>./</image>)"
images, image_sizes, _ = mm_inputs["pixel_values"], mm_inputs["image_sizes"], mm_inputs["tgt_sizes"]
image_index = 0 image_index = 0
for index, message in enumerate(messages): for index, message in enumerate(messages):
text = message['content'] text = message["content"]
image_tags = re.findall(pattern, text) image_tags = re.findall(pattern, text)
text_chunks = text.split(pattern) text_chunks = text.split(pattern)
final_text = "" final_text = ""
for i in range(len(image_tags)): for i in range(len(image_tags)):
final_text = final_text + text_chunks[i] + \ final_text = (
image_processor.get_slice_image_placeholder( final_text
image_sizes[image_index][i], + text_chunks[i]
+ image_processor.get_slice_image_placeholder(
image_sizes[image_index][i],
i, i,
image_processor.max_slice_nums, image_processor.max_slice_nums,
image_processor.use_image_id, image_processor.use_image_id,
) )
)
image_index += 1 image_index += 1
final_text += text_chunks[-1] final_text += text_chunks[-1]
messages[index]['content'] = final_text messages[index]["content"] = final_text
if len(images) != num_image_tokens: if len(images) != num_image_tokens:
raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.") raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
return messages return messages
@override @override
def _get_mm_inputs( def _get_mm_inputs(
self, self,
@ -316,25 +317,26 @@ class CpmOPlugin(BasePlugin):
processor: "ProcessorMixin", processor: "ProcessorMixin",
**kwargs, **kwargs,
) -> Dict[str, "torch.Tensor"]: ) -> Dict[str, "torch.Tensor"]:
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor") image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
mm_inputs = {} mm_inputs = {}
if len(images) != 0: if len(images) != 0:
images = self._regularize_images( images = self._regularize_images(
images, images,
image_resolution=getattr(processor, "image_resolution", 512 * 512), image_resolution=getattr(processor, "image_resolution", 512 * 512),
) )
if "valid_image_nums_ls" in kwargs: if "valid_image_nums_ls" in kwargs:
valid_image_nums_ls = kwargs['valid_image_nums_ls'] valid_image_nums_ls = kwargs["valid_image_nums_ls"]
new_images = [] new_images = []
idx = 0 idx = 0
for valid_image_nums in valid_image_nums_ls: for valid_image_nums in valid_image_nums_ls:
new_images.append(images[idx:idx+valid_image_nums]) new_images.append(images[idx : idx + valid_image_nums])
idx += valid_image_nums idx += valid_image_nums
images = new_images images = new_images
image_inputs = image_processor(images, do_pad=True, max_slice_nums=image_processor.max_slice_nums, return_tensors="pt") image_inputs = image_processor(
images, do_pad=True, max_slice_nums=image_processor.max_slice_nums, return_tensors="pt"
)
mm_inputs.update(image_inputs) mm_inputs.update(image_inputs)
if len(videos) != 0: if len(videos) != 0:
@ -344,26 +346,26 @@ class CpmOPlugin(BasePlugin):
video_fps=getattr(processor, "video_fps", 2.0), video_fps=getattr(processor, "video_fps", 2.0),
video_maxlen=getattr(processor, "video_maxlen", 64), video_maxlen=getattr(processor, "video_maxlen", 64),
) )
return mm_inputs return mm_inputs
def trim_and_pad(self, seq, padding_value=0): def trim_and_pad(self, seq, padding_value=0):
return pad_sequence([s for s in seq], batch_first=True, padding_value=padding_value) return pad_sequence(seq, batch_first=True, padding_value=padding_value)
def pad_data(self, features): def pad_data(self, features):
features['position_ids'] = [torch.arange(input_ids.size(0)).long() for input_ids in features['input_ids']] features["position_ids"] = [torch.arange(input_ids.size(0)).long() for input_ids in features["input_ids"]]
features['input_ids'] = self.trim_and_pad( features["input_ids"] = self.trim_and_pad(
[input_ids for input_ids in features['input_ids']], features["input_ids"],
) )
features['position_ids'] = self.trim_and_pad( features["position_ids"] = self.trim_and_pad(
[position_ids for position_ids in features['position_ids']], features["position_ids"],
) )
features['labels'] = self.trim_and_pad( features["labels"] = self.trim_and_pad(
[labels for labels in features['labels']], features["labels"],
padding_value=-100, padding_value=-100,
) )
features['attention_mask'] = self.trim_and_pad( features["attention_mask"] = self.trim_and_pad(
[attention_mask for attention_mask in features['attention_mask']], features["attention_mask"],
) )
return features return features
@ -379,11 +381,12 @@ class CpmOPlugin(BasePlugin):
) -> Dict[str, Union[List[int], "torch.Tensor"]]: ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
self._validate_input(images, videos) self._validate_input(images, videos)
image_bounds_list = [] image_bounds_list = []
position_ids = []
valid_image_nums_ls = [] valid_image_nums_ls = []
for input_ids in batch_ids: for input_ids in batch_ids:
input_ids_ = torch.tensor(input_ids) input_ids_ = torch.tensor(input_ids)
start_cond = (input_ids_ == processor.tokenizer.im_start_id) | (input_ids_ == processor.tokenizer.slice_start_id) start_cond = (input_ids_ == processor.tokenizer.im_start_id) | (
input_ids_ == processor.tokenizer.slice_start_id
)
end_cond = (input_ids_ == processor.tokenizer.im_end_id) | (input_ids_ == processor.tokenizer.slice_end_id) end_cond = (input_ids_ == processor.tokenizer.im_end_id) | (input_ids_ == processor.tokenizer.slice_end_id)
image_start_tokens = torch.where(start_cond)[0] image_start_tokens = torch.where(start_cond)[0]
image_start_tokens += 1 image_start_tokens += 1
@ -398,10 +401,12 @@ class CpmOPlugin(BasePlugin):
) )
image_bounds_list.append(image_bounds) image_bounds_list.append(image_bounds)
mm_inputs = self._get_mm_inputs(images, videos, processor, valid_image_nums_ls=valid_image_nums_ls) mm_inputs = self._get_mm_inputs(images, videos, processor, valid_image_nums_ls=valid_image_nums_ls)
mm_inputs.update({ mm_inputs.update(
"image_bound": image_bounds_list, {
}) "image_bound": image_bounds_list,
return mm_inputs }
)
return mm_inputs
class LlavaPlugin(BasePlugin): class LlavaPlugin(BasePlugin):

View File

@ -570,13 +570,13 @@ _register_template(
_register_template( _register_template(
name="cpm_o", name="cpm_o",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]), format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_function=FunctionFormatter(slots=["{{content}}", "<|im_end|>"], tool_format="qwen"), format_function=FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen"),
format_observation=StringFormatter( format_observation=StringFormatter(
slots=["<|im_start|>user\n<tool_response>\n{{content}}\n</tool_response><|im_end|>\n<|im_start|>assistant\n"] slots=["<|im_start|>user\n<tool_response>\n{{content}}\n</tool_response><|im_end|>\n<|im_start|>assistant\n"]
), ),
format_tools=ToolFormatter(tool_format="qwen"), format_tools=ToolFormatter(tool_format="qwen"),
format_separator=EmptyFormatter(slots=["\n"]),
default_system="You are a helpful assistant.", default_system="You are a helpful assistant.",
stop_words=["<|im_end|>"], stop_words=["<|im_end|>"],
mm_plugin=get_mm_plugin(name="cpm_o", image_token="<image>", video_token="<video>"), mm_plugin=get_mm_plugin(name="cpm_o", image_token="<image>", video_token="<video>"),

View File

@ -142,7 +142,7 @@ def get_forbidden_modules(config: "PretrainedConfig", finetuning_args: "Finetuni
forbidden_modules.update({"visual.patch_embed", "visual.blocks", "model", "lm_head"}) forbidden_modules.update({"visual.patch_embed", "visual.blocks", "model", "lm_head"})
elif finetuning_args.freeze_vision_tower: elif finetuning_args.freeze_vision_tower:
forbidden_modules.add("visual") forbidden_modules.add("visual")
elif model_type == "minicpmv": elif model_type == "minicpmv":
if finetuning_args.freeze_vision_tower: if finetuning_args.freeze_vision_tower:
forbidden_modules.add("vpm") forbidden_modules.add("vpm")

View File

@ -24,7 +24,6 @@ import numpy as np
import torch import torch
from transformers import Seq2SeqTrainer from transformers import Seq2SeqTrainer
from typing_extensions import override from typing_extensions import override
import copy
from ...extras import logging from ...extras import logging
from ...extras.constants import IGNORE_INDEX from ...extras.constants import IGNORE_INDEX