# Copyright 2024 OpenAccess AI Collective and the LlamaFactory team.
#
# This code is inspired by the OpenAccess AI Collective's axolotl library.
# https://github.com/OpenAccess-AI-Collective/axolotl/blob/main/src/axolotl/monkeypatch/utils.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from typing import Any, Dict, Literal, Sequence

import torch
from transformers import DataCollatorForSeq2Seq

def prepare_4d_attention_mask(attention_mask_with_indices: "torch.Tensor", dtype: "torch.dtype") -> "torch.Tensor":
    r"""
    Expands the attention mask with indices from (batch_size, seq_len) to (batch_size, 1, seq_len, seq_len),
    while handling packed sequences and transforming the mask to lower triangular form to prevent future peeking.

    e.g.
    ```python
    # input
    [[1, 1, 2, 2, 2, 0]]
    # output
    [
        [
            [
                [o, x, x, x, x, x],
                [o, o, x, x, x, x],
                [x, x, o, x, x, x],
                [x, x, o, o, x, x],
                [x, x, o, o, o, x],
                [x, x, x, x, x, x],
            ]
        ]
    ]
    ```
    where `o` equals `0.0` and `x` equals `min_dtype`.
    """
    bsz, seq_len = attention_mask_with_indices.size()
    min_dtype = torch.finfo(dtype).min
    expanded_mask = attention_mask_with_indices[:, None, None, :].expand(bsz, 1, seq_len, seq_len)
    # Create a binary mask from the original mask where zeros remain zeros and all other values are set to one
    padding_mask = torch.where(expanded_mask != 0, 1, 0)
    # Create a block-diagonal mask: positions may only attend within the same packed sequence
    attention_mask_4d = torch.eq(expanded_mask, expanded_mask.transpose(-1, -2)).int() * padding_mask
    # Use the lower triangular mask to zero out the upper triangular part
    attention_mask_4d *= torch.tril(torch.ones((seq_len, seq_len), dtype=torch.long))
    # Invert the attention mask: allowed positions become 0.0, masked positions become min_dtype
    attention_mask_4d = torch.where(attention_mask_4d != 0, torch.tensor(0, dtype=dtype), min_dtype)
    return attention_mask_4d
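
# A minimal, hypothetical sanity check for the docstring example above (the
# helper name is an assumption, not part of the original file): it builds the
# 4D mask for one packed batch and spot-checks the block-diagonal causal shape.
def _demo_prepare_4d_attention_mask() -> None:
    mask = torch.tensor([[1, 1, 2, 2, 2, 0]])  # two packed sequences + one padding position
    out = prepare_4d_attention_mask(mask, torch.float32)
    min_dtype = torch.finfo(torch.float32).min
    assert out.shape == (1, 1, 6, 6)
    assert out[0, 0, 1, 0] == 0.0  # causal attention within sequence 1 is allowed
    assert out[0, 0, 2, 1] == min_dtype  # sequence 2 cannot see sequence 1
    assert out[0, 0, 0, 1] == min_dtype  # no peeking at future tokens
    assert out[0, 0, 5, 5] == min_dtype  # the padding position attends to nothing
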
@dataclass
class SFTDataCollatorWith4DAttentionMask(DataCollatorForSeq2Seq):
    r"""
    Data collator that optionally builds a 4D block-diagonal attention mask for packed sequences.
    """

    block_diag_attn: bool = False
    attn_implementation: Literal["eager", "sdpa", "flash_attention_2"] = "eager"
    compute_dtype: "torch.dtype" = torch.float32

    def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
        image_grid_thw = None
        pixel_values = None
        if "image_grid_thw" in features[0]:
            image_grid_thw_list = [
                torch.Tensor(feature["image_grid_thw"]).long()
                for feature in features
                if feature["image_grid_thw"][0][0] > 0
            ]
            pixel_values_list = [
                torch.Tensor(feature["pixel_values"]) for feature in features if feature["image_grid_thw"][0][0] > 0
            ]
            if image_grid_thw_list:
                image_grid_thw = torch.cat(image_grid_thw_list, 0)
            else:
                # Handle the case where no example in the batch contains a real image
                image_grid_thw = None

            if pixel_values_list:
                pixel_values = torch.cat(pixel_values_list, 0)
            else:
                # Handle the case where no example in the batch contains a real image
                pixel_values = None

            features = [
                {key: feature[key] for key in feature if key not in ["image_grid_thw", "pixel_values"]}
                for feature in features
            ]

        features = super().__call__(features)
        if self.block_diag_attn and self.attn_implementation != "flash_attention_2":
            features["attention_mask"] = prepare_4d_attention_mask(features["attention_mask"], self.compute_dtype)

        if image_grid_thw is not None:
            features["image_grid_thw"] = image_grid_thw
            features["pixel_values"] = pixel_values

        return features
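
# A hedged usage sketch (`tokenizer` and `features` are assumptions, not defined
# in this file): enable the block-diagonal mask when training on packed sequences
# with an attention backend that consumes dense masks.
#
#     collator = SFTDataCollatorWith4DAttentionMask(
#         tokenizer=tokenizer,  # any PreTrainedTokenizer
#         label_pad_token_id=-100,
#         block_diag_attn=True,
#         attn_implementation="sdpa",
#         compute_dtype=torch.bfloat16,
#     )
#     batch = collator(features)  # features: a list of dicts with input_ids/attention_mask/labels
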
@dataclass
class PairwiseDataCollatorWithPadding(DataCollatorForSeq2Seq):
    r"""
    Data collator for pairwise data.
    """

    def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
        r"""
        Pads batched data to the longest sequence in the batch.

        We generate 2 * n examples where the first n examples represent chosen examples and
        the last n examples represent rejected examples.
        """
        concatenated_features = []
        for key in ("chosen", "rejected"):
            for feature in features:
                target_feature = {
                    "input_ids": feature["{}_input_ids".format(key)],
                    "attention_mask": feature["{}_attention_mask".format(key)],
                    "labels": feature["{}_labels".format(key)],
                }
                if "pixel_values" in feature:
                    target_feature["pixel_values"] = feature["pixel_values"]

                if "{}_token_type_ids".format(key) in feature:
                    target_feature["token_type_ids"] = feature["{}_token_type_ids".format(key)]

                concatenated_features.append(target_feature)

        return super().__call__(concatenated_features)
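
# Illustrative batch layout (field names follow this file; the shapes are made
# up): given n pairwise features, each shaped like
#
#     {"chosen_input_ids": [...], "chosen_attention_mask": [...], "chosen_labels": [...],
#      "rejected_input_ids": [...], "rejected_attention_mask": [...], "rejected_labels": [...]}
#
# the collator returns one padded batch of size 2 * n, where indices [0, n) hold
# the chosen examples and [n, 2n) the rejected ones, so a pairwise loss can split
# the batch in half.
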
@dataclass
class KTODataCollatorWithPadding(DataCollatorForSeq2Seq):
    r"""
    Data collator for KTO data.
    """

    def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
        target_features = []
        kl_features = []
        kto_tags = []
        for feature in features:
            target_feature = {
                "input_ids": feature["input_ids"],
                "attention_mask": feature["attention_mask"],
                "labels": feature["labels"],
            }
            kl_feature = {
                "input_ids": feature["kl_input_ids"],
                "attention_mask": feature["kl_attention_mask"],
                "labels": feature["kl_labels"],
            }
            if "pixel_values" in feature:
                target_feature["pixel_values"] = feature["pixel_values"]

            if "token_type_ids" in feature:
                target_feature["token_type_ids"] = feature["token_type_ids"]
                kl_feature["token_type_ids"] = feature["kl_token_type_ids"]

            target_features.append(target_feature)
            kl_features.append(kl_feature)
            kto_tags.append(feature["kto_tags"])

        batch = super().__call__(target_features)
        kl_batch = super().__call__(kl_features)
        batch["kl_input_ids"] = kl_batch["input_ids"]
        batch["kl_attention_mask"] = kl_batch["attention_mask"]
        batch["kl_labels"] = kl_batch["labels"]
        if "token_type_ids" in batch:
            batch["kl_token_type_ids"] = kl_batch["token_type_ids"]

        batch["kto_tags"] = torch.tensor(kto_tags)
        return batch
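
# A hedged note on how a trainer might consume this batch (the two forward
# passes below are a sketch, not code from this repository): the `kl_`-prefixed
# keys carry mismatched completions used to estimate the KL reference term in
# KTO, and `kto_tags` is a boolean tensor marking which target examples are
# desirable.
#
#     policy_logits = model(batch["input_ids"], attention_mask=batch["attention_mask"]).logits
#     kl_logits = model(batch["kl_input_ids"], attention_mask=batch["kl_attention_mask"]).logits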