Mirror of https://github.com/hiyouga/LLaMA-Factory.git
fix yi template #1895

Commit: 622d31e398
Parent: 23a875a8b1
Former-commit-id: 5af8841c4f6c97df522d2cf4e283d5ef0af21a18
@@ -21,6 +21,7 @@ class Template:
     stop_words: List[str]
     use_history: bool
     efficient_eos: bool
+    replace_eos: bool

     def encode_oneturn(
         self,
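The new replace_eos field sits beside efficient_eos on the template container. For context, a minimal sketch of the dataclass after this hunk (field names are taken from the diff; the types of prefix and prompt are assumptions, mirroring the declared type of sep):

    from dataclasses import dataclass
    from typing import Dict, List, Union

    @dataclass
    class Template:
        prefix: List[Union[str, Dict[str, str]]]   # assumed type, mirroring sep
        prompt: List[Union[str, Dict[str, str]]]   # assumed type, mirroring sep
        system: str
        sep: List[Union[str, Dict[str, str]]]
        stop_words: List[str]
        use_history: bool
        efficient_eos: bool
        replace_eos: bool  # new: swap the tokenizer's EOS for the first stop word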
@@ -38,7 +39,8 @@ class Template:
         prompt_ids = []
         for query_ids, resp_ids in encoded_pairs[:-1]:
             prompt_ids = prompt_ids + query_ids + resp_ids
-        prompt_ids, answer_ids = prompt_ids + encoded_pairs[-1][0], encoded_pairs[-1][1]
+        prompt_ids = prompt_ids + encoded_pairs[-1][0]
+        answer_ids = encoded_pairs[-1][1]
         return prompt_ids, answer_ids

     def encode_multiturn(
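The encode_oneturn change is behavior-preserving: the packed tuple assignment is split into two plain statements. A standalone sketch of the same accumulation logic with toy token ids (the structure of encoded_pairs is inferred from the diff):

    from typing import List, Tuple

    def split_last_turn(
        encoded_pairs: List[Tuple[List[int], List[int]]]
    ) -> Tuple[List[int], List[int]]:
        # Fold every turn except the last into the prompt...
        prompt_ids: List[int] = []
        for query_ids, resp_ids in encoded_pairs[:-1]:
            prompt_ids = prompt_ids + query_ids + resp_ids
        # ...then append the final query; its response becomes the answer.
        prompt_ids = prompt_ids + encoded_pairs[-1][0]
        answer_ids = encoded_pairs[-1][1]
        return prompt_ids, answer_ids

    assert split_last_turn([([1], [2]), ([3], [4])]) == ([1, 2, 3], [4])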
@@ -77,13 +79,13 @@ class Template:
     ) -> Tuple[List[int], List[int]]:
         if tokenizer.bos_token_id is not None and getattr(tokenizer, "add_bos_token", True):
             bos_ids = [tokenizer.bos_token_id]
-        else: # baichuan, qwen and gpt2 models have no bos token
+        else: # baichuan, gpt2, qwen, yi models have no bos token
             bos_ids = []

         if tokenizer.eos_token_id is None:
             raise ValueError("EOS token is required.")

-        if self.efficient_eos: # used in baichuan, qwen, chatglm, etc.
+        if self.efficient_eos:
             eos_ids = []
         else:
             eos_ids = [tokenizer.eos_token_id]
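The updated comment records that Yi checkpoints, like baichuan, gpt2, and qwen, ship without a BOS token, so bos_ids stays empty for them; the stale "used in baichuan, qwen, chatglm, etc." note on efficient_eos is dropped. The two branches, illustrated with a stub in place of a real transformers tokenizer:

    class StubTokenizer:
        # Only the attributes the branch actually reads.
        def __init__(self, bos_token_id, eos_token_id, add_bos_token=True):
            self.bos_token_id = bos_token_id
            self.eos_token_id = eos_token_id
            self.add_bos_token = add_bos_token

    def special_ids(tokenizer, efficient_eos=False):
        if tokenizer.bos_token_id is not None and getattr(tokenizer, "add_bos_token", True):
            bos_ids = [tokenizer.bos_token_id]
        else:  # baichuan, gpt2, qwen, yi models have no bos token
            bos_ids = []
        if tokenizer.eos_token_id is None:
            raise ValueError("EOS token is required.")
        eos_ids = [] if efficient_eos else [tokenizer.eos_token_id]
        return bos_ids, eos_ids

    assert special_ids(StubTokenizer(1, 2)) == ([1], [2])    # llama-style: BOS present
    assert special_ids(StubTokenizer(None, 2)) == ([], [2])  # yi-style: no BOS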
@@ -187,9 +189,10 @@ def register_template(
     sep: List[Union[str, Dict[str, str]]],
     stop_words: Optional[List[str]] = [],
     use_history: Optional[bool] = True,
-    efficient_eos: Optional[bool] = False
+    efficient_eos: Optional[bool] = False,
+    replace_eos: Optional[bool] = False
 ) -> None:
-    template_class = Llama2Template if "llama2" in name else Template
+    template_class = Llama2Template if name.startswith("llama2") else Template
     templates[name] = template_class(
         prefix=prefix,
         prompt=prompt,
@@ -197,7 +200,8 @@ def register_template(
         sep=sep,
         stop_words=stop_words,
         use_history=use_history,
-        efficient_eos=efficient_eos
+        efficient_eos=efficient_eos,
+        replace_eos=replace_eos
     )


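register_template now threads the flag straight through to the template constructor, and the llama2 check is tightened from a substring test to a prefix test, so a name that merely contains "llama2" somewhere in the middle no longer selects Llama2Template. A hypothetical registration using the new keyword (the name and format strings here are illustrative, not taken from the repo):

    register_template(
        name="demo",
        prefix=["{{system}}"],
        prompt=["<|im_start|>user\n{{query}}<|im_end|>\n<|im_start|>assistant\n"],
        system="",
        sep=["\n"],
        stop_words=["<|im_end|>"],
        replace_eos=True  # the first stop word will become the tokenizer's EOS
    )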
@@ -213,15 +217,26 @@ def get_template_and_fix_tokenizer(
         tokenizer.pad_token = tokenizer.eos_token
         logger.info("Add pad token: {}".format(tokenizer.pad_token))

-    if name is None:
+    if name is None: # for pre-training
         return None

     template = templates.get(name, None)
     assert template is not None, "Template {} does not exist.".format(name)

+    if template.replace_eos:
+        if not template.stop_words:
+            raise ValueError("Stop words are required to replace the EOS token.")
+
+        tokenizer.eos_token = template.stop_words.pop(0)
+        logger.info("Replace eos token: {}".format(tokenizer.eos_token))
+
+    if template.stop_words:
         tokenizer.add_special_tokens(
             dict(additional_special_tokens=template.stop_words),
             replace_additional_special_tokens=False
         )
+        logger.info("Add {} to stop words.".format(",".join(template.stop_words)))

     return template
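This hunk carries the actual fix: when a template sets replace_eos, the first stop word is popped off and installed as the tokenizer's EOS token, and only the remaining stop words (if any) are registered as additional special tokens. A self-contained sketch of just that control flow, using stubs rather than transformers objects:

    from dataclasses import dataclass, field
    from types import SimpleNamespace
    from typing import List

    @dataclass
    class DemoTemplate:
        stop_words: List[str] = field(default_factory=list)
        replace_eos: bool = False

    def fix_eos(template: DemoTemplate, tokenizer) -> None:
        # Mirrors the new branch: promote the first stop word to EOS.
        if template.replace_eos:
            if not template.stop_words:
                raise ValueError("Stop words are required to replace the EOS token.")
            tokenizer.eos_token = template.stop_words.pop(0)
        # Whatever remains in template.stop_words would still be added
        # via tokenizer.add_special_tokens, as in the real function.

    tok = SimpleNamespace(eos_token="</s>")
    fix_eos(DemoTemplate(stop_words=["<|im_end|>"], replace_eos=True), tok)
    assert tok.eos_token == "<|im_end|>"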
@@ -732,12 +747,12 @@ register_template(
     ],
     system="",
     sep=[
-        "<|im_end|>\n"
+        "\n"
     ],
     stop_words=[
         "<|im_end|>"
     ],
-    efficient_eos=True
+    replace_eos=True
)


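Net effect for Yi models: the separator no longer injects "<|im_end|>" (a plain newline suffices), and instead of relying on efficient_eos, the template promotes "<|im_end|>" to be the tokenizer's EOS, so generation stops at the ChatML turn terminator. A usage sketch (the import path and checkpoint id are illustrative and may differ by version):

    from transformers import AutoTokenizer
    from llmtuner.extras.template import get_template_and_fix_tokenizer  # path may vary

    tokenizer = AutoTokenizer.from_pretrained("01-ai/Yi-6B-Chat")  # illustrative checkpoint
    template = get_template_and_fix_tokenizer(name="yi", tokenizer=tokenizer)
    print(tokenizer.eos_token)  # now "<|im_end|>", popped from the template's stop words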