diff --git a/src/llamafactory/data/collator.py b/src/llamafactory/data/collator.py
index 4fc9e803..f4f97feb 100644
--- a/src/llamafactory/data/collator.py
+++ b/src/llamafactory/data/collator.py
@@ -19,6 +19,7 @@ from dataclasses import dataclass
 from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Sequence
 
 import torch
+import torch.nn.functional as F
 from transformers import DataCollatorForSeq2Seq
 
 
@@ -98,6 +99,12 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
                 feature["token_type_ids"] = token_type_ids[i]
 
         features: Dict[str, "torch.Tensor"] = super().__call__(features)
+        if "cross_attention_mask" in mm_inputs:  # for mllama inputs
+            cross_attention_mask = mm_inputs.pop("cross_attention_mask")
+            seq_len = features["input_ids"].size(1)
+            orig_len = cross_attention_mask.size(1)
+            mm_inputs["cross_attention_mask"] = F.pad(cross_attention_mask, (0, 0, 0, 0, 0, seq_len - orig_len))
+
         features.update(mm_inputs)
         if isinstance(features.get("pixel_values"), list):  # for pixtral inputs
             features = features.data  # use default_collate() instead of BatchEncoding.to()
diff --git a/src/llamafactory/data/mm_plugin.py b/src/llamafactory/data/mm_plugin.py
index 383a1271..389e27b1 100644
--- a/src/llamafactory/data/mm_plugin.py
+++ b/src/llamafactory/data/mm_plugin.py
@@ -241,7 +241,7 @@ class BasePlugin:
             videos: a list of video inputs, shape (num_videos,)
             imglens: number of images in each sample, shape (batch_size,)
             vidlens: number of videos in each sample, shape (batch_size,)
-            batch_ids: input ids of samples, shape (batch_size, seq_len)
+            batch_ids: token ids of input samples, shape (batch_size, seq_len)
             processor: a processor for pre-processing images and videos
         """
         self._validate_input(images, videos)
@@ -760,7 +760,7 @@ class MllamaPlugin(BasePlugin):
                 max_num_tiles=max_image_tiles,
                 length=max(len(input_ids) for input_ids in batch_ids),
            )
-        )
+        )  # shape: (batch_size, length, max_num_images, max_num_tiles)
         return mm_inputs
 
 
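Note on the F.pad call added to collator.py: a minimal standalone sketch, not part of the patch, assuming the cross_attention_mask shape (batch_size, seq_len, max_num_images, max_num_tiles) noted in the mm_plugin.py comment; the concrete sizes below are made up for illustration.

import torch
import torch.nn.functional as F

batch_size, orig_len, max_num_images, max_num_tiles = 2, 10, 1, 4
seq_len = 16  # e.g. the collator padded input_ids beyond the original length

cross_attention_mask = torch.ones(batch_size, orig_len, max_num_images, max_num_tiles)

# F.pad reads the pad tuple from the last dimension backwards, two values per
# dimension: (0, 0) keeps max_num_tiles, (0, 0) keeps max_num_images, and
# (0, seq_len - orig_len) appends zeros along the sequence dimension so the
# mask lines up with the padded input_ids.
padded = F.pad(cross_attention_mask, (0, 0, 0, 0, 0, seq_len - orig_len))
print(padded.shape)  # torch.Size([2, 16, 1, 4])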