diff --git a/src/llamafactory/data/collator.py b/src/llamafactory/data/collator.py
index 4de6bc2c..838c0719 100644
--- a/src/llamafactory/data/collator.py
+++ b/src/llamafactory/data/collator.py
@@ -24,6 +24,7 @@ import torch.nn.functional as F
 from transformers import DataCollatorForSeq2Seq
 
 from ..extras.constants import AUDIO_PLACEHOLDER, IGNORE_INDEX, IMAGE_PLACEHOLDER
+from ..extras.misc import get_current_device
 from ..extras.packages import is_pillow_available
 
 
@@ -63,17 +64,31 @@ def prepare_4d_attention_mask(attention_mask_with_indices: "torch.Tensor", dtype: "torch.dtype") -> "torch.Tensor":
     ```
     where `o` equals to `0.0`, `x` equals to `min_dtype`.
     """
-    bsz, seq_len = attention_mask_with_indices.size()
+    _, seq_len = attention_mask_with_indices.size()
+
+    # Move to compute device if the source is CPU.
+    source_device = attention_mask_with_indices.device
+    compute_device = get_current_device() if source_device.type == "cpu" else source_device
+    if compute_device != source_device:
+        attention_mask_with_indices = attention_mask_with_indices.to(compute_device)
+
     min_dtype = torch.finfo(dtype).min
-    expanded_mask = attention_mask_with_indices[:, None, None, :].expand(bsz, 1, seq_len, seq_len)
-    # Create a binary mask from the original mask where zeros remain zeros and all other values are set to one
-    padding_mask = torch.where(expanded_mask != 0, 1, 0)
-    # Create a block-diagonal mask.
-    attention_mask_4d = torch.eq(expanded_mask, expanded_mask.transpose(-1, -2)).int() * padding_mask
-    # Use the lower triangular mask to zero out the upper triangular part
-    attention_mask_4d *= torch.tril(torch.ones((seq_len, seq_len), dtype=torch.long))
+    zero_tensor = torch.tensor(0, dtype=dtype, device=compute_device)
+
+    # Create a non-padding mask.
+    non_padding = (attention_mask_with_indices != 0).unsqueeze(1).unsqueeze(2)
+    # Create indices for comparison.
+    indices = attention_mask_with_indices.unsqueeze(1).unsqueeze(2)  # [bsz, 1, 1, seq_len]
+    indices_t = attention_mask_with_indices.unsqueeze(1).unsqueeze(3)  # [bsz, 1, seq_len, 1]
+    # Create a lower triangular mask.
+    tril_mask = torch.tril(torch.ones((seq_len, seq_len), dtype=torch.bool, device=compute_device))
+    attention_mask_4d = (indices == indices_t) & non_padding & tril_mask
     # Invert the attention mask.
-    attention_mask_4d = torch.where(attention_mask_4d != 0, torch.tensor(0, dtype=dtype), min_dtype)
+    attention_mask_4d = torch.where(attention_mask_4d, zero_tensor, min_dtype)
+
+    # Move back to original device if needed.
+    if compute_device != source_device:
+        attention_mask_4d = attention_mask_4d.to(source_device)
     return attention_mask_4d
 
 
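For reference, below is a minimal standalone sketch of the new masking logic; the function name `build_packed_causal_mask` is hypothetical, and the CPU-to-accelerator round-trip from the patch (presumably there so the quadratic-size comparison runs on the compute device rather than the CPU collator worker) is omitted. It reproduces the block-diagonal causal pattern described in the docstring example.

```python
import torch

def build_packed_causal_mask(indices: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    """Sketch of the patched mask construction (hypothetical name, device handling omitted)."""
    _, seq_len = indices.size()
    min_dtype = torch.finfo(dtype).min
    non_padding = (indices != 0).unsqueeze(1).unsqueeze(2)   # [bsz, 1, 1, seq_len]
    idx = indices.unsqueeze(1).unsqueeze(2)                  # [bsz, 1, 1, seq_len]
    idx_t = indices.unsqueeze(1).unsqueeze(3)                # [bsz, 1, seq_len, 1]
    tril = torch.tril(torch.ones((seq_len, seq_len), dtype=torch.bool))
    # Attend only within the same packed sequence, to non-padding keys, and causally.
    allowed = (idx == idx_t) & non_padding & tril
    return torch.where(allowed, torch.tensor(0, dtype=dtype), min_dtype)

# Input from the docstring example: two packed sequences (indices 1 and 2) plus one padding slot.
mask = build_packed_causal_mask(torch.tensor([[1, 1, 2, 2, 2, 0]]), torch.float32)
print((mask == 0).int()[0, 0])
# tensor([[1, 0, 0, 0, 0, 0],
#         [1, 1, 0, 0, 0, 0],
#         [0, 0, 1, 0, 0, 0],
#         [0, 0, 1, 1, 0, 0],
#         [0, 0, 1, 1, 1, 0],
#         [0, 0, 0, 0, 0, 0]])
```

Compared with the previous `expand`-based version, the boolean broadcast avoids materializing intermediate integer tensors and keeps the comparison, padding, and causal constraints in a single fused expression before the final `torch.where` converts the mask to `0.0` / `min_dtype` values.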