Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-07-31 10:42:50 +08:00)

[data] fix gemma2 eos token (#8480)

Co-authored-by: Yaowei Zheng <hiyouga@buaa.edu.cn>
Parent: bb7bf51554 · Commit: 0a004904bd

Gemma 2 instruct models close each chat turn with <end_of_turn> rather than the tokenizer's <eos>. This commit registers a dedicated `gemma2` template that stops on both tokens (with `efficient_eos=True`), moves the Gemma 2 / MedGemma checkpoints into their own model group bound to that template, lists the new `gemma2` template name in the README model tables, and bumps the test cache version.
```diff
@@ -263,7 +263,7 @@ Choose your path:
 | [DeepSeek R1 (Distill)](https://huggingface.co/deepseek-ai) | 1.5B/7B/8B/14B/32B/70B/671B | deepseekr1 |
 | [Falcon](https://huggingface.co/tiiuae) | 7B/11B/40B/180B | falcon |
 | [Falcon-H1](https://huggingface.co/tiiuae) | 0.5B/1.5B/3B/7B/34B | falcon_h1 |
-| [Gemma/Gemma 2/CodeGemma](https://huggingface.co/google) | 2B/7B/9B/27B | gemma |
+| [Gemma/Gemma 2/CodeGemma](https://huggingface.co/google) | 2B/7B/9B/27B | gemma/gemma2 |
 | [Gemma 3](https://huggingface.co/google) | 1B/4B/12B/27B | gemma3/gemma (1B) |
 | [GLM-4/GLM-4-0414/GLM-Z1](https://huggingface.co/THUDM) | 9B/32B | glm4/glmz1 |
 | [GPT-2](https://huggingface.co/openai-community) | 0.1B/0.4B/0.8B/1.5B | - |
```
```diff
@@ -265,7 +265,7 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
 | [DeepSeek R1 (Distill)](https://huggingface.co/deepseek-ai) | 1.5B/7B/8B/14B/32B/70B/671B | deepseekr1 |
 | [Falcon](https://huggingface.co/tiiuae) | 7B/11B/40B/180B | falcon |
 | [Falcon-H1](https://huggingface.co/tiiuae) | 0.5B/1.5B/3B/7B/34B | falcon_h1 |
-| [Gemma/Gemma 2/CodeGemma](https://huggingface.co/google) | 2B/7B/9B/27B | gemma |
+| [Gemma/Gemma 2/CodeGemma](https://huggingface.co/google) | 2B/7B/9B/27B | gemma/gemma2 |
 | [Gemma 3](https://huggingface.co/google) | 1B/4B/12B/27B | gemma3/gemma (1B) |
 | [GLM-4/GLM-4-0414/GLM-Z1](https://huggingface.co/THUDM) | 9B/32B | glm4/glmz1 |
 | [GPT-2](https://huggingface.co/openai-community) | 0.1B/0.4B/0.8B/1.5B | - |
```
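Both README tables now list `gemma/gemma2` for this row, so Gemma 2 checkpoints should be paired with the new `gemma2` template. The motivation is visible in the tokenizer itself; a quick inspection (assuming `transformers` is installed and the gated `google/gemma-2-2b-it` repo is accessible) shows that the turn terminator is a token distinct from `<eos>`, which is why both appear in the template's stop words:

```python
# Quick inspection of the token mismatch the new template accounts for.
# Assumes `transformers` is installed and the gated Gemma 2 repo is accessible.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("google/gemma-2-2b-it")
print(tok.eos_token)                               # the tokenizer's EOS token: <eos>
print(tok.convert_tokens_to_ids("<end_of_turn>"))  # a separate special-token id
# Chat turns end with <end_of_turn>, so the gemma2 template must stop on both.
```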
```diff
@@ -951,6 +951,22 @@ register_template(
 )
 
 
+# copied from gemma template
+register_template(
+    name="gemma2",
+    format_user=StringFormatter(slots=["<start_of_turn>user\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]),
+    format_assistant=StringFormatter(slots=["{{content}}<end_of_turn>\n"]),
+    format_system=StringFormatter(slots=["{{content}}\n\n"]),
+    format_observation=StringFormatter(
+        slots=["<start_of_turn>tool\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]
+    ),
+    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
+    stop_words=["<eos>", "<end_of_turn>"],
+    efficient_eos=True,
+    template_class=Llama2Template,
+)
+
+
 # copied from gemma template
 register_template(
     name="gemma3",
```
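For readers who do not want to trace the `Formatter` classes, here is a minimal, dependency-free sketch of the string assembly the `gemma2` slots above describe. The `render_gemma2` helper is hypothetical (not part of LLaMA-Factory) and ignores system prompts, tool turns, and the `Llama2Template` system-message merging:

```python
# Hypothetical, dependency-free sketch of the string assembly described by the
# "gemma2" slots above; the real Template/Formatter classes also handle system
# prompts, tool results, and the Llama2Template merging behavior.
def render_gemma2(messages: list[dict[str, str]]) -> str:
    out = "<bos>"  # format_prefix contributes the tokenizer's BOS token
    for msg in messages:
        if msg["role"] == "user":
            # format_user already appends the generation prompt for the model turn
            out += f"<start_of_turn>user\n{msg['content']}<end_of_turn>\n<start_of_turn>model\n"
        elif msg["role"] == "assistant":
            # format_assistant closes the turn with <end_of_turn>; with
            # efficient_eos=True that token also serves as EOS during training
            out += f"{msg['content']}<end_of_turn>\n"
    return out

print(render_gemma2([{"role": "user", "content": "Hi"}]))
```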
```diff
@@ -712,6 +712,13 @@ register_model_group(
         "Gemma-1.1-7B-Instruct": {
             DownloadSource.DEFAULT: "google/gemma-1.1-7b-it",
         },
+    },
+    template="gemma",
+)
+
+
+register_model_group(
+    models={
         "Gemma-2-2B": {
             DownloadSource.DEFAULT: "google/gemma-2-2b",
             DownloadSource.MODELSCOPE: "LLM-Research/gemma-2-2b",
@@ -751,7 +758,7 @@ register_model_group(
             DownloadSource.MODELSCOPE: "google/medgemma-27b-text-it",
         },
     },
-    template="gemma",
+    template="gemma2",
 )
 
 
```
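The constants change splits the Gemma 1.x and Gemma 2 checkpoints into separate model groups because a group binds all of its members to a single template name. A simplified stand-in for the registry pattern (not the actual implementation) makes the effect clear:

```python
# Simplified stand-in for the registry pattern above (not the real code): a
# model group binds every checkpoint in it to one template name, so Gemma 2
# can only resolve to "gemma2" once it lives in its own group.
TEMPLATE_BY_MODEL: dict[str, str] = {}

def register_model_group(models: dict[str, dict], template: str | None = None) -> None:
    for name in models:
        if template is not None:
            TEMPLATE_BY_MODEL[name] = template

register_model_group(models={"Gemma-1.1-7B-Instruct": {}}, template="gemma")
register_model_group(models={"Gemma-2-2B": {}}, template="gemma2")
assert TEMPLATE_BY_MODEL["Gemma-2-2B"] == "gemma2"
```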
```diff
@@ -226,6 +226,19 @@ def test_gemma_template(use_fast: bool):
     _check_template("google/gemma-3-4b-it", "gemma", prompt_str, answer_str, use_fast)
 
 
+@pytest.mark.skipif(not HF_TOKEN, reason="Gated model.")
+@pytest.mark.parametrize("use_fast", [True, False])
+def test_gemma2_template(use_fast: bool):
+    prompt_str = (
+        f"<bos><start_of_turn>user\n{MESSAGES[0]['content']}<end_of_turn>\n"
+        f"<start_of_turn>model\n{MESSAGES[1]['content']}<end_of_turn>\n"
+        f"<start_of_turn>user\n{MESSAGES[2]['content']}<end_of_turn>\n"
+        "<start_of_turn>model\n"
+    )
+    answer_str = f"{MESSAGES[3]['content']}<end_of_turn>\n"
+    _check_template("google/gemma-2-2b-it", "gemma2", prompt_str, answer_str, use_fast)
+
+
 @pytest.mark.skipif(not HF_TOKEN, reason="Gated model.")
 @pytest.mark.parametrize("use_fast", [True, False])
 def test_llama3_template(use_fast: bool):
```
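The new test pins the exact strings that `_check_template` compares against the checkpoint's own chat template. The same parity can be checked out-of-band with plain `transformers` (again assuming access to the gated repo); if Gemma 2's upstream Jinja template ever changes, this is the assertion that would catch it:

```python
# Out-of-band version of the parity the test asserts: the prompt the gemma2
# template builds should equal what the checkpoint's own chat template renders.
# Assumes `transformers` is installed and the gated repo is accessible.
from transformers import AutoTokenizer

messages = [{"role": "user", "content": "How are you?"}]
tok = AutoTokenizer.from_pretrained("google/gemma-2-2b-it")

expected = (
    f"<bos><start_of_turn>user\n{messages[0]['content']}<end_of_turn>\n"
    "<start_of_turn>model\n"
)
rendered = tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
assert rendered == expected, rendered
```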
```diff
@@ -1,2 +1,2 @@
 # change if test fails or cache is outdated
-0.9.3.108
+0.9.4.100
```