Former-commit-id: d87e16cf5c46dadbfcda7b8ac8edfef6a012f97f
This commit is contained in:
hoshi-hiyouga 2024-11-25 22:22:06 +08:00 committed by GitHub
parent 1a8c26a7d9
commit c0ffe68745

View File

@@ -753,12 +753,12 @@ class MllamaPlugin(BasePlugin):
             cross_attention_token_mask = [
                 get_cross_attention_token_mask(input_ids, image_token_id) for input_ids in batch_ids
             ]
-            mm_inputs["cross_attention_mask"] = convert_sparse_cross_attention_mask_to_dense(
+            mm_inputs["cross_attention_mask"] = torch.tensor(convert_sparse_cross_attention_mask_to_dense(
                 cross_attention_token_mask,
                 num_tiles=num_tiles,
                 max_num_tiles=max_image_tiles,
                 length=max(len(input_ids) for input_ids in batch_ids),
-            )
+            ))
         return mm_inputs