mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-01 11:12:50 +08:00
[data] feat: auto template (#6905)
* support auto template * add unittest Former-commit-id: 2f8b6847f5e199d770e91346dfe205c4b9f1fbb7
This commit is contained in:
parent
1b02183da9
commit
2e2f6bea07
@ -468,17 +468,83 @@ def _register_template(
|
||||
)
|
||||
|
||||
|
||||
def parse_template(tokenizer: "PreTrainedTokenizer") -> "Template":
|
||||
r"""
|
||||
Extracts a chat template from the tokenizer.
|
||||
"""
|
||||
|
||||
def find_diff(short_str: str, long_str: str) -> str:
|
||||
i, j = 0, 0
|
||||
diff = ""
|
||||
while i < len(short_str) and j < len(long_str):
|
||||
if short_str[i] == long_str[j]:
|
||||
i += 1
|
||||
j += 1
|
||||
else:
|
||||
diff += long_str[j]
|
||||
j += 1
|
||||
|
||||
return diff
|
||||
|
||||
prefix = tokenizer.decode(tokenizer.encode(""))
|
||||
|
||||
messages = [{"role": "system", "content": "{{content}}"}]
|
||||
system_slot = tokenizer.apply_chat_template(messages, add_generation_prompt=False, tokenize=False)[len(prefix) :]
|
||||
|
||||
messages = [{"role": "system", "content": ""}, {"role": "user", "content": "{{content}}"}]
|
||||
user_slot_empty_system = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
|
||||
user_slot_empty_system = user_slot_empty_system[len(prefix) :]
|
||||
|
||||
messages = [{"role": "user", "content": "{{content}}"}]
|
||||
user_slot = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
|
||||
user_slot = user_slot[len(prefix) :]
|
||||
|
||||
messages = [{"role": "user", "content": "{{content}}"}, {"role": "assistant", "content": "{{content}}"}]
|
||||
assistant_slot = tokenizer.apply_chat_template(messages, add_generation_prompt=False, tokenize=False)
|
||||
assistant_slot = assistant_slot[len(prefix) + len(user_slot) :]
|
||||
|
||||
if len(user_slot) > len(user_slot_empty_system):
|
||||
default_system = find_diff(user_slot_empty_system, user_slot)
|
||||
sole_system = system_slot.replace("{{content}}", default_system, 1)
|
||||
user_slot = user_slot[len(sole_system) :]
|
||||
else: # if defaut_system is empty, user_slot_empty_system will be longer than user_slot
|
||||
default_system = ""
|
||||
|
||||
return Template(
|
||||
format_user=StringFormatter(slots=[user_slot]),
|
||||
format_assistant=StringFormatter(slots=[assistant_slot]),
|
||||
format_system=StringFormatter(slots=[system_slot]),
|
||||
format_function=FunctionFormatter(slots=[assistant_slot], tool_format="default"),
|
||||
format_observation=StringFormatter(slots=[user_slot]),
|
||||
format_tools=ToolFormatter(tool_format="default"),
|
||||
format_prefix=EmptyFormatter(slots=[prefix]) if prefix else EmptyFormatter(),
|
||||
default_system=default_system,
|
||||
stop_words=[],
|
||||
thought_words=("<think>", "</think>"),
|
||||
efficient_eos=False,
|
||||
replace_eos=False,
|
||||
replace_jinja_template=False,
|
||||
mm_plugin=get_mm_plugin(name="base"),
|
||||
)
|
||||
|
||||
|
||||
def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args: "DataArguments") -> "Template":
|
||||
r"""
|
||||
Gets chat template and fixes the tokenizer.
|
||||
"""
|
||||
if data_args.template is None:
|
||||
template = TEMPLATES["empty"] # placeholder
|
||||
if isinstance(tokenizer.chat_template, str):
|
||||
logger.warning_rank0("`template` was not specified, try parsing the chat template from the tokenizer.")
|
||||
template = parse_template(tokenizer)
|
||||
else:
|
||||
logger.warning_rank0("`template` was not specified, use `empty` template.")
|
||||
template = TEMPLATES["empty"] # placeholder
|
||||
else:
|
||||
template = TEMPLATES.get(data_args.template, None)
|
||||
if template is None:
|
||||
if data_args.template not in TEMPLATES:
|
||||
raise ValueError(f"Template {data_args.template} does not exist.")
|
||||
|
||||
template = TEMPLATES[data_args.template]
|
||||
|
||||
if template.mm_plugin.__class__.__name__ != "BasePlugin":
|
||||
check_version("transformers>=4.45.0")
|
||||
|
||||
|
@ -192,9 +192,6 @@ def get_train_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _
|
||||
_set_transformers_logging()
|
||||
|
||||
# Check arguments
|
||||
if finetuning_args.stage != "pt" and data_args.template is None:
|
||||
raise ValueError("Please specify which `template` to use.")
|
||||
|
||||
if finetuning_args.stage != "sft":
|
||||
if training_args.predict_with_generate:
|
||||
raise ValueError("`predict_with_generate` cannot be set as True except SFT.")
|
||||
@ -402,9 +399,6 @@ def get_infer_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _
|
||||
|
||||
_set_transformers_logging()
|
||||
|
||||
if data_args.template is None:
|
||||
raise ValueError("Please specify which `template` to use.")
|
||||
|
||||
if model_args.infer_backend == "vllm":
|
||||
if finetuning_args.stage != "sft":
|
||||
raise ValueError("vLLM engine only supports auto-regressive models.")
|
||||
@ -435,9 +429,6 @@ def get_eval_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _E
|
||||
|
||||
_set_transformers_logging()
|
||||
|
||||
if data_args.template is None:
|
||||
raise ValueError("Please specify which `template` to use.")
|
||||
|
||||
if model_args.infer_backend == "vllm":
|
||||
raise ValueError("vLLM backend is only available for API, CLI and Web.")
|
||||
|
||||
|
@ -19,6 +19,7 @@ import pytest
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from llamafactory.data import get_template_and_fix_tokenizer
|
||||
from llamafactory.data.template import parse_template
|
||||
from llamafactory.hparams import DataArguments
|
||||
|
||||
|
||||
@ -208,3 +209,27 @@ def test_yi_template(use_fast: bool):
|
||||
)
|
||||
answer_str = "很高兴认识你!<|im_end|>\n"
|
||||
_check_template("01-ai/Yi-1.5-6B-Chat", "yi", prompt_str, answer_str, use_fast)
|
||||
|
||||
|
||||
def test_parse_template():
|
||||
tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA, token=HF_TOKEN)
|
||||
template = parse_template(tokenizer)
|
||||
assert template.format_user.slots == [
|
||||
"<|start_header_id|>user<|end_header_id|>\n\n{{content}}<|eot_id|>"
|
||||
"<|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||
]
|
||||
assert template.format_assistant.slots == ["{{content}}<|eot_id|>"]
|
||||
assert template.format_system.slots == ["<|start_header_id|>system<|end_header_id|>\n\n{{content}}<|eot_id|>"]
|
||||
assert template.format_prefix.slots == ["<|begin_of_text|>"]
|
||||
assert template.default_system == ""
|
||||
|
||||
|
||||
@pytest.mark.skipif(not HF_TOKEN, reason="Gated model.")
|
||||
def test_parse_qwen_template():
|
||||
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-7B-Instruct", token=HF_TOKEN)
|
||||
template = parse_template(tokenizer)
|
||||
assert template.format_user.slots == ["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]
|
||||
assert template.format_assistant.slots == ["{{content}}<|im_end|>\n"]
|
||||
assert template.format_system.slots == ["<|im_start|>system\n{{content}}<|im_end|>\n"]
|
||||
assert template.format_prefix.slots == []
|
||||
assert template.default_system == "You are a helpful assistant."
|
||||
|
Loading…
x
Reference in New Issue
Block a user