diff --git a/tests/data/test_template.py b/tests/data/test_template.py
index a327df22..18d03958 100644
--- a/tests/data/test_template.py
+++ b/tests/data/test_template.py
@@ -19,6 +19,7 @@ import pytest
 from transformers import AutoTokenizer
 
 from llamafactory.data import get_template_and_fix_tokenizer
+from llamafactory.data.template import _get_jinja_template
 from llamafactory.hparams import DataArguments
 
 
@@ -117,7 +118,8 @@ def test_encode_multiturn(use_fast: bool):
 def test_jinja_template(use_fast: bool):
     tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA, use_fast=use_fast)
     ref_tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA, use_fast=use_fast)
-    get_template_and_fix_tokenizer(tokenizer, DataArguments(template="llama3"))
+    template = get_template_and_fix_tokenizer(tokenizer, DataArguments(template="llama3"))
+    tokenizer.chat_template = _get_jinja_template(template, tokenizer)  # llama3 template no replace
     assert tokenizer.chat_template != ref_tokenizer.chat_template
     assert tokenizer.apply_chat_template(MESSAGES) == ref_tokenizer.apply_chat_template(MESSAGES)
 