From dbb284b5a268f91a802278cc56dedb0c08b91f92 Mon Sep 17 00:00:00 2001 From: hiyouga Date: Fri, 4 Aug 2023 23:27:55 +0800 Subject: [PATCH] fix encode Former-commit-id: 8172ad1b5e3fa0b224d761ce6069d0db4397da2d --- src/llmtuner/dsets/preprocess.py | 2 +- src/llmtuner/extras/template.py | 34 +++++++++++++++++--------------- 2 files changed, 19 insertions(+), 17 deletions(-) diff --git a/src/llmtuner/dsets/preprocess.py b/src/llmtuner/dsets/preprocess.py index 2482abe3..0f28ce1e 100644 --- a/src/llmtuner/dsets/preprocess.py +++ b/src/llmtuner/dsets/preprocess.py @@ -55,7 +55,7 @@ def preprocess_dataset( for query, response, history, prefix in construct_example(examples): input_ids, labels = [], [] - for source_ids, target_ids in template.get_dialog(tokenizer, query, response, history, prefix): + for source_ids, target_ids in template.get_dialog(tokenizer, query, response, history, prefix): # TODO: fix bos if len(source_ids) > data_args.max_source_length: source_ids = source_ids[:data_args.max_source_length] if len(target_ids) > data_args.max_target_length - 1: # eos token diff --git a/src/llmtuner/extras/template.py b/src/llmtuner/extras/template.py index 46089413..4ffd569a 100644 --- a/src/llmtuner/extras/template.py +++ b/src/llmtuner/extras/template.py @@ -29,7 +29,7 @@ class Template: encoded_pairs = self._encode(tokenizer=tokenizer, prefix=prefix, history=history) prompt_ids = [] for query_ids, resp_ids in encoded_pairs[:-1]: - prompt_ids = prompt_ids + query_ids + resp_ids + prompt_ids = prompt_ids + query_ids + resp_ids + [tokenizer.eos_token_id] prompt_ids = prompt_ids + encoded_pairs[-1][0] return prompt_ids, encoded_pairs[-1][1] @@ -96,7 +96,9 @@ class Template: token_ids = [] for elem in context: if isinstance(elem, str): - elem = elem.format(query=query) + subelems = elem.split("{{query}}") + if len(subelems) > 1: + elem = subelems[0] + query + subelems[1] token_ids = token_ids + tokenizer.encode(elem, add_special_tokens=False) elif isinstance(elem, dict): token_ids = token_ids + [tokenizer.convert_tokens_to_ids(elem.get("token"))] @@ -165,7 +167,7 @@ register_template( name="vanilla", prefix=[], prompt=[ - "{query}" + "{{query}}" ], sep=[], stop_words=[], @@ -183,7 +185,7 @@ register_template( "The assistant gives helpful, detailed, and polite answers to the user's questions." ], prompt=[ - "Human: {query}\nAssistant: " + "Human: {{query}}\nAssistant: " ], sep=[ "\n" @@ -211,7 +213,7 @@ register_template( "If you don't know the answer to a question, please don't share false information.\n<>\n\n" ], prompt=[ - "[INST] {query} [/INST] " + "[INST] {{query}} [/INST] " ], sep=[ {"token": ""} @@ -232,7 +234,7 @@ register_template( "Write a response that appropriately completes the request." ], prompt=[ - "### Instruction:\n{query}\n\n### Response:\n" + "### Instruction:\n{{query}}\n\n### Response:\n" ], sep=[ "\n\n" @@ -253,7 +255,7 @@ register_template( "The assistant gives helpful, detailed, and polite answers to the user's questions." ], prompt=[ - "USER: {query} ASSISTANT: " + "USER: {{query}} ASSISTANT: " ], sep=[], stop_words=[], @@ -268,7 +270,7 @@ register_template( name="belle", prefix=[], prompt=[ - "Human: {query}\n\nBelle: " + "Human: {{query}}\n\nBelle: " ], sep=[ "\n\n" @@ -285,7 +287,7 @@ register_template( name="linly", prefix=[], prompt=[ - "User: {query}\nBot: " + "User: {{query}}\nBot: " ], sep=[ "\n" @@ -302,7 +304,7 @@ register_template( name="billa", prefix=[], prompt=[ - "Human: {query}\nAssistant: " + "Human: {{query}}\nAssistant: " ], sep=[ "\n" @@ -320,7 +322,7 @@ register_template( prefix=[], prompt=[ {"token": ""}, - ":{query}\n", + ":{{query}}\n", {"token": ""}, ":" ], @@ -342,7 +344,7 @@ register_template( "The assistant gives helpful, detailed, and polite answers to the human's questions." ], prompt=[ - "Human: {query}###Assistant: " + "Human: {{query}}###Assistant: " ], sep=[ "###" @@ -360,7 +362,7 @@ register_template( prefix=[], prompt=[ {"token": "<|User|>"}, - ":{query}", + ":{{query}}", {"token": ""}, "\n", {"token": "<|Bot|>"}, @@ -385,7 +387,7 @@ register_template( prefix=[], prompt=[ {"token": ""}, - "{query}", + "{{query}}", {"token": ""} ], sep=[], @@ -406,7 +408,7 @@ register_template( ], prompt=[ {"token": "<|user|>"}, - "\n{query}", + "\n{{query}}", {"token": "<|end|>"}, "\n", {"token": "<|assistant|>"} @@ -433,7 +435,7 @@ register_template( ], prompt=[ {"token": "<|im_start|>"}, - "user\n{query}", + "user\n{{query}}", {"token": "<|im_end|>"}, "\n", {"token": "<|im_start|>"},