From c0ffe68745eda1f571e07544c696cf0a53d6a63e Mon Sep 17 00:00:00 2001
From: hoshi-hiyouga
Date: Mon, 25 Nov 2024 22:22:06 +0800
Subject: [PATCH] fix #6139

Former-commit-id: d87e16cf5c46dadbfcda7b8ac8edfef6a012f97f
---
 src/llamafactory/data/mm_plugin.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/llamafactory/data/mm_plugin.py b/src/llamafactory/data/mm_plugin.py
index 248cbd38..b76cf4f3 100644
--- a/src/llamafactory/data/mm_plugin.py
+++ b/src/llamafactory/data/mm_plugin.py
@@ -753,12 +753,12 @@ class MllamaPlugin(BasePlugin):
         cross_attention_token_mask = [
             get_cross_attention_token_mask(input_ids, image_token_id) for input_ids in batch_ids
         ]
-        mm_inputs["cross_attention_mask"] = convert_sparse_cross_attention_mask_to_dense(
+        mm_inputs["cross_attention_mask"] = torch.tensor(convert_sparse_cross_attention_mask_to_dense(
             cross_attention_token_mask,
             num_tiles=num_tiles,
             max_num_tiles=max_image_tiles,
             length=max(len(input_ids) for input_ids in batch_ids),
-        )
+        ))
         return mm_inputs
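
Note on the fix: convert_sparse_cross_attention_mask_to_dense (from transformers'
Mllama processing utilities) returns a NumPy array, while the other values placed
into mm_inputs are torch tensors, so wrapping the result in torch.tensor(...)
keeps the batch homogeneous for downstream collation. Below is a minimal sketch
of the conversion, assuming the transformers.models.mllama import path; the token
span [[[2, 5]]], tile counts, and sequence length are made-up example values, not
taken from the patch.

# Hypothetical standalone repro of the type mismatch (not part of the patch).
import torch
from transformers.models.mllama.processing_mllama import (
    convert_sparse_cross_attention_mask_to_dense,
)

# One sample, one image whose tokens attend over positions [2, 5).
cross_attention_token_mask = [[[2, 5]]]
dense = convert_sparse_cross_attention_mask_to_dense(
    cross_attention_token_mask,
    num_tiles=[[4]],   # tiles actually used by each image
    max_num_tiles=4,
    length=8,          # padded sequence length of the batch
)
print(type(dense))     # <class 'numpy.ndarray'> -- the type that leaked through
mask = torch.tensor(dense)
print(mask.shape)      # torch.Size([1, 8, 1, 4]): (batch, seq, images, tiles)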