[data] feat: auto template (#6905)

* support auto template

* add unittest

Former-commit-id: 2f8b6847f5e199d770e91346dfe205c4b9f1fbb7
This commit is contained in:
hoshi-hiyouga 2025-02-12 00:22:53 +08:00 committed by GitHub
parent 1b02183da9
commit 2e2f6bea07
3 changed files with 94 additions and 12 deletions

View File

@ -468,17 +468,83 @@ def _register_template(
)
def parse_template(tokenizer: "PreTrainedTokenizer") -> "Template":
r"""
Extracts a chat template from the tokenizer.
"""
def find_diff(short_str: str, long_str: str) -> str:
i, j = 0, 0
diff = ""
while i < len(short_str) and j < len(long_str):
if short_str[i] == long_str[j]:
i += 1
j += 1
else:
diff += long_str[j]
j += 1
return diff
prefix = tokenizer.decode(tokenizer.encode(""))
messages = [{"role": "system", "content": "{{content}}"}]
system_slot = tokenizer.apply_chat_template(messages, add_generation_prompt=False, tokenize=False)[len(prefix) :]
messages = [{"role": "system", "content": ""}, {"role": "user", "content": "{{content}}"}]
user_slot_empty_system = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
user_slot_empty_system = user_slot_empty_system[len(prefix) :]
messages = [{"role": "user", "content": "{{content}}"}]
user_slot = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
user_slot = user_slot[len(prefix) :]
messages = [{"role": "user", "content": "{{content}}"}, {"role": "assistant", "content": "{{content}}"}]
assistant_slot = tokenizer.apply_chat_template(messages, add_generation_prompt=False, tokenize=False)
assistant_slot = assistant_slot[len(prefix) + len(user_slot) :]
if len(user_slot) > len(user_slot_empty_system):
default_system = find_diff(user_slot_empty_system, user_slot)
sole_system = system_slot.replace("{{content}}", default_system, 1)
user_slot = user_slot[len(sole_system) :]
else: # if defaut_system is empty, user_slot_empty_system will be longer than user_slot
default_system = ""
return Template(
format_user=StringFormatter(slots=[user_slot]),
format_assistant=StringFormatter(slots=[assistant_slot]),
format_system=StringFormatter(slots=[system_slot]),
format_function=FunctionFormatter(slots=[assistant_slot], tool_format="default"),
format_observation=StringFormatter(slots=[user_slot]),
format_tools=ToolFormatter(tool_format="default"),
format_prefix=EmptyFormatter(slots=[prefix]) if prefix else EmptyFormatter(),
default_system=default_system,
stop_words=[],
thought_words=("<think>", "</think>"),
efficient_eos=False,
replace_eos=False,
replace_jinja_template=False,
mm_plugin=get_mm_plugin(name="base"),
)
def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args: "DataArguments") -> "Template":
r"""
Gets chat template and fixes the tokenizer.
"""
if data_args.template is None:
template = TEMPLATES["empty"] # placeholder
if isinstance(tokenizer.chat_template, str):
logger.warning_rank0("`template` was not specified, try parsing the chat template from the tokenizer.")
template = parse_template(tokenizer)
else:
logger.warning_rank0("`template` was not specified, use `empty` template.")
template = TEMPLATES["empty"] # placeholder
else:
template = TEMPLATES.get(data_args.template, None)
if template is None:
if data_args.template not in TEMPLATES:
raise ValueError(f"Template {data_args.template} does not exist.")
template = TEMPLATES[data_args.template]
if template.mm_plugin.__class__.__name__ != "BasePlugin":
check_version("transformers>=4.45.0")

View File

@ -192,9 +192,6 @@ def get_train_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _
_set_transformers_logging()
# Check arguments
if finetuning_args.stage != "pt" and data_args.template is None:
raise ValueError("Please specify which `template` to use.")
if finetuning_args.stage != "sft":
if training_args.predict_with_generate:
raise ValueError("`predict_with_generate` cannot be set as True except SFT.")
@ -402,9 +399,6 @@ def get_infer_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _
_set_transformers_logging()
if data_args.template is None:
raise ValueError("Please specify which `template` to use.")
if model_args.infer_backend == "vllm":
if finetuning_args.stage != "sft":
raise ValueError("vLLM engine only supports auto-regressive models.")
@ -435,9 +429,6 @@ def get_eval_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _E
_set_transformers_logging()
if data_args.template is None:
raise ValueError("Please specify which `template` to use.")
if model_args.infer_backend == "vllm":
raise ValueError("vLLM backend is only available for API, CLI and Web.")

View File

@ -19,6 +19,7 @@ import pytest
from transformers import AutoTokenizer
from llamafactory.data import get_template_and_fix_tokenizer
from llamafactory.data.template import parse_template
from llamafactory.hparams import DataArguments
@ -208,3 +209,27 @@ def test_yi_template(use_fast: bool):
)
answer_str = "很高兴认识你!<|im_end|>\n"
_check_template("01-ai/Yi-1.5-6B-Chat", "yi", prompt_str, answer_str, use_fast)
def test_parse_template():
tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA, token=HF_TOKEN)
template = parse_template(tokenizer)
assert template.format_user.slots == [
"<|start_header_id|>user<|end_header_id|>\n\n{{content}}<|eot_id|>"
"<|start_header_id|>assistant<|end_header_id|>\n\n"
]
assert template.format_assistant.slots == ["{{content}}<|eot_id|>"]
assert template.format_system.slots == ["<|start_header_id|>system<|end_header_id|>\n\n{{content}}<|eot_id|>"]
assert template.format_prefix.slots == ["<|begin_of_text|>"]
assert template.default_system == ""
@pytest.mark.skipif(not HF_TOKEN, reason="Gated model.")
def test_parse_qwen_template():
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-7B-Instruct", token=HF_TOKEN)
template = parse_template(tokenizer)
assert template.format_user.slots == ["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]
assert template.format_assistant.slots == ["{{content}}<|im_end|>\n"]
assert template.format_system.slots == ["<|im_start|>system\n{{content}}<|im_end|>\n"]
assert template.format_prefix.slots == []
assert template.default_system == "You are a helpful assistant."