support full-parameter PPO

Former-commit-id: ce78303600
Author: hiyouga
Date: 2023-11-16 02:08:04 +08:00
parent 0c1fab84f1
commit f441932bd1
20 changed files with 288 additions and 145 deletions

View File

@@ -24,9 +24,7 @@ class BelleMultiturn(datasets.GeneratorBasedBuilder):
     def _info(self):
         features = datasets.Features({
-            "instruction": datasets.Value("string"),
-            "output": datasets.Value("string"),
-            "history": datasets.Sequence(datasets.Sequence(datasets.Value("string")))
+            "conversations": [{"from": datasets.Value("string"), "value": datasets.Value("string")}]
         })
         return datasets.DatasetInfo(
             description=_DESCRIPTION,
@@ -51,6 +49,7 @@ class BelleMultiturn(datasets.GeneratorBasedBuilder):
         with open(filepath, "r", encoding="utf-8") as f:
             for key, row in enumerate(f):
                 data = json.loads(row)
+                conversations = []
                 prompt = data["instruction"].strip()
                 response = data["output"].strip()
@@ -58,7 +57,8 @@ class BelleMultiturn(datasets.GeneratorBasedBuilder):
                 human_idx = prompt.rfind("Human:")
                 query = prompt[human_idx+6:assist_idx].strip()
                 prompt = prompt[:human_idx].strip()
-                history = []
+                conversations.insert(0, {"from": "gpt", "value": response})
+                conversations.insert(0, {"from": "human", "value": query})

                 while prompt.rfind("Assistant:") != -1:
                     assist_idx = prompt.rfind("Assistant:")
@@ -66,13 +66,10 @@ class BelleMultiturn(datasets.GeneratorBasedBuilder):
                     if human_idx != -1:
                         old_query = prompt[human_idx+6:assist_idx].strip()
                         old_resp = prompt[assist_idx+10:].strip()
-                        history.insert(0, (old_query, old_resp))
+                        conversations.insert(0, {"from": "gpt", "value": old_resp})
+                        conversations.insert(0, {"from": "human", "value": old_query})
                     else:
                         break
                     prompt = prompt[:human_idx].strip()

-                yield key, {
-                    "instruction": query,
-                    "output": response,
-                    "history": history
-                }
+                yield key, {"conversations": conversations}
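For reference, a minimal standalone sketch of the conversion the new builder code performs: the flattened Belle prompt ("Human: ... Assistant: ... Human: ... Assistant:") plus the final output field are rewritten into a sharegpt-style list of {"from", "value"} turns. The trailing "Assistant:" cue is what the slicing above implies; the helper name to_sharegpt is hypothetical and not part of this commit.

def to_sharegpt(instruction: str, output: str) -> list:
    conversations = []
    prompt = instruction.strip()
    response = output.strip()

    # Last round: the query sits between the final "Human:" and the trailing "Assistant:".
    assist_idx = prompt.rfind("Assistant:")
    human_idx = prompt.rfind("Human:")
    query = prompt[human_idx + 6:assist_idx].strip()
    prompt = prompt[:human_idx].strip()
    conversations.insert(0, {"from": "gpt", "value": response})
    conversations.insert(0, {"from": "human", "value": query})

    # Walk backwards through earlier rounds, prepending each pair so the oldest turn ends up first.
    while prompt.rfind("Assistant:") != -1:
        assist_idx = prompt.rfind("Assistant:")
        human_idx = prompt.rfind("Human:")
        if human_idx != -1:
            old_query = prompt[human_idx + 6:assist_idx].strip()
            old_resp = prompt[assist_idx + 10:].strip()
            conversations.insert(0, {"from": "gpt", "value": old_resp})
            conversations.insert(0, {"from": "human", "value": old_query})
        else:
            break
        prompt = prompt[:human_idx].strip()

    return conversations

# Invented example:
# to_sharegpt("Human: hi\nAssistant: hello\nHuman: how are you?\nAssistant:", "fine, thanks")
# -> [{"from": "human", "value": "hi"}, {"from": "gpt", "value": "hello"},
#     {"from": "human", "value": "how are you?"}, {"from": "gpt", "value": "fine, thanks"}]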

View File

@@ -88,11 +88,7 @@
   },
   "belle_multiturn": {
     "script_url": "belle_multiturn",
-    "columns": {
-      "prompt": "instruction",
-      "response": "output",
-      "history": "history"
-    }
+    "formatting": "sharegpt"
   },
   "ultra_chat": {
     "script_url": "ultra_chat",
@@ -107,6 +103,13 @@
   "alpaca_cot": {
     "hf_hub_url": "QingyiSi/Alpaca-CoT"
   },
+  "openorca": {
+    "hf_hub_url": "Open-Orca/OpenOrca",
+    "columns": {
+      "prompt": "question",
+      "response": "response"
+    }
+  },
   "mathinstruct": {
     "hf_hub_url": "TIGER-Lab/MathInstruct",
     "columns": {

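Two kinds of entries appear in this file: the new openorca entry keeps the default alpaca-style column mapping, while belle_multiturn now declares "formatting": "sharegpt" and relies on the conversations column its script yields. An illustrative comparison of the two record shapes (field values invented):

# Alpaca-style mapping (e.g. the openorca entry): the mapped prompt/response columns
# appear directly on each record.
alpaca_style_record = {"question": "What is 2 + 2?", "response": "4"}

# sharegpt formatting (the belle_multiturn entry): the whole dialogue travels in a single
# "conversations" list, which is exactly what belle_multiturn.py now yields.
sharegpt_record = {
    "conversations": [
        {"from": "human", "value": "What is 2 + 2?"},
        {"from": "gpt", "value": "4"},
    ]
}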
View File

@@ -66,6 +66,4 @@ class UltraChat(datasets.GeneratorBasedBuilder):
                         "from": "human" if i % 2 == 0 else "gpt",
                         "value": content[i]
                     } for i in range(len(content))]
-                    yield key, {
-                        "conversations": conversations
-                    }
+                    yield key, {"conversations": conversations}
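For intuition, the surrounding comprehension assigns roles by index parity; with an invented content list:

# Invented example of the alternating role assignment above.
content = ["hi", "hello", "what is 2 + 2?", "4"]
conversations = [{
    "from": "human" if i % 2 == 0 else "gpt",
    "value": content[i]
} for i in range(len(content))]
# -> [{"from": "human", "value": "hi"}, {"from": "gpt", "value": "hello"},
#     {"from": "human", "value": "what is 2 + 2?"}, {"from": "gpt", "value": "4"}]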