This commit is contained in:
hiyouga
2024-08-30 03:21:50 +08:00
parent 8b588c7224
commit bee1bd43b9
8 changed files with 24 additions and 13 deletions

View File

@@ -68,7 +68,7 @@ class CustomDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
"""
def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
image_grid_thw = None
image_grid_thw = None # TODO: better handle various VLMs
if "image_grid_thw" in features[0]:
image_grid_thw_list = [
torch.Tensor(feature["image_grid_thw"]).long()