Mirror of https://github.com/hiyouga/LLaMA-Factory.git, synced 2025-09-12 16:12:48 +08:00
Merge pull request #6140 from hiyouga/hiyouga/fix_mllama

[data] fix mllama plugin

Former-commit-id: 3924a3d6e9d761dd51eca92afed7f299be71e42d
Commit: a6aeb98af6
@@ -753,12 +753,12 @@ class MllamaPlugin(BasePlugin):
         cross_attention_token_mask = [
             get_cross_attention_token_mask(input_ids, image_token_id) for input_ids in batch_ids
         ]
-        mm_inputs["cross_attention_mask"] = convert_sparse_cross_attention_mask_to_dense(
+        mm_inputs["cross_attention_mask"] = torch.tensor(convert_sparse_cross_attention_mask_to_dense(
             cross_attention_token_mask,
             num_tiles=num_tiles,
             max_num_tiles=max_image_tiles,
             length=max(len(input_ids) for input_ids in batch_ids),
-        )
+        ))
         return mm_inputs
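The wrap in torch.tensor(...) matters because convert_sparse_cross_attention_mask_to_dense from transformers.models.mllama.processing_mllama returns a NumPy array, while the other values in mm_inputs are torch.Tensor objects that the data collator expects. Below is a minimal standalone sketch of the two helpers involved, assuming transformers >= 4.45 (where the Mllama processing utilities were added); the image_token_id, toy batch_ids, and num_tiles values are made-up illustration inputs, not taken from the plugin.

    # Sketch: why the fix converts the dense mask to a torch.Tensor.
    # Assumes transformers >= 4.45; input values below are hypothetical.
    import torch
    from transformers.models.mllama.processing_mllama import (
        convert_sparse_cross_attention_mask_to_dense,
        get_cross_attention_token_mask,
    )

    image_token_id = 128256          # hypothetical <|image|> token id
    batch_ids = [[128256, 1, 2, 3],  # toy token id sequences, one image each
                 [5, 128256, 6, 7]]

    # Sparse form: per sample, the start/end token span each image attends to
    # (end == -1 means "until the end of the sequence").
    cross_attention_token_mask = [
        get_cross_attention_token_mask(input_ids, image_token_id) for input_ids in batch_ids
    ]

    dense = convert_sparse_cross_attention_mask_to_dense(
        cross_attention_token_mask,
        num_tiles=[[4], [4]],  # toy tiles-per-image counts
        max_num_tiles=4,
        length=max(len(input_ids) for input_ids in batch_ids),
    )
    print(type(dense))  # numpy.ndarray, not torch.Tensor

    # The commit wraps the result so it matches the other tensors in mm_inputs:
    cross_attention_mask = torch.tensor(dense)
    print(cross_attention_mask.shape)  # (batch, seq_len, max_num_images, max_num_tiles)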