fix baichuan and intern template

Former-commit-id: 892fd39373b816cf079e0decc9cb57dfb5565242
Author: hiyouga
Date: 2023-08-17 01:27:20 +08:00
parent 048f99354f
commit 3021a01b71


@@ -20,7 +20,6 @@ class Template:
     sep: List[Union[str, Dict[str, str]]]
     stop_words: List[str]
     use_history: bool
-    bos_after_prefix: bool
 
     def encode_oneturn(
         self,
@@ -75,12 +74,15 @@ class Template:
         self,
         tokenizer: "PreTrainedTokenizer"
     ) -> Tuple[List[int], List[int]]:
-        if tokenizer.bos_token_id:
+        if (
+            tokenizer.bos_token_id is not None
+            and getattr(tokenizer, "add_bos_token", True)
+        ): # baichuan-13b has no bos token
            bos_ids = [tokenizer.bos_token_id]
         else:
             bos_ids = [] # bos token is optional
-        if tokenizer.eos_token_id:
+        if tokenizer.eos_token_id is not None:
             eos_ids = [tokenizer.eos_token_id]
         else:
             raise ValueError("EOS token is required.")
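
For reference, here is a minimal standalone sketch of the BOS/EOS resolution after this hunk, assuming a Hugging Face-style tokenizer; the helper name `resolve_special_ids` is illustrative and not part of the patch:

```python
from typing import List, Tuple

def resolve_special_ids(tokenizer) -> Tuple[List[int], List[int]]:
    # Per the patch comment, Baichuan-13B exposes no usable BOS token; tokenizers
    # that set add_bos_token=False are therefore treated as having no BOS.
    if (
        tokenizer.bos_token_id is not None
        and getattr(tokenizer, "add_bos_token", True)
    ):
        bos_ids = [tokenizer.bos_token_id]
    else:
        bos_ids = []  # bos token is optional
    # An EOS token remains mandatory for every template.
    if tokenizer.eos_token_id is not None:
        eos_ids = [tokenizer.eos_token_id]
    else:
        raise ValueError("EOS token is required.")
    return bos_ids, eos_ids
```

The switch from `if tokenizer.bos_token_id:` to `is not None` also avoids silently dropping a BOS whose id happens to be 0.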
@@ -105,9 +107,6 @@ class Template:
             if turn_idx == 0:
                 prefix_ids = self._convert_inputs_to_ids(tokenizer, context=self.prefix, system=system)
                 if len(prefix_ids) != 0: # has prefix
-                    if self.bos_after_prefix:
-                        prefix_ids = prefix_ids + bos_ids + sep_ids
-                    else:
-                        prefix_ids = bos_ids + prefix_ids + sep_ids
+                    prefix_ids = bos_ids + prefix_ids + sep_ids
                 else:
                     prefix_ids = bos_ids
@@ -185,8 +184,7 @@ def register_template(
     system: str,
     sep: List[Union[str, Dict[str, str]]],
     stop_words: Optional[List[str]] = [],
-    use_history: Optional[bool] = True,
-    bos_after_prefix: Optional[bool] = False
+    use_history: Optional[bool] = True
 ) -> None:
     template_class = Llama2Template if "llama2" in name else Template
     templates[name] = template_class(
@@ -195,8 +193,7 @@ def register_template(
         system=system,
         sep=sep,
         stop_words=stop_words,
-        use_history=use_history,
-        bos_after_prefix=bos_after_prefix
+        use_history=use_history
     )
@@ -208,7 +205,6 @@ def get_template_and_fix_tokenizer(
     assert template is not None, "Template {} does not exist.".format(name)
     additional_special_tokens = template.stop_words
     if len(template.stop_words): # inplace method
         if tokenizer.eos_token_id is not None:
             additional_special_tokens.append(tokenizer.eos_token)
@@ -468,6 +464,7 @@ register_template(
         "\n"
     ],
     stop_words=[
+        "</s>", # internlm cannot replace eos token
         "<eoa>"
     ]
 )
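
The new `"</s>"` entry reflects the comment that InternLM's EOS token cannot be replaced, so generation has to stop on either `</s>` or `<eoa>`. Below is a hedged sketch, not part of this commit, of a `transformers` stopping criterion that honors both stop words:

```python
import torch
from transformers import StoppingCriteria, StoppingCriteriaList

class StopOnTokens(StoppingCriteria):
    """Halt generation when the last emitted token is any registered stop id."""

    def __init__(self, stop_token_ids):
        self.stop_token_ids = set(stop_token_ids)

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # input_ids has shape (batch, seq_len); check the most recent token.
        return input_ids[0, -1].item() in self.stop_token_ids

# Usage sketch: the stop ids would come from the template's stop words.
# stop_ids = tokenizer.convert_tokens_to_ids(["</s>", "<eoa>"])
# model.generate(**inputs, stopping_criteria=StoppingCriteriaList([StopOnTokens(stop_ids)]))
```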
@@ -490,8 +487,7 @@ register_template(
     sep=[],
     stop_words=[
         "<reserved_102>" # user token
-    ],
-    bos_after_prefix=True
+    ]
 )
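
With `bos_after_prefix` removed, whether a BOS id gets prepended is now decided by the tokenizer alone via the `add_bos_token` check earlier in this diff. A small sketch under that assumption, with the model ids used purely for illustration:

```python
from transformers import AutoTokenizer

def prepends_bos(model_name: str) -> bool:
    """Check whether plain-text encoding starts with the tokenizer's BOS id."""
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    ids = tokenizer.encode("hello")
    return bool(ids) and ids[0] == tokenizer.bos_token_id

# Assumed behavior (not verified here): a LLaMA-2 style tokenizer returns True,
# while "baichuan-inc/Baichuan-13B-Chat" returns False, which is why the
# template no longer needs its own bos_after_prefix switch.
```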