From 1bb3d17d9ebcd96920c6ea52b1b21c2f26b436ce Mon Sep 17 00:00:00 2001 From: hoshi-hiyouga Date: Sun, 9 Feb 2025 22:42:25 +0800 Subject: [PATCH] [data] fix mllama collator (#6874) Former-commit-id: b68199db274a53d5916179e1aaf9722fd94fa2dc --- src/llamafactory/data/collator.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/llamafactory/data/collator.py b/src/llamafactory/data/collator.py index 1f665904..9435bed4 100644 --- a/src/llamafactory/data/collator.py +++ b/src/llamafactory/data/collator.py @@ -121,7 +121,6 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq): ) batch_images = fake_images batch_imglens[0] = 1 - batch_input_ids[0] = features[0]["input_ids"] if ( self.template.mm_plugin.audio_token is not None and sum(batch_audlens) == 0 @@ -137,7 +136,6 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq): ) batch_audios = fake_audios batch_audlens[0] = 1 - batch_input_ids[0] = features[0]["input_ids"] if fake_input_ids is not None: if self.tokenizer.padding_side == "right": @@ -149,6 +147,8 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq): features[0]["attention_mask"] = [0] * len(fake_input_ids) + features[0]["attention_mask"] features[0]["labels"] = [IGNORE_INDEX] * len(fake_input_ids) + features[0]["labels"] + batch_input_ids[0] = features[0]["input_ids"] + mm_inputs = self.template.mm_plugin.get_mm_inputs( batch_images, batch_videos,