simplify readme

Former-commit-id: 0da6ec2d516326fe9c7583ba71cd1778eb838178
This commit is contained in:
hiyouga
2024-04-02 20:07:43 +08:00
parent 117b67ea30
commit b12176d818
24 changed files with 244 additions and 890 deletions

View File

@@ -1,29 +0,0 @@
from dataclasses import dataclass
from typing import Any, Dict, Sequence
import torch
from transformers import DataCollatorWithPadding
@dataclass
class PairwiseDataCollatorWithPadding(DataCollatorWithPadding):
r"""
Data collator for pairwise data.
"""
def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
r"""
Pads batched data to the longest sequence in the batch.
We generate 2 * n examples where the first n examples represent chosen examples and
the last n examples represent rejected examples.
"""
features = [
{
"input_ids": feature["prompt_ids"] + feature[key],
"attention_mask": [1] * (len(feature["prompt_ids"]) + len(feature[key])),
}
for key in ("chosen_ids", "rejected_ids")
for feature in features
]
return super().__call__(features)