Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-10-14 15:52:49 +08:00)
[data] feat: auto template (#6905)
* support auto template
* add unittest

Former-commit-id: 0c6c9150db6414a5a05527ea486dce6633dff4b3
Parent: d58fcd094e · Commit: 2581cc844b
src/llamafactory/data/template.py
@@ -468,17 +468,83 @@ def _register_template(
     )
 
 
+def parse_template(tokenizer: "PreTrainedTokenizer") -> "Template":
+    r"""
+    Extracts a chat template from the tokenizer.
+    """
+
+    def find_diff(short_str: str, long_str: str) -> str:
+        i, j = 0, 0
+        diff = ""
+        while i < len(short_str) and j < len(long_str):
+            if short_str[i] == long_str[j]:
+                i += 1
+                j += 1
+            else:
+                diff += long_str[j]
+                j += 1
+
+        return diff
+
+    prefix = tokenizer.decode(tokenizer.encode(""))
+
+    messages = [{"role": "system", "content": "{{content}}"}]
+    system_slot = tokenizer.apply_chat_template(messages, add_generation_prompt=False, tokenize=False)[len(prefix) :]
+
+    messages = [{"role": "system", "content": ""}, {"role": "user", "content": "{{content}}"}]
+    user_slot_empty_system = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
+    user_slot_empty_system = user_slot_empty_system[len(prefix) :]
+
+    messages = [{"role": "user", "content": "{{content}}"}]
+    user_slot = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
+    user_slot = user_slot[len(prefix) :]
+
+    messages = [{"role": "user", "content": "{{content}}"}, {"role": "assistant", "content": "{{content}}"}]
+    assistant_slot = tokenizer.apply_chat_template(messages, add_generation_prompt=False, tokenize=False)
+    assistant_slot = assistant_slot[len(prefix) + len(user_slot) :]
+
+    if len(user_slot) > len(user_slot_empty_system):
+        default_system = find_diff(user_slot_empty_system, user_slot)
+        sole_system = system_slot.replace("{{content}}", default_system, 1)
+        user_slot = user_slot[len(sole_system) :]
+    else:  # if default_system is empty, user_slot_empty_system will be longer than user_slot
+        default_system = ""
+
+    return Template(
+        format_user=StringFormatter(slots=[user_slot]),
+        format_assistant=StringFormatter(slots=[assistant_slot]),
+        format_system=StringFormatter(slots=[system_slot]),
+        format_function=FunctionFormatter(slots=[assistant_slot], tool_format="default"),
+        format_observation=StringFormatter(slots=[user_slot]),
+        format_tools=ToolFormatter(tool_format="default"),
+        format_prefix=EmptyFormatter(slots=[prefix]) if prefix else EmptyFormatter(),
+        default_system=default_system,
+        stop_words=[],
+        thought_words=("<think>", "</think>"),
+        efficient_eos=False,
+        replace_eos=False,
+        replace_jinja_template=False,
+        mm_plugin=get_mm_plugin(name="base"),
+    )
+
+
 def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args: "DataArguments") -> "Template":
     r"""
     Gets chat template and fixes the tokenizer.
     """
     if data_args.template is None:
-        template = TEMPLATES["empty"]  # placeholder
+        if isinstance(tokenizer.chat_template, str):
+            logger.warning_rank0("`template` was not specified, try parsing the chat template from the tokenizer.")
+            template = parse_template(tokenizer)
+        else:
+            logger.warning_rank0("`template` was not specified, use `empty` template.")
+            template = TEMPLATES["empty"]  # placeholder
     else:
-        template = TEMPLATES.get(data_args.template, None)
-        if template is None:
+        if data_args.template not in TEMPLATES:
             raise ValueError(f"Template {data_args.template} does not exist.")
+
+        template = TEMPLATES[data_args.template]
 
     if template.mm_plugin.__class__.__name__ != "BasePlugin":
         check_version("transformers>=4.45.0")
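The nested find_diff helper above walks both strings in lockstep and collects, in order, every character of the longer string that has no counterpart in the shorter one. A standalone sanity check of that behavior (the function body is copied verbatim from the hunk above so the snippet runs on its own; the toy strings are illustrative):

def find_diff(short_str: str, long_str: str) -> str:
    i, j = 0, 0
    diff = ""
    while i < len(short_str) and j < len(long_str):
        if short_str[i] == long_str[j]:
            i += 1
            j += 1
        else:
            diff += long_str[j]
            j += 1
    return diff

# Surplus characters of the longer string come back in order:
assert find_diff("ab", "aXYb") == "XY"
# parse_template uses this to recover a default system prompt: the render
# without any system message is the longer string, and its surplus text is
# exactly what the chat template injected on its own.
assert find_diff("<sys></sys>", "<sys>Be helpful.</sys>") == "Be helpful."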
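To make the four template probes concrete, here is a minimal sketch against a real chat template (Qwen2-7B-Instruct is an illustrative choice; the expected default system prompt below matches this commit's own test for that model):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-7B-Instruct")

# Probe with no system message: Qwen2's chat template injects its default
# system prompt by itself, so the injected text shows up in the render.
user_slot = tokenizer.apply_chat_template(
    [{"role": "user", "content": "{{content}}"}], add_generation_prompt=True, tokenize=False
)

# Probe with an explicitly empty system message, which suppresses the default.
user_slot_empty_system = tokenizer.apply_chat_template(
    [{"role": "system", "content": ""}, {"role": "user", "content": "{{content}}"}],
    add_generation_prompt=True,
    tokenize=False,
)

# The first render is longer, and its surplus characters are the default
# system prompt, which find_diff extracts:
# find_diff(user_slot_empty_system, user_slot) == "You are a helpful assistant."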
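With the new branch in get_template_and_fix_tokenizer, leaving `template` unset no longer forces the `empty` placeholder when the tokenizer ships a chat template. A minimal sketch of the fallback (model name illustrative; this assumes DataArguments defaults `template` to None):

from transformers import AutoTokenizer

from llamafactory.data import get_template_and_fix_tokenizer
from llamafactory.hparams import DataArguments

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-7B-Instruct")
data_args = DataArguments()  # `template` deliberately left unset

# The tokenizer carries a string chat_template, so this logs the warning and
# returns the template recovered by parse_template() instead of `empty`.
template = get_template_and_fix_tokenizer(tokenizer, data_args)
print(template.default_system)  # expected: "You are a helpful assistant."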
src/llamafactory/hparams/parser.py
@@ -192,9 +192,6 @@ def get_train_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _
     _set_transformers_logging()
 
     # Check arguments
-    if finetuning_args.stage != "pt" and data_args.template is None:
-        raise ValueError("Please specify which `template` to use.")
-
     if finetuning_args.stage != "sft":
         if training_args.predict_with_generate:
             raise ValueError("`predict_with_generate` cannot be set as True except SFT.")
@@ -402,9 +399,6 @@ def get_infer_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _
 
     _set_transformers_logging()
 
-    if data_args.template is None:
-        raise ValueError("Please specify which `template` to use.")
-
     if model_args.infer_backend == "vllm":
         if finetuning_args.stage != "sft":
             raise ValueError("vLLM engine only supports auto-regressive models.")
@@ -435,9 +429,6 @@ def get_eval_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _E
 
     _set_transformers_logging()
 
-    if data_args.template is None:
-        raise ValueError("Please specify which `template` to use.")
-
     if model_args.infer_backend == "vllm":
         raise ValueError("vLLM backend is only available for API, CLI and Web.")
 
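All three entry points previously rejected a missing `template` before anything else ran; those checks are removed because the template can now be derived from the tokenizer at load time. A before/after sketch (the argument dict is a minimal illustration, not a complete config):

from llamafactory.hparams import get_infer_args

# Before this commit, the next call raised immediately:
#     ValueError: Please specify which `template` to use.
# Now it parses, and the template is resolved later by
# get_template_and_fix_tokenizer() when the tokenizer is loaded.
model_args, data_args, finetuning_args, generating_args = get_infer_args(
    {"model_name_or_path": "Qwen/Qwen2-7B-Instruct"}  # no "template" key
)
assert data_args.template is None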
tests/data/test_template.py
@@ -19,6 +19,7 @@ import pytest
 from transformers import AutoTokenizer
 
 from llamafactory.data import get_template_and_fix_tokenizer
+from llamafactory.data.template import parse_template
 from llamafactory.hparams import DataArguments
 
 
@@ -208,3 +209,27 @@ def test_yi_template(use_fast: bool):
     )
     answer_str = "很高兴认识你!<|im_end|>\n"
     _check_template("01-ai/Yi-1.5-6B-Chat", "yi", prompt_str, answer_str, use_fast)
+
+
+def test_parse_template():
+    tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA, token=HF_TOKEN)
+    template = parse_template(tokenizer)
+    assert template.format_user.slots == [
+        "<|start_header_id|>user<|end_header_id|>\n\n{{content}}<|eot_id|>"
+        "<|start_header_id|>assistant<|end_header_id|>\n\n"
+    ]
+    assert template.format_assistant.slots == ["{{content}}<|eot_id|>"]
+    assert template.format_system.slots == ["<|start_header_id|>system<|end_header_id|>\n\n{{content}}<|eot_id|>"]
+    assert template.format_prefix.slots == ["<|begin_of_text|>"]
+    assert template.default_system == ""
+
+
+@pytest.mark.skipif(not HF_TOKEN, reason="Gated model.")
+def test_parse_qwen_template():
+    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-7B-Instruct", token=HF_TOKEN)
+    template = parse_template(tokenizer)
+    assert template.format_user.slots == ["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]
+    assert template.format_assistant.slots == ["{{content}}<|im_end|>\n"]
+    assert template.format_system.slots == ["<|im_start|>system\n{{content}}<|im_end|>\n"]
+    assert template.format_prefix.slots == []
+    assert template.default_system == "You are a helpful assistant."
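Both tests download checkpoints from the Hub; the same round trip can be sketched offline by installing a toy Jinja template on a tokenizer whose encode("") adds no BOS text (otherwise the prefix slicing in parse_template would shift the slots). Everything about this toy template is an assumption for illustration, not part of the test suite:

from transformers import AutoTokenizer

from llamafactory.data.template import parse_template

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-7B-Instruct")
tokenizer.chat_template = (
    "{% for m in messages %}<|{{ m['role'] }}|>{{ m['content'] }}</s>{% endfor %}"
    "{% if add_generation_prompt %}<|assistant|>{% endif %}"
)

template = parse_template(tokenizer)
assert template.format_user.slots == ["<|user|>{{content}}</s><|assistant|>"]
assert template.format_assistant.slots == ["{{content}}</s>"]
assert template.format_system.slots == ["<|system|>{{content}}</s>"]
assert template.default_system == ""  # the toy template injects no default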