fix stop words

Former-commit-id: dec360d5aee58c11804b8e45dd8d4c375086887f
hiyouga 2023-12-20 19:06:43 +08:00
parent 622d31e398
commit 9337568e23


@@ -1,4 +1,5 @@
 import tiktoken
+from copy import deepcopy
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
@@ -223,19 +224,20 @@ def get_template_and_fix_tokenizer(
     template = templates.get(name, None)
     assert template is not None, "Template {} does not exist.".format(name)
+    stop_words = deepcopy(template.stop_words)
     if template.replace_eos:
-        if not template.stop_words:
+        if not stop_words:
             raise ValueError("Stop words are required to replace the EOS token.")
-        tokenizer.eos_token = template.stop_words.pop(0)
+        tokenizer.eos_token = stop_words.pop(0)
         logger.info("Replace eos token: {}".format(tokenizer.eos_token))
-    if template.stop_words:
+    if stop_words:
         tokenizer.add_special_tokens(
-            dict(additional_special_tokens=template.stop_words),
+            dict(additional_special_tokens=stop_words),
             replace_additional_special_tokens=False
         )
-        logger.info("Add {} to stop words.".format(",".join(template.stop_words)))
+        logger.info("Add {} to stop words.".format(",".join(stop_words)))
     return template