Mirror of https://github.com/hiyouga/LLaMA-Factory.git
[data] efficient 4d_attention_mask creation in neat_packing (#7272)
parent 9ccfb97a2c
commit d7d79f7e06
@@ -24,6 +24,7 @@ import torch.nn.functional as F
 from transformers import DataCollatorForSeq2Seq
 
 from ..extras.constants import AUDIO_PLACEHOLDER, IGNORE_INDEX, IMAGE_PLACEHOLDER
+from ..extras.misc import get_current_device
 from ..extras.packages import is_pillow_available
 
 
@@ -63,17 +64,31 @@ def prepare_4d_attention_mask(attention_mask_with_indices: "torch.Tensor", dtype
     ```
     where `o` equals to `0.0`, `x` equals to `min_dtype`.
     """
-    bsz, seq_len = attention_mask_with_indices.size()
+    _, seq_len = attention_mask_with_indices.size()
+
+    # Move to compute device if the source is CPU.
+    source_device = attention_mask_with_indices.device
+    compute_device = get_current_device() if source_device.type == "cpu" else source_device
+    if compute_device != source_device:
+        attention_mask_with_indices = attention_mask_with_indices.to(compute_device)
+
     min_dtype = torch.finfo(dtype).min
-    expanded_mask = attention_mask_with_indices[:, None, None, :].expand(bsz, 1, seq_len, seq_len)
-    # Create a binary mask from the original mask where zeros remain zeros and all other values are set to one
-    padding_mask = torch.where(expanded_mask != 0, 1, 0)
-    # Create a block-diagonal mask.
-    attention_mask_4d = torch.eq(expanded_mask, expanded_mask.transpose(-1, -2)).int() * padding_mask
-    # Use the lower triangular mask to zero out the upper triangular part
-    attention_mask_4d *= torch.tril(torch.ones((seq_len, seq_len), dtype=torch.long))
+    zero_tensor = torch.tensor(0, dtype=dtype, device=compute_device)
+
+    # Create a non-padding mask.
+    non_padding = (attention_mask_with_indices != 0).unsqueeze(1).unsqueeze(2)
+    # Create indices for comparison.
+    indices = attention_mask_with_indices.unsqueeze(1).unsqueeze(2)  # [bsz, 1, 1, seq_len]
+    indices_t = attention_mask_with_indices.unsqueeze(1).unsqueeze(3)  # [bsz, 1, seq_len, 1]
+    # Create a lower triangular mask.
+    tril_mask = torch.tril(torch.ones((seq_len, seq_len), dtype=torch.bool, device=compute_device))
+    attention_mask_4d = (indices == indices_t) & non_padding & tril_mask
     # Invert the attention mask.
-    attention_mask_4d = torch.where(attention_mask_4d != 0, torch.tensor(0, dtype=dtype), min_dtype)
+    attention_mask_4d = torch.where(attention_mask_4d, zero_tensor, min_dtype)
+
+    # Move back to original device if needed.
+    if compute_device != source_device:
+        attention_mask_4d = attention_mask_4d.to(source_device)
+
     return attention_mask_4d
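For reference, here is a minimal standalone sketch of the broadcast-comparison logic added by this commit, applied to the toy input from the function's docstring. The `build_packed_causal_mask` helper name and the final `print` are illustrative only (not part of the repository), and the CPU-to-accelerator device handling from the commit is omitted for brevity.

```python
import torch


def build_packed_causal_mask(attention_mask_with_indices: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    """Illustrative re-implementation of the new mask construction (device handling omitted)."""
    _, seq_len = attention_mask_with_indices.size()
    min_dtype = torch.finfo(dtype).min
    zero_tensor = torch.tensor(0, dtype=dtype)

    # Positions whose packing index is 0 are padding and may not be attended to.
    non_padding = (attention_mask_with_indices != 0).unsqueeze(1).unsqueeze(2)
    # Two positions may attend to each other only if they share the same packing index.
    indices = attention_mask_with_indices.unsqueeze(1).unsqueeze(2)    # [bsz, 1, 1, seq_len]
    indices_t = attention_mask_with_indices.unsqueeze(1).unsqueeze(3)  # [bsz, 1, seq_len, 1]
    # Causal constraint: each position sees only itself and earlier positions.
    tril_mask = torch.tril(torch.ones((seq_len, seq_len), dtype=torch.bool))
    allowed = (indices == indices_t) & non_padding & tril_mask
    # Allowed pairs become 0.0, everything else the most negative value of `dtype`.
    return torch.where(allowed, zero_tensor, min_dtype)


# Two sequences of lengths 2 and 3 packed into one row, plus one padding slot.
packed = torch.tensor([[1, 1, 2, 2, 2, 0]])
mask = build_packed_causal_mask(packed, torch.float32)  # shape [1, 1, 6, 6]
print((mask == 0).int()[0, 0])
# Expected block-diagonal causal pattern (1 = may attend, matching the `o` entries in the docstring):
# [[1, 0, 0, 0, 0, 0],
#  [1, 1, 0, 0, 0, 0],
#  [0, 0, 1, 0, 0, 0],
#  [0, 0, 1, 1, 0, 0],
#  [0, 0, 1, 1, 1, 0],
#  [0, 0, 0, 0, 0, 0]]
```

Relative to the removed `expand`/`torch.eq` path, the intermediates here stay boolean rather than integer tensors, and the commit additionally moves the computation to the accelerator via `get_current_device` when the input arrives on CPU, which appears to be where the efficiency gain in the commit title comes from.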