diff --git a/src/llamafactory/data/collator.py b/src/llamafactory/data/collator.py index 5e599653..cf60d944 100644 --- a/src/llamafactory/data/collator.py +++ b/src/llamafactory/data/collator.py @@ -152,7 +152,7 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq): features.update(mm_inputs) if isinstance(features.get("pixel_values"), list): # for pixtral inputs features = features.data # use default_collate() instead of BatchEncoding.to() - + if "image_bound" in features: # for minicpmv inputs features["position_ids"] = [torch.arange(input_ids.size(0)).long() for input_ids in features["input_ids"]] features["position_ids"] = pad_sequence(features["position_ids"], batch_first=True, padding_value=0)