Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-08-23 14:22:51 +08:00)
Commit 118ffe50e3 (parent a6aeb98af6): lint
Former-commit-id: da9e4ddd26ebd6e7eb266aa0bef7505465a6b119
@@ -753,12 +753,14 @@ class MllamaPlugin(BasePlugin):
         cross_attention_token_mask = [
             get_cross_attention_token_mask(input_ids, image_token_id) for input_ids in batch_ids
         ]
-        mm_inputs["cross_attention_mask"] = torch.tensor(convert_sparse_cross_attention_mask_to_dense(
-            cross_attention_token_mask,
-            num_tiles=num_tiles,
-            max_num_tiles=max_image_tiles,
-            length=max(len(input_ids) for input_ids in batch_ids),
-        ))
+        mm_inputs["cross_attention_mask"] = torch.from_numpy(
+            convert_sparse_cross_attention_mask_to_dense(
+                cross_attention_token_mask,
+                num_tiles=num_tiles,
+                max_num_tiles=max_image_tiles,
+                length=max(len(input_ids) for input_ids in batch_ids),
+            )
+        )
         return mm_inputs
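The change swaps torch.tensor(...) for torch.from_numpy(...) around the output of convert_sparse_cross_attention_mask_to_dense. This makes sense if that helper (from transformers' mllama processing utilities) returns a NumPy array: torch.from_numpy wraps the existing buffer without copying, while torch.tensor always allocates and copies. A minimal sketch of the difference, using a hypothetical dense mask array in place of the helper's real output:

import numpy as np
import torch

# Hypothetical stand-in for the dense mask that
# convert_sparse_cross_attention_mask_to_dense would return,
# assumed shape (batch, seq_len, max_num_images, max_num_tiles).
dense_mask = np.zeros((2, 16, 1, 4), dtype=np.int64)

copied = torch.tensor(dense_mask)      # always copies the underlying buffer
shared = torch.from_numpy(dense_mask)  # wraps the same memory, zero-copy

dense_mask[0, 0, 0, 0] = 1
assert shared[0, 0, 0, 0] == 1  # from_numpy sees the mutation (shared memory)
assert copied[0, 0, 0, 0] == 0  # torch.tensor made an independent copy

For a collator that builds this mask on every batch, the zero-copy path avoids one redundant allocation per step, which fits the commit's "lint"-style cleanup intent.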