From 0fb50f9c88b1d476cb83569ad0bf45a9bd950b1c Mon Sep 17 00:00:00 2001 From: fzc8578 <1428195643@qq.com> Date: Fri, 10 Jan 2025 23:29:06 +0800 Subject: [PATCH] add some Former-commit-id: 771cc802941cf1953b32e5102c817c6a3090b5ce --- src/llamafactory/data/collator.py | 14 +++++++--- src/llamafactory/data/mm_plugin.py | 44 +++++++++++++++++++++--------- 2 files changed, 41 insertions(+), 17 deletions(-) diff --git a/src/llamafactory/data/collator.py b/src/llamafactory/data/collator.py index 5aea4c61..036e1a79 100644 --- a/src/llamafactory/data/collator.py +++ b/src/llamafactory/data/collator.py @@ -19,8 +19,8 @@ from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Sequence import torch -from torch.nn.utils.rnn import pad_sequence import torch.nn.functional as F +from torch.nn.utils.rnn import pad_sequence from transformers import DataCollatorForSeq2Seq from ..extras.constants import IGNORE_INDEX, IMAGE_PLACEHOLDER @@ -106,7 +106,9 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq): batch_vidlens.append(len(videos)) batch_input_ids.append(feature["input_ids"]) - if self.processor is not None and sum(batch_imglens) == 0: # avoid process hanging in zero3/fsdp case + if ( + self.processor is not None and sum(batch_imglens) == 0 and sum(batch_vidlens) == 0 + ): # avoid process hanging in zero3/fsdp case fake_messages = [{"role": "user", "content": IMAGE_PLACEHOLDER}] fake_images = [Image.new("RGB", (64, 64), (255, 255, 255))] fake_messages = self.template.mm_plugin.process_messages(fake_messages, fake_images, [], self.processor) @@ -157,10 +159,14 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq): if "image_bound" in features: # for minicpmv inputs features["position_ids"] = [torch.arange(input_ids.size(0)).long() for input_ids in features["input_ids"]] - features["input_ids"] = pad(features["input_ids"],) + features["input_ids"] = pad( + features["input_ids"], + ) features["position_ids"] = pad(features["position_ids"]) features["labels"] = pad(features["labels"], padding_value=-100) - features["attention_mask"] = pad(features["attention_mask"],) + features["attention_mask"] = pad( + features["attention_mask"], + ) new_features = {} new_features.update({"data": features}) new_features.update(features) diff --git a/src/llamafactory/data/mm_plugin.py b/src/llamafactory/data/mm_plugin.py index a40f78fc..55db4626 100644 --- a/src/llamafactory/data/mm_plugin.py +++ b/src/llamafactory/data/mm_plugin.py @@ -265,8 +265,19 @@ class CpmOPlugin(BasePlugin): ) -> List[Dict[str, str]]: self._validate_input(images, videos) num_image_tokens = 0 + num_video_tokens = 0 messages = deepcopy(messages) image_processor: "BaseImageProcessor" = getattr(processor, "image_processor") + mm_inputs = {} + + if len(videos) != 0: + assert len(images) == 0, "Only support video and image sft seperately" + max_slice_nums = 2 + use_image_id = False + mm_inputs = self._get_mm_inputs([], videos, processor) + else: + max_slice_nums = image_processor.max_slice_nums + use_image_id = image_processor.use_image_id for message in messages: content = message["content"] @@ -274,15 +285,21 @@ class CpmOPlugin(BasePlugin): num_image_tokens += 1 content = content.replace(IMAGE_PLACEHOLDER, "{{image}}", 1) + while VIDEO_PLACEHOLDER in content: + num_video_tokens += 1 + content = content.replace( + VIDEO_PLACEHOLDER, "{{image}}" * len(mm_inputs["pixel_values"][num_video_tokens - 1]), 1 + ) + message["content"] = content.replace("{{image}}", "(./)") if num_image_tokens > 0: - mm_inputs = self._get_mm_inputs(images, videos, processor) + mm_inputs = self._get_mm_inputs(images, [], processor) + if mm_inputs: pattern = "(./)" - images, image_sizes = mm_inputs["pixel_values"], mm_inputs["image_sizes"] + image_sizes = mm_inputs["image_sizes"] - image_index = 0 for index, message in enumerate(messages): text = message["content"] image_tags = re.findall(pattern, text) @@ -293,19 +310,21 @@ class CpmOPlugin(BasePlugin): final_text + text_chunks[i] + image_processor.get_slice_image_placeholder( - image_sizes[image_index][i], + image_sizes[0][i], i, - image_processor.max_slice_nums, - image_processor.use_image_id, + max_slice_nums, + use_image_id, ) ) - image_index += 1 final_text += text_chunks[-1] messages[index]["content"] = final_text if len(images) != num_image_tokens: raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.") + if len(videos) != num_video_tokens: + raise ValueError(f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens.") + return messages @override @@ -333,7 +352,7 @@ class CpmOPlugin(BasePlugin): new_images.append(images[idx : idx + valid_image_nums]) idx += valid_image_nums images = new_images - + image_inputs = image_processor( images, do_pad=True, max_slice_nums=image_processor.max_slice_nums, return_tensors="pt" ) @@ -346,6 +365,8 @@ class CpmOPlugin(BasePlugin): video_fps=getattr(processor, "video_fps", 2.0), video_maxlen=getattr(processor, "video_maxlen", 64), ) + video_inputs = image_processor(videos, do_pad=True, max_slice_nums=2, return_tensors="pt") + mm_inputs.update(video_inputs) return mm_inputs @@ -380,12 +401,9 @@ class CpmOPlugin(BasePlugin): ] ) image_bounds_list.append(image_bounds) + mm_inputs = self._get_mm_inputs(images, videos, processor, valid_image_nums_ls=valid_image_nums_ls) - mm_inputs.update( - { - "image_bound": image_bounds_list, - } - ) + mm_inputs.update({"image_bound": image_bounds_list}) return mm_inputs