From e4a424cb6ab5e7da3d8201b10eba2002c05f0e03 Mon Sep 17 00:00:00 2001
From: hiyouga
Date: Thu, 18 Jan 2024 12:25:42 +0800
Subject: [PATCH] enable cutoff len

Former-commit-id: e9513d300c338dfcae98eee7d057bfd00da2da0e
---
 data/glaive_toolcall_10k.json.REMOVED.git-id |  1 +
 src/llmtuner/data/aligner.py                 | 10 +++---
 src/llmtuner/data/formatter.py               |  3 ++
 src/llmtuner/data/parser.py                  |  4 +--
 src/llmtuner/data/preprocess.py              | 10 +++---
 src/llmtuner/data/template.py                | 32 ++++++++++++++++++--
 6 files changed, 46 insertions(+), 14 deletions(-)
 create mode 100644 data/glaive_toolcall_10k.json.REMOVED.git-id

diff --git a/data/glaive_toolcall_10k.json.REMOVED.git-id b/data/glaive_toolcall_10k.json.REMOVED.git-id
new file mode 100644
index 00000000..33e71b2c
--- /dev/null
+++ b/data/glaive_toolcall_10k.json.REMOVED.git-id
@@ -0,0 +1 @@
+0fe460c57c1a1dae70cdd72b185129c2aaae4e24
\ No newline at end of file
diff --git a/src/llmtuner/data/aligner.py b/src/llmtuner/data/aligner.py
index e1e53a0c..ff1a3b44 100644
--- a/src/llmtuner/data/aligner.py
+++ b/src/llmtuner/data/aligner.py
@@ -12,7 +12,7 @@ if TYPE_CHECKING:
 
 
 def convert_alpaca(examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr") -> Dict[str, List[Any]]:
-    outputs = {"prompt": [], "response": [], "system": [], "tool": []}
+    outputs = {"prompt": [], "response": [], "system": [], "tools": []}
     for i in range(len(examples[dataset_attr.prompt])):
         prompt = []
         if dataset_attr.history:
@@ -33,13 +33,13 @@ def convert_alpaca(examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr")
         outputs["prompt"].append(prompt)
         outputs["response"].append(response)
         outputs["system"].append(examples[dataset_attr.system][i] if dataset_attr.system else "")
-        outputs["tool"].append("")
+        outputs["tools"].append("")
 
     return outputs
 
 
 def convert_sharegpt(examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr") -> Dict[str, List[Any]]:
-    outputs = {"prompt": [], "response": [], "system": [], "tool": []}
+    outputs = {"prompt": [], "response": [], "system": [], "tools": []}
     tag_mapping = {
         dataset_attr.user_tag: Role.USER,
         dataset_attr.assistant_tag: Role.ASSISTANT,
@@ -69,7 +69,7 @@ def convert_sharegpt(examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr"
         outputs["prompt"].append(prompt)
         outputs["response"].append(response)
         outputs["system"].append(examples[dataset_attr.system][i] if dataset_attr.system else "")
-        outputs["tool"].append(examples[dataset_attr.tool][i] if dataset_attr.tool else "")
+        outputs["tools"].append(examples[dataset_attr.tools][i] if dataset_attr.tools else "")
 
     return outputs
 
@@ -82,7 +82,7 @@ def align_dataset(
         prompt: [{"role": "user", "content": "..."}]
         response: [{"role": "assistant", "content": "..."}]
         system: "..."
-        tool: "..."
+        tools: "..."
     """
     if dataset_attr.formatting == "alpaca":
        convert_func = partial(convert_alpaca, dataset_attr=dataset_attr)
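Note (not part of the patch): after this rename, every aligned example exposes a
"tools" column next to "prompt", "response", and "system". Alpaca-style data gets
an empty string, while sharegpt-style data keeps its tool definitions as a JSON
string. A hypothetical aligned record, with illustrative contents only, could look
like this:

    aligned_example = {
        "prompt": [{"role": "user", "content": "What is the weather in Paris?"}],
        "response": [{"role": "assistant", "content": "Let me check that for you."}],
        "system": "You are a helpful assistant.",
        "tools": '[{"name": "get_weather", "description": "Query the current weather"}]',
    }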
""" if dataset_attr.formatting == "alpaca": convert_func = partial(convert_alpaca, dataset_attr=dataset_attr) diff --git a/src/llmtuner/data/formatter.py b/src/llmtuner/data/formatter.py index 36484d3f..ce7b2819 100644 --- a/src/llmtuner/data/formatter.py +++ b/src/llmtuner/data/formatter.py @@ -93,6 +93,9 @@ class ToolFormatter: def __call__(self, content: str) -> List[Union[str, Dict[str, str]]]: try: tools = json.loads(content) + if not len(tools): + return [""] + if self.type == "default": return [self._default(tools)] except json.JSONDecodeError: diff --git a/src/llmtuner/data/parser.py b/src/llmtuner/data/parser.py index 842461d4..9da5b732 100644 --- a/src/llmtuner/data/parser.py +++ b/src/llmtuner/data/parser.py @@ -29,7 +29,7 @@ class DatasetAttr: history: Optional[str] = None messages: Optional[str] = "conversations" - tool: Optional[str] = None + tools: Optional[str] = None role_tag: Optional[str] = "from" content_tag: Optional[str] = "value" @@ -86,7 +86,7 @@ def get_dataset_list(data_args: "DataArguments") -> List["DatasetAttr"]: if dataset_attr.formatting == "alpaca": column_names = ["prompt", "query", "response", "history"] else: - column_names = ["messages", "tool"] + column_names = ["messages", "tools"] column_names += ["system"] for column_name in column_names: diff --git a/src/llmtuner/data/preprocess.py b/src/llmtuner/data/preprocess.py index 1e6eadb1..81caeda7 100644 --- a/src/llmtuner/data/preprocess.py +++ b/src/llmtuner/data/preprocess.py @@ -58,7 +58,7 @@ def preprocess_supervised_dataset( messages = examples["prompt"][i] + examples["response"][i] input_ids, labels = [], [] for turn_idx, (source_ids, target_ids) in enumerate(template.encode_multiturn( - tokenizer, messages, examples["system"][i], examples["tool"][i], data_args.cutoff_len + tokenizer, messages, examples["system"][i], examples["tools"][i], data_args.cutoff_len )): if data_args.train_on_prompt: source_mask = source_ids @@ -97,7 +97,7 @@ def preprocess_packed_supervised_dataset( messages = examples["prompt"][i] + examples["response"][i] for turn_idx, (source_ids, target_ids) in enumerate(template.encode_multiturn( - tokenizer, messages, examples["system"][i], examples["tool"][i] + tokenizer, messages, examples["system"][i], examples["tools"][i] )): if data_args.train_on_prompt: source_mask = source_ids @@ -141,7 +141,7 @@ def preprocess_unsupervised_dataset( messages = examples["prompt"][i] + examples["response"][i] input_ids, labels = template.encode_oneturn( - tokenizer, messages, examples["system"][i], examples["tool"][i], data_args.cutoff_len + tokenizer, messages, examples["system"][i], examples["tools"][i], data_args.cutoff_len ) if template.efficient_eos: @@ -170,10 +170,10 @@ def preprocess_pairwise_dataset( rejected_messages = examples["prompt"][i] + [examples["response"][i][1]] prompt_ids, chosen_ids = template.encode_oneturn( - tokenizer, chosen_messages, examples["system"][i], examples["tool"][i], data_args.cutoff_len + tokenizer, chosen_messages, examples["system"][i], examples["tools"][i], data_args.cutoff_len ) _, rejected_ids = template.encode_oneturn( - tokenizer, rejected_messages, examples["system"][i], examples["tool"][i], data_args.cutoff_len + tokenizer, rejected_messages, examples["system"][i], examples["tools"][i], data_args.cutoff_len ) if template.efficient_eos: diff --git a/src/llmtuner/data/template.py b/src/llmtuner/data/template.py index 935d951c..f4537d86 100644 --- a/src/llmtuner/data/template.py +++ b/src/llmtuner/data/template.py @@ -95,7 +95,21 @@ class Template: 
diff --git a/src/llmtuner/data/template.py b/src/llmtuner/data/template.py
index 935d951c..f4537d86 100644
--- a/src/llmtuner/data/template.py
+++ b/src/llmtuner/data/template.py
@@ -95,7 +95,21 @@ class Template:
 
             encoded_messages.append(self._convert_elements_to_ids(tokenizer, elements))
 
-        return [(encoded_messages[i], encoded_messages[i+1]) for i in range(0, len(encoded_messages), 2)]
+        # TODO: need to improve
+        encoded_pairs = []
+        total_length = 0
+        for i in range(0, len(encoded_messages), 2):
+            if total_length >= cutoff_len:
+                break
+
+            encoded_messages[i] = encoded_messages[i][:cutoff_len-total_length]
+            total_length += len(encoded_messages[i])
+
+            encoded_messages[i+1] = encoded_messages[i+1][:max(1, cutoff_len-total_length)]
+            total_length += len(encoded_messages[i+1])
+            encoded_pairs.append((encoded_messages[i], encoded_messages[i+1]))
+
+        return encoded_pairs
 
     def _convert_elements_to_ids(
         self,
@@ -161,7 +175,21 @@ class Llama2Template(Template):
 
             encoded_messages.append(self._convert_elements_to_ids(tokenizer, elements))
 
-        return [(encoded_messages[i], encoded_messages[i+1]) for i in range(0, len(encoded_messages), 2)]
+        # TODO: need to improve
+        encoded_pairs = []
+        total_length = 0
+        for i in range(0, len(encoded_messages), 2):
+            if total_length >= cutoff_len:
+                break
+
+            encoded_messages[i] = encoded_messages[i][:cutoff_len-total_length]
+            total_length += len(encoded_messages[i])
+
+            encoded_messages[i+1] = encoded_messages[i+1][:max(1, cutoff_len-total_length)]
+            total_length += len(encoded_messages[i+1])
+            encoded_pairs.append((encoded_messages[i], encoded_messages[i+1]))
+
+        return encoded_pairs
 
 
 templates: Dict[str, Template] = {}
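Note (not part of the patch): the loop added to both templates enforces cutoff_len
turn by turn: each source chunk is clipped to the remaining budget, each target chunk
keeps at least one token through max(1, ...), and any later turns are dropped once the
budget is spent (hence the "TODO: need to improve" left in the code). A standalone
sketch of the same strategy on plain lists of token ids:

    from typing import List, Tuple

    def truncate_pairs(encoded_messages: List[List[int]], cutoff_len: int) -> List[Tuple[List[int], List[int]]]:
        # Walk (source, target) pairs, clip each side to the remaining budget,
        # and stop emitting pairs once the budget is exhausted.
        encoded_pairs = []
        total_length = 0
        for i in range(0, len(encoded_messages), 2):
            if total_length >= cutoff_len:
                break

            source = encoded_messages[i][:cutoff_len - total_length]
            total_length += len(source)

            target = encoded_messages[i + 1][:max(1, cutoff_len - total_length)]
            total_length += len(target)
            encoded_pairs.append((source, target))

        return encoded_pairs

    # Two turns of six tokens each against a budget of eight: the second source is
    # clipped to two ids and its target keeps a single id.
    pairs = truncate_pairs([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]], cutoff_len=8)
    assert pairs == [([1, 2, 3], [4, 5, 6]), ([7, 8], [10])]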