diff --git a/src/llmtuner/data/template.py b/src/llmtuner/data/template.py index 459382bf..6d57698d 100644 --- a/src/llmtuner/data/template.py +++ b/src/llmtuner/data/template.py @@ -21,6 +21,7 @@ class Template: stop_words: List[str] use_history: bool efficient_eos: bool + replace_eos: bool def encode_oneturn( self, @@ -38,7 +39,8 @@ class Template: prompt_ids = [] for query_ids, resp_ids in encoded_pairs[:-1]: prompt_ids = prompt_ids + query_ids + resp_ids - prompt_ids, answer_ids = prompt_ids + encoded_pairs[-1][0], encoded_pairs[-1][1] + prompt_ids = prompt_ids + encoded_pairs[-1][0] + answer_ids = encoded_pairs[-1][1] return prompt_ids, answer_ids def encode_multiturn( @@ -77,13 +79,13 @@ class Template: ) -> Tuple[List[int], List[int]]: if tokenizer.bos_token_id is not None and getattr(tokenizer, "add_bos_token", True): bos_ids = [tokenizer.bos_token_id] - else: # baichuan, qwen and gpt2 models have no bos token + else: # baichuan, gpt2, qwen, yi models have no bos token bos_ids = [] if tokenizer.eos_token_id is None: raise ValueError("EOS token is required.") - if self.efficient_eos: # used in baichuan, qwen, chatglm, etc. + if self.efficient_eos: eos_ids = [] else: eos_ids = [tokenizer.eos_token_id] @@ -187,9 +189,10 @@ def register_template( sep: List[Union[str, Dict[str, str]]], stop_words: Optional[List[str]] = [], use_history: Optional[bool] = True, - efficient_eos: Optional[bool] = False + efficient_eos: Optional[bool] = False, + replace_eos: Optional[bool] = False ) -> None: - template_class = Llama2Template if "llama2" in name else Template + template_class = Llama2Template if name.startswith("llama2") else Template templates[name] = template_class( prefix=prefix, prompt=prompt, @@ -197,7 +200,8 @@ def register_template( sep=sep, stop_words=stop_words, use_history=use_history, - efficient_eos=efficient_eos + efficient_eos=efficient_eos, + replace_eos=replace_eos ) @@ -213,15 +217,26 @@ def get_template_and_fix_tokenizer( tokenizer.pad_token = tokenizer.eos_token logger.info("Add pad token: {}".format(tokenizer.pad_token)) - if name is None: + if name is None: # for pre-training return None template = templates.get(name, None) assert template is not None, "Template {} does not exist.".format(name) - tokenizer.add_special_tokens( - dict(additional_special_tokens=template.stop_words), - replace_additional_special_tokens=False - ) + + if template.replace_eos: + if not template.stop_words: + raise ValueError("Stop words are required to replace the EOS token.") + + tokenizer.eos_token = template.stop_words.pop(0) + logger.info("Replace eos token: {}".format(tokenizer.eos_token)) + + if template.stop_words: + tokenizer.add_special_tokens( + dict(additional_special_tokens=template.stop_words), + replace_additional_special_tokens=False + ) + logger.info("Add {} to stop words.".format(",".join(template.stop_words))) + return template @@ -732,12 +747,12 @@ register_template( ], system="", sep=[ - "<|im_end|>\n" + "\n" ], stop_words=[ "<|im_end|>" ], - efficient_eos=True + replace_eos=True )