From 33fc7bec85c37672652173f8d04d4518dc56b2c3 Mon Sep 17 00:00:00 2001 From: hiyouga <467089858@qq.com> Date: Wed, 25 Sep 2024 23:14:17 +0800 Subject: [PATCH] fix ci Former-commit-id: b8e616183cb0252da6efbeac3372b78098b2d6bd --- tests/data/test_template.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/data/test_template.py b/tests/data/test_template.py index a327df22..18d03958 100644 --- a/tests/data/test_template.py +++ b/tests/data/test_template.py @@ -19,6 +19,7 @@ import pytest from transformers import AutoTokenizer from llamafactory.data import get_template_and_fix_tokenizer +from llamafactory.data.template import _get_jinja_template from llamafactory.hparams import DataArguments @@ -117,7 +118,8 @@ def test_encode_multiturn(use_fast: bool): def test_jinja_template(use_fast: bool): tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA, use_fast=use_fast) ref_tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA, use_fast=use_fast) - get_template_and_fix_tokenizer(tokenizer, DataArguments(template="llama3")) + template = get_template_and_fix_tokenizer(tokenizer, DataArguments(template="llama3")) + tokenizer.chat_template = _get_jinja_template(template, tokenizer) # llama3 template no replace assert tokenizer.chat_template != ref_tokenizer.chat_template assert tokenizer.apply_chat_template(MESSAGES) == ref_tokenizer.apply_chat_template(MESSAGES)