From 165fe8e219e697ab94a22f963c86ac83ae6d6f61 Mon Sep 17 00:00:00 2001 From: fzc8578 <1428195643@qq.com> Date: Fri, 10 Jan 2025 20:01:22 +0800 Subject: [PATCH] add some Former-commit-id: 096a6cb67a7dfd14a6e339d96baab78c12d36a87 --- src/llamafactory/data/collator.py | 8 ++-- src/llamafactory/data/mm_plugin.py | 41 ++++++++++++++++---- src/llamafactory/model/loader.py | 2 +- src/llamafactory/model/model_utils/misc.py | 3 ++ src/llamafactory/model/model_utils/visual.py | 4 ++ 5 files changed, 45 insertions(+), 13 deletions(-) diff --git a/src/llamafactory/data/collator.py b/src/llamafactory/data/collator.py index 90abf34c..3011dc2b 100644 --- a/src/llamafactory/data/collator.py +++ b/src/llamafactory/data/collator.py @@ -149,14 +149,14 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq): features.update(mm_inputs) if isinstance(features.get("pixel_values"), list): # for pixtral inputs features = features.data # use default_collate() instead of BatchEncoding.to() - if "image_bound" in features: - input_ids, position_ids = features['input_ids'], features['position_ids'] - features['position_ids'] = F.pad(position_ids, (0, input_ids.shape[-1] - position_ids.shape[-1])) + + if "image_bound" in features: # for minicpmv inputs + features = self.template.mm_plugin.pad_data(features) new_features = {} new_features.update({"data": features}) new_features.update(features) features = new_features - + return features diff --git a/src/llamafactory/data/mm_plugin.py b/src/llamafactory/data/mm_plugin.py index 013672c0..fe8cd450 100644 --- a/src/llamafactory/data/mm_plugin.py +++ b/src/llamafactory/data/mm_plugin.py @@ -6,6 +6,7 @@ import re import numpy as np import torch +from torch.nn.utils.rnn import pad_sequence from transformers.image_utils import get_image_size, to_numpy_array from typing_extensions import override @@ -297,7 +298,6 @@ class CpmOPlugin(BasePlugin): image_index += 1 final_text += text_chunks[-1] messages[index]['content'] = final_text - # print(messages) if len(images) != num_image_tokens: raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.") @@ -310,6 +310,7 @@ class CpmOPlugin(BasePlugin): images: Sequence["ImageInput"], videos: Sequence["VideoInput"], processor: "ProcessorMixin", + **kwargs, ) -> Dict[str, "torch.Tensor"]: image_processor: "BaseImageProcessor" = getattr(processor, "image_processor") @@ -321,6 +322,14 @@ class CpmOPlugin(BasePlugin): images, image_resolution=getattr(processor, "image_resolution", 512 * 512), ) + if "valid_image_nums_ls" in kwargs: + valid_image_nums_ls = kwargs['valid_image_nums_ls'] + new_images = [] + idx = 0 + for valid_image_nums in valid_image_nums_ls: + new_images.append(images[idx:idx+valid_image_nums]) + idx += valid_image_nums + images = new_images image_inputs = image_processor(images, do_pad=True, max_slice_nums=image_processor.max_slice_nums, return_tensors="pt") mm_inputs.update(image_inputs) @@ -333,6 +342,26 @@ class CpmOPlugin(BasePlugin): ) return mm_inputs + + def trim_and_pad(self, seq, padding_value=0): + return pad_sequence([s for s in seq], batch_first=True, padding_value=padding_value) + + def pad_data(self, features): + features['position_ids'] = [torch.arange(input_ids.size(0)).long() for input_ids in features['input_ids']] + features['input_ids'] = self.trim_and_pad( + [input_ids for input_ids in features['input_ids']], + ) + features['position_ids'] = self.trim_and_pad( + [position_ids for position_ids in features['position_ids']], + ) + features['labels'] = self.trim_and_pad( + [labels for labels in features['labels']], + padding_value=-100, + ) + features['attention_mask'] = self.trim_and_pad( + [attention_mask for attention_mask in features['attention_mask']], + ) + return features @override def get_mm_inputs( @@ -345,9 +374,9 @@ class CpmOPlugin(BasePlugin): processor: Optional["ProcessorMixin"], ) -> Dict[str, Union[List[int], "torch.Tensor"]]: self._validate_input(images, videos) - mm_inputs = self._get_mm_inputs(images, videos, processor) image_bounds_list = [] position_ids = [] + valid_image_nums_ls = [] for input_ids in batch_ids: input_ids_ = torch.tensor(input_ids) start_cond = (input_ids_ == processor.tokenizer.im_start_id) | (input_ids_ == processor.tokenizer.slice_start_id) @@ -356,6 +385,7 @@ class CpmOPlugin(BasePlugin): image_start_tokens += 1 image_end_tokens = torch.where(end_cond)[0] valid_image_nums = max(len(image_start_tokens), len(image_end_tokens)) + valid_image_nums_ls.append(valid_image_nums) image_bounds = torch.hstack( [ image_start_tokens[:valid_image_nums].unsqueeze(-1), @@ -363,14 +393,9 @@ class CpmOPlugin(BasePlugin): ] ) image_bounds_list.append(image_bounds) - position_ids_ = list(range(input_ids_.size(0))) - # print(input_ids_.shape, len(position_ids_) - position_ids.append(position_ids_) - #TODO add pad - position_ids = torch.tensor(position_ids, dtype=torch.int64) + mm_inputs = self._get_mm_inputs(images, videos, processor, valid_image_nums_ls=valid_image_nums_ls) mm_inputs.update({ "image_bound": image_bounds_list, - "position_ids": position_ids, }) return mm_inputs diff --git a/src/llamafactory/model/loader.py b/src/llamafactory/model/loader.py index 022cce06..6aadb866 100644 --- a/src/llamafactory/model/loader.py +++ b/src/llamafactory/model/loader.py @@ -100,7 +100,7 @@ def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule": processor = AutoProcessor.from_pretrained(model_args.model_name_or_path, **init_kwargs) patch_processor(processor, config, tokenizer, model_args) except Exception as e: - logger.debug(f"Processor was not found: {e}.") + logger.info(f"Processor was not found: {e}.") processor = None # Avoid load tokenizer, see: diff --git a/src/llamafactory/model/model_utils/misc.py b/src/llamafactory/model/model_utils/misc.py index 4fb43883..9d204bd9 100644 --- a/src/llamafactory/model/model_utils/misc.py +++ b/src/llamafactory/model/model_utils/misc.py @@ -46,6 +46,9 @@ def find_all_linear_modules(model: "PreTrainedModel", freeze_vision_tower: bool) forbidden_modules.add("visual") elif model_type in ["minicpmv"]: forbidden_modules.add("vpm") + forbidden_modules.add("apm") + forbidden_modules.add("resampler") + forbidden_modules.add("tts") else: forbidden_modules.add("vision_tower") diff --git a/src/llamafactory/model/model_utils/visual.py b/src/llamafactory/model/model_utils/visual.py index d0649514..49aadf2f 100644 --- a/src/llamafactory/model/model_utils/visual.py +++ b/src/llamafactory/model/model_utils/visual.py @@ -145,7 +145,11 @@ def get_forbidden_modules(config: "PretrainedConfig", finetuning_args: "Finetuni elif model_type == "minicpmv": if finetuning_args.freeze_vision_tower: + print("******************", model_type) forbidden_modules.add("vpm") + forbidden_modules.add("apm") + forbidden_modules.add("resampler") + forbidden_modules.add("tts") return forbidden_modules