diff --git a/src/llamafactory/data/template.py b/src/llamafactory/data/template.py
index 54da4757..c7d47ebc 100644
--- a/src/llamafactory/data/template.py
+++ b/src/llamafactory/data/template.py
@@ -49,6 +49,7 @@ class Template:
     stop_words: List[str]
     efficient_eos: bool
     replace_eos: bool
+    replace_jinja_template: bool
     mm_plugin: "BasePlugin"

     def encode_oneturn(
@@ -214,6 +215,7 @@ def _register_template(
     stop_words: Sequence[str] = [],
     efficient_eos: bool = False,
     replace_eos: bool = False,
+    replace_jinja_template: bool = True,
     mm_plugin: "BasePlugin" = get_mm_plugin(name="base"),
 ) -> None:
     r"""
@@ -263,6 +265,7 @@ def _register_template(
         stop_words=stop_words,
         efficient_eos=efficient_eos,
         replace_eos=replace_eos,
+        replace_jinja_template=replace_jinja_template,
         mm_plugin=mm_plugin,
     )

@@ -398,10 +401,11 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args:
         if num_added_tokens > 0:
             logger.warning("New tokens have been added, make sure `resize_vocab` is True.")

-    try:
-        tokenizer.chat_template = _get_jinja_template(template, tokenizer)
-    except ValueError:
-        logger.info("Cannot add this chat template to tokenizer.")
+    if template.replace_jinja_template:
+        try:
+            tokenizer.chat_template = _get_jinja_template(template, tokenizer)
+        except ValueError:
+            logger.info("Cannot add this chat template to tokenizer.")

     return template

@@ -664,6 +668,7 @@ _register_template(
     format_separator=EmptyFormatter(slots=["\n"]),
     format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
     efficient_eos=True,
+    replace_jinja_template=False,
 )


@@ -740,6 +745,7 @@ _register_template(
     format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
     stop_words=["<|eot_id|>"],
     replace_eos=True,
+    replace_jinja_template=False,
 )


@@ -831,6 +837,7 @@ _register_template(
     default_system="You are a helpful assistant.",
     stop_words=["<|im_end|>"],
     replace_eos=True,
+    replace_jinja_template=False,
 )


@@ -843,6 +850,7 @@ _register_template(
     default_system="You are a helpful assistant.",
     stop_words=["<|im_end|>"],
     replace_eos=True,
+    replace_jinja_template=False,
     mm_plugin=get_mm_plugin(name="qwen2_vl", image_token="<|image_pad|>", video_token="<|video_pad|>"),
 )
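
For context, not part of the patch: a minimal usage sketch of what the new flag changes. The import paths, the model id, and the way DataArguments is constructed are assumptions based on the surrounding repository layout; the point is that a template registered with replace_jinja_template=False no longer has the tokenizer's chat_template overwritten by get_template_and_fix_tokenizer.

    from transformers import AutoTokenizer

    # assumed import paths; both names are public in this repository
    from llamafactory.data import get_template_and_fix_tokenizer
    from llamafactory.hparams import DataArguments

    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-7B-Instruct")  # model id is illustrative
    shipped_template = tokenizer.chat_template

    # "qwen" is registered above with replace_jinja_template=False, so this call
    # still adds stop words / fixes the eos token but leaves chat_template alone.
    template = get_template_and_fix_tokenizer(tokenizer, DataArguments(template="qwen"))

    # Before this patch, chat_template was always rebuilt from the registered
    # template; with the flag set to False the model-shipped template survives.
    assert tokenizer.chat_template == shipped_template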