mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-04 20:52:59 +08:00
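fix stop words: work on a deepcopy of template.stop_words instead of popping from (and mutating) the template's own list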
Former-commit-id: dec360d5aee58c11804b8e45dd8d4c375086887f
This commit is contained in:
parent
622d31e398
commit
9337568e23
@ -1,4 +1,5 @@
|
|||||||
import tiktoken
|
import tiktoken
|
||||||
|
from copy import deepcopy
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
|
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
@ -223,19 +224,20 @@ def get_template_and_fix_tokenizer(
|
|||||||
template = templates.get(name, None)
|
template = templates.get(name, None)
|
||||||
assert template is not None, "Template {} does not exist.".format(name)
|
assert template is not None, "Template {} does not exist.".format(name)
|
||||||
|
|
||||||
|
stop_words = deepcopy(template.stop_words)
|
||||||
if template.replace_eos:
|
if template.replace_eos:
|
||||||
if not template.stop_words:
|
if not stop_words:
|
||||||
raise ValueError("Stop words are required to replace the EOS token.")
|
raise ValueError("Stop words are required to replace the EOS token.")
|
||||||
|
|
||||||
tokenizer.eos_token = template.stop_words.pop(0)
|
tokenizer.eos_token = stop_words.pop(0)
|
||||||
logger.info("Replace eos token: {}".format(tokenizer.eos_token))
|
logger.info("Replace eos token: {}".format(tokenizer.eos_token))
|
||||||
|
|
||||||
if template.stop_words:
|
if stop_words:
|
||||||
tokenizer.add_special_tokens(
|
tokenizer.add_special_tokens(
|
||||||
dict(additional_special_tokens=template.stop_words),
|
dict(additional_special_tokens=stop_words),
|
||||||
replace_additional_special_tokens=False
|
replace_additional_special_tokens=False
|
||||||
)
|
)
|
||||||
logger.info("Add {} to stop words.".format(",".join(template.stop_words)))
|
logger.info("Add {} to stop words.".format(",".join(stop_words)))
|
||||||
|
|
||||||
return template
|
return template
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user