diff --git a/src/llamafactory/data/template.py b/src/llamafactory/data/template.py
index ccd322c4..ba043f6c 100644
--- a/src/llamafactory/data/template.py
+++ b/src/llamafactory/data/template.py
@@ -468,17 +468,83 @@ def _register_template(
     )


+def parse_template(tokenizer: "PreTrainedTokenizer") -> "Template":
+    r"""
+    Extracts a chat template from the tokenizer.
+    """
+
+    def find_diff(short_str: str, long_str: str) -> str:
+        i, j = 0, 0
+        diff = ""
+        while i < len(short_str) and j < len(long_str):
+            if short_str[i] == long_str[j]:
+                i += 1
+                j += 1
+            else:
+                diff += long_str[j]
+                j += 1
+
+        return diff
+
+    prefix = tokenizer.decode(tokenizer.encode(""))
+
+    messages = [{"role": "system", "content": "{{content}}"}]
+    system_slot = tokenizer.apply_chat_template(messages, add_generation_prompt=False, tokenize=False)[len(prefix) :]
+
+    messages = [{"role": "system", "content": ""}, {"role": "user", "content": "{{content}}"}]
+    user_slot_empty_system = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
+    user_slot_empty_system = user_slot_empty_system[len(prefix) :]
+
+    messages = [{"role": "user", "content": "{{content}}"}]
+    user_slot = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
+    user_slot = user_slot[len(prefix) :]
+
+    messages = [{"role": "user", "content": "{{content}}"}, {"role": "assistant", "content": "{{content}}"}]
+    assistant_slot = tokenizer.apply_chat_template(messages, add_generation_prompt=False, tokenize=False)
+    assistant_slot = assistant_slot[len(prefix) + len(user_slot) :]
+
+    if len(user_slot) > len(user_slot_empty_system):
+        default_system = find_diff(user_slot_empty_system, user_slot)
+        sole_system = system_slot.replace("{{content}}", default_system, 1)
+        user_slot = user_slot[len(sole_system) :]
+    else:  # if default_system is empty, user_slot_empty_system will be longer than user_slot
+        default_system = ""
+
+    return Template(
+        format_user=StringFormatter(slots=[user_slot]),
+        format_assistant=StringFormatter(slots=[assistant_slot]),
+        format_system=StringFormatter(slots=[system_slot]),
+        format_function=FunctionFormatter(slots=[assistant_slot], tool_format="default"),
+        format_observation=StringFormatter(slots=[user_slot]),
+        format_tools=ToolFormatter(tool_format="default"),
+        format_prefix=EmptyFormatter(slots=[prefix]) if prefix else EmptyFormatter(),
+        default_system=default_system,
+        stop_words=[],
+        thought_words=("", ""),
+        efficient_eos=False,
+        replace_eos=False,
+        replace_jinja_template=False,
+        mm_plugin=get_mm_plugin(name="base"),
+    )
+
+
 def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args: "DataArguments") -> "Template":
     r"""
     Gets chat template and fixes the tokenizer.
""" if data_args.template is None: - template = TEMPLATES["empty"] # placeholder + if isinstance(tokenizer.chat_template, str): + logger.warning_rank0("`template` was not specified, try parsing the chat template from the tokenizer.") + template = parse_template(tokenizer) + else: + logger.warning_rank0("`template` was not specified, use `empty` template.") + template = TEMPLATES["empty"] # placeholder else: - template = TEMPLATES.get(data_args.template, None) - if template is None: + if data_args.template not in TEMPLATES: raise ValueError(f"Template {data_args.template} does not exist.") + template = TEMPLATES[data_args.template] + if template.mm_plugin.__class__.__name__ != "BasePlugin": check_version("transformers>=4.45.0") diff --git a/src/llamafactory/hparams/parser.py b/src/llamafactory/hparams/parser.py index 95f5b34a..19708156 100644 --- a/src/llamafactory/hparams/parser.py +++ b/src/llamafactory/hparams/parser.py @@ -192,9 +192,6 @@ def get_train_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _ _set_transformers_logging() # Check arguments - if finetuning_args.stage != "pt" and data_args.template is None: - raise ValueError("Please specify which `template` to use.") - if finetuning_args.stage != "sft": if training_args.predict_with_generate: raise ValueError("`predict_with_generate` cannot be set as True except SFT.") @@ -402,9 +399,6 @@ def get_infer_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _ _set_transformers_logging() - if data_args.template is None: - raise ValueError("Please specify which `template` to use.") - if model_args.infer_backend == "vllm": if finetuning_args.stage != "sft": raise ValueError("vLLM engine only supports auto-regressive models.") @@ -435,9 +429,6 @@ def get_eval_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _E _set_transformers_logging() - if data_args.template is None: - raise ValueError("Please specify which `template` to use.") - if model_args.infer_backend == "vllm": raise ValueError("vLLM backend is only available for API, CLI and Web.") diff --git a/tests/data/test_template.py b/tests/data/test_template.py index 9f8fc835..e2e6b942 100644 --- a/tests/data/test_template.py +++ b/tests/data/test_template.py @@ -19,6 +19,7 @@ import pytest from transformers import AutoTokenizer from llamafactory.data import get_template_and_fix_tokenizer +from llamafactory.data.template import parse_template from llamafactory.hparams import DataArguments @@ -208,3 +209,27 @@ def test_yi_template(use_fast: bool): ) answer_str = "很高兴认识你!<|im_end|>\n" _check_template("01-ai/Yi-1.5-6B-Chat", "yi", prompt_str, answer_str, use_fast) + + +def test_parse_template(): + tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA, token=HF_TOKEN) + template = parse_template(tokenizer) + assert template.format_user.slots == [ + "<|start_header_id|>user<|end_header_id|>\n\n{{content}}<|eot_id|>" + "<|start_header_id|>assistant<|end_header_id|>\n\n" + ] + assert template.format_assistant.slots == ["{{content}}<|eot_id|>"] + assert template.format_system.slots == ["<|start_header_id|>system<|end_header_id|>\n\n{{content}}<|eot_id|>"] + assert template.format_prefix.slots == ["<|begin_of_text|>"] + assert template.default_system == "" + + +@pytest.mark.skipif(not HF_TOKEN, reason="Gated model.") +def test_parse_qwen_template(): + tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-7B-Instruct", token=HF_TOKEN) + template = parse_template(tokenizer) + assert template.format_user.slots == 
["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"] + assert template.format_assistant.slots == ["{{content}}<|im_end|>\n"] + assert template.format_system.slots == ["<|im_start|>system\n{{content}}<|im_end|>\n"] + assert template.format_prefix.slots == [] + assert template.default_system == "You are a helpful assistant."