improve KTO impl., replace datasets

Former-commit-id: c450ee87a3
hiyouga
2024-05-18 03:44:56 +08:00
parent 97469892c3
commit 13d7b48efe
66 changed files with 46444 additions and 28125 deletions


@@ -19,7 +19,10 @@ If you are using a custom dataset, please add your **dataset description** to `d
"messages": "the column name in the dataset containing the messages. (default: conversations)",
"system": "the column name in the dataset containing the system prompts. (default: None)",
"tools": "the column name in the dataset containing the tool description. (default: None)",
"images": "the column name in the dataset containing the image inputs. (default: None)"
"images": "the column name in the dataset containing the image inputs. (default: None)",
"chosen": "the column name in the dataset containing the chosen answers. (default: None)",
"rejected": "the column name in the dataset containing the rejected answers. (default: None)",
"kto_tag": "the column name in the dataset containing the kto tags. (default: None)"
},
"tags (optional, used for the sharegpt format)": {
"role_tag": "the key in the message represents the identity. (default: from)",
@@ -42,13 +45,13 @@ Currently we support dataset in **alpaca** or **sharegpt** format, the dataset i
```json
[
{
"instruction": "user instruction (required)",
"input": "user input (optional)",
"instruction": "human instruction (required)",
"input": "human input (optional)",
"output": "model response (required)",
"system": "system prompt (optional)",
"history": [
["user instruction in the first round (optional)", "model response in the first round (optional)"],
["user instruction in the second round (optional)", "model response in the second round (optional)"]
["human instruction in the first round (optional)", "model response in the first round (optional)"],
["human instruction in the second round (optional)", "model response in the second round (optional)"]
]
}
]
@@ -69,7 +72,7 @@ Regarding the above dataset, the description in `dataset_info.json` should be:
}
```
The `query` column will be concatenated with the `prompt` column and used as the user prompt, then the user prompt would be `prompt\nquery`. The `response` column represents the model response.
The `query` column will be concatenated with the `prompt` column and used as the human prompt; the resulting human prompt will be `prompt\nquery`. The `response` column represents the model response.
The `system` column will be used as the system prompt. The `history` column is a list consisting of string tuples representing prompt-response pairs in the history. Note that the responses in the history **will also be used for training** in supervised fine-tuning.
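As a concrete illustration (the values below are invented), a record such as the following would yield the human prompt `Summarize the following text.\nLLaMA Factory is an efficient fine-tuning framework.`:

```json
{
  "instruction": "Summarize the following text.",
  "input": "LLaMA Factory is an efficient fine-tuning framework.",
  "output": "It is a framework for efficient fine-tuning of large language models."
}
```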
@@ -98,12 +101,10 @@ For the **preference datasets**, the `response` column should be a string list w
```json
[
{
"instruction": "user instruction",
"input": "user input",
"output": [
"chosen answer",
"rejected answer"
]
"instruction": "human instruction",
"input": "human input",
"chosen": "chosen answer",
"rejected": "rejected answer"
}
]
```
@@ -117,7 +118,8 @@ Regarding the above dataset, the description in `dataset_info.json` should be:
"columns": {
"prompt": "instruction",
"query": "input",
"response": "output",
"chosen": "chosen",
"rejected": "rejected"
}
}
```
@@ -132,7 +134,7 @@ The dataset in **sharegpt** format should follow the below format:
"conversations": [
{
"from": "human",
"value": "user instruction"
"value": "human instruction"
},
{
"from": "gpt",
@@ -179,7 +181,7 @@ We also support datasets in the **openai** format:
},
{
"role": "user",
"content": "user instruction"
"content": "human instruction"
},
{
"role": "assistant",


@@ -19,7 +19,10 @@
"messages": "数据集代表消息列表的表头名称默认conversations",
"system": "数据集代表系统提示的表头名称默认None",
"tools": "数据集代表工具描述的表头名称默认None",
"images": "数据集代表图像输入的表头名称默认None"
"images": "数据集代表图像输入的表头名称默认None",
"chosen": "数据集代表更优回复的表头名称默认None",
"rejected": "数据集代表更差回复的表头名称默认None",
"kto_tag": "数据集代表 KTO 标签的表头名称默认None"
},
"tags可选用于 sharegpt 格式)": {
"role_tag": "消息中代表发送者身份的键名默认from",
@@ -42,8 +45,8 @@
```json
[
{
"instruction": "用户指令(必填)",
"input": "用户输入(选填)",
"instruction": "人类指令(必填)",
"input": "人类输入(选填)",
"output": "模型回答(必填)",
"system": "系统提示词(选填)",
"history": [
@@ -69,7 +72,7 @@
}
```
The content of the `query` column will be concatenated with the content of the `prompt` column and used as the user instruction, i.e. the user instruction will be `prompt\nquery`. The content of the `response` column is the model response.
The content of the `query` column will be concatenated with the content of the `prompt` column and used as the human instruction, i.e. the human instruction will be `prompt\nquery`. The content of the `response` column is the model response.
The content of the `system` column will be used as the system prompt. The `history` column is a list of string pairs, each representing the instruction and the response of one earlier turn. Note that during supervised fine-tuning, the responses in the history **will also be used for training**.
@@ -98,12 +101,10 @@
```json
[
{
"instruction": "用户指令",
"input": "用户输入",
"output": [
"质回答",
"劣质回答"
]
"instruction": "人类指令",
"input": "人类输入",
"chosen": "优质回答",
"rejected": "质回答"
}
]
```
@@ -117,7 +118,8 @@
"columns": {
"prompt": "instruction",
"query": "input",
"response": "output",
"chosen": "chosen",
"rejected": "rejected"
}
}
```
@@ -132,7 +134,7 @@
"conversations": [
{
"from": "human",
"value": "用户指令"
"value": "人类指令"
},
{
"from": "gpt",
@@ -165,7 +167,7 @@
}
```
The `messages` column must be a list whose items follow the `user/model/user/model/user/model` order.
The `messages` column must be a list whose items follow the `human/model/human/model/human/model` order.
We also support datasets in the **openai** format:
@@ -179,7 +181,7 @@
},
{
"role": "user",
"content": "用户指令"
"content": "人类指令"
},
{
"role": "assistant",


@@ -1 +0,0 @@
3779ddbc040543ab1834ef216c983d6fcc06cc9a


@@ -1 +0,0 @@
a97cf9475291591843976554878568e046d8a46d

5002
data/alpaca_en_demo.json Normal file

File diff suppressed because it is too large


@@ -1 +0,0 @@
25508714b7879a1e5a6764ba7f979a980f549f1a


@@ -1 +0,0 @@
7cb6a7d11455bddc3d495750a2392683d775b184

5002
data/alpaca_zh_demo.json Normal file

File diff suppressed because it is too large


@@ -1 +0,0 @@
f5cb08305ff5dc9c17a09809c54c8c8834aadc70


@@ -1 +0,0 @@
aee47b7b443496e37808d7f34ef10403ff99bcc3


@@ -1,48 +1,23 @@
{
"alpaca_en": {
"file_name": "alpaca_data_en_52k.json"
},
"alpaca_zh": {
"file_name": "alpaca_data_zh_51k.json"
},
"alpaca_gpt4_en": {
"file_name": "alpaca_gpt4_data_en.json"
},
"alpaca_gpt4_zh": {
"file_name": "alpaca_gpt4_data_zh.json"
},
"identity": {
"file_name": "identity.json"
},
"oaast_sft_zh": {
"file_name": "oaast_sft_zh.json",
"alpaca_en_demo": {
"file_name": "alpaca_en_demo.json"
},
"alpaca_zh_demo": {
"file_name": "alpaca_zh_demo.json"
},
"glaive_toolcall_en_demo": {
"file_name": "glaive_toolcall_en_demo.json",
"formatting": "sharegpt",
"columns": {
"prompt": "instruction",
"query": "input",
"response": "output",
"history": "history"
"messages": "conversations",
"tools": "tools"
}
},
"lima": {
"file_name": "lima.json",
"columns": {
"prompt": "instruction",
"query": "input",
"response": "output",
"history": "history"
}
},
"kto-mix-test": {
"file_name": "kto-mix-test.json",
"file_sha1": "91b59f657007dc4b17529fc643v9b9cd6d640fha",
"columns": {
"prompt": "instruction",
"response": "output",
"tag": "tag"
}
},
"glaive_toolcall": {
"file_name": "glaive_toolcall_10k.json",
"glaive_toolcall_zh_demo": {
"file_name": "glaive_toolcall_zh_demo.json",
"formatting": "sharegpt",
"columns": {
"messages": "conversations",
@@ -63,15 +38,42 @@
"assistant_tag": "assistant"
}
},
"example": {
"script_url": "example_dataset",
"alpaca_en": {
"hf_hub_url": "llamafactory/alpaca_en",
"ms_hub_url": "llamafactory/alpaca_en"
},
"alpaca_zh": {
"hf_hub_url": "llamafactory/alpaca_zh",
"ms_hub_url": "llamafactory/alpaca_zh"
},
"alpaca_gpt4_en": {
"hf_hub_url": "llamafactory/alpaca_gpt4_en",
"ms_hub_url": "llamafactory/alpaca_gpt4_en"
},
"alpaca_gpt4_zh": {
"hf_hub_url": "llamafactory/alpaca_gpt4_zh",
"ms_hub_url": "llamafactory/alpaca_gpt4_zh"
},
"glaive_toolcall_en": {
"hf_hub_url": "llamafactory/glaive_toolcall_en",
"formatting": "sharegpt",
"columns": {
"prompt": "instruction",
"query": "input",
"response": "output",
"history": "history"
"messages": "conversations",
"tools": "tools"
}
},
"glaive_toolcall_zh": {
"hf_hub_url": "llamafactory/glaive_toolcall_zh",
"formatting": "sharegpt",
"columns": {
"messages": "conversations",
"tools": "tools"
}
},
"lima": {
"hf_hub_url": "llamafactory/lima",
"formatting": "sharegpt"
},
"guanaco": {
"hf_hub_url": "JosephusCheung/GuanacoDataset",
"ms_hub_url": "AI-ModelScope/GuanacoDataset"
@@ -240,6 +242,12 @@
"response": "text"
}
},
"stem_zh": {
"hf_hub_url": "hfl/stem_zh_instruction"
},
"ruozhiba_gpt4": {
"hf_hub_url": "hfl/ruozhiba_gpt4_turbo"
},
"llava_150k_en": {
"hf_hub_url": "BUAADreamer/llava-en-zh-300k",
"subset": "en",
@@ -297,73 +305,105 @@
"ultrachat_de": {
"hf_hub_url": "mayflowergmbh/ultra-chat_de"
},
"hh_rlhf_en": {
"script_url": "hh_rlhf_en",
"dpo_en_demo": {
"file_name": "dpo_en_demo.json",
"ranking": true,
"formatting": "sharegpt",
"columns": {
"prompt": "instruction",
"response": "output",
"history": "history"
},
"ranking": true
"messages": "conversations",
"chosen": "chosen",
"rejected": "rejected"
}
},
"oaast_rm_zh": {
"file_name": "oaast_rm_zh.json",
"dpo_zh_demo": {
"file_name": "dpo_zh_demo.json",
"ranking": true,
"formatting": "sharegpt",
"columns": {
"prompt": "instruction",
"query": "input",
"response": "output",
"history": "history"
},
"ranking": true
"messages": "conversations",
"chosen": "chosen",
"rejected": "rejected"
}
},
"comparison_gpt4_en": {
"file_name": "comparison_gpt4_data_en.json",
"ranking": true
"dpo_mix_en": {
"hf_hub_url": "hiyouga/DPO-En-Zh-20k",
"subset": "en",
"ranking": true,
"formatting": "sharegpt",
"columns": {
"messages": "conversations",
"chosen": "chosen",
"rejected": "rejected"
}
},
"comparison_gpt4_zh": {
"file_name": "comparison_gpt4_data_zh.json",
"ranking": true
"dpo_mix_zh": {
"hf_hub_url": "hiyouga/DPO-En-Zh-20k",
"subset": "zh",
"ranking": true,
"formatting": "sharegpt",
"columns": {
"messages": "conversations",
"chosen": "chosen",
"rejected": "rejected"
}
},
"orca_rlhf": {
"file_name": "orca_rlhf.json",
"orca_pairs": {
"hf_hub_url": "Intel/orca_dpo_pairs",
"ranking": true,
"columns": {
"prompt": "question",
"response": "answer",
"chosen": "chosen",
"rejected": "rejected",
"system": "system"
}
},
"hh_rlhf_en": {
"script_url": "hh_rlhf_en",
"ranking": true,
"columns": {
"prompt": "instruction",
"chosen": "chosen",
"rejected": "rejected",
"history": "history"
}
},
"nectar_rm": {
"hf_hub_url": "AstraMindAI/RLAIF-Nectar",
"ms_hub_url": "AI-ModelScope/RLAIF-Nectar",
"ranking": true
},
"dpo_mix_en": {
"hf_hub_url": "hiyouga/DPO-En-Zh-20k",
"subset": "en",
"ranking": true,
"columns": {
"prompt": "prompt",
"response": "answer",
"system": "system",
"history": "history"
}
},
"dpo_mix_zh": {
"hf_hub_url": "hiyouga/DPO-En-Zh-20k",
"subset": "zh",
"ranking": true,
"columns": {
"prompt": "prompt",
"response": "answer",
"system": "system",
"history": "history"
}
},
"orca_dpo_de": {
"hf_hub_url": "mayflowergmbh/intel_orca_dpo_pairs_de",
"ranking": true
},
"kto_en_demo": {
"file_name": "kto_en_demo.json",
"formatting": "sharegpt",
"columns": {
"messages": "messages",
"kto_tag": "label"
},
"tags": {
"role_tag": "role",
"content_tag": "content",
"user_tag": "user",
"assistant_tag": "assistant"
}
},
"kto_mix_en": {
"hf_hub_url": "argilla/kto-mix-15k",
"formatting": "sharegpt",
"columns": {
"messages": "completion",
"kto_tag": "label"
},
"tags": {
"role_tag": "role",
"content_tag": "content",
"user_tag": "user",
"assistant_tag": "assistant"
}
},
"wiki_demo": {
"file_name": "wiki_demo.txt",
"columns": {

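The new `kto_en_demo` and `kto_mix_en` entries above pair an openai-style message list with a boolean `kto_tag` column (mapped from `label`). A record in such a dataset is therefore expected to look roughly like the following sketch (content invented for illustration):

```json
{
  "messages": [
    {"role": "user", "content": "human instruction"},
    {"role": "assistant", "content": "model response"}
  ],
  "label": true
}
```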
7226
data/dpo_en_demo.json Normal file

File diff suppressed because one or more lines are too long

5058
data/dpo_zh_demo.json Normal file

File diff suppressed because one or more lines are too long


@@ -1,37 +0,0 @@
import json
from typing import Any, Dict, Generator, List, Tuple
import datasets
_DESCRIPTION = "An example of dataset."
_CITATION = ""
_HOMEPAGE = ""
_LICENSE = ""
_URL = "examples.json"
class ExampleDataset(datasets.GeneratorBasedBuilder):
VERSION = datasets.Version("0.0.0")
def _info(self) -> datasets.DatasetInfo:
features = datasets.Features(
{
"instruction": datasets.Value("string"),
"input": datasets.Value("string"),
"output": datasets.Value("string"),
"history": datasets.Sequence(datasets.Sequence(datasets.Value("string"))),
}
)
return datasets.DatasetInfo(
description=_DESCRIPTION, features=features, homepage=_HOMEPAGE, license=_LICENSE, citation=_CITATION
)
def _split_generators(self, dl_manager: datasets.DownloadManager) -> List[datasets.SplitGenerator]:
file_path = dl_manager.download(_URL)
return [datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={"filepath": file_path})]
def _generate_examples(self, filepath: str) -> Generator[Tuple[int, Dict[str, Any]], None, None]:
example_dataset = json.load(open(filepath, "r", encoding="utf-8"))
for key, example in enumerate(example_dataset):
yield key, example


@@ -1,20 +0,0 @@
[
{
"instruction": "听起来很不错。人工智能可能在哪些方面面临挑战呢?",
"input": "",
"output": "人工智能面临的挑战包括数据隐私、安全和道德方面的问题,以及影响就业机会的自动化等问题。",
"history": [
["你好,你能帮我解答一个问题吗?", "当然,请问有什么问题?"],
["我想了解人工智能的未来发展方向,你有什么想法吗?", "人工智能在未来的发展方向可能包括更强大的机器学习算法,更先进的自然语言处理技术,以及更加智能的机器人。"]
]
},
{
"instruction": "好的,谢谢你!",
"input": "",
"output": "不客气,有其他需要帮忙的地方可以继续问我。",
"history": [
["你好,能告诉我今天天气怎么样吗?", "当然可以,请问您所在的城市是哪里?"],
["我在纽约。", "纽约今天晴间多云气温最高约26摄氏度最低约18摄氏度记得注意保暖喔。"]
]
}
]


@@ -1 +0,0 @@
4748dff00d1dc42768a5b6cc772143c313017812

File diff suppressed because one or more lines are too long

File diff suppressed because it is too large


@@ -79,5 +79,5 @@ class HhRlhfEn(datasets.GeneratorBasedBuilder):
break
prompt = prompt[:human_idx]
yield key, {"instruction": query, "output": [r_accept, r_reject], "history": history}
yield key, {"instruction": query, "chosen": r_accept, "rejected": r_reject, "history": history}
key += 1

File diff suppressed because one or more lines are too long

5398
data/kto_en_demo.json Normal file

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long


@@ -1 +0,0 @@
736bcedea2b24a1414765c6d69cbdafaea839f3c

30
data/wiki_demo.txt Normal file

File diff suppressed because one or more lines are too long


@@ -1 +0,0 @@
c9cf509b7fdac5490cfd6dae72c2d7b8a60af6cb