From ea045b0e5b7573b392a287e7eeb0beaf1bb73667 Mon Sep 17 00:00:00 2001 From: hiyouga Date: Fri, 4 Aug 2023 23:14:28 +0800 Subject: [PATCH] support chatml safe encoding Former-commit-id: b4852f94065a11c8cd00ffa7e71ac0e0b2bf477a --- src/llmtuner/chat/stream_chat.py | 11 +- src/llmtuner/dsets/preprocess.py | 17 +- src/llmtuner/extras/template.py | 408 +++++++++++++++++++++---------- 3 files changed, 293 insertions(+), 143 deletions(-) diff --git a/src/llmtuner/chat/stream_chat.py b/src/llmtuner/chat/stream_chat.py index 13168007..ef30a324 100644 --- a/src/llmtuner/chat/stream_chat.py +++ b/src/llmtuner/chat/stream_chat.py @@ -30,10 +30,11 @@ class ChatModel: ) -> Tuple[Dict[str, Any], int]: prefix = prefix or self.source_prefix - prompt = self.template.get_prompt(query, history, prefix, self.tokenizer.eos_token) - inputs = self.tokenizer([prompt], return_tensors="pt") - inputs = inputs.to(self.model.device) - prompt_length = len(inputs["input_ids"][0]) + prompt, _ = self.template.get_prompt( + tokenizer=self.tokenizer, query=query, resp="", history=history, prefix=prefix + ) + input_ids = torch.tensor([prompt], device=self.model.device) + prompt_length = len(input_ids[0]) do_sample = input_kwargs.pop("do_sample", None) temperature = input_kwargs.pop("temperature", None) @@ -45,7 +46,7 @@ class ChatModel: gen_kwargs = self.generating_args.to_dict() gen_kwargs.update(dict( - input_ids=inputs["input_ids"], + input_ids=input_ids, do_sample=do_sample if do_sample is not None else gen_kwargs["do_sample"], temperature=temperature or gen_kwargs["temperature"], top_p=top_p or gen_kwargs["top_p"], diff --git a/src/llmtuner/dsets/preprocess.py b/src/llmtuner/dsets/preprocess.py index 0257b244..2482abe3 100644 --- a/src/llmtuner/dsets/preprocess.py +++ b/src/llmtuner/dsets/preprocess.py @@ -55,10 +55,7 @@ def preprocess_dataset( for query, response, history, prefix in construct_example(examples): input_ids, labels = [], [] - for i, (query_i, resp_i) in enumerate(template.get_dialog(query, response, history, prefix)): - source_ids = tokenizer.encode(text=query_i, add_special_tokens=(i == 0)) - target_ids = tokenizer.encode(text=resp_i, add_special_tokens=False) - + for source_ids, target_ids in template.get_dialog(tokenizer, query, response, history, prefix): if len(source_ids) > data_args.max_source_length: source_ids = source_ids[:data_args.max_source_length] if len(target_ids) > data_args.max_target_length - 1: # eos token @@ -81,10 +78,7 @@ def preprocess_dataset( model_inputs = {"input_ids": [], "attention_mask": [], "labels": []} for query, response, history, prefix in construct_example(examples): - prompt = template.get_prompt(query, history, prefix, tokenizer.eos_token) - - source_ids = tokenizer.encode(text=prompt, add_special_tokens=True) - target_ids = tokenizer.encode(text=response, add_special_tokens=True) + source_ids, target_ids = template.get_prompt(tokenizer, query, response, history, prefix) if len(source_ids) > data_args.max_source_length: source_ids = source_ids[:data_args.max_source_length] @@ -101,11 +95,8 @@ def preprocess_dataset( # build input pairs with format ` X Y1 ` and ` X Y2 ` model_inputs = {"accept_ids": [], "reject_ids": []} for query, response, history, prefix in construct_example(examples): - prompt = template.get_prompt(query, history, prefix, tokenizer.eos_token) - - source_ids = tokenizer.encode(text=prompt, add_special_tokens=True) - accept_ids = tokenizer.encode(text=response[0], add_special_tokens=False) - reject_ids = tokenizer.encode(text=response[1], add_special_tokens=False) + source_ids, accept_ids = template.get_prompt(tokenizer, query, response[0], history, prefix) + source_ids, reject_ids = template.get_prompt(tokenizer, query, response[1], history, prefix) if len(source_ids) > data_args.max_source_length: source_ids = source_ids[:data_args.max_source_length] diff --git a/src/llmtuner/extras/template.py b/src/llmtuner/extras/template.py index bb550058..46089413 100644 --- a/src/llmtuner/extras/template.py +++ b/src/llmtuner/extras/template.py @@ -1,82 +1,146 @@ -from typing import Dict, List, Optional, Tuple +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union from dataclasses import dataclass +if TYPE_CHECKING: + from transformers import PreTrainedTokenizer + @dataclass class Template: - prefix: str - prompt: str - sep: str - use_history: bool + prefix: List[Union[str, Dict[str, str]]] + prompt: List[Union[str, Dict[str, str]]] + sep: List[Union[str, Dict[str, str]]] stop_words: List[str] + use_history: bool def get_prompt( self, + tokenizer: "PreTrainedTokenizer", query: str, + resp: str, history: Optional[List[Tuple[str, str]]] = None, - prefix: Optional[str] = "", - eos_token: Optional[str] = "" - ) -> str: + prefix: Optional[str] = None + ) -> Tuple[List[int], List[int]]: r""" - Returns a string containing prompt without response. + Returns a single pair of token ids representing prompt and response respectively. """ - return eos_token.join(map(lambda x: x[0] + x[1], self._format_example(query, history, prefix))) + prefix, history = self._format(query=query, resp=resp, history=history, prefix=prefix) + encoded_pairs = self._encode(tokenizer=tokenizer, prefix=prefix, history=history) + prompt_ids = [] + for query_ids, resp_ids in encoded_pairs[:-1]: + prompt_ids = prompt_ids + query_ids + resp_ids + prompt_ids = prompt_ids + encoded_pairs[-1][0] + return prompt_ids, encoded_pairs[-1][1] def get_dialog( + self, + tokenizer: "PreTrainedTokenizer", + query: str, + resp: str, + history: Optional[List[Tuple[str, str]]] = None, + prefix: Optional[str] = None + ) -> List[Tuple[List[int], List[int]]]: + r""" + Returns multiple pairs of token ids representing prompts and responses respectively. + """ + prefix, history = self._format(query=query, resp=resp, history=history, prefix=prefix) + encoded_pairs = self._encode(tokenizer=tokenizer, prefix=prefix, history=history) + return encoded_pairs + + def _format( self, query: str, resp: str, history: Optional[List[Tuple[str, str]]] = None, - prefix: Optional[str] = "" - ) -> List[Tuple[str, str]]: + prefix: Optional[str] = None + ) -> Tuple[List[Union[str, Dict[str, str]]], List[Tuple[str, str]]]: r""" - Returns a list containing prompt-response pairs. + Aligns inputs to a special format. """ - result = self._format_example(query, history, prefix) - result[-1][-1] = resp - return result - - def _format_example( - self, - query: str, - history: Optional[List[Tuple[str, str]]] = None, - prefix: Optional[str] = "" - ) -> List[Tuple[str, str]]: - prefix = prefix or self.prefix # use prefix if provided - prefix = prefix + self.sep if prefix else "" # add separator for non-empty prefix + prefix = [prefix] if prefix is not None else self.prefix # use prefix if provided + prefix = prefix + self.sep if prefix else [] # add separator for non-empty prefix history = history if (history and self.use_history) else [] - history = history + [(query, "")] - return [ - [(self.sep if i else prefix) + self.prompt.format(query=q), r] - for i, (q, r) in enumerate(history) - ] + history = history + [(query, resp)] + return prefix, history + + def _encode( + self, + tokenizer: "PreTrainedTokenizer", + prefix: List[Union[str, Dict[str, str]]], + history: List[Tuple[str, str]] + ) -> List[Tuple[List[int], List[int]]]: + r""" + Encodes formatted inputs to pairs of token ids. + """ + encoded_pairs = [] + for turn_idx, (query, resp) in enumerate(history): + if turn_idx == 0: + prefix_ids = self._convert_inputs_to_ids(tokenizer, context=prefix) + else: + prefix_ids = self._convert_inputs_to_ids(tokenizer, context=self.sep) + query_ids = self._convert_inputs_to_ids(tokenizer, context=self.prompt, query=query) + resp_ids = self._convert_inputs_to_ids(tokenizer, context=[resp]) + encoded_pairs.append((prefix_ids + query_ids, resp_ids)) + return encoded_pairs + + def _convert_inputs_to_ids( + self, + tokenizer: "PreTrainedTokenizer", + context: List[Union[str, Dict[str, str]]], + query: Optional[str] = "" + ) -> List[int]: + r""" + Converts context to token ids. + """ + token_ids = [] + for elem in context: + if isinstance(elem, str): + elem = elem.format(query=query) + token_ids = token_ids + tokenizer.encode(elem, add_special_tokens=False) + elif isinstance(elem, dict): + token_ids = token_ids + [tokenizer.convert_tokens_to_ids(elem.get("token"))] + else: + raise NotImplementedError + return token_ids @dataclass class Llama2Template(Template): - def _format_example( + def _encode( self, - query: str, - history: Optional[List[Tuple[str, str]]] = None, - prefix: Optional[str] = "" - ) -> List[Tuple[str, str]]: - prefix = prefix or self.prefix # use prefix if provided - prefix = prefix if prefix.startswith("<>") else "<>\n{}\n<>\n\n".format(prefix) - history = history if (history and self.use_history) else [] - history = history + [(query, "")] - return [ - [(self.sep if i else "") + self.prompt.format(query=(q if i else prefix + q)), r] - for i, (q, r) in enumerate(history) - ] + tokenizer: "PreTrainedTokenizer", + prefix: List[Union[str, Dict[str, str]]], + history: List[Tuple[str, str]] + ) -> List[Tuple[List[int], List[int]]]: + r""" + Encodes formatted inputs to pairs of token ids. + """ + encoded_pairs = [] + assert isinstance(prefix[0], str), "LLaMA-2 template only accepts list containing a single str." + for turn_idx, (query, resp) in enumerate(history): + if turn_idx == 0: + prefix_ids = [] + query = prefix[0] + query + else: + prefix_ids = self._convert_inputs_to_ids(tokenizer, context=self.sep) + query_ids = self._convert_inputs_to_ids(tokenizer, context=self.prompt, query=query) + resp_ids = self._convert_inputs_to_ids(tokenizer, context=[resp]) + encoded_pairs.append((prefix_ids + query_ids, resp_ids)) + return encoded_pairs templates: Dict[str, Template] = {} def register_template( - name: str, prefix: str, prompt: str, sep: str, use_history: bool, stop_words: List[str] + name: str, + prefix: List[Union[str, Dict[str, str]]], + prompt: List[Union[str, Dict[str, str]]], + sep: List[Union[str, Dict[str, str]]], + stop_words: List[str], + use_history: bool ) -> None: template_class = Llama2Template if name == "llama2" else Template templates[name] = template_class( @@ -99,11 +163,13 @@ Supports language model inference without histories. """ register_template( name="vanilla", - prefix="", - prompt="{query}", - sep="", - use_history=False, - stop_words=[] + prefix=[], + prompt=[ + "{query}" + ], + sep=[], + stop_words=[], + use_history=False ) @@ -112,12 +178,18 @@ Default template. """ register_template( name="default", - prefix="A chat between a curious user and an artificial intelligence assistant. " - "The assistant gives helpful, detailed, and polite answers to the user's questions.", - prompt="Human: {query}\nAssistant: ", - sep="\n", - use_history=True, - stop_words=[] + prefix=[ + "A chat between a curious user and an artificial intelligence assistant. " + "The assistant gives helpful, detailed, and polite answers to the user's questions." + ], + prompt=[ + "Human: {query}\nAssistant: " + ], + sep=[ + "\n" + ], + stop_words=[], + use_history=True ) @@ -128,18 +200,24 @@ Supports: https://huggingface.co/meta-llama/Llama-2-7b-chat-hf """ register_template( name="llama2", - prefix="<>\nYou are a helpful, respectful and honest assistant. " - "Always answer as helpfully as possible, while being safe. " - "Your answers should not include any harmful, unethical, " - "racist, sexist, toxic, dangerous, or illegal content. " - "Please ensure that your responses are socially unbiased and positive in nature.\n" - "If a question does not make any sense, or is not factually coherent, " - "explain why instead of answering something not correct. " - "If you don't know the answer to a question, please don't share false information.\n<>\n\n", - prompt="[INST] {query} [/INST] ", - sep="", - use_history=True, - stop_words=[] + prefix=[ + "<>\nYou are a helpful, respectful and honest assistant. " + "Always answer as helpfully as possible, while being safe. " + "Your answers should not include any harmful, unethical, " + "racist, sexist, toxic, dangerous, or illegal content. " + "Please ensure that your responses are socially unbiased and positive in nature.\n" + "If a question does not make any sense, or is not factually coherent, " + "explain why instead of answering something not correct. " + "If you don't know the answer to a question, please don't share false information.\n<>\n\n" + ], + prompt=[ + "[INST] {query} [/INST] " + ], + sep=[ + {"token": ""} + ], + stop_words=[], + use_history=True ) @@ -149,12 +227,18 @@ Supports: https://huggingface.co/tatsu-lab/alpaca-7b-wdiff """ register_template( name="alpaca", - prefix="Below is an instruction that describes a task. " - "Write a response that appropriately completes the request.", - prompt="### Instruction:\n{query}\n\n### Response:\n", - sep="\n\n", - use_history=True, - stop_words=[] + prefix=[ + "Below is an instruction that describes a task. " + "Write a response that appropriately completes the request." + ], + prompt=[ + "### Instruction:\n{query}\n\n### Response:\n" + ], + sep=[ + "\n\n" + ], + stop_words=[], + use_history=True ) @@ -164,12 +248,16 @@ Supports: https://huggingface.co/lmsys/vicuna-7b-delta-v1.1 """ register_template( name="vicuna", - prefix="A chat between a curious user and an artificial intelligence assistant. " - "The assistant gives helpful, detailed, and polite answers to the user's questions.", - prompt="USER: {query} ASSISTANT: ", - sep="", - use_history=True, - stop_words=[] + prefix=[ + "A chat between a curious user and an artificial intelligence assistant. " + "The assistant gives helpful, detailed, and polite answers to the user's questions." + ], + prompt=[ + "USER: {query} ASSISTANT: " + ], + sep=[], + stop_words=[], + use_history=True ) @@ -178,11 +266,15 @@ Supports: https://huggingface.co/BelleGroup/BELLE-LLaMA-EXT-13B """ register_template( name="belle", - prefix="", - prompt="Human: {query}\n\nBelle: ", - sep="\n\n", - use_history=True, - stop_words=[] + prefix=[], + prompt=[ + "Human: {query}\n\nBelle: " + ], + sep=[ + "\n\n" + ], + stop_words=[], + use_history=True ) @@ -191,11 +283,15 @@ Supports: https://github.com/CVI-SZU/Linly """ register_template( name="linly", - prefix="", - prompt="User: {query}\nBot: ", - sep="\n", - use_history=True, - stop_words=[] + prefix=[], + prompt=[ + "User: {query}\nBot: " + ], + sep=[ + "\n" + ], + stop_words=[], + use_history=True ) @@ -204,11 +300,15 @@ Supports: https://github.com/Neutralzz/BiLLa """ register_template( name="billa", - prefix="", - prompt="Human: {query}\nAssistant: ", - sep="\n", - use_history=True, - stop_words=[] + prefix=[], + prompt=[ + "Human: {query}\nAssistant: " + ], + sep=[ + "\n" + ], + stop_words=[], + use_history=True ) @@ -217,11 +317,18 @@ Supports: https://huggingface.co/IDEA-CCNL/Ziya-LLaMA-13B-v1 """ register_template( name="ziya", - prefix="", - prompt=":{query}\n:", - sep="\n", - use_history=True, - stop_words=[] + prefix=[], + prompt=[ + {"token": ""}, + ":{query}\n", + {"token": ""}, + ":" + ], + sep=[ + "\n" + ], + stop_words=[], + use_history=True ) @@ -230,12 +337,18 @@ Supports: https://huggingface.co/qhduan/aquilachat-7b """ register_template( name="aquila", - prefix="A chat between a curious human and an artificial intelligence assistant. " - "The assistant gives helpful, detailed, and polite answers to the human's questions.", - prompt="Human: {query}###Assistant: ", - sep="###", - use_history=True, - stop_words=[] + prefix=[ + "A chat between a curious human and an artificial intelligence assistant. " + "The assistant gives helpful, detailed, and polite answers to the human's questions." + ], + prompt=[ + "Human: {query}###Assistant: " + ], + sep=[ + "###" + ], + stop_words=[], + use_history=True ) @@ -244,11 +357,23 @@ Supports: https://huggingface.co/internlm/internlm-chat-7b """ register_template( name="intern", - prefix="", - prompt="<|User|>:{query}\n<|Bot|>:", - sep="\n", - use_history=True, - stop_words=[""] + prefix=[], + prompt=[ + {"token": "<|User|>"}, + ":{query}", + {"token": ""}, + "\n", + {"token": "<|Bot|>"}, + ":" + ], + sep=[ + {"token": ""}, + "\n" + ], + stop_words=[ + "" + ], + use_history=True ) @@ -257,11 +382,15 @@ Supports: https://huggingface.co/baichuan-inc/Baichuan-13B-Chat """ register_template( name="baichuan", - prefix="", - prompt="{query}", - sep="", - use_history=True, - stop_words=[] + prefix=[], + prompt=[ + {"token": ""}, + "{query}", + {"token": ""} + ], + sep=[], + stop_words=[], + use_history=True ) @@ -271,11 +400,25 @@ Supports: https://huggingface.co/HuggingFaceH4/starchat-alpha """ register_template( name="starchat", - prefix="<|system|>\n", - prompt="<|user|>\n{query}<|end|>\n<|assistant|>\n", - sep="<|end|>\n", - use_history=True, - stop_words=["<|end|>"] + prefix=[ + {"token": "<|system|>"}, + "\n" + ], + prompt=[ + {"token": "<|user|>"}, + "\n{query}", + {"token": "<|end|>"}, + "\n", + {"token": "<|assistant|>"} + ], + sep=[ + {"token": "<|end|>"}, + "\n" + ], + stop_words=[ + "<|end|>" + ], + use_history=True ) @@ -284,9 +427,24 @@ Supports: https://huggingface.co/Qwen/Qwen-7B-Chat """ register_template( name="chatml", - prefix="<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n", - prompt="<|im_start|>user\n{query}<|im_end|>\n<|im_start|>assistant\n", - sep="<|im_end|>\n", - use_history=True, - stop_words=["<|im_end|>"] + prefix=[ + {"token": "<|im_start|>"}, + "system\nYou are a helpful assistant." + ], + prompt=[ + {"token": "<|im_start|>"}, + "user\n{query}", + {"token": "<|im_end|>"}, + "\n", + {"token": "<|im_start|>"}, + "assistant\n" + ], + sep=[ + {"token": "<|im_end|>"}, + "\n" + ], + stop_words=[ + "<|im_end|>" + ], + use_history=True )