From 48cab43cb587630ae3bc5dac6c21cc82c14d7a90 Mon Sep 17 00:00:00 2001
From: hiyouga
Date: Sun, 21 Jan 2024 22:17:48 +0800
Subject: [PATCH] add array param format

Former-commit-id: 486cc8d3600397812e3927d43ab4181f4e86f5dd
---
 data/README.md                 | 35 ++++++++++++++++++++++------------
 data/README_zh.md              | 35 ++++++++++++++++++++++------------
 src/llmtuner/data/formatter.py |  6 +++++-
 3 files changed, 51 insertions(+), 25 deletions(-)

diff --git a/data/README.md b/data/README.md
index 7e56aa30..f2fd7bb1 100644
--- a/data/README.md
+++ b/data/README.md
@@ -12,14 +12,21 @@ If you are using a custom dataset, please provide your dataset definition in the
   "ranking": "whether the dataset is a preference dataset or not. (default: false)",
   "formatting": "the format of the dataset. (optional, default: alpaca, can be chosen from {alpaca, sharegpt})",
   "columns": {
-    "prompt": "the column name in the dataset containing the prompts. (default: instruction, for alpaca)",
-    "query": "the column name in the dataset containing the queries. (default: input, for alpaca)",
-    "response": "the column name in the dataset containing the responses. (default: output, for alpaca)",
-    "history": "the column name in the dataset containing the histories. (default: None, for alpaca)",
-    "messages": "the column name in the dataset containing the messages. (default: conversations, for sharegpt)",
-    "role": "the key in the message represents the identity. (default: from, for sharegpt)",
-    "content": "the key in the message represents the content. (default: value, for sharegpt)",
-    "system": "the column name in the dataset containing the system prompts. (default: None, for both)"
+    "prompt": "the column name in the dataset containing the prompts. (default: instruction)",
+    "query": "the column name in the dataset containing the queries. (default: input)",
+    "response": "the column name in the dataset containing the responses. (default: output)",
+    "history": "the column name in the dataset containing the histories. (default: None)",
+    "messages": "the column name in the dataset containing the messages. (default: conversations)",
+    "system": "the column name in the dataset containing the system prompts. (default: None)",
+    "tools": "the column name in the dataset containing the tool description. (default: None)"
+  },
+  "tags": {
+    "role_tag": "the key in the message that represents the identity. (default: from)",
+    "content_tag": "the key in the message that represents the content. (default: value)",
+    "user_tag": "the value of the role_tag that represents the user. (default: human)",
+    "assistant_tag": "the value of the role_tag that represents the assistant. (default: gpt)",
+    "observation_tag": "the value of the role_tag that represents the tool results. (default: observation)",
+    "function_tag": "the value of the role_tag that represents the function call. (default: function_call)"
   }
 }
 ```
@@ -91,7 +98,8 @@ The dataset in sharegpt format should follow the below format:
         "value": "model response"
       }
     ],
-    "system": "system prompt (optional)"
+    "system": "system prompt (optional)",
+    "tools": "tool description (optional)"
   }
 ]
 ```
@@ -102,9 +110,12 @@ Regarding the above dataset, the `columns` in `dataset_info.json` should be:
 "dataset_name": {
   "columns": {
     "messages": "conversations",
-    "role": "from",
-    "content": "value",
-    "system": "system"
+    "system": "system",
+    "tools": "tools"
+  },
+  "tags": {
+    "role_tag": "from",
+    "content_tag": "value"
   }
 }
 ```
diff --git a/data/README_zh.md b/data/README_zh.md
index cb867a5b..8c46e2ae 100644
--- a/data/README_zh.md
+++ b/data/README_zh.md
@@ -12,14 +12,21 @@
   "ranking": "是否为偏好数据集（可选，默认：False）",
   "formatting": "数据集格式（可选，默认：alpaca，可以为 alpaca 或 sharegpt）",
   "columns": {
-    "prompt": "数据集代表提示词的表头名称（默认：instruction，用于 alpaca 格式）",
-    "query": "数据集代表请求的表头名称（默认：input，用于 alpaca 格式）",
-    "response": "数据集代表回答的表头名称（默认：output，用于 alpaca 格式）",
-    "history": "数据集代表历史对话的表头名称（默认：None，用于 alpaca 格式）",
-    "messages": "数据集代表消息列表的表头名称（默认：conversations，用于 sharegpt 格式）",
-    "role": "消息中代表发送者身份的键名（默认：from，用于 sharegpt 格式）",
-    "content": "消息中代表文本内容的键名（默认：value，用于 sharegpt 格式）",
-    "system": "数据集代表系统提示的表头名称（默认：None，用于两种格式）"
+    "prompt": "数据集代表提示词的表头名称（默认：instruction）",
+    "query": "数据集代表请求的表头名称（默认：input）",
+    "response": "数据集代表回答的表头名称（默认：output）",
+    "history": "数据集代表历史对话的表头名称（默认：None）",
+    "messages": "数据集代表消息列表的表头名称（默认：conversations）",
+    "system": "数据集代表系统提示的表头名称（默认：None）",
+    "tools": "数据集代表工具描述的表头名称（默认：None）"
+  },
+  "tags": {
+    "role_tag": "消息中代表发送者身份的键名（默认：from）",
+    "content_tag": "消息中代表文本内容的键名（默认：value）",
+    "user_tag": "消息中代表用户的 role_tag（默认：human）",
+    "assistant_tag": "消息中代表助手的 role_tag（默认：gpt）",
+    "observation_tag": "消息中代表工具返回结果的 role_tag（默认：observation）",
+    "function_tag": "消息中代表工具调用的 role_tag（默认：function_call）"
   }
 }
 ```
@@ -91,7 +98,8 @@
         "value": "模型回答"
       }
     ],
-    "system": "系统提示词（选填）"
+    "system": "系统提示词（选填）",
+    "tools": "工具描述（选填）"
   }
 ]
 ```
@@ -102,9 +110,12 @@
 "数据集名称": {
   "columns": {
     "messages": "conversations",
-    "role": "from",
-    "content": "value",
-    "system": "system"
+    "system": "system",
+    "tools": "tools"
+  },
+  "tags": {
+    "role_tag": "from",
+    "content_tag": "value"
   }
 }
 ```
diff --git a/src/llmtuner/data/formatter.py b/src/llmtuner/data/formatter.py
index 5eb3a658..e29cb833 100644
--- a/src/llmtuner/data/formatter.py
+++ b/src/llmtuner/data/formatter.py
@@ -31,12 +31,16 @@ def default_tool_formatter(tools: List[Dict[str, Any]]) -> str:
         for name, param in tool["parameters"]["properties"].items():
             required = ", required" if name in tool["parameters"].get("required", []) else ""
             enum = ", should be one of [{}]".format(", ".join(param["enum"])) if param.get("enum", None) else ""
-            param_text += " - {name} ({type}{required}): {desc}{enum}\n".format(
+            items = (
+                ", where each item should be {}".format(param["items"].get("type", "")) if param.get("items") else ""
+            )
+            param_text += " - {name} ({type}{required}): {desc}{enum}{items}\n".format(
                 name=name,
                 type=param.get("type", ""),
                 required=required,
                 desc=param.get("description", ""),
                 enum=enum,
+                items=items,
             )

         tool_text += "> Tool Name: {name}\nTool Description: {desc}\nTool Args:\n{args}\n".format(
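
Below is a minimal, self-contained sketch (not part of the patch) of the per-parameter formatting loop as it behaves after this change, showing the new `items` clause for an array-typed argument. The `search_papers` tool definition is hypothetical and serves only as example input.

```python
from typing import Any, Dict

# Hypothetical tool definition (for illustration only) with an array-typed parameter.
tool: Dict[str, Any] = {
    "name": "search_papers",
    "description": "Search for papers by keyword.",
    "parameters": {
        "type": "object",
        "properties": {
            "keywords": {
                "type": "array",
                "items": {"type": "string"},
                "description": "keywords to search for",
            }
        },
        "required": ["keywords"],
    },
}

# Per-parameter formatting loop, mirroring the patched default_tool_formatter logic.
param_text = ""
for name, param in tool["parameters"]["properties"].items():
    required = ", required" if name in tool["parameters"].get("required", []) else ""
    enum = ", should be one of [{}]".format(", ".join(param["enum"])) if param.get("enum", None) else ""
    # Added by this patch: describe the element type of array parameters.
    items = (
        ", where each item should be {}".format(param["items"].get("type", "")) if param.get("items") else ""
    )
    param_text += " - {name} ({type}{required}): {desc}{enum}{items}\n".format(
        name=name,
        type=param.get("type", ""),
        required=required,
        desc=param.get("description", ""),
        enum=enum,
        items=items,
    )

print(param_text, end="")
# Expected output:
#  - keywords (array, required): keywords to search for, where each item should be string
```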