improve KTO impl., replace datasets

Former-commit-id: e56a57ddcf061de6e4acc8679f7dbf0b68364986
This commit is contained in:
hiyouga
2024-05-18 03:44:56 +08:00
parent 65715cab9d
commit d24969bb7e
61 changed files with 39218 additions and 440 deletions

View File

@@ -19,7 +19,10 @@ If you are using a custom dataset, please add your **dataset description** to `d
"messages": "the column name in the dataset containing the messages. (default: conversations)",
"system": "the column name in the dataset containing the system prompts. (default: None)",
"tools": "the column name in the dataset containing the tool description. (default: None)",
"images": "the column name in the dataset containing the image inputs. (default: None)"
"images": "the column name in the dataset containing the image inputs. (default: None)",
"chosen": "the column name in the dataset containing the chosen answers. (default: None)",
"rejected": "the column name in the dataset containing the rejected answers. (default: None)",
"kto_tag": "the column name in the dataset containing the kto tags. (default: None)"
},
"tags (optional, used for the sharegpt format)": {
"role_tag": "the key in the message represents the identity. (default: from)",
@@ -42,13 +45,13 @@ Currently we support dataset in **alpaca** or **sharegpt** format, the dataset i
```json
[
{
"instruction": "user instruction (required)",
"input": "user input (optional)",
"instruction": "human instruction (required)",
"input": "human input (optional)",
"output": "model response (required)",
"system": "system prompt (optional)",
"history": [
["user instruction in the first round (optional)", "model response in the first round (optional)"],
["user instruction in the second round (optional)", "model response in the second round (optional)"]
["human instruction in the first round (optional)", "model response in the first round (optional)"],
["human instruction in the second round (optional)", "model response in the second round (optional)"]
]
}
]
@@ -69,7 +72,7 @@ Regarding the above dataset, the description in `dataset_info.json` should be:
}
```
The `query` column will be concatenated with the `prompt` column and used as the user prompt, then the user prompt would be `prompt\nquery`. The `response` column represents the model response.
The `query` column will be concatenated with the `prompt` column and used as the human prompt, then the human prompt would be `prompt\nquery`. The `response` column represents the model response.
The `system` column will be used as the system prompt. The `history` column is a list consisting string tuples representing prompt-response pairs in the history. Note that the responses in the history **will also be used for training** in supervised fine-tuning.
@@ -98,12 +101,10 @@ For the **preference datasets**, the `response` column should be a string list w
```json
[
{
"instruction": "user instruction",
"input": "user input",
"output": [
"chosen answer",
"rejected answer"
]
"instruction": "human instruction",
"input": "human input",
"chosen": "chosen answer",
"rejected": "rejected answer"
}
]
```
@@ -117,7 +118,8 @@ Regarding the above dataset, the description in `dataset_info.json` should be:
"columns": {
"prompt": "instruction",
"query": "input",
"response": "output",
"chosen": "chosen",
"rejected": "rejected"
}
}
```
@@ -132,7 +134,7 @@ The dataset in **sharegpt** format should follow the below format:
"conversations": [
{
"from": "human",
"value": "user instruction"
"value": "human instruction"
},
{
"from": "gpt",
@@ -179,7 +181,7 @@ We also supports the dataset in the **openai** format:
},
{
"role": "user",
"content": "user instruction"
"content": "human instruction"
},
{
"role": "assistant",

View File

@@ -19,7 +19,10 @@
"messages": "数据集代表消息列表的表头名称默认conversations",
"system": "数据集代表系统提示的表头名称默认None",
"tools": "数据集代表工具描述的表头名称默认None",
"images": "数据集代表图像输入的表头名称默认None"
"images": "数据集代表图像输入的表头名称默认None",
"chosen": "数据集代表更优回复的表头名称默认None",
"rejected": "数据集代表更差回复的表头名称默认None",
"kto_tag": "数据集代表 KTO 标签的表头名称默认None"
},
"tags可选用于 sharegpt 格式)": {
"role_tag": "消息中代表发送者身份的键名默认from",
@@ -42,8 +45,8 @@
```json
[
{
"instruction": "用户指令(必填)",
"input": "用户输入(选填)",
"instruction": "人类指令(必填)",
"input": "人类输入(选填)",
"output": "模型回答(必填)",
"system": "系统提示词(选填)",
"history": [
@@ -69,7 +72,7 @@
}
```
其中 `query` 列对应的内容会与 `prompt` 列对应的内容拼接后作为用户指令,即用户指令为 `prompt\nquery``response` 列对应的内容为模型回答。
其中 `query` 列对应的内容会与 `prompt` 列对应的内容拼接后作为人类指令,即人类指令为 `prompt\nquery``response` 列对应的内容为模型回答。
`system` 列对应的内容将被作为系统提示词。`history` 列是由多个字符串二元组构成的列表,分别代表历史消息中每轮的指令和回答。注意在指令监督学习时,历史消息中的回答**也会被用于训练**。
@@ -98,12 +101,10 @@
```json
[
{
"instruction": "用户指令",
"input": "用户输入",
"output": [
"质回答",
"劣质回答"
]
"instruction": "人类指令",
"input": "人类输入",
"chosen": "优质回答",
"rejected": "质回答"
}
]
```
@@ -117,7 +118,8 @@
"columns": {
"prompt": "instruction",
"query": "input",
"response": "output",
"chosen": "chosen",
"rejected": "rejected"
}
}
```
@@ -132,7 +134,7 @@
"conversations": [
{
"from": "human",
"value": "用户指令"
"value": "人类指令"
},
{
"from": "gpt",
@@ -165,7 +167,7 @@
}
```
其中 `messages` 列应当是一个列表,且符合 `用户/模型/用户/模型/用户/模型` 的顺序。
其中 `messages` 列应当是一个列表,且符合 `人类/模型/人类/模型/人类/模型` 的顺序。
我们同样支持 **openai** 格式的数据集:
@@ -179,7 +181,7 @@
},
{
"role": "user",
"content": "用户指令"
"content": "人类指令"
},
{
"role": "assistant",

View File

@@ -1 +0,0 @@
3779ddbc040543ab1834ef216c983d6fcc06cc9a

View File

@@ -1 +0,0 @@
a97cf9475291591843976554878568e046d8a46d

5002
data/alpaca_en_demo.json Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -1 +0,0 @@
25508714b7879a1e5a6764ba7f979a980f549f1a

View File

@@ -1 +0,0 @@
7cb6a7d11455bddc3d495750a2392683d775b184

5002
data/alpaca_zh_demo.json Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -1 +0,0 @@
f5cb08305ff5dc9c17a09809c54c8c8834aadc70

View File

@@ -1 +0,0 @@
aee47b7b443496e37808d7f34ef10403ff99bcc3

View File

@@ -1,48 +1,23 @@
{
"alpaca_en": {
"file_name": "alpaca_data_en_52k.json"
},
"alpaca_zh": {
"file_name": "alpaca_data_zh_51k.json"
},
"alpaca_gpt4_en": {
"file_name": "alpaca_gpt4_data_en.json"
},
"alpaca_gpt4_zh": {
"file_name": "alpaca_gpt4_data_zh.json"
},
"identity": {
"file_name": "identity.json"
},
"oaast_sft_zh": {
"file_name": "oaast_sft_zh.json",
"alpaca_en_demo": {
"file_name": "alpaca_en_demo.json"
},
"alpaca_zh_demo": {
"file_name": "alpaca_zh_demo.json"
},
"glaive_toolcall_en_demo": {
"file_name": "glaive_toolcall_en_demo.json",
"formatting": "sharegpt",
"columns": {
"prompt": "instruction",
"query": "input",
"response": "output",
"history": "history"
"messages": "conversations",
"tools": "tools"
}
},
"lima": {
"file_name": "lima.json",
"columns": {
"prompt": "instruction",
"query": "input",
"response": "output",
"history": "history"
}
},
"kto-mix-test": {
"file_name": "kto-mix-test.json",
"file_sha1": "91b59f657007dc4b17529fc643v9b9cd6d640fha",
"columns": {
"prompt": "instruction",
"response": "output",
"tag": "tag"
}
},
"glaive_toolcall": {
"file_name": "glaive_toolcall_10k.json",
"glaive_toolcall_zh_demo": {
"file_name": "glaive_toolcall_zh_demo.json",
"formatting": "sharegpt",
"columns": {
"messages": "conversations",
@@ -63,15 +38,42 @@
"assistant_tag": "assistant"
}
},
"example": {
"script_url": "example_dataset",
"alpaca_en": {
"hf_hub_url": "llamafactory/alpaca_en",
"ms_hub_url": "llamafactory/alpaca_en"
},
"alpaca_zh": {
"hf_hub_url": "llamafactory/alpaca_zh",
"ms_hub_url": "llamafactory/alpaca_zh"
},
"alpaca_gpt4_en": {
"hf_hub_url": "llamafactory/alpaca_gpt4_en",
"ms_hub_url": "llamafactory/alpaca_gpt4_en"
},
"alpaca_gpt4_zh": {
"hf_hub_url": "llamafactory/alpaca_gpt4_zh",
"ms_hub_url": "llamafactory/alpaca_gpt4_zh"
},
"glaive_toolcall_en": {
"hf_hub_url": "llamafactory/glaive_toolcall_en",
"formatting": "sharegpt",
"columns": {
"prompt": "instruction",
"query": "input",
"response": "output",
"history": "history"
"messages": "conversations",
"tools": "tools"
}
},
"glaive_toolcall_zh": {
"hf_hub_url": "llamafactory/glaive_toolcall_zh",
"formatting": "sharegpt",
"columns": {
"messages": "conversations",
"tools": "tools"
}
},
"lima": {
"hf_hub_url": "llamafactory/lima",
"formatting": "sharegpt"
},
"guanaco": {
"hf_hub_url": "JosephusCheung/GuanacoDataset",
"ms_hub_url": "AI-ModelScope/GuanacoDataset"
@@ -240,6 +242,12 @@
"response": "text"
}
},
"stem_zh": {
"hf_hub_url": "hfl/stem_zh_instruction"
},
"ruozhiba_gpt4": {
"hf_hub_url": "hfl/ruozhiba_gpt4_turbo"
},
"llava_150k_en": {
"hf_hub_url": "BUAADreamer/llava-en-zh-300k",
"subset": "en",
@@ -297,73 +305,105 @@
"ultrachat_de": {
"hf_hub_url": "mayflowergmbh/ultra-chat_de"
},
"hh_rlhf_en": {
"script_url": "hh_rlhf_en",
"dpo_en_demo": {
"file_name": "dpo_en_demo.json",
"ranking": true,
"formatting": "sharegpt",
"columns": {
"prompt": "instruction",
"response": "output",
"history": "history"
},
"ranking": true
"messages": "conversations",
"chosen": "chosen",
"rejected": "rejected"
}
},
"oaast_rm_zh": {
"file_name": "oaast_rm_zh.json",
"dpo_zh_demo": {
"file_name": "dpo_zh_demo.json",
"ranking": true,
"formatting": "sharegpt",
"columns": {
"prompt": "instruction",
"query": "input",
"response": "output",
"history": "history"
},
"ranking": true
"messages": "conversations",
"chosen": "chosen",
"rejected": "rejected"
}
},
"comparison_gpt4_en": {
"file_name": "comparison_gpt4_data_en.json",
"ranking": true
"dpo_mix_en": {
"hf_hub_url": "hiyouga/DPO-En-Zh-20k",
"subset": "en",
"ranking": true,
"formatting": "sharegpt",
"columns": {
"messages": "conversations",
"chosen": "chosen",
"rejected": "rejected"
}
},
"comparison_gpt4_zh": {
"file_name": "comparison_gpt4_data_zh.json",
"ranking": true
"dpo_mix_zh": {
"hf_hub_url": "hiyouga/DPO-En-Zh-20k",
"subset": "zh",
"ranking": true,
"formatting": "sharegpt",
"columns": {
"messages": "conversations",
"chosen": "chosen",
"rejected": "rejected"
}
},
"orca_rlhf": {
"file_name": "orca_rlhf.json",
"orca_pairs": {
"hf_hub_url": "Intel/orca_dpo_pairs",
"ranking": true,
"columns": {
"prompt": "question",
"response": "answer",
"chosen": "chosen",
"rejected": "rejected",
"system": "system"
}
},
"hh_rlhf_en": {
"script_url": "hh_rlhf_en",
"ranking": true,
"columns": {
"prompt": "instruction",
"chosen": "chosen",
"rejected": "rejected",
"history": "history"
}
},
"nectar_rm": {
"hf_hub_url": "AstraMindAI/RLAIF-Nectar",
"ms_hub_url": "AI-ModelScope/RLAIF-Nectar",
"ranking": true
},
"dpo_mix_en": {
"hf_hub_url": "hiyouga/DPO-En-Zh-20k",
"subset": "en",
"ranking": true,
"columns": {
"prompt": "prompt",
"response": "answer",
"system": "system",
"history": "history"
}
},
"dpo_mix_zh": {
"hf_hub_url": "hiyouga/DPO-En-Zh-20k",
"subset": "zh",
"ranking": true,
"columns": {
"prompt": "prompt",
"response": "answer",
"system": "system",
"history": "history"
}
},
"orca_dpo_de": {
"hf_hub_url": "mayflowergmbh/intel_orca_dpo_pairs_de",
"ranking": true
},
"kto_en_demo": {
"file_name": "kto_en_demo.json",
"formatting": "sharegpt",
"columns": {
"messages": "messages",
"kto_tag": "label"
},
"tags": {
"role_tag": "role",
"content_tag": "content",
"user_tag": "user",
"assistant_tag": "assistant"
}
},
"kto_mix_en": {
"hf_hub_url": "argilla/kto-mix-15k",
"formatting": "sharegpt",
"columns": {
"messages": "completion",
"kto_tag": "label"
},
"tags": {
"role_tag": "role",
"content_tag": "content",
"user_tag": "user",
"assistant_tag": "assistant"
}
},
"wiki_demo": {
"file_name": "wiki_demo.txt",
"columns": {

5058
data/dpo_zh_demo.json Normal file

File diff suppressed because one or more lines are too long

View File

@@ -1,37 +0,0 @@
import json
from typing import Any, Dict, Generator, List, Tuple
import datasets
_DESCRIPTION = "An example of dataset."
_CITATION = ""
_HOMEPAGE = ""
_LICENSE = ""
_URL = "examples.json"
class ExampleDataset(datasets.GeneratorBasedBuilder):
VERSION = datasets.Version("0.0.0")
def _info(self) -> datasets.DatasetInfo:
features = datasets.Features(
{
"instruction": datasets.Value("string"),
"input": datasets.Value("string"),
"output": datasets.Value("string"),
"history": datasets.Sequence(datasets.Sequence(datasets.Value("string"))),
}
)
return datasets.DatasetInfo(
description=_DESCRIPTION, features=features, homepage=_HOMEPAGE, license=_LICENSE, citation=_CITATION
)
def _split_generators(self, dl_manager: datasets.DownloadManager) -> List[datasets.SplitGenerator]:
file_path = dl_manager.download(_URL)
return [datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={"filepath": file_path})]
def _generate_examples(self, filepath: str) -> Generator[Tuple[int, Dict[str, Any]], None, None]:
example_dataset = json.load(open(filepath, "r", encoding="utf-8"))
for key, example in enumerate(example_dataset):
yield key, example

View File

@@ -1,20 +0,0 @@
[
{
"instruction": "听起来很不错。人工智能可能在哪些方面面临挑战呢?",
"input": "",
"output": "人工智能面临的挑战包括数据隐私、安全和道德方面的问题,以及影响就业机会的自动化等问题。",
"history": [
["你好,你能帮我解答一个问题吗?", "当然,请问有什么问题?"],
["我想了解人工智能的未来发展方向,你有什么想法吗?", "人工智能在未来的发展方向可能包括更强大的机器学习算法,更先进的自然语言处理技术,以及更加智能的机器人。"]
]
},
{
"instruction": "好的,谢谢你!",
"input": "",
"output": "不客气,有其他需要帮忙的地方可以继续问我。",
"history": [
["你好,能告诉我今天天气怎么样吗?", "当然可以,请问您所在的城市是哪里?"],
["我在纽约。", "纽约今天晴间多云气温最高约26摄氏度最低约18摄氏度记得注意保暖喔。"]
]
}
]

View File

@@ -1 +0,0 @@
4748dff00d1dc42768a5b6cc772143c313017812

File diff suppressed because one or more lines are too long

File diff suppressed because it is too large Load Diff

View File

@@ -79,5 +79,5 @@ class HhRlhfEn(datasets.GeneratorBasedBuilder):
break
prompt = prompt[:human_idx]
yield key, {"instruction": query, "output": [r_accept, r_reject], "history": history}
yield key, {"instruction": query, "chosen": r_accept, "rejected": r_reject, "history": history}
key += 1

5398
data/kto_en_demo.json Normal file

File diff suppressed because one or more lines are too long

View File

@@ -1 +0,0 @@
736bcedea2b24a1414765c6d69cbdafaea839f3c

30
data/wiki_demo.txt Normal file

File diff suppressed because one or more lines are too long

View File

@@ -1 +0,0 @@
c9cf509b7fdac5490cfd6dae72c2d7b8a60af6cb