From 7138b43873c09ebf07924992001525044da429c9 Mon Sep 17 00:00:00 2001
From: fzc8578 <1428195643@qq.com>
Date: Fri, 10 Jan 2025 20:27:06 +0800
Subject: [PATCH] fix some

Former-commit-id: 2ee8ba2f390551af1b865cfa813f5c8b7bbb41c5
---
 src/llamafactory/data/collator.py            |  6 +-
 src/llamafactory/data/mm_plugin.py           | 85 +++++++++++---------
 src/llamafactory/data/template.py            |  4 +-
 src/llamafactory/model/model_utils/visual.py |  2 +-
 src/llamafactory/train/sft/trainer.py        |  1 -
 5 files changed, 51 insertions(+), 47 deletions(-)

diff --git a/src/llamafactory/data/collator.py b/src/llamafactory/data/collator.py
index 3011dc2b..a3f9dfd1 100644
--- a/src/llamafactory/data/collator.py
+++ b/src/llamafactory/data/collator.py
@@ -149,14 +149,14 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
         features.update(mm_inputs)
         if isinstance(features.get("pixel_values"), list):  # for pixtral inputs
             features = features.data  # use default_collate() instead of BatchEncoding.to()
-
-        if "image_bound" in features: # for minicpmv inputs
+
+        if "image_bound" in features:  # for minicpmv inputs
             features = self.template.mm_plugin.pad_data(features)
             new_features = {}
             new_features.update({"data": features})
             new_features.update(features)
             features = new_features
-
+
         return features

diff --git a/src/llamafactory/data/mm_plugin.py b/src/llamafactory/data/mm_plugin.py
index cf002948..8c2a4dd0 100644
--- a/src/llamafactory/data/mm_plugin.py
+++ b/src/llamafactory/data/mm_plugin.py
@@ -1,8 +1,8 @@
 import math
+import re
 from copy import deepcopy
 from io import BytesIO
 from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, TypedDict, Union
-import re

 import numpy as np
 import torch
@@ -276,38 +276,39 @@ class CpmOPlugin(BasePlugin):
                 content = content.replace(IMAGE_PLACEHOLDER, "{{image}}", 1)

             message["content"] = content.replace("{{image}}", "(<image>./</image>)")
-
-        if num_image_tokens>0:
-            mm_inputs = self._get_mm_inputs(images, videos, processor)
-
-            pattern = "(<image>./</image>)"
-            images, image_sizes, tgt_sizes = mm_inputs["pixel_values"], mm_inputs["image_sizes"], mm_inputs["tgt_sizes"]
-            input_ids_list = []
-            image_bounds_list = []
+        if num_image_tokens > 0:
+            mm_inputs = self._get_mm_inputs(images, videos, processor)
+
+            pattern = "(<image>./</image>)"
+            images, image_sizes, _ = mm_inputs["pixel_values"], mm_inputs["image_sizes"], mm_inputs["tgt_sizes"]
+
             image_index = 0
             for index, message in enumerate(messages):
-                text = message['content']
+                text = message["content"]
                 image_tags = re.findall(pattern, text)
                 text_chunks = text.split(pattern)
                 final_text = ""
                 for i in range(len(image_tags)):
-                    final_text = final_text + text_chunks[i] + \
-                        image_processor.get_slice_image_placeholder(
-                            image_sizes[image_index][i],
+                    final_text = (
+                        final_text
+                        + text_chunks[i]
+                        + image_processor.get_slice_image_placeholder(
+                            image_sizes[image_index][i],
                             i,
                             image_processor.max_slice_nums,
                             image_processor.use_image_id,
                         )
+                    )
                     image_index += 1

                 final_text += text_chunks[-1]
-                messages[index]['content'] = final_text
+                messages[index]["content"] = final_text

         if len(images) != num_image_tokens:
             raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")

         return messages
-
+
     @override
     def _get_mm_inputs(
         self,
@@ -316,25 +317,26 @@ class CpmOPlugin(BasePlugin):
         processor: "ProcessorMixin",
         **kwargs,
     ) -> Dict[str, "torch.Tensor"]:
         image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
-
+
         mm_inputs = {}
-
+
         if len(images) != 0:
             images = self._regularize_images(
                 images,
                 image_resolution=getattr(processor, "image_resolution", 512 * 512),
             )
             if "valid_image_nums_ls" in kwargs:
-                valid_image_nums_ls = kwargs['valid_image_nums_ls']
+                valid_image_nums_ls = kwargs["valid_image_nums_ls"]
                 new_images = []
                 idx = 0
                 for valid_image_nums in valid_image_nums_ls:
-                    new_images.append(images[idx:idx+valid_image_nums])
+                    new_images.append(images[idx : idx + valid_image_nums])
                     idx += valid_image_nums

                 images = new_images

-            image_inputs = image_processor(images, do_pad=True, max_slice_nums=image_processor.max_slice_nums, return_tensors="pt")
+            image_inputs = image_processor(
+                images, do_pad=True, max_slice_nums=image_processor.max_slice_nums, return_tensors="pt"
+            )
             mm_inputs.update(image_inputs)

         if len(videos) != 0:
@@ -344,26 +346,26 @@ class CpmOPlugin(BasePlugin):
                 video_fps=getattr(processor, "video_fps", 2.0),
                 video_maxlen=getattr(processor, "video_maxlen", 64),
             )
-
+
         return mm_inputs
-
+
     def trim_and_pad(self, seq, padding_value=0):
-        return pad_sequence([s for s in seq], batch_first=True, padding_value=padding_value)
-
+        return pad_sequence(seq, batch_first=True, padding_value=padding_value)
+
     def pad_data(self, features):
-        features['position_ids'] = [torch.arange(input_ids.size(0)).long() for input_ids in features['input_ids']]
-        features['input_ids'] = self.trim_and_pad(
-            [input_ids for input_ids in features['input_ids']],
+        features["position_ids"] = [torch.arange(input_ids.size(0)).long() for input_ids in features["input_ids"]]
+        features["input_ids"] = self.trim_and_pad(
+            features["input_ids"],
         )
-        features['position_ids'] = self.trim_and_pad(
-            [position_ids for position_ids in features['position_ids']],
+        features["position_ids"] = self.trim_and_pad(
+            features["position_ids"],
         )
-        features['labels'] = self.trim_and_pad(
-            [labels for labels in features['labels']],
+        features["labels"] = self.trim_and_pad(
+            features["labels"],
             padding_value=-100,
         )
-        features['attention_mask'] = self.trim_and_pad(
-            [attention_mask for attention_mask in features['attention_mask']],
+        features["attention_mask"] = self.trim_and_pad(
+            features["attention_mask"],
         )
         return features
@@ -379,11 +381,12 @@ class CpmOPlugin(BasePlugin):
     ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
         self._validate_input(images, videos)
         image_bounds_list = []
-        position_ids = []
         valid_image_nums_ls = []
         for input_ids in batch_ids:
             input_ids_ = torch.tensor(input_ids)
-            start_cond = (input_ids_ == processor.tokenizer.im_start_id) | (input_ids_ == processor.tokenizer.slice_start_id)
+            start_cond = (input_ids_ == processor.tokenizer.im_start_id) | (
+                input_ids_ == processor.tokenizer.slice_start_id
+            )
             end_cond = (input_ids_ == processor.tokenizer.im_end_id) | (input_ids_ == processor.tokenizer.slice_end_id)
             image_start_tokens = torch.where(start_cond)[0]
             image_start_tokens += 1
@@ -398,10 +401,12 @@ class CpmOPlugin(BasePlugin):
             )
             image_bounds_list.append(image_bounds)

         mm_inputs = self._get_mm_inputs(images, videos, processor, valid_image_nums_ls=valid_image_nums_ls)
-        mm_inputs.update({
-            "image_bound": image_bounds_list,
-        })
-        return mm_inputs
+        mm_inputs.update(
+            {
+                "image_bound": image_bounds_list,
+            }
+        )
+        return mm_inputs


 class LlavaPlugin(BasePlugin):
diff --git a/src/llamafactory/data/template.py b/src/llamafactory/data/template.py
index 68e00b1b..285714cd 100644
--- a/src/llamafactory/data/template.py
+++ b/src/llamafactory/data/template.py
@@ -570,13 +570,13 @@ _register_template(

 _register_template(
     name="cpm_o",
     format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
     format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
-    format_function=FunctionFormatter(slots=["{{content}}", "<|im_end|>"], tool_format="qwen"),
+    format_function=FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen"),
     format_observation=StringFormatter(
         slots=["<|im_start|>user\n<tool_response>\n{{content}}\n</tool_response><|im_end|>\n<|im_start|>assistant\n"]
     ),
     format_tools=ToolFormatter(tool_format="qwen"),
-    format_separator=EmptyFormatter(slots=["\n"]),
     default_system="You are a helpful assistant.",
     stop_words=["<|im_end|>"],
     mm_plugin=get_mm_plugin(name="cpm_o", image_token="<image>", video_token="
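
For reviewers, a minimal sketch (not part of the patch) of what the new
trim_and_pad/pad_data helpers do to a ragged MiniCPM-o batch: pad_sequence
right-pads every field to the longest sample, using -100 for labels so padded
positions are ignored by the loss. The toy values below are illustrative only.

    import torch
    from torch.nn.utils.rnn import pad_sequence

    # Toy batch of two samples with different lengths (illustrative values).
    features = {
        "input_ids": [torch.tensor([1, 2, 3]), torch.tensor([4, 5])],
        "labels": [torch.tensor([1, 2, 3]), torch.tensor([4, 5])],
    }

    # As in pad_data(): position_ids are built per sample *before* padding.
    features["position_ids"] = [torch.arange(ids.size(0)).long() for ids in features["input_ids"]]

    # trim_and_pad() is a thin wrapper over pad_sequence with batch_first=True.
    input_ids = pad_sequence(features["input_ids"], batch_first=True, padding_value=0)
    labels = pad_sequence(features["labels"], batch_first=True, padding_value=-100)
    print(input_ids)  # tensor([[1, 2, 3], [4, 5, 0]])
    print(labels)     # tensor([[1, 2, 3], [4, 5, -100]])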
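
Likewise, a small sketch (again outside the patch) of how get_mm_inputs derives
the image_bound spans: it locates the image/slice start and end marker ids in
input_ids and stacks them into an (N, 2) tensor of [start, end) bounds. The
marker ids below are made up for illustration.

    import torch

    IM_START, IM_END = 101, 102  # hypothetical stand-ins for im_start_id / im_end_id
    input_ids = torch.tensor([1, 101, 7, 7, 102, 2, 101, 8, 102])

    starts = torch.where(input_ids == IM_START)[0] + 1  # first token inside each span
    ends = torch.where(input_ids == IM_END)[0]
    image_bound = torch.hstack([starts.unsqueeze(-1), ends.unsqueeze(-1)])
    print(image_bound)  # tensor([[2, 4], [7, 8]])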