From ea045b0e5b7573b392a287e7eeb0beaf1bb73667 Mon Sep 17 00:00:00 2001 From: hiyouga Date: Fri, 4 Aug 2023 23:14:28 +0800 Subject: [PATCH 1/3] support chatml safe encoding Former-commit-id: b4852f94065a11c8cd00ffa7e71ac0e0b2bf477a --- src/llmtuner/chat/stream_chat.py | 11 +- src/llmtuner/dsets/preprocess.py | 17 +- src/llmtuner/extras/template.py | 408 +++++++++++++++++++++---------- 3 files changed, 293 insertions(+), 143 deletions(-) diff --git a/src/llmtuner/chat/stream_chat.py b/src/llmtuner/chat/stream_chat.py index 13168007..ef30a324 100644 --- a/src/llmtuner/chat/stream_chat.py +++ b/src/llmtuner/chat/stream_chat.py @@ -30,10 +30,11 @@ class ChatModel: ) -> Tuple[Dict[str, Any], int]: prefix = prefix or self.source_prefix - prompt = self.template.get_prompt(query, history, prefix, self.tokenizer.eos_token) - inputs = self.tokenizer([prompt], return_tensors="pt") - inputs = inputs.to(self.model.device) - prompt_length = len(inputs["input_ids"][0]) + prompt, _ = self.template.get_prompt( + tokenizer=self.tokenizer, query=query, resp="", history=history, prefix=prefix + ) + input_ids = torch.tensor([prompt], device=self.model.device) + prompt_length = len(input_ids[0]) do_sample = input_kwargs.pop("do_sample", None) temperature = input_kwargs.pop("temperature", None) @@ -45,7 +46,7 @@ class ChatModel: gen_kwargs = self.generating_args.to_dict() gen_kwargs.update(dict( - input_ids=inputs["input_ids"], + input_ids=input_ids, do_sample=do_sample if do_sample is not None else gen_kwargs["do_sample"], temperature=temperature or gen_kwargs["temperature"], top_p=top_p or gen_kwargs["top_p"], diff --git a/src/llmtuner/dsets/preprocess.py b/src/llmtuner/dsets/preprocess.py index 0257b244..2482abe3 100644 --- a/src/llmtuner/dsets/preprocess.py +++ b/src/llmtuner/dsets/preprocess.py @@ -55,10 +55,7 @@ def preprocess_dataset( for query, response, history, prefix in construct_example(examples): input_ids, labels = [], [] - for i, (query_i, resp_i) in enumerate(template.get_dialog(query, response, history, prefix)): - source_ids = tokenizer.encode(text=query_i, add_special_tokens=(i == 0)) - target_ids = tokenizer.encode(text=resp_i, add_special_tokens=False) - + for source_ids, target_ids in template.get_dialog(tokenizer, query, response, history, prefix): if len(source_ids) > data_args.max_source_length: source_ids = source_ids[:data_args.max_source_length] if len(target_ids) > data_args.max_target_length - 1: # eos token @@ -81,10 +78,7 @@ def preprocess_dataset( model_inputs = {"input_ids": [], "attention_mask": [], "labels": []} for query, response, history, prefix in construct_example(examples): - prompt = template.get_prompt(query, history, prefix, tokenizer.eos_token) - - source_ids = tokenizer.encode(text=prompt, add_special_tokens=True) - target_ids = tokenizer.encode(text=response, add_special_tokens=True) + source_ids, target_ids = template.get_prompt(tokenizer, query, response, history, prefix) if len(source_ids) > data_args.max_source_length: source_ids = source_ids[:data_args.max_source_length] @@ -101,11 +95,8 @@ def preprocess_dataset( # build input pairs with format ` X Y1 ` and ` X Y2 ` model_inputs = {"accept_ids": [], "reject_ids": []} for query, response, history, prefix in construct_example(examples): - prompt = template.get_prompt(query, history, prefix, tokenizer.eos_token) - - source_ids = tokenizer.encode(text=prompt, add_special_tokens=True) - accept_ids = tokenizer.encode(text=response[0], add_special_tokens=False) - reject_ids = tokenizer.encode(text=response[1], 
add_special_tokens=False) + source_ids, accept_ids = template.get_prompt(tokenizer, query, response[0], history, prefix) + source_ids, reject_ids = template.get_prompt(tokenizer, query, response[1], history, prefix) if len(source_ids) > data_args.max_source_length: source_ids = source_ids[:data_args.max_source_length] diff --git a/src/llmtuner/extras/template.py b/src/llmtuner/extras/template.py index bb550058..46089413 100644 --- a/src/llmtuner/extras/template.py +++ b/src/llmtuner/extras/template.py @@ -1,82 +1,146 @@ -from typing import Dict, List, Optional, Tuple +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union from dataclasses import dataclass +if TYPE_CHECKING: + from transformers import PreTrainedTokenizer + @dataclass class Template: - prefix: str - prompt: str - sep: str - use_history: bool + prefix: List[Union[str, Dict[str, str]]] + prompt: List[Union[str, Dict[str, str]]] + sep: List[Union[str, Dict[str, str]]] stop_words: List[str] + use_history: bool def get_prompt( self, + tokenizer: "PreTrainedTokenizer", query: str, + resp: str, history: Optional[List[Tuple[str, str]]] = None, - prefix: Optional[str] = "", - eos_token: Optional[str] = "" - ) -> str: + prefix: Optional[str] = None + ) -> Tuple[List[int], List[int]]: r""" - Returns a string containing prompt without response. + Returns a single pair of token ids representing prompt and response respectively. """ - return eos_token.join(map(lambda x: x[0] + x[1], self._format_example(query, history, prefix))) + prefix, history = self._format(query=query, resp=resp, history=history, prefix=prefix) + encoded_pairs = self._encode(tokenizer=tokenizer, prefix=prefix, history=history) + prompt_ids = [] + for query_ids, resp_ids in encoded_pairs[:-1]: + prompt_ids = prompt_ids + query_ids + resp_ids + prompt_ids = prompt_ids + encoded_pairs[-1][0] + return prompt_ids, encoded_pairs[-1][1] def get_dialog( + self, + tokenizer: "PreTrainedTokenizer", + query: str, + resp: str, + history: Optional[List[Tuple[str, str]]] = None, + prefix: Optional[str] = None + ) -> List[Tuple[List[int], List[int]]]: + r""" + Returns multiple pairs of token ids representing prompts and responses respectively. + """ + prefix, history = self._format(query=query, resp=resp, history=history, prefix=prefix) + encoded_pairs = self._encode(tokenizer=tokenizer, prefix=prefix, history=history) + return encoded_pairs + + def _format( self, query: str, resp: str, history: Optional[List[Tuple[str, str]]] = None, - prefix: Optional[str] = "" - ) -> List[Tuple[str, str]]: + prefix: Optional[str] = None + ) -> Tuple[List[Union[str, Dict[str, str]]], List[Tuple[str, str]]]: r""" - Returns a list containing prompt-response pairs. + Aligns inputs to a special format. 
""" - result = self._format_example(query, history, prefix) - result[-1][-1] = resp - return result - - def _format_example( - self, - query: str, - history: Optional[List[Tuple[str, str]]] = None, - prefix: Optional[str] = "" - ) -> List[Tuple[str, str]]: - prefix = prefix or self.prefix # use prefix if provided - prefix = prefix + self.sep if prefix else "" # add separator for non-empty prefix + prefix = [prefix] if prefix is not None else self.prefix # use prefix if provided + prefix = prefix + self.sep if prefix else [] # add separator for non-empty prefix history = history if (history and self.use_history) else [] - history = history + [(query, "")] - return [ - [(self.sep if i else prefix) + self.prompt.format(query=q), r] - for i, (q, r) in enumerate(history) - ] + history = history + [(query, resp)] + return prefix, history + + def _encode( + self, + tokenizer: "PreTrainedTokenizer", + prefix: List[Union[str, Dict[str, str]]], + history: List[Tuple[str, str]] + ) -> List[Tuple[List[int], List[int]]]: + r""" + Encodes formatted inputs to pairs of token ids. + """ + encoded_pairs = [] + for turn_idx, (query, resp) in enumerate(history): + if turn_idx == 0: + prefix_ids = self._convert_inputs_to_ids(tokenizer, context=prefix) + else: + prefix_ids = self._convert_inputs_to_ids(tokenizer, context=self.sep) + query_ids = self._convert_inputs_to_ids(tokenizer, context=self.prompt, query=query) + resp_ids = self._convert_inputs_to_ids(tokenizer, context=[resp]) + encoded_pairs.append((prefix_ids + query_ids, resp_ids)) + return encoded_pairs + + def _convert_inputs_to_ids( + self, + tokenizer: "PreTrainedTokenizer", + context: List[Union[str, Dict[str, str]]], + query: Optional[str] = "" + ) -> List[int]: + r""" + Converts context to token ids. + """ + token_ids = [] + for elem in context: + if isinstance(elem, str): + elem = elem.format(query=query) + token_ids = token_ids + tokenizer.encode(elem, add_special_tokens=False) + elif isinstance(elem, dict): + token_ids = token_ids + [tokenizer.convert_tokens_to_ids(elem.get("token"))] + else: + raise NotImplementedError + return token_ids @dataclass class Llama2Template(Template): - def _format_example( + def _encode( self, - query: str, - history: Optional[List[Tuple[str, str]]] = None, - prefix: Optional[str] = "" - ) -> List[Tuple[str, str]]: - prefix = prefix or self.prefix # use prefix if provided - prefix = prefix if prefix.startswith("<>") else "<>\n{}\n<>\n\n".format(prefix) - history = history if (history and self.use_history) else [] - history = history + [(query, "")] - return [ - [(self.sep if i else "") + self.prompt.format(query=(q if i else prefix + q)), r] - for i, (q, r) in enumerate(history) - ] + tokenizer: "PreTrainedTokenizer", + prefix: List[Union[str, Dict[str, str]]], + history: List[Tuple[str, str]] + ) -> List[Tuple[List[int], List[int]]]: + r""" + Encodes formatted inputs to pairs of token ids. + """ + encoded_pairs = [] + assert isinstance(prefix[0], str), "LLaMA-2 template only accepts list containing a single str." 
+        for turn_idx, (query, resp) in enumerate(history):
+            if turn_idx == 0:
+                prefix_ids = []
+                query = prefix[0] + query
+            else:
+                prefix_ids = self._convert_inputs_to_ids(tokenizer, context=self.sep)
+            query_ids = self._convert_inputs_to_ids(tokenizer, context=self.prompt, query=query)
+            resp_ids = self._convert_inputs_to_ids(tokenizer, context=[resp])
+            encoded_pairs.append((prefix_ids + query_ids, resp_ids))
+        return encoded_pairs
 
 
 templates: Dict[str, Template] = {}
 
 
 def register_template(
-    name: str, prefix: str, prompt: str, sep: str, use_history: bool, stop_words: List[str]
+    name: str,
+    prefix: List[Union[str, Dict[str, str]]],
+    prompt: List[Union[str, Dict[str, str]]],
+    sep: List[Union[str, Dict[str, str]]],
+    stop_words: List[str],
+    use_history: bool
 ) -> None:
     template_class = Llama2Template if name == "llama2" else Template
     templates[name] = template_class(
@@ -99,11 +163,13 @@ Supports language model inference without histories.
 """
 register_template(
     name="vanilla",
-    prefix="",
-    prompt="{query}",
-    sep="",
-    use_history=False,
-    stop_words=[]
+    prefix=[],
+    prompt=[
+        "{query}"
+    ],
+    sep=[],
+    stop_words=[],
+    use_history=False
 )
 
 
@@ -112,12 +178,18 @@ Default template.
 """
 register_template(
     name="default",
-    prefix="A chat between a curious user and an artificial intelligence assistant. "
-           "The assistant gives helpful, detailed, and polite answers to the user's questions.",
-    prompt="Human: {query}\nAssistant: ",
-    sep="\n",
-    use_history=True,
-    stop_words=[]
+    prefix=[
+        "A chat between a curious user and an artificial intelligence assistant. "
+        "The assistant gives helpful, detailed, and polite answers to the user's questions."
+    ],
+    prompt=[
+        "Human: {query}\nAssistant: "
+    ],
+    sep=[
+        "\n"
+    ],
+    stop_words=[],
+    use_history=True
 )
 
 
@@ -128,18 +200,24 @@ Supports: https://huggingface.co/meta-llama/Llama-2-7b-chat-hf
 """
 register_template(
     name="llama2",
-    prefix="<<SYS>>\nYou are a helpful, respectful and honest assistant. "
-           "Always answer as helpfully as possible, while being safe. "
-           "Your answers should not include any harmful, unethical, "
-           "racist, sexist, toxic, dangerous, or illegal content. "
-           "Please ensure that your responses are socially unbiased and positive in nature.\n"
-           "If a question does not make any sense, or is not factually coherent, "
-           "explain why instead of answering something not correct. "
-           "If you don't know the answer to a question, please don't share false information.\n<</SYS>>\n\n",
-    prompt="[INST] {query} [/INST] ",
-    sep="",
-    use_history=True,
-    stop_words=[]
+    prefix=[
+        "<<SYS>>\nYou are a helpful, respectful and honest assistant. "
+        "Always answer as helpfully as possible, while being safe. "
+        "Your answers should not include any harmful, unethical, "
+        "racist, sexist, toxic, dangerous, or illegal content. "
+        "Please ensure that your responses are socially unbiased and positive in nature.\n"
+        "If a question does not make any sense, or is not factually coherent, "
+        "explain why instead of answering something not correct. "
+        "If you don't know the answer to a question, please don't share false information.\n<</SYS>>\n\n"
+    ],
+    prompt=[
+        "[INST] {query} [/INST] "
+    ],
+    sep=[
+        {"token": ""}
+    ],
+    stop_words=[],
+    use_history=True
 )
 
 
@@ -149,12 +227,18 @@ Supports: https://huggingface.co/tatsu-lab/alpaca-7b-wdiff
 """
 register_template(
     name="alpaca",
-    prefix="Below is an instruction that describes a task. 
" - "Write a response that appropriately completes the request.", - prompt="### Instruction:\n{query}\n\n### Response:\n", - sep="\n\n", - use_history=True, - stop_words=[] + prefix=[ + "Below is an instruction that describes a task. " + "Write a response that appropriately completes the request." + ], + prompt=[ + "### Instruction:\n{query}\n\n### Response:\n" + ], + sep=[ + "\n\n" + ], + stop_words=[], + use_history=True ) @@ -164,12 +248,16 @@ Supports: https://huggingface.co/lmsys/vicuna-7b-delta-v1.1 """ register_template( name="vicuna", - prefix="A chat between a curious user and an artificial intelligence assistant. " - "The assistant gives helpful, detailed, and polite answers to the user's questions.", - prompt="USER: {query} ASSISTANT: ", - sep="", - use_history=True, - stop_words=[] + prefix=[ + "A chat between a curious user and an artificial intelligence assistant. " + "The assistant gives helpful, detailed, and polite answers to the user's questions." + ], + prompt=[ + "USER: {query} ASSISTANT: " + ], + sep=[], + stop_words=[], + use_history=True ) @@ -178,11 +266,15 @@ Supports: https://huggingface.co/BelleGroup/BELLE-LLaMA-EXT-13B """ register_template( name="belle", - prefix="", - prompt="Human: {query}\n\nBelle: ", - sep="\n\n", - use_history=True, - stop_words=[] + prefix=[], + prompt=[ + "Human: {query}\n\nBelle: " + ], + sep=[ + "\n\n" + ], + stop_words=[], + use_history=True ) @@ -191,11 +283,15 @@ Supports: https://github.com/CVI-SZU/Linly """ register_template( name="linly", - prefix="", - prompt="User: {query}\nBot: ", - sep="\n", - use_history=True, - stop_words=[] + prefix=[], + prompt=[ + "User: {query}\nBot: " + ], + sep=[ + "\n" + ], + stop_words=[], + use_history=True ) @@ -204,11 +300,15 @@ Supports: https://github.com/Neutralzz/BiLLa """ register_template( name="billa", - prefix="", - prompt="Human: {query}\nAssistant: ", - sep="\n", - use_history=True, - stop_words=[] + prefix=[], + prompt=[ + "Human: {query}\nAssistant: " + ], + sep=[ + "\n" + ], + stop_words=[], + use_history=True ) @@ -217,11 +317,18 @@ Supports: https://huggingface.co/IDEA-CCNL/Ziya-LLaMA-13B-v1 """ register_template( name="ziya", - prefix="", - prompt=":{query}\n:", - sep="\n", - use_history=True, - stop_words=[] + prefix=[], + prompt=[ + {"token": ""}, + ":{query}\n", + {"token": ""}, + ":" + ], + sep=[ + "\n" + ], + stop_words=[], + use_history=True ) @@ -230,12 +337,18 @@ Supports: https://huggingface.co/qhduan/aquilachat-7b """ register_template( name="aquila", - prefix="A chat between a curious human and an artificial intelligence assistant. " - "The assistant gives helpful, detailed, and polite answers to the human's questions.", - prompt="Human: {query}###Assistant: ", - sep="###", - use_history=True, - stop_words=[] + prefix=[ + "A chat between a curious human and an artificial intelligence assistant. " + "The assistant gives helpful, detailed, and polite answers to the human's questions." 
+ ], + prompt=[ + "Human: {query}###Assistant: " + ], + sep=[ + "###" + ], + stop_words=[], + use_history=True ) @@ -244,11 +357,23 @@ Supports: https://huggingface.co/internlm/internlm-chat-7b """ register_template( name="intern", - prefix="", - prompt="<|User|>:{query}\n<|Bot|>:", - sep="\n", - use_history=True, - stop_words=[""] + prefix=[], + prompt=[ + {"token": "<|User|>"}, + ":{query}", + {"token": ""}, + "\n", + {"token": "<|Bot|>"}, + ":" + ], + sep=[ + {"token": ""}, + "\n" + ], + stop_words=[ + "" + ], + use_history=True ) @@ -257,11 +382,15 @@ Supports: https://huggingface.co/baichuan-inc/Baichuan-13B-Chat """ register_template( name="baichuan", - prefix="", - prompt="{query}", - sep="", - use_history=True, - stop_words=[] + prefix=[], + prompt=[ + {"token": ""}, + "{query}", + {"token": ""} + ], + sep=[], + stop_words=[], + use_history=True ) @@ -271,11 +400,25 @@ Supports: https://huggingface.co/HuggingFaceH4/starchat-alpha """ register_template( name="starchat", - prefix="<|system|>\n", - prompt="<|user|>\n{query}<|end|>\n<|assistant|>\n", - sep="<|end|>\n", - use_history=True, - stop_words=["<|end|>"] + prefix=[ + {"token": "<|system|>"}, + "\n" + ], + prompt=[ + {"token": "<|user|>"}, + "\n{query}", + {"token": "<|end|>"}, + "\n", + {"token": "<|assistant|>"} + ], + sep=[ + {"token": "<|end|>"}, + "\n" + ], + stop_words=[ + "<|end|>" + ], + use_history=True ) @@ -284,9 +427,24 @@ Supports: https://huggingface.co/Qwen/Qwen-7B-Chat """ register_template( name="chatml", - prefix="<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n", - prompt="<|im_start|>user\n{query}<|im_end|>\n<|im_start|>assistant\n", - sep="<|im_end|>\n", - use_history=True, - stop_words=["<|im_end|>"] + prefix=[ + {"token": "<|im_start|>"}, + "system\nYou are a helpful assistant." 
+ ], + prompt=[ + {"token": "<|im_start|>"}, + "user\n{query}", + {"token": "<|im_end|>"}, + "\n", + {"token": "<|im_start|>"}, + "assistant\n" + ], + sep=[ + {"token": "<|im_end|>"}, + "\n" + ], + stop_words=[ + "<|im_end|>" + ], + use_history=True ) From dbb284b5a268f91a802278cc56dedb0c08b91f92 Mon Sep 17 00:00:00 2001 From: hiyouga Date: Fri, 4 Aug 2023 23:27:55 +0800 Subject: [PATCH 2/3] fix encode Former-commit-id: 8172ad1b5e3fa0b224d761ce6069d0db4397da2d --- src/llmtuner/dsets/preprocess.py | 2 +- src/llmtuner/extras/template.py | 34 +++++++++++++++++--------------- 2 files changed, 19 insertions(+), 17 deletions(-) diff --git a/src/llmtuner/dsets/preprocess.py b/src/llmtuner/dsets/preprocess.py index 2482abe3..0f28ce1e 100644 --- a/src/llmtuner/dsets/preprocess.py +++ b/src/llmtuner/dsets/preprocess.py @@ -55,7 +55,7 @@ def preprocess_dataset( for query, response, history, prefix in construct_example(examples): input_ids, labels = [], [] - for source_ids, target_ids in template.get_dialog(tokenizer, query, response, history, prefix): + for source_ids, target_ids in template.get_dialog(tokenizer, query, response, history, prefix): # TODO: fix bos if len(source_ids) > data_args.max_source_length: source_ids = source_ids[:data_args.max_source_length] if len(target_ids) > data_args.max_target_length - 1: # eos token diff --git a/src/llmtuner/extras/template.py b/src/llmtuner/extras/template.py index 46089413..4ffd569a 100644 --- a/src/llmtuner/extras/template.py +++ b/src/llmtuner/extras/template.py @@ -29,7 +29,7 @@ class Template: encoded_pairs = self._encode(tokenizer=tokenizer, prefix=prefix, history=history) prompt_ids = [] for query_ids, resp_ids in encoded_pairs[:-1]: - prompt_ids = prompt_ids + query_ids + resp_ids + prompt_ids = prompt_ids + query_ids + resp_ids + [tokenizer.eos_token_id] prompt_ids = prompt_ids + encoded_pairs[-1][0] return prompt_ids, encoded_pairs[-1][1] @@ -96,7 +96,9 @@ class Template: token_ids = [] for elem in context: if isinstance(elem, str): - elem = elem.format(query=query) + subelems = elem.split("{{query}}") + if len(subelems) > 1: + elem = subelems[0] + query + subelems[1] token_ids = token_ids + tokenizer.encode(elem, add_special_tokens=False) elif isinstance(elem, dict): token_ids = token_ids + [tokenizer.convert_tokens_to_ids(elem.get("token"))] @@ -165,7 +167,7 @@ register_template( name="vanilla", prefix=[], prompt=[ - "{query}" + "{{query}}" ], sep=[], stop_words=[], @@ -183,7 +185,7 @@ register_template( "The assistant gives helpful, detailed, and polite answers to the user's questions." ], prompt=[ - "Human: {query}\nAssistant: " + "Human: {{query}}\nAssistant: " ], sep=[ "\n" @@ -211,7 +213,7 @@ register_template( "If you don't know the answer to a question, please don't share false information.\n<>\n\n" ], prompt=[ - "[INST] {query} [/INST] " + "[INST] {{query}} [/INST] " ], sep=[ {"token": ""} @@ -232,7 +234,7 @@ register_template( "Write a response that appropriately completes the request." ], prompt=[ - "### Instruction:\n{query}\n\n### Response:\n" + "### Instruction:\n{{query}}\n\n### Response:\n" ], sep=[ "\n\n" @@ -253,7 +255,7 @@ register_template( "The assistant gives helpful, detailed, and polite answers to the user's questions." 
], prompt=[ - "USER: {query} ASSISTANT: " + "USER: {{query}} ASSISTANT: " ], sep=[], stop_words=[], @@ -268,7 +270,7 @@ register_template( name="belle", prefix=[], prompt=[ - "Human: {query}\n\nBelle: " + "Human: {{query}}\n\nBelle: " ], sep=[ "\n\n" @@ -285,7 +287,7 @@ register_template( name="linly", prefix=[], prompt=[ - "User: {query}\nBot: " + "User: {{query}}\nBot: " ], sep=[ "\n" @@ -302,7 +304,7 @@ register_template( name="billa", prefix=[], prompt=[ - "Human: {query}\nAssistant: " + "Human: {{query}}\nAssistant: " ], sep=[ "\n" @@ -320,7 +322,7 @@ register_template( prefix=[], prompt=[ {"token": ""}, - ":{query}\n", + ":{{query}}\n", {"token": ""}, ":" ], @@ -342,7 +344,7 @@ register_template( "The assistant gives helpful, detailed, and polite answers to the human's questions." ], prompt=[ - "Human: {query}###Assistant: " + "Human: {{query}}###Assistant: " ], sep=[ "###" @@ -360,7 +362,7 @@ register_template( prefix=[], prompt=[ {"token": "<|User|>"}, - ":{query}", + ":{{query}}", {"token": ""}, "\n", {"token": "<|Bot|>"}, @@ -385,7 +387,7 @@ register_template( prefix=[], prompt=[ {"token": ""}, - "{query}", + "{{query}}", {"token": ""} ], sep=[], @@ -406,7 +408,7 @@ register_template( ], prompt=[ {"token": "<|user|>"}, - "\n{query}", + "\n{{query}}", {"token": "<|end|>"}, "\n", {"token": "<|assistant|>"} @@ -433,7 +435,7 @@ register_template( ], prompt=[ {"token": "<|im_start|>"}, - "user\n{query}", + "user\n{{query}}", {"token": "<|im_end|>"}, "\n", {"token": "<|im_start|>"}, From 65369ecf48420fb8e4c2cc4ea72e5dbc286dd5f5 Mon Sep 17 00:00:00 2001 From: hiyouga Date: Fri, 4 Aug 2023 23:55:57 +0800 Subject: [PATCH 3/3] fix bos and eos token Former-commit-id: d87c8fd8ab84c9f58c0b1f3fb4ad0adf98b25715 --- src/llmtuner/dsets/preprocess.py | 14 +++++++------- src/llmtuner/extras/template.py | 9 +++++++-- 2 files changed, 14 insertions(+), 9 deletions(-) diff --git a/src/llmtuner/dsets/preprocess.py b/src/llmtuner/dsets/preprocess.py index 0f28ce1e..ad481ffd 100644 --- a/src/llmtuner/dsets/preprocess.py +++ b/src/llmtuner/dsets/preprocess.py @@ -30,7 +30,7 @@ def preprocess_dataset( yield query, response, history, prefix def preprocess_pretrain_dataset(examples: Dict[str, List[Any]]) -> Dict[str, Any]: - # build grouped texts with format ` X1 X2 X3 ...` (without ) + # build grouped texts with format `X1 X2 X3 ...` (without ) tokenized_examples = tokenizer(examples["prompt"], add_special_tokens=False) concatenated_examples = {k: list(chain(*tokenized_examples[k])) for k in tokenized_examples.keys()} total_length = len(concatenated_examples[list(concatenated_examples.keys())[0]]) @@ -55,17 +55,17 @@ def preprocess_dataset( for query, response, history, prefix in construct_example(examples): input_ids, labels = [], [] - for source_ids, target_ids in template.get_dialog(tokenizer, query, response, history, prefix): # TODO: fix bos + for source_ids, target_ids in template.get_dialog(tokenizer, query, response, history, prefix): if len(source_ids) > data_args.max_source_length: source_ids = source_ids[:data_args.max_source_length] - if len(target_ids) > data_args.max_target_length - 1: # eos token - target_ids = target_ids[:data_args.max_target_length - 1] + if len(target_ids) > data_args.max_target_length: + target_ids = target_ids[:data_args.max_target_length] - if len(input_ids) + len(source_ids) + len(target_ids) + 1 > max_length: + if len(input_ids) + len(source_ids) + len(target_ids) > max_length: break - input_ids += source_ids + target_ids + [tokenizer.eos_token_id] - labels += 
[IGNORE_INDEX] * len(source_ids) + target_ids + [tokenizer.eos_token_id] + input_ids += source_ids + target_ids + labels += [IGNORE_INDEX] * len(source_ids) + target_ids model_inputs["input_ids"].append(input_ids) model_inputs["attention_mask"].append([1] * len(input_ids)) diff --git a/src/llmtuner/extras/template.py b/src/llmtuner/extras/template.py index 4ffd569a..4f8e6301 100644 --- a/src/llmtuner/extras/template.py +++ b/src/llmtuner/extras/template.py @@ -29,7 +29,7 @@ class Template: encoded_pairs = self._encode(tokenizer=tokenizer, prefix=prefix, history=history) prompt_ids = [] for query_ids, resp_ids in encoded_pairs[:-1]: - prompt_ids = prompt_ids + query_ids + resp_ids + [tokenizer.eos_token_id] + prompt_ids = prompt_ids + query_ids + resp_ids prompt_ids = prompt_ids + encoded_pairs[-1][0] return prompt_ids, encoded_pairs[-1][1] @@ -73,6 +73,11 @@ class Template: r""" Encodes formatted inputs to pairs of token ids. """ + if tokenizer.bos_token and getattr(tokenizer, "add_bos_token", False): # bos token is optional + bos_token_id = [tokenizer.bos_token_id] + else: + bos_token_id = [] + eos_token_id = [tokenizer.eos_token_id] # eos token is required encoded_pairs = [] for turn_idx, (query, resp) in enumerate(history): if turn_idx == 0: @@ -81,7 +86,7 @@ class Template: prefix_ids = self._convert_inputs_to_ids(tokenizer, context=self.sep) query_ids = self._convert_inputs_to_ids(tokenizer, context=self.prompt, query=query) resp_ids = self._convert_inputs_to_ids(tokenizer, context=[resp]) - encoded_pairs.append((prefix_ids + query_ids, resp_ids)) + encoded_pairs.append((bos_token_id + prefix_ids + query_ids, resp_ids + eos_token_id)) return encoded_pairs def _convert_inputs_to_ids(
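
Taken together, the three patches change how a chat template becomes token ids: prefix, prompt and sep are now lists that mix plain strings with {"token": ...} entries, plain strings are encoded with add_special_tokens=False after the user query has been spliced in at the literal {{query}} placeholder, {"token": ...} entries are mapped to exactly one id via convert_tokens_to_ids, and BOS/EOS handling moves out of the prompt strings into Template._encode. The sketch below walks through that encoding path in isolation. It is an illustration only: ToyTokenizer, convert_inputs_to_ids and encode_turn are names made up for this note (they are not llmtuner's API), and the toy vocabulary is invented so the snippet runs without loading a real tokenizer.

from typing import Dict, List, Tuple, Union

ContextElem = Union[str, Dict[str, str]]


class ToyTokenizer:
    """Hypothetical stand-in for a transformers PreTrainedTokenizer so the sketch
    runs offline; only the attributes the template code touches are defined."""
    bos_token, eos_token = "<s>", "</s>"
    add_bos_token = True  # e.g. the LLaMA tokenizer adds BOS by default
    _specials = {"<s>": 1, "</s>": 2, "<|im_start|>": 3, "<|im_end|>": 4}

    @property
    def bos_token_id(self) -> int:
        return self._specials[self.bos_token]

    @property
    def eos_token_id(self) -> int:
        return self._specials[self.eos_token]

    def encode(self, text: str, add_special_tokens: bool = False) -> List[int]:
        # Fake subword ids: one id per whitespace-separated chunk.
        return [100 + i for i, _ in enumerate(text.split())]

    def convert_tokens_to_ids(self, token: str) -> int:
        return self._specials[token]


def convert_inputs_to_ids(tokenizer, context: List[ContextElem], query: str = "") -> List[int]:
    # Mirrors Template._convert_inputs_to_ids after patch 2: the user query is
    # spliced in at the literal "{{query}}" placeholder (no str.format, so braces
    # in user input are harmless), and {"token": ...} entries become exactly one
    # special-token id instead of being re-tokenized as text.
    token_ids: List[int] = []
    for elem in context:
        if isinstance(elem, str):
            if "{{query}}" in elem:
                left, right = elem.split("{{query}}", 1)
                elem = left + query + right
            token_ids += tokenizer.encode(elem, add_special_tokens=False)
        elif isinstance(elem, dict):
            token_ids.append(tokenizer.convert_tokens_to_ids(elem["token"]))
        else:
            raise NotImplementedError
    return token_ids


def encode_turn(tokenizer, prefix: List[ContextElem], prompt: List[ContextElem],
                query: str, resp: str) -> Tuple[List[int], List[int]]:
    # Mirrors the first-turn branch of Template._encode after patch 3: BOS is
    # prepended only when the tokenizer is configured to add one, while EOS
    # always terminates the response ids.
    if tokenizer.bos_token and getattr(tokenizer, "add_bos_token", False):
        bos_ids = [tokenizer.bos_token_id]
    else:
        bos_ids = []
    prefix_ids = convert_inputs_to_ids(tokenizer, prefix)
    query_ids = convert_inputs_to_ids(tokenizer, prompt, query=query)
    resp_ids = convert_inputs_to_ids(tokenizer, [resp])
    return bos_ids + prefix_ids + query_ids, resp_ids + [tokenizer.eos_token_id]


if __name__ == "__main__":
    # ChatML-style context written the way the patched register_template calls do;
    # the prefix below already has the separator appended, as Template._format does.
    prefix = [{"token": "<|im_start|>"}, "system\nYou are a helpful assistant.",
              {"token": "<|im_end|>"}, "\n"]
    prompt = [{"token": "<|im_start|>"}, "user\n{{query}}", {"token": "<|im_end|>"}, "\n",
              {"token": "<|im_start|>"}, "assistant\n"]
    source_ids, target_ids = encode_turn(ToyTokenizer(), prefix, prompt,
                                         query="What is 1 + 1?", resp="The answer is 2.")
    print(source_ids)  # control tokens appear as single ids (3, 4), never split into text pieces
    print(target_ids)  # ends with the EOS id (2)

With a real tokenizer the same pairs are what the updated stream_chat.py consumes: get_prompt concatenates the encoded pairs of all past turns and returns the ids of the final prompt together with the (here empty) response, so the chat loop builds input_ids directly instead of re-tokenizing a formatted string, and template control tokens such as <|im_start|> reach the model as single ids even if the text tokenizer would otherwise split them.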