diff --git a/src/llamafactory/data/mm_plugin.py b/src/llamafactory/data/mm_plugin.py
index b76cf4f3..383a1271 100644
--- a/src/llamafactory/data/mm_plugin.py
+++ b/src/llamafactory/data/mm_plugin.py
@@ -753,12 +753,14 @@ class MllamaPlugin(BasePlugin):
         cross_attention_token_mask = [
             get_cross_attention_token_mask(input_ids, image_token_id) for input_ids in batch_ids
         ]
-        mm_inputs["cross_attention_mask"] = torch.tensor(convert_sparse_cross_attention_mask_to_dense(
-            cross_attention_token_mask,
-            num_tiles=num_tiles,
-            max_num_tiles=max_image_tiles,
-            length=max(len(input_ids) for input_ids in batch_ids),
-        ))
+        mm_inputs["cross_attention_mask"] = torch.from_numpy(
+            convert_sparse_cross_attention_mask_to_dense(
+                cross_attention_token_mask,
+                num_tiles=num_tiles,
+                max_num_tiles=max_image_tiles,
+                length=max(len(input_ids) for input_ids in batch_ids),
+            )
+        )
         return mm_inputs
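
Note (not part of the patch): a minimal sketch of the behavioral difference the change relies on, assuming convert_sparse_cross_attention_mask_to_dense returns a numpy.ndarray as it does in transformers' Mllama processing code. torch.from_numpy wraps the existing buffer without copying, whereas torch.tensor always allocates a new copy. The array shape and dtype below are illustrative only.

import numpy as np
import torch

# Hypothetical stand-in for the dense cross-attention mask, shaped
# (batch_size, length, max_num_images, max_num_tiles).
dense_mask = np.zeros((2, 16, 1, 4), dtype=np.int64)

copied = torch.tensor(dense_mask)      # always allocates and copies the data
shared = torch.from_numpy(dense_mask)  # wraps the same buffer, no copy

dense_mask[0, 0, 0, 0] = 1
print(shared[0, 0, 0, 0].item())  # 1 -- from_numpy shares memory with the array
print(copied[0, 0, 0, 0].item())  # 0 -- torch.tensor made an independent copy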