diff --git a/src/llmtuner/chat/stream_chat.py b/src/llmtuner/chat/stream_chat.py index 13168007..ef30a324 100644 --- a/src/llmtuner/chat/stream_chat.py +++ b/src/llmtuner/chat/stream_chat.py @@ -30,10 +30,11 @@ class ChatModel: ) -> Tuple[Dict[str, Any], int]: prefix = prefix or self.source_prefix - prompt = self.template.get_prompt(query, history, prefix, self.tokenizer.eos_token) - inputs = self.tokenizer([prompt], return_tensors="pt") - inputs = inputs.to(self.model.device) - prompt_length = len(inputs["input_ids"][0]) + prompt, _ = self.template.get_prompt( + tokenizer=self.tokenizer, query=query, resp="", history=history, prefix=prefix + ) + input_ids = torch.tensor([prompt], device=self.model.device) + prompt_length = len(input_ids[0]) do_sample = input_kwargs.pop("do_sample", None) temperature = input_kwargs.pop("temperature", None) @@ -45,7 +46,7 @@ class ChatModel: gen_kwargs = self.generating_args.to_dict() gen_kwargs.update(dict( - input_ids=inputs["input_ids"], + input_ids=input_ids, do_sample=do_sample if do_sample is not None else gen_kwargs["do_sample"], temperature=temperature or gen_kwargs["temperature"], top_p=top_p or gen_kwargs["top_p"], diff --git a/src/llmtuner/dsets/preprocess.py b/src/llmtuner/dsets/preprocess.py index 0257b244..ad481ffd 100644 --- a/src/llmtuner/dsets/preprocess.py +++ b/src/llmtuner/dsets/preprocess.py @@ -30,7 +30,7 @@ def preprocess_dataset( yield query, response, history, prefix def preprocess_pretrain_dataset(examples: Dict[str, List[Any]]) -> Dict[str, Any]: - # build grouped texts with format ` X1 X2 X3 ...` (without ) + # build grouped texts with format `X1 X2 X3 ...` (without ) tokenized_examples = tokenizer(examples["prompt"], add_special_tokens=False) concatenated_examples = {k: list(chain(*tokenized_examples[k])) for k in tokenized_examples.keys()} total_length = len(concatenated_examples[list(concatenated_examples.keys())[0]]) @@ -55,20 +55,17 @@ def preprocess_dataset( for query, response, history, prefix in construct_example(examples): input_ids, labels = [], [] - for i, (query_i, resp_i) in enumerate(template.get_dialog(query, response, history, prefix)): - source_ids = tokenizer.encode(text=query_i, add_special_tokens=(i == 0)) - target_ids = tokenizer.encode(text=resp_i, add_special_tokens=False) - + for source_ids, target_ids in template.get_dialog(tokenizer, query, response, history, prefix): if len(source_ids) > data_args.max_source_length: source_ids = source_ids[:data_args.max_source_length] - if len(target_ids) > data_args.max_target_length - 1: # eos token - target_ids = target_ids[:data_args.max_target_length - 1] + if len(target_ids) > data_args.max_target_length: + target_ids = target_ids[:data_args.max_target_length] - if len(input_ids) + len(source_ids) + len(target_ids) + 1 > max_length: + if len(input_ids) + len(source_ids) + len(target_ids) > max_length: break - input_ids += source_ids + target_ids + [tokenizer.eos_token_id] - labels += [IGNORE_INDEX] * len(source_ids) + target_ids + [tokenizer.eos_token_id] + input_ids += source_ids + target_ids + labels += [IGNORE_INDEX] * len(source_ids) + target_ids model_inputs["input_ids"].append(input_ids) model_inputs["attention_mask"].append([1] * len(input_ids)) @@ -81,10 +78,7 @@ def preprocess_dataset( model_inputs = {"input_ids": [], "attention_mask": [], "labels": []} for query, response, history, prefix in construct_example(examples): - prompt = template.get_prompt(query, history, prefix, tokenizer.eos_token) - - source_ids = tokenizer.encode(text=prompt, add_special_tokens=True) - target_ids = tokenizer.encode(text=response, add_special_tokens=True) + source_ids, target_ids = template.get_prompt(tokenizer, query, response, history, prefix) if len(source_ids) > data_args.max_source_length: source_ids = source_ids[:data_args.max_source_length] @@ -101,11 +95,8 @@ def preprocess_dataset( # build input pairs with format ` X Y1 ` and ` X Y2 ` model_inputs = {"accept_ids": [], "reject_ids": []} for query, response, history, prefix in construct_example(examples): - prompt = template.get_prompt(query, history, prefix, tokenizer.eos_token) - - source_ids = tokenizer.encode(text=prompt, add_special_tokens=True) - accept_ids = tokenizer.encode(text=response[0], add_special_tokens=False) - reject_ids = tokenizer.encode(text=response[1], add_special_tokens=False) + source_ids, accept_ids = template.get_prompt(tokenizer, query, response[0], history, prefix) + source_ids, reject_ids = template.get_prompt(tokenizer, query, response[1], history, prefix) if len(source_ids) > data_args.max_source_length: source_ids = source_ids[:data_args.max_source_length] diff --git a/src/llmtuner/extras/template.py b/src/llmtuner/extras/template.py index bb550058..4f8e6301 100644 --- a/src/llmtuner/extras/template.py +++ b/src/llmtuner/extras/template.py @@ -1,82 +1,153 @@ -from typing import Dict, List, Optional, Tuple +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union from dataclasses import dataclass +if TYPE_CHECKING: + from transformers import PreTrainedTokenizer + @dataclass class Template: - prefix: str - prompt: str - sep: str - use_history: bool + prefix: List[Union[str, Dict[str, str]]] + prompt: List[Union[str, Dict[str, str]]] + sep: List[Union[str, Dict[str, str]]] stop_words: List[str] + use_history: bool def get_prompt( self, + tokenizer: "PreTrainedTokenizer", query: str, + resp: str, history: Optional[List[Tuple[str, str]]] = None, - prefix: Optional[str] = "", - eos_token: Optional[str] = "" - ) -> str: + prefix: Optional[str] = None + ) -> Tuple[List[int], List[int]]: r""" - Returns a string containing prompt without response. + Returns a single pair of token ids representing prompt and response respectively. """ - return eos_token.join(map(lambda x: x[0] + x[1], self._format_example(query, history, prefix))) + prefix, history = self._format(query=query, resp=resp, history=history, prefix=prefix) + encoded_pairs = self._encode(tokenizer=tokenizer, prefix=prefix, history=history) + prompt_ids = [] + for query_ids, resp_ids in encoded_pairs[:-1]: + prompt_ids = prompt_ids + query_ids + resp_ids + prompt_ids = prompt_ids + encoded_pairs[-1][0] + return prompt_ids, encoded_pairs[-1][1] def get_dialog( + self, + tokenizer: "PreTrainedTokenizer", + query: str, + resp: str, + history: Optional[List[Tuple[str, str]]] = None, + prefix: Optional[str] = None + ) -> List[Tuple[List[int], List[int]]]: + r""" + Returns multiple pairs of token ids representing prompts and responses respectively. + """ + prefix, history = self._format(query=query, resp=resp, history=history, prefix=prefix) + encoded_pairs = self._encode(tokenizer=tokenizer, prefix=prefix, history=history) + return encoded_pairs + + def _format( self, query: str, resp: str, history: Optional[List[Tuple[str, str]]] = None, - prefix: Optional[str] = "" - ) -> List[Tuple[str, str]]: + prefix: Optional[str] = None + ) -> Tuple[List[Union[str, Dict[str, str]]], List[Tuple[str, str]]]: r""" - Returns a list containing prompt-response pairs. + Aligns inputs to a special format. """ - result = self._format_example(query, history, prefix) - result[-1][-1] = resp - return result - - def _format_example( - self, - query: str, - history: Optional[List[Tuple[str, str]]] = None, - prefix: Optional[str] = "" - ) -> List[Tuple[str, str]]: - prefix = prefix or self.prefix # use prefix if provided - prefix = prefix + self.sep if prefix else "" # add separator for non-empty prefix + prefix = [prefix] if prefix is not None else self.prefix # use prefix if provided + prefix = prefix + self.sep if prefix else [] # add separator for non-empty prefix history = history if (history and self.use_history) else [] - history = history + [(query, "")] - return [ - [(self.sep if i else prefix) + self.prompt.format(query=q), r] - for i, (q, r) in enumerate(history) - ] + history = history + [(query, resp)] + return prefix, history + + def _encode( + self, + tokenizer: "PreTrainedTokenizer", + prefix: List[Union[str, Dict[str, str]]], + history: List[Tuple[str, str]] + ) -> List[Tuple[List[int], List[int]]]: + r""" + Encodes formatted inputs to pairs of token ids. + """ + if tokenizer.bos_token and getattr(tokenizer, "add_bos_token", False): # bos token is optional + bos_token_id = [tokenizer.bos_token_id] + else: + bos_token_id = [] + eos_token_id = [tokenizer.eos_token_id] # eos token is required + encoded_pairs = [] + for turn_idx, (query, resp) in enumerate(history): + if turn_idx == 0: + prefix_ids = self._convert_inputs_to_ids(tokenizer, context=prefix) + else: + prefix_ids = self._convert_inputs_to_ids(tokenizer, context=self.sep) + query_ids = self._convert_inputs_to_ids(tokenizer, context=self.prompt, query=query) + resp_ids = self._convert_inputs_to_ids(tokenizer, context=[resp]) + encoded_pairs.append((bos_token_id + prefix_ids + query_ids, resp_ids + eos_token_id)) + return encoded_pairs + + def _convert_inputs_to_ids( + self, + tokenizer: "PreTrainedTokenizer", + context: List[Union[str, Dict[str, str]]], + query: Optional[str] = "" + ) -> List[int]: + r""" + Converts context to token ids. + """ + token_ids = [] + for elem in context: + if isinstance(elem, str): + subelems = elem.split("{{query}}") + if len(subelems) > 1: + elem = subelems[0] + query + subelems[1] + token_ids = token_ids + tokenizer.encode(elem, add_special_tokens=False) + elif isinstance(elem, dict): + token_ids = token_ids + [tokenizer.convert_tokens_to_ids(elem.get("token"))] + else: + raise NotImplementedError + return token_ids @dataclass class Llama2Template(Template): - def _format_example( + def _encode( self, - query: str, - history: Optional[List[Tuple[str, str]]] = None, - prefix: Optional[str] = "" - ) -> List[Tuple[str, str]]: - prefix = prefix or self.prefix # use prefix if provided - prefix = prefix if prefix.startswith("<>") else "<>\n{}\n<>\n\n".format(prefix) - history = history if (history and self.use_history) else [] - history = history + [(query, "")] - return [ - [(self.sep if i else "") + self.prompt.format(query=(q if i else prefix + q)), r] - for i, (q, r) in enumerate(history) - ] + tokenizer: "PreTrainedTokenizer", + prefix: List[Union[str, Dict[str, str]]], + history: List[Tuple[str, str]] + ) -> List[Tuple[List[int], List[int]]]: + r""" + Encodes formatted inputs to pairs of token ids. + """ + encoded_pairs = [] + assert isinstance(prefix[0], str), "LLaMA-2 template only accepts list containing a single str." + for turn_idx, (query, resp) in enumerate(history): + if turn_idx == 0: + prefix_ids = [] + query = prefix[0] + query + else: + prefix_ids = self._convert_inputs_to_ids(tokenizer, context=self.sep) + query_ids = self._convert_inputs_to_ids(tokenizer, context=self.prompt, query=query) + resp_ids = self._convert_inputs_to_ids(tokenizer, context=[resp]) + encoded_pairs.append((prefix_ids + query_ids, resp_ids)) + return encoded_pairs templates: Dict[str, Template] = {} def register_template( - name: str, prefix: str, prompt: str, sep: str, use_history: bool, stop_words: List[str] + name: str, + prefix: List[Union[str, Dict[str, str]]], + prompt: List[Union[str, Dict[str, str]]], + sep: List[Union[str, Dict[str, str]]], + stop_words: List[str], + use_history: bool ) -> None: template_class = Llama2Template if name == "llama2" else Template templates[name] = template_class( @@ -99,11 +170,13 @@ Supports language model inference without histories. """ register_template( name="vanilla", - prefix="", - prompt="{query}", - sep="", - use_history=False, - stop_words=[] + prefix=[], + prompt=[ + "{{query}}" + ], + sep=[], + stop_words=[], + use_history=False ) @@ -112,12 +185,18 @@ Default template. """ register_template( name="default", - prefix="A chat between a curious user and an artificial intelligence assistant. " - "The assistant gives helpful, detailed, and polite answers to the user's questions.", - prompt="Human: {query}\nAssistant: ", - sep="\n", - use_history=True, - stop_words=[] + prefix=[ + "A chat between a curious user and an artificial intelligence assistant. " + "The assistant gives helpful, detailed, and polite answers to the user's questions." + ], + prompt=[ + "Human: {{query}}\nAssistant: " + ], + sep=[ + "\n" + ], + stop_words=[], + use_history=True ) @@ -128,18 +207,24 @@ Supports: https://huggingface.co/meta-llama/Llama-2-7b-chat-hf """ register_template( name="llama2", - prefix="<>\nYou are a helpful, respectful and honest assistant. " - "Always answer as helpfully as possible, while being safe. " - "Your answers should not include any harmful, unethical, " - "racist, sexist, toxic, dangerous, or illegal content. " - "Please ensure that your responses are socially unbiased and positive in nature.\n" - "If a question does not make any sense, or is not factually coherent, " - "explain why instead of answering something not correct. " - "If you don't know the answer to a question, please don't share false information.\n<>\n\n", - prompt="[INST] {query} [/INST] ", - sep="", - use_history=True, - stop_words=[] + prefix=[ + "<>\nYou are a helpful, respectful and honest assistant. " + "Always answer as helpfully as possible, while being safe. " + "Your answers should not include any harmful, unethical, " + "racist, sexist, toxic, dangerous, or illegal content. " + "Please ensure that your responses are socially unbiased and positive in nature.\n" + "If a question does not make any sense, or is not factually coherent, " + "explain why instead of answering something not correct. " + "If you don't know the answer to a question, please don't share false information.\n<>\n\n" + ], + prompt=[ + "[INST] {{query}} [/INST] " + ], + sep=[ + {"token": ""} + ], + stop_words=[], + use_history=True ) @@ -149,12 +234,18 @@ Supports: https://huggingface.co/tatsu-lab/alpaca-7b-wdiff """ register_template( name="alpaca", - prefix="Below is an instruction that describes a task. " - "Write a response that appropriately completes the request.", - prompt="### Instruction:\n{query}\n\n### Response:\n", - sep="\n\n", - use_history=True, - stop_words=[] + prefix=[ + "Below is an instruction that describes a task. " + "Write a response that appropriately completes the request." + ], + prompt=[ + "### Instruction:\n{{query}}\n\n### Response:\n" + ], + sep=[ + "\n\n" + ], + stop_words=[], + use_history=True ) @@ -164,12 +255,16 @@ Supports: https://huggingface.co/lmsys/vicuna-7b-delta-v1.1 """ register_template( name="vicuna", - prefix="A chat between a curious user and an artificial intelligence assistant. " - "The assistant gives helpful, detailed, and polite answers to the user's questions.", - prompt="USER: {query} ASSISTANT: ", - sep="", - use_history=True, - stop_words=[] + prefix=[ + "A chat between a curious user and an artificial intelligence assistant. " + "The assistant gives helpful, detailed, and polite answers to the user's questions." + ], + prompt=[ + "USER: {{query}} ASSISTANT: " + ], + sep=[], + stop_words=[], + use_history=True ) @@ -178,11 +273,15 @@ Supports: https://huggingface.co/BelleGroup/BELLE-LLaMA-EXT-13B """ register_template( name="belle", - prefix="", - prompt="Human: {query}\n\nBelle: ", - sep="\n\n", - use_history=True, - stop_words=[] + prefix=[], + prompt=[ + "Human: {{query}}\n\nBelle: " + ], + sep=[ + "\n\n" + ], + stop_words=[], + use_history=True ) @@ -191,11 +290,15 @@ Supports: https://github.com/CVI-SZU/Linly """ register_template( name="linly", - prefix="", - prompt="User: {query}\nBot: ", - sep="\n", - use_history=True, - stop_words=[] + prefix=[], + prompt=[ + "User: {{query}}\nBot: " + ], + sep=[ + "\n" + ], + stop_words=[], + use_history=True ) @@ -204,11 +307,15 @@ Supports: https://github.com/Neutralzz/BiLLa """ register_template( name="billa", - prefix="", - prompt="Human: {query}\nAssistant: ", - sep="\n", - use_history=True, - stop_words=[] + prefix=[], + prompt=[ + "Human: {{query}}\nAssistant: " + ], + sep=[ + "\n" + ], + stop_words=[], + use_history=True ) @@ -217,11 +324,18 @@ Supports: https://huggingface.co/IDEA-CCNL/Ziya-LLaMA-13B-v1 """ register_template( name="ziya", - prefix="", - prompt=":{query}\n:", - sep="\n", - use_history=True, - stop_words=[] + prefix=[], + prompt=[ + {"token": ""}, + ":{{query}}\n", + {"token": ""}, + ":" + ], + sep=[ + "\n" + ], + stop_words=[], + use_history=True ) @@ -230,12 +344,18 @@ Supports: https://huggingface.co/qhduan/aquilachat-7b """ register_template( name="aquila", - prefix="A chat between a curious human and an artificial intelligence assistant. " - "The assistant gives helpful, detailed, and polite answers to the human's questions.", - prompt="Human: {query}###Assistant: ", - sep="###", - use_history=True, - stop_words=[] + prefix=[ + "A chat between a curious human and an artificial intelligence assistant. " + "The assistant gives helpful, detailed, and polite answers to the human's questions." + ], + prompt=[ + "Human: {{query}}###Assistant: " + ], + sep=[ + "###" + ], + stop_words=[], + use_history=True ) @@ -244,11 +364,23 @@ Supports: https://huggingface.co/internlm/internlm-chat-7b """ register_template( name="intern", - prefix="", - prompt="<|User|>:{query}\n<|Bot|>:", - sep="\n", - use_history=True, - stop_words=[""] + prefix=[], + prompt=[ + {"token": "<|User|>"}, + ":{{query}}", + {"token": ""}, + "\n", + {"token": "<|Bot|>"}, + ":" + ], + sep=[ + {"token": ""}, + "\n" + ], + stop_words=[ + "" + ], + use_history=True ) @@ -257,11 +389,15 @@ Supports: https://huggingface.co/baichuan-inc/Baichuan-13B-Chat """ register_template( name="baichuan", - prefix="", - prompt="{query}", - sep="", - use_history=True, - stop_words=[] + prefix=[], + prompt=[ + {"token": ""}, + "{{query}}", + {"token": ""} + ], + sep=[], + stop_words=[], + use_history=True ) @@ -271,11 +407,25 @@ Supports: https://huggingface.co/HuggingFaceH4/starchat-alpha """ register_template( name="starchat", - prefix="<|system|>\n", - prompt="<|user|>\n{query}<|end|>\n<|assistant|>\n", - sep="<|end|>\n", - use_history=True, - stop_words=["<|end|>"] + prefix=[ + {"token": "<|system|>"}, + "\n" + ], + prompt=[ + {"token": "<|user|>"}, + "\n{{query}}", + {"token": "<|end|>"}, + "\n", + {"token": "<|assistant|>"} + ], + sep=[ + {"token": "<|end|>"}, + "\n" + ], + stop_words=[ + "<|end|>" + ], + use_history=True ) @@ -284,9 +434,24 @@ Supports: https://huggingface.co/Qwen/Qwen-7B-Chat """ register_template( name="chatml", - prefix="<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n", - prompt="<|im_start|>user\n{query}<|im_end|>\n<|im_start|>assistant\n", - sep="<|im_end|>\n", - use_history=True, - stop_words=["<|im_end|>"] + prefix=[ + {"token": "<|im_start|>"}, + "system\nYou are a helpful assistant." + ], + prompt=[ + {"token": "<|im_start|>"}, + "user\n{{query}}", + {"token": "<|im_end|>"}, + "\n", + {"token": "<|im_start|>"}, + "assistant\n" + ], + sep=[ + {"token": "<|im_end|>"}, + "\n" + ], + stop_words=[ + "<|im_end|>" + ], + use_history=True )