diff --git a/src/llamafactory/data/collator.py b/src/llamafactory/data/collator.py index 42b4f565..8fa6f0dd 100644 --- a/src/llamafactory/data/collator.py +++ b/src/llamafactory/data/collator.py @@ -99,8 +99,8 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq): features: Dict[str, "torch.Tensor"] = super().__call__(features) features.update(mm_inputs) - if features.get("pixel_values") is not None and isinstance(features["pixel_values"], list): - features = features.data + if isinstance(features.get("pixel_values"), list): # for pixtral inputs + features = features.data # use default_collate() instead of BatchEncoding.to() return features