mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2026-03-07 20:26:00 +08:00
@@ -1,29 +0,0 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Sequence
|
||||
|
||||
import torch
|
||||
from transformers import DataCollatorWithPadding
|
||||
|
||||
|
||||
@dataclass
|
||||
class PairwiseDataCollatorWithPadding(DataCollatorWithPadding):
|
||||
r"""
|
||||
Data collator for pairwise data.
|
||||
"""
|
||||
|
||||
def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
|
||||
r"""
|
||||
Pads batched data to the longest sequence in the batch.
|
||||
|
||||
We generate 2 * n examples where the first n examples represent chosen examples and
|
||||
the last n examples represent rejected examples.
|
||||
"""
|
||||
features = [
|
||||
{
|
||||
"input_ids": feature["prompt_ids"] + feature[key],
|
||||
"attention_mask": [1] * (len(feature["prompt_ids"]) + len(feature[key])),
|
||||
}
|
||||
for key in ("chosen_ids", "rejected_ids")
|
||||
for feature in features
|
||||
]
|
||||
return super().__call__(features)
|
||||
Reference in New Issue
Block a user