mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-08-23 22:32:54 +08:00)
add MiniCPM-o batch padding (CpmOPlugin) and update frozen multimodal modules
Former-commit-id: 096a6cb67a7dfd14a6e339d96baab78c12d36a87
This commit is contained in:
parent b9eeaa9706
commit 165fe8e219
@@ -149,9 +149,9 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
         features.update(mm_inputs)
         if isinstance(features.get("pixel_values"), list):  # for pixtral inputs
             features = features.data  # use default_collate() instead of BatchEncoding.to()
-        if "image_bound" in features:
-            input_ids, position_ids = features["input_ids"], features["position_ids"]
-            features["position_ids"] = F.pad(position_ids, (0, input_ids.shape[-1] - position_ids.shape[-1]))
 
+        if "image_bound" in features:  # for minicpmv inputs
+            features = self.template.mm_plugin.pad_data(features)
+            new_features = {}
+            new_features.update({"data": features})
+            new_features.update(features)
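The wrapping at the end of this hunk keeps two views of the same batch: the MiniCPM-V/o forward pass appears to consume the whole padded batch through a single "data" argument, while the HF Trainer still reads input_ids and labels from the top level. A minimal sketch of the resulting structure (toy tensors, values hypothetical):

import torch

features = {
    "input_ids": torch.tensor([[1, 2, 3], [4, 5, 0]]),
    "labels": torch.tensor([[1, 2, 3], [4, 5, -100]]),
    "image_bound": [torch.tensor([[1, 2]]), torch.tensor([[0, 1]])],
}

new_features = {"data": features}  # the model reads the batch via forward(data=...)
new_features.update(features)      # Trainer-facing keys stay addressable at the top level
assert new_features["input_ids"] is features["input_ids"]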
@@ -6,6 +6,7 @@ import re
 
 import numpy as np
 import torch
+from torch.nn.utils.rnn import pad_sequence
 from transformers.image_utils import get_image_size, to_numpy_array
 from typing_extensions import override
 
@@ -297,7 +298,6 @@ class CpmOPlugin(BasePlugin):
                 image_index += 1
             final_text += text_chunks[-1]
             messages[index]["content"] = final_text
-            # print(messages)
 
         if len(images) != num_image_tokens:
             raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
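The check above enforces a simple invariant: every IMAGE_PLACEHOLDER left in the rewritten messages must correspond to exactly one entry in images. A toy illustration of the same invariant (placeholder string assumed; the plugin actually counts tokens while splitting chunks):

IMAGE_PLACEHOLDER = "<image>"  # assumption; the plugin imports the real constant

messages = [{"role": "user", "content": "<image><image>What differs between these?"}]
images = ["img0.png", "img1.png"]

num_image_tokens = sum(m["content"].count(IMAGE_PLACEHOLDER) for m in messages)
if len(images) != num_image_tokens:
    raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")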
@@ -310,6 +310,7 @@ class CpmOPlugin(BasePlugin):
         images: Sequence["ImageInput"],
         videos: Sequence["VideoInput"],
         processor: "ProcessorMixin",
+        **kwargs,
     ) -> Dict[str, "torch.Tensor"]:
 
         image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
@@ -321,6 +322,14 @@ class CpmOPlugin(BasePlugin):
                 images,
                 image_resolution=getattr(processor, "image_resolution", 512 * 512),
             )
+            if "valid_image_nums_ls" in kwargs:
+                valid_image_nums_ls = kwargs["valid_image_nums_ls"]
+                new_images = []
+                idx = 0
+                for valid_image_nums in valid_image_nums_ls:
+                    new_images.append(images[idx : idx + valid_image_nums])
+                    idx += valid_image_nums
+                images = new_images
             image_inputs = image_processor(images, do_pad=True, max_slice_nums=image_processor.max_slice_nums, return_tensors="pt")
             mm_inputs.update(image_inputs)
 
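The new kwargs branch regroups the flat, batch-wide image list into one sublist per sample, so the image processor handles each conversation's images together; valid_image_nums_ls carries the per-sample counts computed in get_mm_inputs below. The same loop in isolation (file names hypothetical):

# flat batch of 5 images: sample 0 contributed 2, sample 1 none, sample 2 three
images = ["a.png", "b.png", "c.png", "d.png", "e.png"]
valid_image_nums_ls = [2, 0, 3]

new_images = []
idx = 0
for valid_image_nums in valid_image_nums_ls:
    new_images.append(images[idx : idx + valid_image_nums])
    idx += valid_image_nums

assert new_images == [["a.png", "b.png"], [], ["c.png", "d.png", "e.png"]]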
@@ -334,6 +343,26 @@ class CpmOPlugin(BasePlugin):
 
         return mm_inputs
 
+    def trim_and_pad(self, seq, padding_value=0):
+        # right-pad a batch of variable-length 1-D tensors to the longest one
+        return pad_sequence(list(seq), batch_first=True, padding_value=padding_value)
+
+    def pad_data(self, features):
+        # rebuild position ids from each sample's true length, then pad every field;
+        # labels are padded with -100 so the loss ignores padding positions
+        features["position_ids"] = [torch.arange(input_ids.size(0)).long() for input_ids in features["input_ids"]]
+        features["input_ids"] = self.trim_and_pad(features["input_ids"])
+        features["position_ids"] = self.trim_and_pad(features["position_ids"])
+        features["labels"] = self.trim_and_pad(features["labels"], padding_value=-100)
+        features["attention_mask"] = self.trim_and_pad(features["attention_mask"])
+        return features
+
     @override
     def get_mm_inputs(
         self,
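trim_and_pad above is a thin wrapper over torch.nn.utils.rnn.pad_sequence, and pad_data applies it field by field. A self-contained demonstration on a ragged two-sample batch:

import torch
from torch.nn.utils.rnn import pad_sequence

batch = [torch.tensor([1, 2, 3]), torch.tensor([4, 5])]

pad_sequence(batch, batch_first=True, padding_value=0)
# tensor([[1, 2, 3],
#         [4, 5, 0]])

pad_sequence(batch, batch_first=True, padding_value=-100)  # labels variant
# tensor([[   1,    2,    3],
#         [   4,    5, -100]])

position_ids = [torch.arange(t.size(0)).long() for t in batch]  # [0, 1, 2] and [0, 1]
pad_sequence(position_ids, batch_first=True, padding_value=0)
# tensor([[0, 1, 2],
#         [0, 1, 0]])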
@@ -345,9 +374,9 @@ class CpmOPlugin(BasePlugin):
         processor: Optional["ProcessorMixin"],
     ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
         self._validate_input(images, videos)
-        mm_inputs = self._get_mm_inputs(images, videos, processor)
         image_bounds_list = []
         position_ids = []
+        valid_image_nums_ls = []
         for input_ids in batch_ids:
             input_ids_ = torch.tensor(input_ids)
             start_cond = (input_ids_ == processor.tokenizer.im_start_id) | (input_ids_ == processor.tokenizer.slice_start_id)
@@ -356,6 +385,7 @@ class CpmOPlugin(BasePlugin):
             image_start_tokens += 1
             image_end_tokens = torch.where(end_cond)[0]
             valid_image_nums = max(len(image_start_tokens), len(image_end_tokens))
+            valid_image_nums_ls.append(valid_image_nums)
             image_bounds = torch.hstack(
                 [
                     image_start_tokens[:valid_image_nums].unsqueeze(-1),
@@ -363,14 +393,9 @@ class CpmOPlugin(BasePlugin):
                 ]
             )
             image_bounds_list.append(image_bounds)
-            position_ids_ = list(range(input_ids_.size(0)))
-            # print(input_ids_.shape, len(position_ids_))
-            position_ids.append(position_ids_)
-        # TODO: add padding
-        position_ids = torch.tensor(position_ids, dtype=torch.int64)
+        mm_inputs = self._get_mm_inputs(images, videos, processor, valid_image_nums_ls=valid_image_nums_ls)
         mm_inputs.update({
             "image_bound": image_bounds_list,
             "position_ids": position_ids,
         })
         return mm_inputs
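The loop in the three hunks above recovers each image's span directly from the token ids: locate the image/slice start markers, shift by one so the bound begins at the first image token, locate the end markers, and stack the two index columns. A toy run with made-up marker ids (the real ids come from processor.tokenizer):

import torch

IM_START_ID, IM_END_ID = 100, 101  # hypothetical marker token ids
input_ids_ = torch.tensor([7, 100, 8, 9, 101, 7, 100, 8, 101])

image_start_tokens = torch.where(input_ids_ == IM_START_ID)[0] + 1  # first token inside each image
image_end_tokens = torch.where(input_ids_ == IM_END_ID)[0]
valid_image_nums = max(len(image_start_tokens), len(image_end_tokens))

image_bounds = torch.hstack(
    [
        image_start_tokens[:valid_image_nums].unsqueeze(-1),
        image_end_tokens[:valid_image_nums].unsqueeze(-1),
    ]
)
# tensor([[2, 4],
#         [7, 8]])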
@@ -100,7 +100,7 @@ def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
         processor = AutoProcessor.from_pretrained(model_args.model_name_or_path, **init_kwargs)
         patch_processor(processor, config, tokenizer, model_args)
     except Exception as e:
-        logger.debug(f"Processor was not found: {e}.")
+        logger.info(f"Processor was not found: {e}.")
         processor = None
 
     # Avoid load tokenizer, see:
@@ -46,6 +46,9 @@ def find_all_linear_modules(model: "PreTrainedModel", freeze_vision_tower: bool)
             forbidden_modules.add("visual")
         elif model_type in ["minicpmv"]:
             forbidden_modules.add("vpm")
+            forbidden_modules.add("apm")
+            forbidden_modules.add("resampler")
+            forbidden_modules.add("tts")
         else:
             forbidden_modules.add("vision_tower")
 
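forbidden_modules collects name fragments of submodules that LoRA target discovery must skip; for MiniCPM-o this now covers the vision encoder (vpm), audio encoder (apm), resampler, and TTS head. A hedged sketch of how such a set typically filters the model, simplified from what find_all_linear_modules does around these lines (function name here is illustrative):

import torch.nn as nn

def find_linear_names(model: nn.Module, forbidden_modules: set) -> set:
    # collect the leaf names of all nn.Linear layers outside the frozen towers
    module_names = set()
    for name, module in model.named_modules():
        if any(forbidden in name for forbidden in forbidden_modules):
            continue  # e.g. "vpm.encoder.layers.0.mlp.fc1" is skipped
        if isinstance(module, nn.Linear):
            module_names.add(name.split(".")[-1])
    return module_names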
@@ -145,7 +145,11 @@ def get_forbidden_modules(config: "PretrainedConfig", finetuning_args: "Finetuni
 
     elif model_type == "minicpmv":
         if finetuning_args.freeze_vision_tower:
             forbidden_modules.add("vpm")
+            forbidden_modules.add("apm")
+            forbidden_modules.add("resampler")
+            forbidden_modules.add("tts")
 
     return forbidden_modules