diff --git a/src/llamafactory/data/collator.py b/src/llamafactory/data/collator.py index 5e67899b..968d7018 100644 --- a/src/llamafactory/data/collator.py +++ b/src/llamafactory/data/collator.py @@ -93,7 +93,7 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq): else: image_grid_thw.append(torch.tensor(feature["image_grid_thw"], dtype=torch.long)) - extra_features["image_grid_thw"] = torch.cat(pixel_values, dim=0) + extra_features["image_grid_thw"] = torch.cat(image_grid_thw, dim=0) if extra_features["image_grid_thw"].numel() == 0: extra_features["image_grid_thw"] = None