support DPO training (2305.18290)

This commit is contained in:
hiyouga
2023-08-11 03:02:53 +08:00
parent 685dae4eff
commit 3ec4351cfd
34 changed files with 513 additions and 1027300 deletions

View File

@@ -1,8 +1,10 @@
import torch
from dataclasses import dataclass
from typing import Any, Dict, Sequence
from transformers import DataCollatorWithPadding
@dataclass
class PairwiseDataCollatorWithPadding(DataCollatorWithPadding):
r"""
Data collator for pairwise data.
@@ -16,7 +18,10 @@ class PairwiseDataCollatorWithPadding(DataCollatorWithPadding):
the last n examples represent rejected examples.
"""
features = [
{"input_ids": feature[key], "attention_mask": [1] * len(feature[key])}
for key in ("accept_ids", "reject_ids") for feature in features
{
"input_ids": feature["prompt_ids"] + feature[key],
"attention_mask": [1] * (len(feature["prompt_ids"]) + len(feature[key]))
}
for key in ("chosen_ids", "rejected_ids") for feature in features
]
return super().__call__(features)