from dataclasses import dataclass
from typing import Any, Dict, Sequence

import torch
from transformers import DataCollatorForSeq2Seq


@dataclass
class PairwiseDataCollatorWithPadding(DataCollatorForSeq2Seq):
    r"""
    Data collator for pairwise data.
    """

    def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
        r"""
        Pads batched data to the longest sequence in the batch.

        We generate 2 * n examples where the first n examples represent chosen examples and
        the last n examples represent rejected examples.
        """
        paired_features = []
        # Emit all chosen examples first, then all rejected ones, so a downstream
        # consumer can split the padded batch in half to recover the two sides.
        for prefix in ("chosen", "rejected"):
            for feature in features:
                entry = {
                    "input_ids": feature[f"{prefix}_input_ids"],
                    "attention_mask": feature[f"{prefix}_attention_mask"],
                    "labels": feature[f"{prefix}_labels"],
                }
                # The same pixel values are attached to both the chosen and the
                # rejected copy of a multimodal example.
                if "pixel_values" in feature:
                    entry["pixel_values"] = feature["pixel_values"]
                if f"{prefix}_token_type_ids" in feature:
                    entry["token_type_ids"] = feature[f"{prefix}_token_type_ids"]
                paired_features.append(entry)

        return super().__call__(paired_features)


@dataclass
class KTODataCollatorWithPadding(DataCollatorForSeq2Seq):
    r"""
    Data collator for KTO data.
    """

    def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
        r"""
        Pads the target sequences and the KL (reference) sequences as two separate
        batches, then merges the KL tensors back into the target batch under
        ``kl_``-prefixed keys, together with the per-example ``kto_tags``.
        """
        target_features = []
        kl_features = []
        kto_tags = []
        for feature in features:
            target_entry = {
                "input_ids": feature["input_ids"],
                "attention_mask": feature["attention_mask"],
                "labels": feature["labels"],
            }
            kl_entry = {
                "input_ids": feature["kl_input_ids"],
                "attention_mask": feature["kl_attention_mask"],
                "labels": feature["kl_labels"],
            }
            # NOTE(review): pixel values are only attached to the target batch,
            # not the KL batch — presumably the KL forward pass reuses them.
            if "pixel_values" in feature:
                target_entry["pixel_values"] = feature["pixel_values"]
            if "token_type_ids" in feature:
                target_entry["token_type_ids"] = feature["token_type_ids"]
                kl_entry["token_type_ids"] = feature["kl_token_type_ids"]
            target_features.append(target_entry)
            kl_features.append(kl_entry)
            kto_tags.append(feature["kto_tags"])

        # Pad each half independently so target and KL lengths do not interfere.
        batch = super().__call__(target_features)
        kl_batch = super().__call__(kl_features)
        batch["kl_input_ids"] = kl_batch["input_ids"]
        batch["kl_attention_mask"] = kl_batch["attention_mask"]
        batch["kl_labels"] = kl_batch["labels"]
        if "token_type_ids" in batch:
            batch["kl_token_type_ids"] = kl_batch["token_type_ids"]
        batch["kto_tags"] = torch.tensor(kto_tags)
        return batch