From 118ffe50e3825e985e954a1fa3520218dd974cec Mon Sep 17 00:00:00 2001
From: hoshi-hiyouga
Date: Mon, 25 Nov 2024 22:55:56 +0800
Subject: [PATCH] lint

Former-commit-id: da9e4ddd26ebd6e7eb266aa0bef7505465a6b119
---
 src/llamafactory/data/mm_plugin.py | 14 ++++++++------
 1 file changed, 8 insertions(+), 6 deletions(-)

diff --git a/src/llamafactory/data/mm_plugin.py b/src/llamafactory/data/mm_plugin.py
index b76cf4f3..383a1271 100644
--- a/src/llamafactory/data/mm_plugin.py
+++ b/src/llamafactory/data/mm_plugin.py
@@ -753,12 +753,14 @@ class MllamaPlugin(BasePlugin):
         cross_attention_token_mask = [
             get_cross_attention_token_mask(input_ids, image_token_id) for input_ids in batch_ids
         ]
-        mm_inputs["cross_attention_mask"] = torch.tensor(convert_sparse_cross_attention_mask_to_dense(
-            cross_attention_token_mask,
-            num_tiles=num_tiles,
-            max_num_tiles=max_image_tiles,
-            length=max(len(input_ids) for input_ids in batch_ids),
-        ))
+        mm_inputs["cross_attention_mask"] = torch.from_numpy(
+            convert_sparse_cross_attention_mask_to_dense(
+                cross_attention_token_mask,
+                num_tiles=num_tiles,
+                max_num_tiles=max_image_tiles,
+                length=max(len(input_ids) for input_ids in batch_ids),
+            )
+        )
 
         return mm_inputs
 
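Note: besides reflowing the call for the linter, this patch swaps torch.tensor for
torch.from_numpy. convert_sparse_cross_attention_mask_to_dense (from transformers'
Mllama processing code) returns a numpy array; torch.from_numpy wraps that array's
buffer zero-copy, while torch.tensor always allocates and copies. A minimal sketch
of the difference, independent of the patch (the array name and values here are
illustrative, not from the repository):

    import numpy as np
    import torch

    mask = np.zeros((2, 3), dtype=np.int64)

    copied = torch.tensor(mask)       # allocates new storage and copies the data
    shared = torch.from_numpy(mask)   # zero-copy view over the numpy buffer

    mask[0, 0] = 7
    assert copied[0, 0].item() == 0   # the copy does not see the mutation
    assert shared[0, 0].item() == 7   # the view shares memory with the array

For a collator building mm_inputs once per batch the saved copy is small, but
from_numpy also sidesteps torch.tensor's slow per-element conversion path and the
UserWarning PyTorch can emit when tensor-like conversions are ambiguous.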