mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-04 12:42:51 +08:00
fix baichuan and intern template
Former-commit-id: 892fd39373b816cf079e0decc9cb57dfb5565242
This commit is contained in:
parent
048f99354f
commit
3021a01b71
@ -20,7 +20,6 @@ class Template:
|
|||||||
sep: List[Union[str, Dict[str, str]]]
|
sep: List[Union[str, Dict[str, str]]]
|
||||||
stop_words: List[str]
|
stop_words: List[str]
|
||||||
use_history: bool
|
use_history: bool
|
||||||
bos_after_prefix: bool
|
|
||||||
|
|
||||||
def encode_oneturn(
|
def encode_oneturn(
|
||||||
self,
|
self,
|
||||||
@ -75,12 +74,15 @@ class Template:
|
|||||||
self,
|
self,
|
||||||
tokenizer: "PreTrainedTokenizer"
|
tokenizer: "PreTrainedTokenizer"
|
||||||
) -> Tuple[List[int], List[int]]:
|
) -> Tuple[List[int], List[int]]:
|
||||||
if tokenizer.bos_token_id:
|
if (
|
||||||
|
tokenizer.bos_token_id is not None
|
||||||
|
and getattr(tokenizer, "add_bos_token", True)
|
||||||
|
): # baichuan-13b has no bos token
|
||||||
bos_ids = [tokenizer.bos_token_id]
|
bos_ids = [tokenizer.bos_token_id]
|
||||||
else:
|
else:
|
||||||
bos_ids = [] # bos token is optional
|
bos_ids = [] # bos token is optional
|
||||||
|
|
||||||
if tokenizer.eos_token_id:
|
if tokenizer.eos_token_id is not None:
|
||||||
eos_ids = [tokenizer.eos_token_id]
|
eos_ids = [tokenizer.eos_token_id]
|
||||||
else:
|
else:
|
||||||
raise ValueError("EOS token is required.")
|
raise ValueError("EOS token is required.")
|
||||||
@ -105,10 +107,7 @@ class Template:
|
|||||||
if turn_idx == 0:
|
if turn_idx == 0:
|
||||||
prefix_ids = self._convert_inputs_to_ids(tokenizer, context=self.prefix, system=system)
|
prefix_ids = self._convert_inputs_to_ids(tokenizer, context=self.prefix, system=system)
|
||||||
if len(prefix_ids) != 0: # has prefix
|
if len(prefix_ids) != 0: # has prefix
|
||||||
if self.bos_after_prefix:
|
prefix_ids = bos_ids + prefix_ids + sep_ids
|
||||||
prefix_ids = prefix_ids + bos_ids + sep_ids
|
|
||||||
else:
|
|
||||||
prefix_ids = bos_ids + prefix_ids + sep_ids
|
|
||||||
else:
|
else:
|
||||||
prefix_ids = bos_ids
|
prefix_ids = bos_ids
|
||||||
else:
|
else:
|
||||||
@ -185,8 +184,7 @@ def register_template(
|
|||||||
system: str,
|
system: str,
|
||||||
sep: List[Union[str, Dict[str, str]]],
|
sep: List[Union[str, Dict[str, str]]],
|
||||||
stop_words: Optional[List[str]] = [],
|
stop_words: Optional[List[str]] = [],
|
||||||
use_history: Optional[bool] = True,
|
use_history: Optional[bool] = True
|
||||||
bos_after_prefix: Optional[bool] = False
|
|
||||||
) -> None:
|
) -> None:
|
||||||
template_class = Llama2Template if "llama2" in name else Template
|
template_class = Llama2Template if "llama2" in name else Template
|
||||||
templates[name] = template_class(
|
templates[name] = template_class(
|
||||||
@ -195,8 +193,7 @@ def register_template(
|
|||||||
system=system,
|
system=system,
|
||||||
sep=sep,
|
sep=sep,
|
||||||
stop_words=stop_words,
|
stop_words=stop_words,
|
||||||
use_history=use_history,
|
use_history=use_history
|
||||||
bos_after_prefix=bos_after_prefix
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -208,7 +205,6 @@ def get_template_and_fix_tokenizer(
|
|||||||
assert template is not None, "Template {} does not exist.".format(name)
|
assert template is not None, "Template {} does not exist.".format(name)
|
||||||
|
|
||||||
additional_special_tokens = template.stop_words
|
additional_special_tokens = template.stop_words
|
||||||
|
|
||||||
if len(template.stop_words): # inplace method
|
if len(template.stop_words): # inplace method
|
||||||
if tokenizer.eos_token_id is not None:
|
if tokenizer.eos_token_id is not None:
|
||||||
additional_special_tokens.append(tokenizer.eos_token)
|
additional_special_tokens.append(tokenizer.eos_token)
|
||||||
@ -468,6 +464,7 @@ register_template(
|
|||||||
"\n"
|
"\n"
|
||||||
],
|
],
|
||||||
stop_words=[
|
stop_words=[
|
||||||
|
"</s>", # internlm cannot replace eos token
|
||||||
"<eoa>"
|
"<eoa>"
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
@ -490,8 +487,7 @@ register_template(
|
|||||||
sep=[],
|
sep=[],
|
||||||
stop_words=[
|
stop_words=[
|
||||||
"<reserved_102>" # user token
|
"<reserved_102>" # user token
|
||||||
],
|
]
|
||||||
bos_after_prefix=True
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user