from dataclasses import dataclass
from typing import Any, Dict, Sequence

import torch
from transformers import DataCollatorForSeq2Seq


@dataclass
class PairwiseDataCollatorWithPadding(DataCollatorForSeq2Seq):
    r"""
    Data collator for pairwise data.
    """

    def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
        r"""
        Pads batched data to the longest sequence in the batch.

        We generate 2 * n examples where the first n examples represent chosen examples and
        the last n examples represent rejected examples.
        """
        concatenated_features = []
        for key in ("chosen", "rejected"):
            for feature in features:
                target_feature = {
                    "input_ids": feature["{}_input_ids".format(key)],
                    "attention_mask": feature["{}_attention_mask".format(key)],
                    "labels": feature["{}_labels".format(key)],
                }
                if "pixel_values" in feature:
                    target_feature["pixel_values"] = feature["pixel_values"]

                if "{}_token_type_ids".format(key) in feature:
                    target_feature["token_type_ids"] = feature["{}_token_type_ids".format(key)]

                concatenated_features.append(target_feature)

        return super().__call__(concatenated_features)
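
# A minimal usage sketch (illustrative only; the checkpoint name and the toy feature
# values below are assumptions, not part of this module). Each input feature carries
# "chosen_*" and "rejected_*" keys; the collator pads both halves together and returns
# a batch of 2 * n rows, with the chosen examples first and the rejected examples last:
#
#   from transformers import AutoTokenizer
#
#   tokenizer = AutoTokenizer.from_pretrained("gpt2")
#   tokenizer.pad_token = tokenizer.eos_token  # gpt2 has no pad token by default
#   collator = PairwiseDataCollatorWithPadding(tokenizer=tokenizer, label_pad_token_id=-100)
#   batch = collator([
#       {
#           "chosen_input_ids": [1, 2, 3],
#           "chosen_attention_mask": [1, 1, 1],
#           "chosen_labels": [1, 2, 3],
#           "rejected_input_ids": [1, 2],
#           "rejected_attention_mask": [1, 1],
#           "rejected_labels": [1, 2],
#       }
#   ])
#   # batch["input_ids"].shape == (2, 3): row 0 is the chosen sequence, row 1 the rejected one.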


@dataclass
class KTODataCollatorWithPadding(DataCollatorForSeq2Seq):
    r"""
    Data collator for KTO data.
    """

    def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
        r"""
        Pads batched data to the longest sequence in the batch.

        Each feature holds a target example ("input_ids", "attention_mask", "labels"),
        a KL reference example ("kl_input_ids", "kl_attention_mask", "kl_labels") and a
        boolean "kto_tags" flag marking whether the target response is desirable.
        The target and KL examples are padded separately and merged into a single batch.
        """
        target_features = []
        kl_features = []
        kto_tags = []
        for feature in features:
            target_feature = {
                "input_ids": feature["input_ids"],
                "attention_mask": feature["attention_mask"],
                "labels": feature["labels"],
            }
            kl_feature = {
                "input_ids": feature["kl_input_ids"],
                "attention_mask": feature["kl_attention_mask"],
                "labels": feature["kl_labels"],
            }
            if "pixel_values" in feature:
                target_feature["pixel_values"] = feature["pixel_values"]

            if "token_type_ids" in feature:
                target_feature["token_type_ids"] = feature["token_type_ids"]
                kl_feature["token_type_ids"] = feature["kl_token_type_ids"]

            target_features.append(target_feature)
            kl_features.append(kl_feature)
            kto_tags.append(feature["kto_tags"])

        batch = super().__call__(target_features)
        kl_batch = super().__call__(kl_features)
        batch["kl_input_ids"] = kl_batch["input_ids"]
        batch["kl_attention_mask"] = kl_batch["attention_mask"]
        batch["kl_labels"] = kl_batch["labels"]
        if "token_type_ids" in batch:
            batch["kl_token_type_ids"] = kl_batch["token_type_ids"]

        batch["kto_tags"] = torch.tensor(kto_tags)
        return batch
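
# A minimal usage sketch for the KTO collator (illustrative only; the checkpoint name and
# the toy values are assumptions). Each feature supplies a target example, a "kl_*"
# reference example and a "kto_tags" flag; the collator returns one batch holding the
# padded target tensors plus "kl_input_ids", "kl_attention_mask", "kl_labels" and "kto_tags":
#
#   from transformers import AutoTokenizer
#
#   tokenizer = AutoTokenizer.from_pretrained("gpt2")
#   tokenizer.pad_token = tokenizer.eos_token  # gpt2 has no pad token by default
#   collator = KTODataCollatorWithPadding(tokenizer=tokenizer)
#   batch = collator([
#       {
#           "input_ids": [1, 2, 3],
#           "attention_mask": [1, 1, 1],
#           "labels": [1, 2, 3],
#           "kl_input_ids": [4, 5],
#           "kl_attention_mask": [1, 1],
#           "kl_labels": [4, 5],
#           "kto_tags": True,
#       }
#   ])
#   # batch keys: input_ids, attention_mask, labels, kl_input_ids, kl_attention_mask,
#   # kl_labels and kto_tags (a bool tensor with one entry per example).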