[v1] add pair data converter (#9360)

Co-authored-by: frozenleaves <frozen@Mac.local>
This commit is contained in:
魅影 2025-11-06 14:05:58 +08:00 committed by GitHub
parent bd30c0003b
commit bd24350cbf
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 149 additions and 1 deletions

10
data/v1_dpo_demo.jsonl Normal file

File diff suppressed because one or more lines are too long

4
data/v1_dpo_demo.yaml Normal file
View File

@ -0,0 +1,4 @@
dpo_zh_demo:
  hf_hub_url: HuggingFaceH4/orca_dpo_pairs
  split: train_prefs
  converter: pair

View File

@ -17,7 +17,7 @@ from typing import Callable, TypedDict
from typing_extensions import NotRequired
from ...extras.types import DPOSample, Sample, SFTSample
class AlpacaSample(TypedDict, total=False): class AlpacaSample(TypedDict, total=False):
@ -27,6 +27,12 @@ class AlpacaSample(TypedDict, total=False):
output: NotRequired[str] output: NotRequired[str]
class PairSample(TypedDict, total=False):
prompt: NotRequired[str]
chosen: NotRequired[list[dict]]
rejected: NotRequired[list[dict]]
def alpaca_converter(raw_sample: AlpacaSample) -> SFTSample: def alpaca_converter(raw_sample: AlpacaSample) -> SFTSample:
"""Convert Alpaca sample to SFT sample. """Convert Alpaca sample to SFT sample.
@ -61,8 +67,87 @@ def alpaca_converter(raw_sample: AlpacaSample) -> SFTSample:
return {"messages": messages} return {"messages": messages}
# Per-role loss weights for converted messages: only assistant turns
# contribute to the loss; system/user turns are context. A role absent
# from this table is silently dropped (same as the original per-role ifs).
_PAIR_LOSS_WEIGHTS = {"system": 0.0, "user": 0.0, "assistant": 1.0}


def _pair_messages(turns: list[dict]) -> list[dict]:
    """Normalize a list of raw role/content dicts into standard message dicts.

    Args:
        turns (list[dict]): raw turns, each with optional ``role``/``content`` keys.

    Returns:
        list[dict]: messages with ``role``, text ``content`` parts, and ``loss_weight``.
            Turns whose role is missing or not in ``_PAIR_LOSS_WEIGHTS`` are skipped.
    """
    messages = []
    for turn in turns:
        role = turn.get("role", "")
        if role in _PAIR_LOSS_WEIGHTS:
            messages.append(
                {
                    "role": role,
                    "content": [{"type": "text", "value": turn.get("content", "")}],
                    "loss_weight": _PAIR_LOSS_WEIGHTS[role],
                }
            )
    return messages


def pair_converter(raw_sample: "PairSample") -> "DPOSample":
    """Convert Pair sample to standard DPO sample.

    Args:
        raw_sample (PairSample): pair sample with prompt, chosen, rejected fields.
            see raw example at: https://huggingface.co/datasets/HuggingFaceH4/orca_dpo_pairs

    Returns:
        DPOSample: DPO sample with chosen_messages and rejected_messages.
            see the standard DPO sample at: https://huggingface.co/datasets/frozenleaves/v1-dpo-demo/raw/main/v1-dpo-demo.jsonl
    """
    # NOTE(review): assert-based validation is stripped under `python -O`;
    # kept here to preserve the AssertionError contract, but consider ValueError.
    assert "chosen" in raw_sample, "chosen field is required in pair sample."
    assert "rejected" in raw_sample, "rejected field is required in pair sample."
    assert isinstance(raw_sample["chosen"], list) and isinstance(raw_sample["rejected"], list), (
        "chosen and rejected field should be a list[dict], or you may need to implement your custom converter."
    )
    # Both sides get the identical normalization; the original duplicated
    # ~60 lines of per-role handling here and re-checked keys the asserts
    # above already guarantee.
    return {
        "chosen_messages": _pair_messages(raw_sample["chosen"]),
        "rejected_messages": _pair_messages(raw_sample["rejected"]),
    }
# Registry mapping the `converter:` name used in dataset YAML configs to the
# callable that normalizes one raw sample into the standard sample format.
CONVERTERS = {
    "alpaca": alpaca_converter,
    "pair": pair_converter,
}

View File

@ -48,5 +48,54 @@ def test_alpaca_converter(num_samples: int):
assert data_engine[index] == {"_dataset_name": "tiny_dataset", **expected_data} assert data_engine[index] == {"_dataset_name": "tiny_dataset", **expected_data}
@pytest.mark.parametrize("num_samples", [16])
def test_pair_converter(num_samples: int):
    """Spot-check that the engine's pair-converted samples match the raw dataset.

    Draws ``num_samples`` random indices (with replacement) and compares each
    engine sample against expectations rebuilt directly from the HF dataset.
    """
    data_args = DataArguments(dataset="frozenleaves/tiny-dpo/dataset_info.yaml")
    data_engine = DataEngine(data_args)
    original_data = load_dataset("HuggingFaceH4/orca_dpo_pairs", split="train_prefs")
    # Positional layout of each side in this dataset: system, user, assistant.
    roles = (("system", 0.0), ("user", 0.0), ("assistant", 1.0))
    for index in random.choices(range(len(data_engine)), k=num_samples):
        print(data_engine[index])
        print(original_data[index])
        expected_data = {
            f"{side}_messages": [
                {
                    "role": role,
                    "content": [{"type": "text", "value": original_data[index][side][pos]["content"]}],
                    "loss_weight": weight,
                }
                for pos, (role, weight) in enumerate(roles)
            ]
            for side in ("chosen", "rejected")
        }
        assert data_engine[index] == {"_dataset_name": "dpo_zh_demo", **expected_data}
if __name__ == "__main__":
    # Ad-hoc manual entry point: run one sample through each converter test.
    test_alpaca_converter(1)
    test_pair_converter(1)