mirror of https://github.com/hiyouga/LLaMA-Factory.git
[data] feat: auto template (#6905)

* support auto template
* add unittest

Former-commit-id: 0c6c9150db6414a5a05527ea486dce6633dff4b3

parent d58fcd094e
commit 2581cc844b
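This commit makes the `template` argument optional: when it is omitted, `get_template_and_fix_tokenizer` now tries to recover a Template from the tokenizer's built-in `chat_template` instead of raising. A minimal sketch of the new call path, assuming a tokenizer that ships a chat template (the model ID is only an example) and that `DataArguments` accepts `template` as a keyword:

from transformers import AutoTokenizer

from llamafactory.data import get_template_and_fix_tokenizer
from llamafactory.hparams import DataArguments

# any tokenizer that ships a `chat_template` string works here (example model ID)
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-7B-Instruct")

# template=None used to raise during argument checks; it now triggers
# auto-parsing of tokenizer.chat_template (see the hunks below)
template = get_template_and_fix_tokenizer(tokenizer, DataArguments(template=None))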
					
@@ -468,17 +468,83 @@ def _register_template(
     )
 
 
+def parse_template(tokenizer: "PreTrainedTokenizer") -> "Template":
+    r"""
+    Extracts a chat template from the tokenizer.
+    """
+
+    def find_diff(short_str: str, long_str: str) -> str:
+        i, j = 0, 0
+        diff = ""
+        while i < len(short_str) and j < len(long_str):
+            if short_str[i] == long_str[j]:
+                i += 1
+                j += 1
+            else:
+                diff += long_str[j]
+                j += 1
+
+        return diff
+
+    prefix = tokenizer.decode(tokenizer.encode(""))
+
+    messages = [{"role": "system", "content": "{{content}}"}]
+    system_slot = tokenizer.apply_chat_template(messages, add_generation_prompt=False, tokenize=False)[len(prefix) :]
+
+    messages = [{"role": "system", "content": ""}, {"role": "user", "content": "{{content}}"}]
+    user_slot_empty_system = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
+    user_slot_empty_system = user_slot_empty_system[len(prefix) :]
+
+    messages = [{"role": "user", "content": "{{content}}"}]
+    user_slot = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
+    user_slot = user_slot[len(prefix) :]
+
+    messages = [{"role": "user", "content": "{{content}}"}, {"role": "assistant", "content": "{{content}}"}]
+    assistant_slot = tokenizer.apply_chat_template(messages, add_generation_prompt=False, tokenize=False)
+    assistant_slot = assistant_slot[len(prefix) + len(user_slot) :]
+
+    if len(user_slot) > len(user_slot_empty_system):
+        default_system = find_diff(user_slot_empty_system, user_slot)
+        sole_system = system_slot.replace("{{content}}", default_system, 1)
+        user_slot = user_slot[len(sole_system) :]
+    else:  # if default_system is empty, user_slot_empty_system will be longer than user_slot
+        default_system = ""
+
+    return Template(
+        format_user=StringFormatter(slots=[user_slot]),
+        format_assistant=StringFormatter(slots=[assistant_slot]),
+        format_system=StringFormatter(slots=[system_slot]),
+        format_function=FunctionFormatter(slots=[assistant_slot], tool_format="default"),
+        format_observation=StringFormatter(slots=[user_slot]),
+        format_tools=ToolFormatter(tool_format="default"),
+        format_prefix=EmptyFormatter(slots=[prefix]) if prefix else EmptyFormatter(),
+        default_system=default_system,
+        stop_words=[],
+        thought_words=("<think>", "</think>"),
+        efficient_eos=False,
+        replace_eos=False,
+        replace_jinja_template=False,
+        mm_plugin=get_mm_plugin(name="base"),
+    )
+
+
 def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args: "DataArguments") -> "Template":
     r"""
     Gets chat template and fixes the tokenizer.
     """
     if data_args.template is None:
-        template = TEMPLATES["empty"]  # placeholder
+        if isinstance(tokenizer.chat_template, str):
+            logger.warning_rank0("`template` was not specified, try parsing the chat template from the tokenizer.")
+            template = parse_template(tokenizer)
+        else:
+            logger.warning_rank0("`template` was not specified, use `empty` template.")
+            template = TEMPLATES["empty"]  # placeholder
     else:
-        template = TEMPLATES.get(data_args.template, None)
-        if template is None:
+        if data_args.template not in TEMPLATES:
             raise ValueError(f"Template {data_args.template} does not exist.")
 
+        template = TEMPLATES[data_args.template]
+
     if template.mm_plugin.__class__.__name__ != "BasePlugin":
         check_version("transformers>=4.45.0")
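parse_template works by rendering `apply_chat_template` with literal `{{content}}` placeholders and slicing away the parts it already knows (the prefix and any preceding slots); `find_diff` then recovers the template's default system prompt as the characters that appear only in the longer of two renders. A self-contained sketch of that recovery step, with the helper re-implemented from the hunk above and example strings modeled on a Qwen2-style template:

def find_diff(short_str: str, long_str: str) -> str:
    # re-implemented from the hunk above: collect the characters that appear
    # in long_str but not in the aligned character walk over short_str
    i, j = 0, 0
    diff = ""
    while i < len(short_str) and j < len(long_str):
        if short_str[i] == long_str[j]:
            i += 1
            j += 1
        else:
            diff += long_str[j]
            j += 1

    return diff


# render with an explicitly empty system prompt vs. render where the template
# injected its default system prompt (strings modeled on a Qwen2-style template)
empty_system = "<|im_start|>system\n<|im_end|>\n<|im_start|>user\nhi<|im_end|>\n"
with_default = "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\nhi<|im_end|>\n"
assert find_diff(empty_system, with_default) == "You are a helpful assistant."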
@@ -192,9 +192,6 @@ def get_train_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _
         _set_transformers_logging()
 
     # Check arguments
-    if finetuning_args.stage != "pt" and data_args.template is None:
-        raise ValueError("Please specify which `template` to use.")
-
     if finetuning_args.stage != "sft":
         if training_args.predict_with_generate:
             raise ValueError("`predict_with_generate` cannot be set as True except SFT.")
@@ -402,9 +399,6 @@ def get_infer_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _
 
     _set_transformers_logging()
 
-    if data_args.template is None:
-        raise ValueError("Please specify which `template` to use.")
-
     if model_args.infer_backend == "vllm":
         if finetuning_args.stage != "sft":
             raise ValueError("vLLM engine only supports auto-regressive models.")
@@ -435,9 +429,6 @@ def get_eval_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _E
 
     _set_transformers_logging()
 
-    if data_args.template is None:
-        raise ValueError("Please specify which `template` to use.")
-
     if model_args.infer_backend == "vllm":
         raise ValueError("vLLM backend is only available for API, CLI and Web.")
 
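The three removed checks all enforced that `template` was set; with auto-parsing available, that decision moves into `get_template_and_fix_tokenizer`. A standalone sketch of the resulting dispatch, using stand-in names (`registry`, `parse`) for the real `TEMPLATES` dict and `parse_template`:

def resolve_template(name, chat_template, registry, parse):
    # stand-in for the new get_template_and_fix_tokenizer dispatch
    if name is None:
        # previously get_train_args/get_infer_args/get_eval_args raised here
        if isinstance(chat_template, str):
            return parse(chat_template)  # auto template from the tokenizer
        return registry["empty"]  # placeholder fallback
    if name not in registry:
        raise ValueError(f"Template {name} does not exist.")
    return registry[name]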
@@ -19,6 +19,7 @@ import pytest
 from transformers import AutoTokenizer
 
 from llamafactory.data import get_template_and_fix_tokenizer
+from llamafactory.data.template import parse_template
 from llamafactory.hparams import DataArguments
 
 
@@ -208,3 +209,27 @@ def test_yi_template(use_fast: bool):
     )
     answer_str = "很高兴认识你!<|im_end|>\n"
     _check_template("01-ai/Yi-1.5-6B-Chat", "yi", prompt_str, answer_str, use_fast)
+
+
+def test_parse_template():
+    tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA, token=HF_TOKEN)
+    template = parse_template(tokenizer)
+    assert template.format_user.slots == [
+        "<|start_header_id|>user<|end_header_id|>\n\n{{content}}<|eot_id|>"
+        "<|start_header_id|>assistant<|end_header_id|>\n\n"
+    ]
+    assert template.format_assistant.slots == ["{{content}}<|eot_id|>"]
+    assert template.format_system.slots == ["<|start_header_id|>system<|end_header_id|>\n\n{{content}}<|eot_id|>"]
+    assert template.format_prefix.slots == ["<|begin_of_text|>"]
+    assert template.default_system == ""
+
+
+@pytest.mark.skipif(not HF_TOKEN, reason="Gated model.")
+def test_parse_qwen_template():
+    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-7B-Instruct", token=HF_TOKEN)
+    template = parse_template(tokenizer)
+    assert template.format_user.slots == ["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]
+    assert template.format_assistant.slots == ["{{content}}<|im_end|>\n"]
+    assert template.format_system.slots == ["<|im_start|>system\n{{content}}<|im_end|>\n"]
+    assert template.format_prefix.slots == []
+    assert template.default_system == "You are a helpful assistant."
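A quick manual check mirroring `test_parse_qwen_template` above; it assumes network access to download the tokenizer from the Hugging Face Hub:

from transformers import AutoTokenizer

from llamafactory.data.template import parse_template

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-7B-Instruct")
template = parse_template(tokenizer)
print(template.format_user.slots)  # ["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]
print(template.default_system)     # "You are a helpful assistant."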