From 0a004904bd9f655d5048bdd457c5f76444389fe1 Mon Sep 17 00:00:00 2001
From: Liu Jiajun <939282975@qq.com>
Date: Fri, 27 Jun 2025 18:19:15 +0800
Subject: [PATCH] [data] fix gemma2 eos token (#8480)

Co-authored-by: Yaowei Zheng
---
 README.md                            |  2 +-
 README_zh.md                         |  2 +-
 src/llamafactory/data/template.py    | 16 ++++++++++++++++
 src/llamafactory/extras/constants.py |  9 ++++++++-
 tests/data/test_template.py          | 13 +++++++++++++
 tests/version.txt                    |  2 +-
 6 files changed, 40 insertions(+), 4 deletions(-)

diff --git a/README.md b/README.md
index d86a5323..cfd76f9a 100644
--- a/README.md
+++ b/README.md
@@ -263,7 +263,7 @@ Choose your path:
 | [DeepSeek R1 (Distill)](https://huggingface.co/deepseek-ai) | 1.5B/7B/8B/14B/32B/70B/671B | deepseekr1 |
 | [Falcon](https://huggingface.co/tiiuae) | 7B/11B/40B/180B | falcon |
 | [Falcon-H1](https://huggingface.co/tiiuae) | 0.5B/1.5B/3B/7B/34B | falcon_h1 |
-| [Gemma/Gemma 2/CodeGemma](https://huggingface.co/google) | 2B/7B/9B/27B | gemma |
+| [Gemma/Gemma 2/CodeGemma](https://huggingface.co/google) | 2B/7B/9B/27B | gemma/gemma2 |
 | [Gemma 3](https://huggingface.co/google) | 1B/4B/12B/27B | gemma3/gemma (1B) |
 | [GLM-4/GLM-4-0414/GLM-Z1](https://huggingface.co/THUDM) | 9B/32B | glm4/glmz1 |
 | [GPT-2](https://huggingface.co/openai-community) | 0.1B/0.4B/0.8B/1.5B | - |
diff --git a/README_zh.md b/README_zh.md
index 530218bf..836b2ffa 100644
--- a/README_zh.md
+++ b/README_zh.md
@@ -265,7 +265,7 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
 | [DeepSeek R1 (Distill)](https://huggingface.co/deepseek-ai) | 1.5B/7B/8B/14B/32B/70B/671B | deepseekr1 |
 | [Falcon](https://huggingface.co/tiiuae) | 7B/11B/40B/180B | falcon |
 | [Falcon-H1](https://huggingface.co/tiiuae) | 0.5B/1.5B/3B/7B/34B | falcon_h1 |
-| [Gemma/Gemma 2/CodeGemma](https://huggingface.co/google) | 2B/7B/9B/27B | gemma |
+| [Gemma/Gemma 2/CodeGemma](https://huggingface.co/google) | 2B/7B/9B/27B | gemma/gemma2 |
 | [Gemma 3](https://huggingface.co/google) | 1B/4B/12B/27B | gemma3/gemma (1B) |
 | [GLM-4/GLM-4-0414/GLM-Z1](https://huggingface.co/THUDM) | 9B/32B | glm4/glmz1 |
 | [GPT-2](https://huggingface.co/openai-community) | 0.1B/0.4B/0.8B/1.5B | - |
diff --git a/src/llamafactory/data/template.py b/src/llamafactory/data/template.py
index b4eda7f5..b3e2132b 100644
--- a/src/llamafactory/data/template.py
+++ b/src/llamafactory/data/template.py
@@ -951,6 +951,22 @@ register_template(
 )
 
 
+# copied from gemma template
+register_template(
+    name="gemma2",
+    format_user=StringFormatter(slots=["<start_of_turn>user\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]),
+    format_assistant=StringFormatter(slots=["{{content}}<end_of_turn>\n"]),
+    format_system=StringFormatter(slots=["{{content}}\n\n"]),
+    format_observation=StringFormatter(
+        slots=["<start_of_turn>tool\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]
+    ),
+    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
+    stop_words=["<end_of_turn>", "<eos>"],
+    efficient_eos=True,
+    template_class=Llama2Template,
+)
+
+
 # copied from gemma template
 register_template(
     name="gemma3",
diff --git a/src/llamafactory/extras/constants.py b/src/llamafactory/extras/constants.py
index a1a31c5a..6459168d 100644
--- a/src/llamafactory/extras/constants.py
+++ b/src/llamafactory/extras/constants.py
@@ -712,6 +712,13 @@ register_model_group(
         "Gemma-1.1-7B-Instruct": {
             DownloadSource.DEFAULT: "google/gemma-1.1-7b-it",
         },
+    },
+    template="gemma",
+)
+
+
+register_model_group(
+    models={
         "Gemma-2-2B": {
             DownloadSource.DEFAULT: "google/gemma-2-2b",
             DownloadSource.MODELSCOPE: "LLM-Research/gemma-2-2b",
@@ -751,7 +758,7 @@ register_model_group(
             DownloadSource.MODELSCOPE: "google/medgemma-27b-text-it",
         },
     },
-    template="gemma",
+    template="gemma2",
 )
 
 
diff --git a/tests/data/test_template.py b/tests/data/test_template.py
index 4a9aa061..dd2deca8 100644
--- a/tests/data/test_template.py
+++ b/tests/data/test_template.py
@@ -226,6 +226,19 @@ def test_gemma_template(use_fast: bool):
     _check_template("google/gemma-3-4b-it", "gemma", prompt_str, answer_str, use_fast)
 
 
+@pytest.mark.skipif(not HF_TOKEN, reason="Gated model.")
+@pytest.mark.parametrize("use_fast", [True, False])
+def test_gemma2_template(use_fast: bool):
+    prompt_str = (
+        f"<bos><start_of_turn>user\n{MESSAGES[0]['content']}<end_of_turn>\n"
+        f"<start_of_turn>model\n{MESSAGES[1]['content']}<end_of_turn>\n"
+        f"<start_of_turn>user\n{MESSAGES[2]['content']}<end_of_turn>\n"
+        "<start_of_turn>model\n"
+    )
+    answer_str = f"{MESSAGES[3]['content']}<end_of_turn>\n"
+    _check_template("google/gemma-2-2b-it", "gemma2", prompt_str, answer_str, use_fast)
+
+
 @pytest.mark.skipif(not HF_TOKEN, reason="Gated model.")
 @pytest.mark.parametrize("use_fast", [True, False])
 def test_llama3_template(use_fast: bool):
diff --git a/tests/version.txt b/tests/version.txt
index 0f1383aa..ff9e9d5e 100644
--- a/tests/version.txt
+++ b/tests/version.txt
@@ -1,2 +1,2 @@
 # change if test fails or cache is outdated
-0.9.3.108
+0.9.4.100
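
For local verification, a minimal sketch of how the new "gemma2" template can be
exercised outside the test suite, assuming the llamafactory package is installed,
the in-repo get_template_and_fix_tokenizer / DataArguments helpers keep their
current signatures, and the gated google/gemma-2-2b-it checkpoint is accessible
(the message contents below are illustrative only):

    from transformers import AutoTokenizer

    from llamafactory.data import get_template_and_fix_tokenizer
    from llamafactory.hparams import DataArguments

    # Illustrative single-turn conversation (contents are arbitrary).
    messages = [
        {"role": "user", "content": "How are you"},
        {"role": "assistant", "content": "I am fine!"},
    ]

    tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b-it")
    template = get_template_and_fix_tokenizer(tokenizer, DataArguments(template="gemma2"))
    prompt_ids, answer_ids = template.encode_oneturn(tokenizer, messages)

    # The decoded prompt should end with "<start_of_turn>model\n" and the decoded
    # answer with "<end_of_turn>\n", matching the stop_words registered above.
    print(repr(tokenizer.decode(prompt_ids)))
    print(repr(tokenizer.decode(answer_ids)))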