diff --git a/data/dataset_info.json b/data/dataset_info.json
index f5fe4edd..ee35ac0e 100644
--- a/data/dataset_info.json
+++ b/data/dataset_info.json
@@ -198,12 +198,12 @@
     "hf_hub_url": "WizardLM/WizardLM_evol_instruct_V2_196k",
     "formatting": "sharegpt"
   },
-  "glaive_func_call": {
-    "file_name": "glaive_func_call.json",
+  "glaive_toolcall": {
+    "file_name": "glaive_toolcall_10k.json",
     "formatting": "sharegpt",
     "columns": {
       "messages": "conversations",
-      "tool": "tools"
+      "tools": "tools"
     }
   },
   "hh_rlhf_en": {
diff --git a/data/glaive_func_call.json b/data/glaive_func_call.json
deleted file mode 100644
index e1d6cce7..00000000
--- a/data/glaive_func_call.json
+++ /dev/null
@@ -1,68 +0,0 @@
-[
-  {
-    "conversations": [
-      {
-        "from": "human",
-        "value": "I need a new password. Can you generate one for me?"
-      },
-      {
-        "from": "gpt",
-        "value": "Of course. How long would you like your password to be? And would you like it to include symbols?"
-      },
-      {
-        "from": "human",
-        "value": "I would like it to be 12 characters long and yes, please include symbols."
-      },
-      {
-        "from": "function_call",
-        "value": "{\"name\": \"generate_password\", \"arguments\": {\"length\": 12, \"include_symbols\": true}}"
-      },
-      {
-        "from": "observation",
-        "value": "{\"password\": \"4&7j#9@1Q6*\"}"
-      },
-      {
-        "from": "gpt",
-        "value": "Here is your new password: 4&7j#9@1Q6*. Please make sure to save it in a secure location."
-      }
-    ],
-    "tools": "[{\"name\": \"generate_password\", \"description\": \"Generate a random password\", \"parameters\": {\"type\": \"object\", \"properties\": {\"length\": {\"type\": \"integer\", \"description\": \"The length of the password\"}, \"include_symbols\": {\"type\": \"boolean\", \"description\": \"Whether to include symbols in the password\"}}, \"required\": [\"length\"]}}, {\"name\": \"create_task\", \"description\": \"Create a new task in a task management system\", \"parameters\": {\"type\": \"object\", \"properties\": {\"title\": {\"type\": \"string\", \"description\": \"The title of the task\"}, \"due_date\": {\"type\": \"string\", \"format\": \"date\", \"description\": \"The due date of the task\"}, \"priority\": {\"type\": \"string\", \"enum\": [\"low\", \"medium\", \"high\"], \"description\": \"The priority of the task\"}}, \"required\": [\"title\", \"due_date\", \"priority\"]}}]"
-  },
-  {
-    "conversations": [
-      {
-        "from": "human",
-        "value": "Can you tell me the latest news headlines for the United States?"
-      },
-      {
-        "from": "function_call",
-        "value": "{\"name\": \"get_news_headlines\", \"arguments\": {\"country\": \"United States\"}}"
-      },
-      {
-        "from": "observation",
-        "value": "{\"headlines\": [\"Biden announces new vaccine mandates\", \"Hurricane Ida devastates Louisiana\", \"Apple unveils new iPhone\", \"NASA's Perseverance rover collects first Mars rock sample\"]}"
-      },
-      {
-        "from": "gpt",
-        "value": "Here are the latest news headlines for the United States:"
-      },
-      {
-        "from": "human",
-        "value": "That's interesting. What about the news in France?"
-      },
-      {
-        "from": "function_call",
-        "value": "{\"name\": \"get_news_headlines\", \"arguments\": {\"country\": \"France\"}}"
-      },
-      {
-        "from": "observation",
-        "value": "{\"headlines\": [\"France recalls ambassadors to US and Australia\", \"French election: Macron's party braces for tough fight\", \"Louvre Museum to undergo major overhaul\", \"France to offer free birth control to all women under 25\"]}"
-      },
-      {
-        "from": "gpt",
-        "value": "Here are the latest news headlines for France:"
-      }
-    ],
-    "tools": "[{\"name\": \"get_news_headlines\", \"description\": \"Get the latest news headlines\", \"parameters\": {\"type\": \"object\", \"properties\": {\"country\": {\"type\": \"string\", \"description\": \"The country for which to fetch news\"}}, \"required\": [\"country\"]}}]"
-  }
-]
\ No newline at end of file
diff --git a/data/glaive_toolcall_10k.json.REMOVED.git-id b/data/glaive_toolcall_10k.json.REMOVED.git-id
new file mode 100644
index 00000000..33e71b2c
--- /dev/null
+++ b/data/glaive_toolcall_10k.json.REMOVED.git-id
@@ -0,0 +1 @@
+0fe460c57c1a1dae70cdd72b185129c2aaae4e24
\ No newline at end of file
diff --git a/src/llmtuner/data/aligner.py b/src/llmtuner/data/aligner.py
index e1e53a0c..ff1a3b44 100644
--- a/src/llmtuner/data/aligner.py
+++ b/src/llmtuner/data/aligner.py
@@ -12,7 +12,7 @@ if TYPE_CHECKING:
 
 
 def convert_alpaca(examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr") -> Dict[str, List[Any]]:
-    outputs = {"prompt": [], "response": [], "system": [], "tool": []}
+    outputs = {"prompt": [], "response": [], "system": [], "tools": []}
     for i in range(len(examples[dataset_attr.prompt])):
         prompt = []
         if dataset_attr.history:
@@ -33,13 +33,13 @@ def convert_alpaca(examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr")
         outputs["prompt"].append(prompt)
         outputs["response"].append(response)
         outputs["system"].append(examples[dataset_attr.system][i] if dataset_attr.system else "")
-        outputs["tool"].append("")
+        outputs["tools"].append("")
 
     return outputs
 
 
 def convert_sharegpt(examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr") -> Dict[str, List[Any]]:
-    outputs = {"prompt": [], "response": [], "system": [], "tool": []}
+    outputs = {"prompt": [], "response": [], "system": [], "tools": []}
     tag_mapping = {
         dataset_attr.user_tag: Role.USER,
         dataset_attr.assistant_tag: Role.ASSISTANT,
@@ -69,7 +69,7 @@ def convert_sharegpt(examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr"
         outputs["prompt"].append(prompt)
         outputs["response"].append(response)
         outputs["system"].append(examples[dataset_attr.system][i] if dataset_attr.system else "")
-        outputs["tool"].append(examples[dataset_attr.tool][i] if dataset_attr.tool else "")
+        outputs["tools"].append(examples[dataset_attr.tools][i] if dataset_attr.tools else "")
 
     return outputs
 
@@ -82,7 +82,7 @@ def align_dataset(
         prompt: [{"role": "user", "content": "..."}]
         response: [{"role": "assistant", "content": "..."}]
         system: "..."
-        tool: "..."
+        tools: "..."
     """
     if dataset_attr.formatting == "alpaca":
         convert_func = partial(convert_alpaca, dataset_attr=dataset_attr)
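Reviewer note on the aligner change: after the rename, every aligned example carries a "tools" column holding a JSON string (empty for alpaca-style data). A rough sketch of the per-example layout that convert_sharegpt now targets; the column names come from the diff above, all values are invented for illustration:

# Illustrative only: one aligned row after this patch. Only the keys
# ("prompt", "response", "system", "tools") are taken from the diff above.
aligned_row = {
    "prompt": [{"role": "user", "content": "I need a new password. Can you generate one for me?"}],
    "response": [{"role": "assistant", "content": "How long would you like it to be?"}],
    "system": "",
    "tools": '[{"name": "generate_password", "description": "Generate a random password"}]',
}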
""" if dataset_attr.formatting == "alpaca": convert_func = partial(convert_alpaca, dataset_attr=dataset_attr) diff --git a/src/llmtuner/data/formatter.py b/src/llmtuner/data/formatter.py index 36484d3f..ce7b2819 100644 --- a/src/llmtuner/data/formatter.py +++ b/src/llmtuner/data/formatter.py @@ -93,6 +93,9 @@ class ToolFormatter: def __call__(self, content: str) -> List[Union[str, Dict[str, str]]]: try: tools = json.loads(content) + if not len(tools): + return [""] + if self.type == "default": return [self._default(tools)] except json.JSONDecodeError: diff --git a/src/llmtuner/data/parser.py b/src/llmtuner/data/parser.py index 842461d4..9da5b732 100644 --- a/src/llmtuner/data/parser.py +++ b/src/llmtuner/data/parser.py @@ -29,7 +29,7 @@ class DatasetAttr: history: Optional[str] = None messages: Optional[str] = "conversations" - tool: Optional[str] = None + tools: Optional[str] = None role_tag: Optional[str] = "from" content_tag: Optional[str] = "value" @@ -86,7 +86,7 @@ def get_dataset_list(data_args: "DataArguments") -> List["DatasetAttr"]: if dataset_attr.formatting == "alpaca": column_names = ["prompt", "query", "response", "history"] else: - column_names = ["messages", "tool"] + column_names = ["messages", "tools"] column_names += ["system"] for column_name in column_names: diff --git a/src/llmtuner/data/preprocess.py b/src/llmtuner/data/preprocess.py index 1e6eadb1..81caeda7 100644 --- a/src/llmtuner/data/preprocess.py +++ b/src/llmtuner/data/preprocess.py @@ -58,7 +58,7 @@ def preprocess_supervised_dataset( messages = examples["prompt"][i] + examples["response"][i] input_ids, labels = [], [] for turn_idx, (source_ids, target_ids) in enumerate(template.encode_multiturn( - tokenizer, messages, examples["system"][i], examples["tool"][i], data_args.cutoff_len + tokenizer, messages, examples["system"][i], examples["tools"][i], data_args.cutoff_len )): if data_args.train_on_prompt: source_mask = source_ids @@ -97,7 +97,7 @@ def preprocess_packed_supervised_dataset( messages = examples["prompt"][i] + examples["response"][i] for turn_idx, (source_ids, target_ids) in enumerate(template.encode_multiturn( - tokenizer, messages, examples["system"][i], examples["tool"][i] + tokenizer, messages, examples["system"][i], examples["tools"][i] )): if data_args.train_on_prompt: source_mask = source_ids @@ -141,7 +141,7 @@ def preprocess_unsupervised_dataset( messages = examples["prompt"][i] + examples["response"][i] input_ids, labels = template.encode_oneturn( - tokenizer, messages, examples["system"][i], examples["tool"][i], data_args.cutoff_len + tokenizer, messages, examples["system"][i], examples["tools"][i], data_args.cutoff_len ) if template.efficient_eos: @@ -170,10 +170,10 @@ def preprocess_pairwise_dataset( rejected_messages = examples["prompt"][i] + [examples["response"][i][1]] prompt_ids, chosen_ids = template.encode_oneturn( - tokenizer, chosen_messages, examples["system"][i], examples["tool"][i], data_args.cutoff_len + tokenizer, chosen_messages, examples["system"][i], examples["tools"][i], data_args.cutoff_len ) _, rejected_ids = template.encode_oneturn( - tokenizer, rejected_messages, examples["system"][i], examples["tool"][i], data_args.cutoff_len + tokenizer, rejected_messages, examples["system"][i], examples["tools"][i], data_args.cutoff_len ) if template.efficient_eos: diff --git a/src/llmtuner/data/template.py b/src/llmtuner/data/template.py index 935d951c..f4537d86 100644 --- a/src/llmtuner/data/template.py +++ b/src/llmtuner/data/template.py @@ -95,7 +95,21 @@ class Template: 
diff --git a/src/llmtuner/data/template.py b/src/llmtuner/data/template.py
index 935d951c..f4537d86 100644
--- a/src/llmtuner/data/template.py
+++ b/src/llmtuner/data/template.py
@@ -95,7 +95,21 @@ class Template:
             encoded_messages.append(self._convert_elements_to_ids(tokenizer, elements))
 
-        return [(encoded_messages[i], encoded_messages[i+1]) for i in range(0, len(encoded_messages), 2)]
+        # TODO: need to improve
+        encoded_pairs = []
+        total_length = 0
+        for i in range(0, len(encoded_messages), 2):
+            if total_length >= cutoff_len:
+                break
+
+            encoded_messages[i] = encoded_messages[i][:cutoff_len-total_length]
+            total_length += len(encoded_messages[i])
+
+            encoded_messages[i+1] = encoded_messages[i+1][:max(1, cutoff_len-total_length)]
+            total_length += len(encoded_messages[i+1])
+            encoded_pairs.append((encoded_messages[i], encoded_messages[i+1]))
+
+        return encoded_pairs
 
     def _convert_elements_to_ids(
         self,
@@ -161,7 +175,21 @@ class Llama2Template(Template):
             encoded_messages.append(self._convert_elements_to_ids(tokenizer, elements))
 
-        return [(encoded_messages[i], encoded_messages[i+1]) for i in range(0, len(encoded_messages), 2)]
+        # TODO: need to improve
+        encoded_pairs = []
+        total_length = 0
+        for i in range(0, len(encoded_messages), 2):
+            if total_length >= cutoff_len:
+                break
+
+            encoded_messages[i] = encoded_messages[i][:cutoff_len-total_length]
+            total_length += len(encoded_messages[i])
+
+            encoded_messages[i+1] = encoded_messages[i+1][:max(1, cutoff_len-total_length)]
+            total_length += len(encoded_messages[i+1])
+            encoded_pairs.append((encoded_messages[i], encoded_messages[i+1]))
+
+        return encoded_pairs
 
 
 templates: Dict[str, Template] = {}
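Reviewer note on the template change: the loop added to both Template and Llama2Template (hence the duplicated TODO) greedily packs whole (source, target) turns until cutoff_len is spent, clipping the turn that crosses the boundary and keeping at least one target token. A standalone sketch of the same logic on plain token-id lists, for illustration only:

from typing import List, Tuple

def truncate_pairs(encoded_messages: List[List[int]], cutoff_len: int) -> List[Tuple[List[int], List[int]]]:
    # Same logic as the loop in the diff: clip each (source, target) turn to the
    # remaining budget, keep at least one target token, stop once the budget is spent.
    encoded_pairs = []
    total_length = 0
    for i in range(0, len(encoded_messages), 2):
        if total_length >= cutoff_len:
            break
        source = encoded_messages[i][: cutoff_len - total_length]
        total_length += len(source)
        target = encoded_messages[i + 1][: max(1, cutoff_len - total_length)]
        total_length += len(target)
        encoded_pairs.append((source, target))
    return encoded_pairs

# truncate_pairs([[1] * 6, [2] * 6, [3] * 6, [4] * 6], cutoff_len=10)
# -> [([1, 1, 1, 1, 1, 1], [2, 2, 2, 2])]  (the second turn is dropped entirely)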