mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-11-07 03:12:13 +08:00
[v1] add pair data converter (#9360)
Co-authored-by: frozenleaves <frozen@Mac.local>
This commit is contained in:
parent
bd30c0003b
commit
bd24350cbf
10
data/v1_dpo_demo.jsonl
Normal file
10
data/v1_dpo_demo.jsonl
Normal file
File diff suppressed because one or more lines are too long
4
data/v1_dpo_demo.yaml
Normal file
4
data/v1_dpo_demo.yaml
Normal file
@ -0,0 +1,4 @@
|
||||
dpo_zh_demo:
|
||||
hf_hub_url: HuggingFaceH4/orca_dpo_pairs
|
||||
split: train_prefs
|
||||
converter: pair
|
||||
@ -17,7 +17,7 @@ from typing import Callable, TypedDict
|
||||
|
||||
from typing_extensions import NotRequired
|
||||
|
||||
from ...extras.types import Sample, SFTSample
|
||||
from ...extras.types import DPOSample, Sample, SFTSample
|
||||
|
||||
|
||||
class AlpacaSample(TypedDict, total=False):
|
||||
@ -27,6 +27,12 @@ class AlpacaSample(TypedDict, total=False):
|
||||
output: NotRequired[str]
|
||||
|
||||
|
||||
class PairSample(TypedDict, total=False):
|
||||
prompt: NotRequired[str]
|
||||
chosen: NotRequired[list[dict]]
|
||||
rejected: NotRequired[list[dict]]
|
||||
|
||||
|
||||
def alpaca_converter(raw_sample: AlpacaSample) -> SFTSample:
|
||||
"""Convert Alpaca sample to SFT sample.
|
||||
|
||||
@ -61,8 +67,87 @@ def alpaca_converter(raw_sample: AlpacaSample) -> SFTSample:
|
||||
return {"messages": messages}
|
||||
|
||||
|
||||
def pair_converter(raw_sample: PairSample) -> DPOSample:
|
||||
"""Convert Pair sample to standard DPO sample.
|
||||
|
||||
Args:
|
||||
raw_sample (PairSample): pair sample with prompt, chosen, rejected fields.
|
||||
see raw example at: https://huggingface.co/datasets/HuggingFaceH4/orca_dpo_pairs
|
||||
|
||||
Returns:
|
||||
DPOSample: DPO sample with chosen_messages and rejected_messages.
|
||||
see the standard DPO sample at: https://huggingface.co/datasets/frozenleaves/v1-dpo-demo/raw/main/v1-dpo-demo.jsonl
|
||||
"""
|
||||
chosen_messages = []
|
||||
assert "chosen" in raw_sample, "chosen field is required in pair sample."
|
||||
assert "rejected" in raw_sample, "rejected field is required in pair sample."
|
||||
assert isinstance(raw_sample["chosen"], list) and isinstance(raw_sample["rejected"], list), (
|
||||
"chosen and rejected field should be a list[dict], or you may need to implement your custom converter."
|
||||
)
|
||||
|
||||
if "chosen" in raw_sample:
|
||||
value = raw_sample.get("chosen", "")
|
||||
for item in value:
|
||||
if item.get("role", "") == "system":
|
||||
chosen_messages.append(
|
||||
{
|
||||
"role": "system",
|
||||
"content": [{"type": "text", "value": item.get("content", "")}],
|
||||
"loss_weight": 0.0,
|
||||
}
|
||||
)
|
||||
if item.get("role", "") == "user":
|
||||
chosen_messages.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": [{"type": "text", "value": item.get("content", "")}],
|
||||
"loss_weight": 0.0,
|
||||
}
|
||||
)
|
||||
if item.get("role", "") == "assistant":
|
||||
chosen_messages.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [{"type": "text", "value": item.get("content", "")}],
|
||||
"loss_weight": 1.0,
|
||||
}
|
||||
)
|
||||
|
||||
rejected_messages = []
|
||||
if "rejected" in raw_sample:
|
||||
value = raw_sample.get("rejected", "")
|
||||
for item in value:
|
||||
if item.get("role", "") == "system":
|
||||
rejected_messages.append(
|
||||
{
|
||||
"role": "system",
|
||||
"content": [{"type": "text", "value": item.get("content", "")}],
|
||||
"loss_weight": 0.0,
|
||||
}
|
||||
)
|
||||
if item.get("role", "") == "user":
|
||||
rejected_messages.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": [{"type": "text", "value": item.get("content", "")}],
|
||||
"loss_weight": 0.0,
|
||||
}
|
||||
)
|
||||
if item.get("role", "") == "assistant":
|
||||
rejected_messages.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [{"type": "text", "value": item.get("content", "")}],
|
||||
"loss_weight": 1.0,
|
||||
}
|
||||
)
|
||||
|
||||
return {"chosen_messages": chosen_messages, "rejected_messages": rejected_messages}
|
||||
|
||||
|
||||
CONVERTERS = {
|
||||
"alpaca": alpaca_converter,
|
||||
"pair": pair_converter,
|
||||
}
|
||||
|
||||
|
||||
|
||||
@ -48,5 +48,54 @@ def test_alpaca_converter(num_samples: int):
|
||||
assert data_engine[index] == {"_dataset_name": "tiny_dataset", **expected_data}
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_samples", [16])
|
||||
def test_pair_converter(num_samples: int):
|
||||
data_args = DataArguments(dataset="frozenleaves/tiny-dpo/dataset_info.yaml")
|
||||
data_engine = DataEngine(data_args)
|
||||
original_data = load_dataset("HuggingFaceH4/orca_dpo_pairs", split="train_prefs")
|
||||
indexes = random.choices(range(len(data_engine)), k=num_samples)
|
||||
for index in indexes:
|
||||
print(data_engine[index])
|
||||
print(original_data[index])
|
||||
expected_data = {
|
||||
"chosen_messages": [
|
||||
{
|
||||
"role": "system",
|
||||
"content": [{"type": "text", "value": original_data[index]["chosen"][0]["content"]}],
|
||||
"loss_weight": 0.0,
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [{"type": "text", "value": original_data[index]["chosen"][1]["content"]}],
|
||||
"loss_weight": 0.0,
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [{"type": "text", "value": original_data[index]["chosen"][2]["content"]}],
|
||||
"loss_weight": 1.0,
|
||||
},
|
||||
],
|
||||
"rejected_messages": [
|
||||
{
|
||||
"role": "system",
|
||||
"content": [{"type": "text", "value": original_data[index]["rejected"][0]["content"]}],
|
||||
"loss_weight": 0.0,
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [{"type": "text", "value": original_data[index]["rejected"][1]["content"]}],
|
||||
"loss_weight": 0.0,
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [{"type": "text", "value": original_data[index]["rejected"][2]["content"]}],
|
||||
"loss_weight": 1.0,
|
||||
},
|
||||
],
|
||||
}
|
||||
assert data_engine[index] == {"_dataset_name": "dpo_zh_demo", **expected_data}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_alpaca_converter(1)
|
||||
test_pair_converter(1)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user