diff --git a/src/llamafactory/data/collator.py b/src/llamafactory/data/collator.py index a54970d31..677045cd0 100644 --- a/src/llamafactory/data/collator.py +++ b/src/llamafactory/data/collator.py @@ -380,6 +380,19 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq): for i, feature in enumerate(features): feature["token_type_ids"] = token_type_ids[i] + if "mm_token_type_ids" in mm_inputs: # need tensor-like for gemma4 + mm_token_type_ids = mm_inputs.pop("mm_token_type_ids") + max_len = max(len(ids) for ids in mm_token_type_ids) + padded = [] + for ids in mm_token_type_ids: + pad_len = max_len - len(ids) + if self.tokenizer.padding_side == "right": + padded.append(ids + [0] * pad_len) + else: + padded.append([0] * pad_len + ids) + + mm_inputs["mm_token_type_ids"] = torch.tensor(padded, dtype=torch.long) + features: dict[str, torch.Tensor] = super().__call__(features) bsz, seq_len = features["input_ids"].shape[:2]