From b5ef5059eee019b2469b968a6bfe9b1d1f115601 Mon Sep 17 00:00:00 2001 From: fzc8578 <1428195643@qq.com> Date: Sat, 4 Jan 2025 11:11:15 +0800 Subject: [PATCH] add some Former-commit-id: 79c2d7090cbf364063ea3608814ab18aa27fdc87 --- requirements.txt | 2 +- src/llamafactory/data/collator.py | 7 ++ src/llamafactory/data/mm_plugin.py | 126 +++++++++++++++++++++ src/llamafactory/data/template.py | 16 +++ src/llamafactory/extras/constants.py | 11 ++ src/llamafactory/extras/misc.py | 2 +- src/llamafactory/model/model_utils/misc.py | 2 + 7 files changed, 164 insertions(+), 2 deletions(-) diff --git a/requirements.txt b/requirements.txt index b47643e3..07034651 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -transformers>=4.41.2,<=4.46.1 +transformers>=4.41.2 datasets>=2.16.0,<=3.1.0 accelerate>=0.34.0,<=1.0.1 peft>=0.11.1,<=0.12.0 diff --git a/src/llamafactory/data/collator.py b/src/llamafactory/data/collator.py index f360f0f5..90abf34c 100644 --- a/src/llamafactory/data/collator.py +++ b/src/llamafactory/data/collator.py @@ -149,6 +149,13 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq): features.update(mm_inputs) if isinstance(features.get("pixel_values"), list): # for pixtral inputs features = features.data # use default_collate() instead of BatchEncoding.to() + if "image_bound" in features: + input_ids, position_ids = features['input_ids'], features['position_ids'] + features['position_ids'] = F.pad(position_ids, (0, input_ids.shape[-1] - position_ids.shape[-1])) + new_features = {} + new_features.update({"data": features}) + new_features.update(features) + features = new_features return features diff --git a/src/llamafactory/data/mm_plugin.py b/src/llamafactory/data/mm_plugin.py index 4e1f418a..0fbb5b1d 100644 --- a/src/llamafactory/data/mm_plugin.py +++ b/src/llamafactory/data/mm_plugin.py @@ -2,6 +2,7 @@ import math from copy import deepcopy from io import BytesIO from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, TypedDict, Union +import re import numpy as np import torch @@ -249,6 +250,130 @@ class BasePlugin: return {} +class CpmOPlugin(BasePlugin): + @override + def process_messages( + self, + messages: Sequence[Dict[str, str]], + images: Sequence["ImageInput"], + videos: Sequence["VideoInput"], + processor: Optional["ProcessorMixin"], + ) -> List[Dict[str, str]]: + self._validate_input(images, videos) + num_image_tokens = 0 + messages = deepcopy(messages) + image_processor: "BaseImageProcessor" = getattr(processor, "image_processor") + + for message in messages: + content = message["content"] + while IMAGE_PLACEHOLDER in content: + num_image_tokens += 1 + content = content.replace(IMAGE_PLACEHOLDER, "{{image}}", 1) + + message["content"] = content.replace("{{image}}", "(./)") + + if num_image_tokens>0: + mm_inputs = self._get_mm_inputs(images, videos, processor) + + pattern = "(./)" + images, image_sizes, tgt_sizes = mm_inputs["pixel_values"], mm_inputs["image_sizes"], mm_inputs["tgt_sizes"] + + input_ids_list = [] + image_bounds_list = [] + image_index = 0 + for index, message in enumerate(messages): + text = message['content'] + image_tags = re.findall(pattern, text) + text_chunks = text.split(pattern) + final_text = "" + for i in range(len(image_tags)): + final_text = final_text + text_chunks[i] + \ + image_processor.get_slice_image_placeholder( + image_sizes[image_index][i], + i, + image_processor.max_slice_nums, + image_processor.use_image_id, + ) + image_index += 1 + final_text += text_chunks[-1] + messages[index]['content'] = final_text + # print(messages) + + if len(images) != num_image_tokens: + raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.") + + return messages + + @override + def _get_mm_inputs( + self, + images: Sequence["ImageInput"], + videos: Sequence["VideoInput"], + processor: "ProcessorMixin", + ) -> Dict[str, "torch.Tensor"]: + + image_processor: "BaseImageProcessor" = getattr(processor, "image_processor") + + mm_inputs = {} + + if len(images) != 0: + images = self._regularize_images( + images, + image_resolution=getattr(processor, "image_resolution", 512 * 512), + ) + image_inputs = image_processor(images, do_pad=True, max_slice_nums=image_processor.max_slice_nums, return_tensors="pt") + mm_inputs.update(image_inputs) + + if len(videos) != 0: + videos = self._regularize_videos( + videos, + image_resolution=getattr(processor, "video_resolution", 128 * 128), + video_fps=getattr(processor, "video_fps", 2.0), + video_maxlen=getattr(processor, "video_maxlen", 64), + ) + + return mm_inputs + + @override + def get_mm_inputs( + self, + images: Sequence["ImageInput"], + videos: Sequence["VideoInput"], + imglens: Sequence[int], + vidlens: Sequence[int], + batch_ids: Sequence[List[int]], + processor: Optional["ProcessorMixin"], + ) -> Dict[str, Union[List[int], "torch.Tensor"]]: + self._validate_input(images, videos) + mm_inputs = self._get_mm_inputs(images, videos, processor) + image_bounds_list = [] + position_ids = [] + for input_ids in batch_ids: + input_ids_ = torch.tensor(input_ids) + start_cond = (input_ids_ == processor.tokenizer.im_start_id) | (input_ids_ == processor.tokenizer.slice_start_id) + end_cond = (input_ids_ == processor.tokenizer.im_end_id) | (input_ids_ == processor.tokenizer.slice_end_id) + image_start_tokens = torch.where(start_cond)[0] + image_start_tokens += 1 + image_end_tokens = torch.where(end_cond)[0] + valid_image_nums = max(len(image_start_tokens), len(image_end_tokens)) + image_bounds = torch.hstack( + [ + image_start_tokens[:valid_image_nums].unsqueeze(-1), + image_end_tokens[:valid_image_nums].unsqueeze(-1), + ] + ) + image_bounds_list.append(image_bounds) + position_ids_ = list(range(input_ids_.size(0))) + # print(input_ids_.shape, len(position_ids_) + position_ids.append(position_ids_) + position_ids = torch.tensor(position_ids, dtype=torch.int64) + mm_inputs.update({ + "image_bound": image_bounds_list, + "position_ids": position_ids, + }) + return mm_inputs + + class LlavaPlugin(BasePlugin): @override def process_messages( @@ -790,6 +915,7 @@ class MllamaPlugin(BasePlugin): PLUGINS = { "base": BasePlugin, + "cpm_o": CpmOPlugin, "llava": LlavaPlugin, "llava_next": LlavaNextPlugin, "llava_next_video": LlavaNextVideoPlugin, diff --git a/src/llamafactory/data/template.py b/src/llamafactory/data/template.py index 5768cf7b..64911c58 100644 --- a/src/llamafactory/data/template.py +++ b/src/llamafactory/data/template.py @@ -583,6 +583,22 @@ _register_template( ) +_register_template( + name="cpm_o", + format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), + format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]), + format_function=FunctionFormatter(slots=["{{content}}", "<|im_end|>"], tool_format="qwen"), + format_observation=StringFormatter( + slots=["<|im_start|>user\n\n{{content}}\n<|im_end|>\n<|im_start|>assistant\n"] + ), + format_tools=ToolFormatter(tool_format="qwen"), + format_separator=EmptyFormatter(slots=["\n"]), + default_system="You are a helpful assistant.", + stop_words=["<|im_end|>"], + mm_plugin=get_mm_plugin(name="cpm_o", image_token="", video_token="