mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-12-18 12:50:38 +08:00
[v1] add pair data converter (#9360)
Co-authored-by: frozenleaves <frozen@Mac.local>
This commit is contained in:
@@ -17,7 +17,7 @@ from typing import Callable, TypedDict
|
||||
|
||||
from typing_extensions import NotRequired
|
||||
|
||||
from ...extras.types import Sample, SFTSample
|
||||
from ...extras.types import DPOSample, Sample, SFTSample
|
||||
|
||||
|
||||
class AlpacaSample(TypedDict, total=False):
|
||||
@@ -27,6 +27,12 @@ class AlpacaSample(TypedDict, total=False):
|
||||
output: NotRequired[str]
|
||||
|
||||
|
||||
class PairSample(TypedDict, total=False):
|
||||
prompt: NotRequired[str]
|
||||
chosen: NotRequired[list[dict]]
|
||||
rejected: NotRequired[list[dict]]
|
||||
|
||||
|
||||
def alpaca_converter(raw_sample: AlpacaSample) -> SFTSample:
|
||||
"""Convert Alpaca sample to SFT sample.
|
||||
|
||||
@@ -61,8 +67,87 @@ def alpaca_converter(raw_sample: AlpacaSample) -> SFTSample:
|
||||
return {"messages": messages}
|
||||
|
||||
|
||||
def pair_converter(raw_sample: PairSample) -> DPOSample:
|
||||
"""Convert Pair sample to standard DPO sample.
|
||||
|
||||
Args:
|
||||
raw_sample (PairSample): pair sample with prompt, chosen, rejected fields.
|
||||
see raw example at: https://huggingface.co/datasets/HuggingFaceH4/orca_dpo_pairs
|
||||
|
||||
Returns:
|
||||
DPOSample: DPO sample with chosen_messages and rejected_messages.
|
||||
see the standard DPO sample at: https://huggingface.co/datasets/frozenleaves/v1-dpo-demo/raw/main/v1-dpo-demo.jsonl
|
||||
"""
|
||||
chosen_messages = []
|
||||
assert "chosen" in raw_sample, "chosen field is required in pair sample."
|
||||
assert "rejected" in raw_sample, "rejected field is required in pair sample."
|
||||
assert isinstance(raw_sample["chosen"], list) and isinstance(raw_sample["rejected"], list), (
|
||||
"chosen and rejected field should be a list[dict], or you may need to implement your custom converter."
|
||||
)
|
||||
|
||||
if "chosen" in raw_sample:
|
||||
value = raw_sample.get("chosen", "")
|
||||
for item in value:
|
||||
if item.get("role", "") == "system":
|
||||
chosen_messages.append(
|
||||
{
|
||||
"role": "system",
|
||||
"content": [{"type": "text", "value": item.get("content", "")}],
|
||||
"loss_weight": 0.0,
|
||||
}
|
||||
)
|
||||
if item.get("role", "") == "user":
|
||||
chosen_messages.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": [{"type": "text", "value": item.get("content", "")}],
|
||||
"loss_weight": 0.0,
|
||||
}
|
||||
)
|
||||
if item.get("role", "") == "assistant":
|
||||
chosen_messages.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [{"type": "text", "value": item.get("content", "")}],
|
||||
"loss_weight": 1.0,
|
||||
}
|
||||
)
|
||||
|
||||
rejected_messages = []
|
||||
if "rejected" in raw_sample:
|
||||
value = raw_sample.get("rejected", "")
|
||||
for item in value:
|
||||
if item.get("role", "") == "system":
|
||||
rejected_messages.append(
|
||||
{
|
||||
"role": "system",
|
||||
"content": [{"type": "text", "value": item.get("content", "")}],
|
||||
"loss_weight": 0.0,
|
||||
}
|
||||
)
|
||||
if item.get("role", "") == "user":
|
||||
rejected_messages.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": [{"type": "text", "value": item.get("content", "")}],
|
||||
"loss_weight": 0.0,
|
||||
}
|
||||
)
|
||||
if item.get("role", "") == "assistant":
|
||||
rejected_messages.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [{"type": "text", "value": item.get("content", "")}],
|
||||
"loss_weight": 1.0,
|
||||
}
|
||||
)
|
||||
|
||||
return {"chosen_messages": chosen_messages, "rejected_messages": rejected_messages}
|
||||
|
||||
|
||||
CONVERTERS = {
|
||||
"alpaca": alpaca_converter,
|
||||
"pair": pair_converter,
|
||||
}
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user