diff --git a/src/llamafactory/data/collator.py b/src/llamafactory/data/collator.py
index 5aea4c61..036e1a79 100644
--- a/src/llamafactory/data/collator.py
+++ b/src/llamafactory/data/collator.py
@@ -19,8 +19,8 @@ from dataclasses import dataclass
 from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Sequence
 
 import torch
-from torch.nn.utils.rnn import pad_sequence
 import torch.nn.functional as F
+from torch.nn.utils.rnn import pad_sequence
 from transformers import DataCollatorForSeq2Seq
 
 from ..extras.constants import IGNORE_INDEX, IMAGE_PLACEHOLDER
@@ -106,7 +106,9 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
             batch_vidlens.append(len(videos))
             batch_input_ids.append(feature["input_ids"])
 
-        if self.processor is not None and sum(batch_imglens) == 0:  # avoid process hanging in zero3/fsdp case
+        if (
+            self.processor is not None and sum(batch_imglens) == 0 and sum(batch_vidlens) == 0
+        ):  # avoid process hanging in zero3/fsdp case
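+            # the dummy 64x64 white image below keeps every rank running the
+            # visual branch on text-only batches (see the zero3/fsdp note above)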
fake_messages = [{"role": "user", "content": IMAGE_PLACEHOLDER}]
fake_images = [Image.new("RGB", (64, 64), (255, 255, 255))]
fake_messages = self.template.mm_plugin.process_messages(fake_messages, fake_images, [], self.processor)
@@ -157,10 +159,10 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
 
         if "image_bound" in features:  # for minicpmv inputs
             features["position_ids"] = [torch.arange(input_ids.size(0)).long() for input_ids in features["input_ids"]]
-            features["input_ids"] = pad(features["input_ids"],)
+            features["input_ids"] = pad(features["input_ids"])
             features["position_ids"] = pad(features["position_ids"])
             features["labels"] = pad(features["labels"], padding_value=-100)
-            features["attention_mask"] = pad(features["attention_mask"],)
+            features["attention_mask"] = pad(features["attention_mask"])
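+            # assumption: minicpmv's forward consumes the batch from the "data"
+            # key while generic consumers read top-level keys, so both views are kept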
             new_features = {}
             new_features.update({"data": features})
             new_features.update(features)
diff --git a/src/llamafactory/data/mm_plugin.py b/src/llamafactory/data/mm_plugin.py
index a40f78fc..55db4626 100644
--- a/src/llamafactory/data/mm_plugin.py
+++ b/src/llamafactory/data/mm_plugin.py
@@ -265,8 +265,19 @@ class CpmOPlugin(BasePlugin):
     ) -> List[Dict[str, str]]:
         self._validate_input(images, videos)
         num_image_tokens = 0
+        num_video_tokens = 0
         messages = deepcopy(messages)
         image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
+        mm_inputs = {}
+
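+        # images and videos take separate paths (see the assert below): for video
+        # input, frame features are computed up front so the placeholder loop can
+        # size each video's expansion; slicing is capped and image ids disabled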
+        if len(videos) != 0:
+            assert len(images) == 0, "Videos and images cannot be mixed in one sample; run video and image SFT separately."
+            max_slice_nums = 2
+            use_image_id = False
+            mm_inputs = self._get_mm_inputs([], videos, processor)
+        else:
+            max_slice_nums = image_processor.max_slice_nums
+            use_image_id = image_processor.use_image_id
 
         for message in messages:
             content = message["content"]
@@ -274,15 +285,21 @@ class CpmOPlugin(BasePlugin):
             while IMAGE_PLACEHOLDER in content:
                 num_image_tokens += 1
                 content = content.replace(IMAGE_PLACEHOLDER, "{{image}}", 1)
 
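+            # each video placeholder expands to one "{{image}}" slot per sampled frame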
+            while VIDEO_PLACEHOLDER in content:
+                num_video_tokens += 1
+                content = content.replace(
+                    VIDEO_PLACEHOLDER, "{{image}}" * len(mm_inputs["pixel_values"][num_video_tokens - 1]), 1
+                )
+
             message["content"] = content.replace("{{image}}", "(./)")
 
         if num_image_tokens > 0:
-            mm_inputs = self._get_mm_inputs(images, videos, processor)
+            mm_inputs = self._get_mm_inputs(images, [], processor)
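+        # mm_inputs now holds either the image features computed here or the
+        # video-frame features prepared in the video branch above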
+        if mm_inputs:
             pattern = "(./)"
-            images, image_sizes = mm_inputs["pixel_values"], mm_inputs["image_sizes"]
+            image_sizes = mm_inputs["image_sizes"]
 
-            image_index = 0
             for index, message in enumerate(messages):
                 text = message["content"]
                 image_tags = re.findall(pattern, text)
@@ -293,19 +310,21 @@ class CpmOPlugin(BasePlugin):
                         final_text
                         + text_chunks[i]
                         + image_processor.get_slice_image_placeholder(
-                            image_sizes[image_index][i],
+                            image_sizes[0][i],
                             i,
-                            image_processor.max_slice_nums,
-                            image_processor.use_image_id,
+                            max_slice_nums,
+                            use_image_id,
                         )
                     )
-                    image_index += 1
 
                 final_text += text_chunks[-1]
                 messages[index]["content"] = final_text
 
         if len(images) != num_image_tokens:
             raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
+        if len(videos) != num_video_tokens:
+            raise ValueError(f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens.")
+
         return messages
 
     @override
@@ -333,7 +352,7 @@ class CpmOPlugin(BasePlugin):
                 new_images.append(images[idx : idx + valid_image_nums])
                 idx += valid_image_nums
             images = new_images
-
+
         image_inputs = image_processor(
             images, do_pad=True, max_slice_nums=image_processor.max_slice_nums, return_tensors="pt"
         )
@@ -346,6 +365,8 @@ class CpmOPlugin(BasePlugin):
                 video_fps=getattr(processor, "video_fps", 2.0),
                 video_maxlen=getattr(processor, "video_maxlen", 64),
             )
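+            # frames are encoded like ordinary images; max_slice_nums=2 bounds the
+            # tokens per frame, matching the value hard-coded in process_messages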
+            video_inputs = image_processor(videos, do_pad=True, max_slice_nums=2, return_tensors="pt")
+            mm_inputs.update(video_inputs)
 
         return mm_inputs
@@ -380,12 +401,9 @@ class CpmOPlugin(BasePlugin):
                     ]
                 )
                 image_bounds_list.append(image_bounds)
+
         mm_inputs = self._get_mm_inputs(images, videos, processor, valid_image_nums_ls=valid_image_nums_ls)
-        mm_inputs.update(
-            {
-                "image_bound": image_bounds_list,
-            }
-        )
+        mm_inputs.update({"image_bound": image_bounds_list})
         return mm_inputs