Former-commit-id: a650e114e907278ece188922467c2514de544eeb
This commit is contained in:
fzc8578 2025-01-11 01:10:24 +08:00
parent 08e8499a98
commit 62c12a133e

View File

@@ -152,14 +152,11 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
features.update(mm_inputs) features.update(mm_inputs)
if isinstance(features.get("pixel_values"), list): # for pixtral inputs if isinstance(features.get("pixel_values"), list): # for pixtral inputs
features = features.data # use default_collate() instead of BatchEncoding.to() features = features.data # use default_collate() instead of BatchEncoding.to()
if "image_bound" in features: # for minicpmv inputs if "image_bound" in features: # for minicpmv inputs
features["position_ids"] = [torch.arange(input_ids.size(0)).long() for input_ids in features["input_ids"]] features["position_ids"] = [torch.arange(input_ids.size(0)).long() for input_ids in features["input_ids"]]
features["position_ids"] = pad_sequence(features["position_ids"], batch_first=True, padding_value=0) features["position_ids"] = pad_sequence(features["position_ids"], batch_first=True, padding_value=0)
features["labels"] = pad_sequence(features["labels"], batch_first=True, padding_value=-100) new_features = {"data": features}
features["attention_mask"] = pad_sequence(features["attention_mask"], batch_first=True, padding_value=0)
new_features = {}
new_features.update({"data": features})
new_features.update(features) new_features.update(features)
features = new_features features = new_features