From e8083f8f3f0326494c08a41e6792d1fce9a8138b Mon Sep 17 00:00:00 2001
From: hoshi-hiyouga
Date: Mon, 25 Nov 2024 22:55:56 +0800
Subject: [PATCH] lint

Former-commit-id: 57c3cf1f498d5ffafdc8c06e0f8713f8ff77de81
---
 src/llamafactory/data/mm_plugin.py | 14 ++++++++------
 1 file changed, 8 insertions(+), 6 deletions(-)

diff --git a/src/llamafactory/data/mm_plugin.py b/src/llamafactory/data/mm_plugin.py
index b76cf4f3..383a1271 100644
--- a/src/llamafactory/data/mm_plugin.py
+++ b/src/llamafactory/data/mm_plugin.py
@@ -753,12 +753,14 @@ class MllamaPlugin(BasePlugin):
         cross_attention_token_mask = [
             get_cross_attention_token_mask(input_ids, image_token_id) for input_ids in batch_ids
         ]
-        mm_inputs["cross_attention_mask"] = torch.tensor(convert_sparse_cross_attention_mask_to_dense(
-            cross_attention_token_mask,
-            num_tiles=num_tiles,
-            max_num_tiles=max_image_tiles,
-            length=max(len(input_ids) for input_ids in batch_ids),
-        ))
+        mm_inputs["cross_attention_mask"] = torch.from_numpy(
+            convert_sparse_cross_attention_mask_to_dense(
+                cross_attention_token_mask,
+                num_tiles=num_tiles,
+                max_num_tiles=max_image_tiles,
+                length=max(len(input_ids) for input_ids in batch_ids),
+            )
+        )
         return mm_inputs
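Note (not part of the patch): the substantive change is swapping `torch.tensor(...)` for `torch.from_numpy(...)` on the dense mask that `convert_sparse_cross_attention_mask_to_dense` returns as a numpy array. `torch.tensor` always copies its input, while `torch.from_numpy` wraps the existing numpy buffer without a copy. A minimal sketch of the difference, using a stand-in array rather than the real mask:

```python
import numpy as np
import torch

# Stand-in for the dense cross-attention mask (shape chosen for illustration only).
mask = np.zeros((2, 5, 4), dtype=np.int64)

copied = torch.tensor(mask)      # new tensor with its own storage
shared = torch.from_numpy(mask)  # zero-copy view over the numpy buffer

mask[0, 0, 0] = 1
print(copied[0, 0, 0].item())  # 0 -- the copy does not see the mutation
print(shared[0, 0, 0].item())  # 1 -- the view reflects it
```

Since the mask is freshly allocated and immediately handed off in `mm_inputs`, sharing the buffer is safe here and avoids one redundant copy per batch.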