From 3021a01b7176aa0e4acb735df9a5540c8a9bcda8 Mon Sep 17 00:00:00 2001 From: hiyouga Date: Thu, 17 Aug 2023 01:27:20 +0800 Subject: [PATCH] fix baichuan and intern template Former-commit-id: 892fd39373b816cf079e0decc9cb57dfb5565242 --- src/llmtuner/extras/template.py | 24 ++++++++++-------------- 1 file changed, 10 insertions(+), 14 deletions(-) diff --git a/src/llmtuner/extras/template.py b/src/llmtuner/extras/template.py index 25907382..5d5a03fb 100644 --- a/src/llmtuner/extras/template.py +++ b/src/llmtuner/extras/template.py @@ -20,7 +20,6 @@ class Template: sep: List[Union[str, Dict[str, str]]] stop_words: List[str] use_history: bool - bos_after_prefix: bool def encode_oneturn( self, @@ -75,12 +74,15 @@ class Template: self, tokenizer: "PreTrainedTokenizer" ) -> Tuple[List[int], List[int]]: - if tokenizer.bos_token_id: + if ( + tokenizer.bos_token_id is not None + and getattr(tokenizer, "add_bos_token", True) + ): # baichuan-13b has no bos token bos_ids = [tokenizer.bos_token_id] else: bos_ids = [] # bos token is optional - if tokenizer.eos_token_id: + if tokenizer.eos_token_id is not None: eos_ids = [tokenizer.eos_token_id] else: raise ValueError("EOS token is required.") @@ -105,10 +107,7 @@ class Template: if turn_idx == 0: prefix_ids = self._convert_inputs_to_ids(tokenizer, context=self.prefix, system=system) if len(prefix_ids) != 0: # has prefix - if self.bos_after_prefix: - prefix_ids = prefix_ids + bos_ids + sep_ids - else: - prefix_ids = bos_ids + prefix_ids + sep_ids + prefix_ids = bos_ids + prefix_ids + sep_ids else: prefix_ids = bos_ids else: @@ -185,8 +184,7 @@ def register_template( system: str, sep: List[Union[str, Dict[str, str]]], stop_words: Optional[List[str]] = [], - use_history: Optional[bool] = True, - bos_after_prefix: Optional[bool] = False + use_history: Optional[bool] = True ) -> None: template_class = Llama2Template if "llama2" in name else Template templates[name] = template_class( @@ -195,8 +193,7 @@ def register_template( system=system, sep=sep, stop_words=stop_words, - use_history=use_history, - bos_after_prefix=bos_after_prefix + use_history=use_history ) @@ -208,7 +205,6 @@ def get_template_and_fix_tokenizer( assert template is not None, "Template {} does not exist.".format(name) additional_special_tokens = template.stop_words - if len(template.stop_words): # inplace method if tokenizer.eos_token_id is not None: additional_special_tokens.append(tokenizer.eos_token) @@ -468,6 +464,7 @@ register_template( "\n" ], stop_words=[ + "", # internlm cannot replace eos token "" ] ) @@ -490,8 +487,7 @@ register_template( sep=[], stop_words=[ "" # user token - ], - bos_after_prefix=True + ] )