diff --git a/src/llamafactory/data/collator.py b/src/llamafactory/data/collator.py index 1f665904..9435bed4 100644 --- a/src/llamafactory/data/collator.py +++ b/src/llamafactory/data/collator.py @@ -121,7 +121,6 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq): ) batch_images = fake_images batch_imglens[0] = 1 - batch_input_ids[0] = features[0]["input_ids"] if ( self.template.mm_plugin.audio_token is not None and sum(batch_audlens) == 0 @@ -137,7 +136,6 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq): ) batch_audios = fake_audios batch_audlens[0] = 1 - batch_input_ids[0] = features[0]["input_ids"] if fake_input_ids is not None: if self.tokenizer.padding_side == "right": @@ -149,6 +147,8 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq): features[0]["attention_mask"] = [0] * len(fake_input_ids) + features[0]["attention_mask"] features[0]["labels"] = [IGNORE_INDEX] * len(fake_input_ids) + features[0]["labels"] + batch_input_ids[0] = features[0]["input_ids"] + mm_inputs = self.template.mm_plugin.get_mm_inputs( batch_images, batch_videos,