From a9a652eb6fa3279564da47df63687b8a7835d38a Mon Sep 17 00:00:00 2001 From: hiyouga <467089858@qq.com> Date: Mon, 15 Jul 2024 00:49:34 +0800 Subject: [PATCH] update test template Former-commit-id: a4ae3ab4ab8e3c6ad9feba4c185e3b592eda3f09 --- tests/data/test_template.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/data/test_template.py b/tests/data/test_template.py index 3dd83546..1ce1b0d8 100644 --- a/tests/data/test_template.py +++ b/tests/data/test_template.py @@ -25,6 +25,8 @@ if TYPE_CHECKING: from transformers import PreTrainedTokenizer +HF_TOKEN = os.environ.get("HF_TOKEN", None) + TINY_LLAMA = os.environ.get("TINY_LLAMA", "llamafactory/tiny-random-Llama-3") MESSAGES = [ @@ -46,7 +48,7 @@ def _check_tokenization( def _check_single_template( model_id: str, template_name: str, prompt_str: str, answer_str: str, extra_str: str, use_fast: bool ): - tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=use_fast, token=os.environ.get("HF_TOKEN", None)) + tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=use_fast, token=HF_TOKEN) content_str = tokenizer.apply_chat_template(MESSAGES, tokenize=False) content_ids = tokenizer.apply_chat_template(MESSAGES, tokenize=True) template = get_template_and_fix_tokenizer(tokenizer, name=template_name) @@ -119,6 +121,7 @@ def test_jinja_template(use_fast: bool): assert tokenizer.apply_chat_template(MESSAGES) == ref_tokenizer.apply_chat_template(MESSAGES) +@pytest.mark.skipif(HF_TOKEN is None, reason="Gated model.") def test_gemma_template(): prompt_str = ( "user\nHow are you\n" @@ -130,6 +133,7 @@ def test_gemma_template(): _check_template("google/gemma-2-9b-it", "gemma", prompt_str, answer_str, extra_str="\n") +@pytest.mark.skipif(HF_TOKEN is None, reason="Gated model.") def test_llama3_template(): prompt_str = ( "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nHow are you<|eot_id|>"