optionally replace jinja template

Former-commit-id: ba52103ba7f8f0c856691c4a9a706a6e19e73c1e
This commit is contained in:
hiyouga 2024-09-25 23:02:02 +08:00
parent 52a6667da6
commit f30e0a75c4

View File

@ -49,6 +49,7 @@ class Template:
stop_words: List[str] stop_words: List[str]
efficient_eos: bool efficient_eos: bool
replace_eos: bool replace_eos: bool
replace_jinja_template: bool
mm_plugin: "BasePlugin" mm_plugin: "BasePlugin"
def encode_oneturn( def encode_oneturn(
@ -214,6 +215,7 @@ def _register_template(
stop_words: Sequence[str] = [], stop_words: Sequence[str] = [],
efficient_eos: bool = False, efficient_eos: bool = False,
replace_eos: bool = False, replace_eos: bool = False,
replace_jinja_template: bool = True,
mm_plugin: "BasePlugin" = get_mm_plugin(name="base"), mm_plugin: "BasePlugin" = get_mm_plugin(name="base"),
) -> None: ) -> None:
r""" r"""
@ -263,6 +265,7 @@ def _register_template(
stop_words=stop_words, stop_words=stop_words,
efficient_eos=efficient_eos, efficient_eos=efficient_eos,
replace_eos=replace_eos, replace_eos=replace_eos,
replace_jinja_template=replace_jinja_template,
mm_plugin=mm_plugin, mm_plugin=mm_plugin,
) )
@ -398,10 +401,11 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args:
if num_added_tokens > 0: if num_added_tokens > 0:
logger.warning("New tokens have been added, make sure `resize_vocab` is True.") logger.warning("New tokens have been added, make sure `resize_vocab` is True.")
try: if template.replace_jinja_template:
tokenizer.chat_template = _get_jinja_template(template, tokenizer) try:
except ValueError: tokenizer.chat_template = _get_jinja_template(template, tokenizer)
logger.info("Cannot add this chat template to tokenizer.") except ValueError:
logger.info("Cannot add this chat template to tokenizer.")
return template return template
@ -664,6 +668,7 @@ _register_template(
format_separator=EmptyFormatter(slots=["<end_of_turn>\n"]), format_separator=EmptyFormatter(slots=["<end_of_turn>\n"]),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]), format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
efficient_eos=True, efficient_eos=True,
replace_jinja_template=False,
) )
@ -740,6 +745,7 @@ _register_template(
format_prefix=EmptyFormatter(slots=[{"bos_token"}]), format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
stop_words=["<|eot_id|>"], stop_words=["<|eot_id|>"],
replace_eos=True, replace_eos=True,
replace_jinja_template=False,
) )
@ -831,6 +837,7 @@ _register_template(
default_system="You are a helpful assistant.", default_system="You are a helpful assistant.",
stop_words=["<|im_end|>"], stop_words=["<|im_end|>"],
replace_eos=True, replace_eos=True,
replace_jinja_template=False,
) )
@ -843,6 +850,7 @@ _register_template(
default_system="You are a helpful assistant.", default_system="You are a helpful assistant.",
stop_words=["<|im_end|>"], stop_words=["<|im_end|>"],
replace_eos=True, replace_eos=True,
replace_jinja_template=False,
mm_plugin=get_mm_plugin(name="qwen2_vl", image_token="<|image_pad|>", video_token="<|video_pad|>"), mm_plugin=get_mm_plugin(name="qwen2_vl", image_token="<|image_pad|>", video_token="<|video_pad|>"),
) )