mirror of
				https://github.com/hiyouga/LLaMA-Factory.git
				synced 2025-11-04 18:02:19 +08:00 
			
		
		
		
	[data] refactor template (#6896)
Former-commit-id: f78d5a3eca947ed965ca2f6c87d60441b1a59867
This commit is contained in:
		
							parent
							
								
									b72c4bd118
								
							
						
					
					
						commit
						3f7bd98bfa
					
				@ -13,7 +13,7 @@
 | 
			
		||||
# limitations under the License.
 | 
			
		||||
 | 
			
		||||
from dataclasses import dataclass
 | 
			
		||||
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Union
 | 
			
		||||
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Type, Union
 | 
			
		||||
 | 
			
		||||
from typing_extensions import override
 | 
			
		||||
 | 
			
		||||
@ -68,8 +68,8 @@ class Template:
 | 
			
		||||
        for encoded_ids in encoded_messages[:-1]:
 | 
			
		||||
            prompt_ids += encoded_ids
 | 
			
		||||
 | 
			
		||||
        answer_ids = encoded_messages[-1]
 | 
			
		||||
        return prompt_ids, answer_ids
 | 
			
		||||
        response_ids = encoded_messages[-1]
 | 
			
		||||
        return prompt_ids, response_ids
 | 
			
		||||
 | 
			
		||||
    def encode_multiturn(
 | 
			
		||||
        self,
 | 
			
		||||
@ -100,6 +100,27 @@ class Template:
 | 
			
		||||
 | 
			
		||||
        return list(stop_token_ids)
 | 
			
		||||
 | 
			
		||||
    def _convert_elements_to_ids(self, tokenizer: "PreTrainedTokenizer", elements: "SLOTS") -> List[int]:
 | 
			
		||||
        r"""
 | 
			
		||||
        Converts elements to token ids.
 | 
			
		||||
        """
 | 
			
		||||
        token_ids = []
 | 
			
		||||
        for elem in elements:
 | 
			
		||||
            if isinstance(elem, str):
 | 
			
		||||
                if len(elem) != 0:
 | 
			
		||||
                    token_ids += tokenizer.encode(elem, add_special_tokens=False)
 | 
			
		||||
            elif isinstance(elem, dict):
 | 
			
		||||
                token_ids += [tokenizer.convert_tokens_to_ids(elem.get("token"))]
 | 
			
		||||
            elif isinstance(elem, set):
 | 
			
		||||
                if "bos_token" in elem and tokenizer.bos_token_id is not None:
 | 
			
		||||
                    token_ids += [tokenizer.bos_token_id]
 | 
			
		||||
                elif "eos_token" in elem and tokenizer.eos_token_id is not None:
 | 
			
		||||
                    token_ids += [tokenizer.eos_token_id]
 | 
			
		||||
            else:
 | 
			
		||||
                raise ValueError(f"Input must be string, set[str] or dict[str, str], got {type(elem)}")
 | 
			
		||||
 | 
			
		||||
        return token_ids
 | 
			
		||||
 | 
			
		||||
    def _encode(
 | 
			
		||||
        self,
 | 
			
		||||
        tokenizer: "PreTrainedTokenizer",
 | 
			
		||||
@ -110,7 +131,7 @@ class Template:
 | 
			
		||||
        r"""
 | 
			
		||||
        Encodes formatted inputs to pairs of token ids.
 | 
			
		||||
        Turn 0: prefix + system + query        resp
 | 
			
		||||
        Turn t: sep + query                    resp
 | 
			
		||||
        Turn t: query                          resp
 | 
			
		||||
        """
 | 
			
		||||
        system = system or self.default_system
 | 
			
		||||
        encoded_messages = []
 | 
			
		||||
@ -138,26 +159,122 @@ class Template:
 | 
			
		||||
 | 
			
		||||
        return encoded_messages
 | 
			
		||||
 | 
			
		||||
    def _convert_elements_to_ids(self, tokenizer: "PreTrainedTokenizer", elements: "SLOTS") -> List[int]:
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def _add_or_replace_eos_token(tokenizer: "PreTrainedTokenizer", eos_token: str) -> None:
 | 
			
		||||
        r"""
 | 
			
		||||
        Converts elements to token ids.
 | 
			
		||||
        Adds or replaces eos token to the tokenizer.
 | 
			
		||||
        """
 | 
			
		||||
        token_ids = []
 | 
			
		||||
        for elem in elements:
 | 
			
		||||
            if isinstance(elem, str):
 | 
			
		||||
                if len(elem) != 0:
 | 
			
		||||
                    token_ids += tokenizer.encode(elem, add_special_tokens=False)
 | 
			
		||||
            elif isinstance(elem, dict):
 | 
			
		||||
                token_ids += [tokenizer.convert_tokens_to_ids(elem.get("token"))]
 | 
			
		||||
            elif isinstance(elem, set):
 | 
			
		||||
                if "bos_token" in elem and tokenizer.bos_token_id is not None:
 | 
			
		||||
                    token_ids += [tokenizer.bos_token_id]
 | 
			
		||||
                elif "eos_token" in elem and tokenizer.eos_token_id is not None:
 | 
			
		||||
                    token_ids += [tokenizer.eos_token_id]
 | 
			
		||||
            else:
 | 
			
		||||
                raise ValueError(f"Input must be string, set[str] or dict[str, str], got {type(elem)}")
 | 
			
		||||
        is_added = tokenizer.eos_token_id is None
 | 
			
		||||
        num_added_tokens = tokenizer.add_special_tokens({"eos_token": eos_token})
 | 
			
		||||
 | 
			
		||||
        return token_ids
 | 
			
		||||
        if is_added:
 | 
			
		||||
            logger.info_rank0(f"Add eos token: {tokenizer.eos_token}.")
 | 
			
		||||
        else:
 | 
			
		||||
            logger.info_rank0(f"Replace eos token: {tokenizer.eos_token}.")
 | 
			
		||||
 | 
			
		||||
        if num_added_tokens > 0:
 | 
			
		||||
            logger.warning_rank0("New tokens have been added, make sure `resize_vocab` is True.")
 | 
			
		||||
 | 
			
		||||
    def fix_special_tokens(self, tokenizer: "PreTrainedTokenizer") -> None:
 | 
			
		||||
        r"""
 | 
			
		||||
        Adds eos token and pad token to the tokenizer.
 | 
			
		||||
        """
 | 
			
		||||
        stop_words = self.stop_words
 | 
			
		||||
        if self.replace_eos:
 | 
			
		||||
            if not stop_words:
 | 
			
		||||
                raise ValueError("Stop words are required to replace the EOS token.")
 | 
			
		||||
 | 
			
		||||
            self._add_or_replace_eos_token(tokenizer, eos_token=stop_words[0])
 | 
			
		||||
            stop_words = stop_words[1:]
 | 
			
		||||
 | 
			
		||||
        if tokenizer.eos_token_id is None:
 | 
			
		||||
            self._add_or_replace_eos_token(tokenizer, eos_token="<|endoftext|>")
 | 
			
		||||
 | 
			
		||||
        if tokenizer.pad_token_id is None:
 | 
			
		||||
            tokenizer.pad_token = tokenizer.eos_token
 | 
			
		||||
            logger.info_rank0(f"Add pad token: {tokenizer.pad_token}")
 | 
			
		||||
 | 
			
		||||
        if stop_words:
 | 
			
		||||
            num_added_tokens = tokenizer.add_special_tokens(
 | 
			
		||||
                dict(additional_special_tokens=stop_words), replace_additional_special_tokens=False
 | 
			
		||||
            )
 | 
			
		||||
            logger.info_rank0("Add {} to stop words.".format(",".join(stop_words)))
 | 
			
		||||
            if num_added_tokens > 0:
 | 
			
		||||
                logger.warning_rank0("New tokens have been added, make sure `resize_vocab` is True.")
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def _jinja_escape(content: str) -> str:
 | 
			
		||||
        r"""
 | 
			
		||||
        Escape single quotes in content.
 | 
			
		||||
        """
 | 
			
		||||
        return content.replace("'", r"\'")
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def _convert_slots_to_jinja(slots: "SLOTS", tokenizer: "PreTrainedTokenizer", placeholder: str = "content") -> str:
 | 
			
		||||
        r"""
 | 
			
		||||
        Converts slots to jinja template.
 | 
			
		||||
        """
 | 
			
		||||
        slot_items = []
 | 
			
		||||
        for slot in slots:
 | 
			
		||||
            if isinstance(slot, str):
 | 
			
		||||
                slot_pieces = slot.split("{{content}}")
 | 
			
		||||
                if slot_pieces[0]:
 | 
			
		||||
                    slot_items.append("'" + Template._jinja_escape(slot_pieces[0]) + "'")
 | 
			
		||||
                if len(slot_pieces) > 1:
 | 
			
		||||
                    slot_items.append(placeholder)
 | 
			
		||||
                    if slot_pieces[1]:
 | 
			
		||||
                        slot_items.append("'" + Template._jinja_escape(slot_pieces[1]) + "'")
 | 
			
		||||
            elif isinstance(slot, set):  # do not use {{ eos_token }} since it may be replaced
 | 
			
		||||
                if "bos_token" in slot and tokenizer.bos_token_id is not None:
 | 
			
		||||
                    slot_items.append("'" + tokenizer.bos_token + "'")
 | 
			
		||||
                elif "eos_token" in slot and tokenizer.eos_token_id is not None:
 | 
			
		||||
                    slot_items.append("'" + tokenizer.eos_token + "'")
 | 
			
		||||
            elif isinstance(slot, dict):
 | 
			
		||||
                raise ValueError("Dict is not supported.")
 | 
			
		||||
 | 
			
		||||
        return " + ".join(slot_items)
 | 
			
		||||
 | 
			
		||||
    def _get_jinja_template(self, tokenizer: "PreTrainedTokenizer") -> str:
 | 
			
		||||
        r"""
 | 
			
		||||
        Returns the jinja template.
 | 
			
		||||
        """
 | 
			
		||||
        prefix = self._convert_slots_to_jinja(self.format_prefix.apply(), tokenizer)
 | 
			
		||||
        system_message = self._convert_slots_to_jinja(
 | 
			
		||||
            self.format_system.apply(), tokenizer, placeholder="system_message"
 | 
			
		||||
        )
 | 
			
		||||
        user_message = self._convert_slots_to_jinja(self.format_user.apply(), tokenizer)
 | 
			
		||||
        assistant_message = self._convert_slots_to_jinja(self.format_assistant.apply(), tokenizer)
 | 
			
		||||
        jinja_template = ""
 | 
			
		||||
        if prefix:
 | 
			
		||||
            jinja_template += "{{ " + prefix + " }}"
 | 
			
		||||
 | 
			
		||||
        if self.default_system:
 | 
			
		||||
            jinja_template += "{% set system_message = '" + self._jinja_escape(self.default_system) + "' %}"
 | 
			
		||||
 | 
			
		||||
        jinja_template += (
 | 
			
		||||
            "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}"
 | 
			
		||||
            "{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% endif %}"
 | 
			
		||||
            "{% if system_message is defined %}{{ " + system_message + " }}{% endif %}"
 | 
			
		||||
            "{% for message in loop_messages %}"
 | 
			
		||||
            "{% set content = message['content'] %}"
 | 
			
		||||
            "{% if message['role'] == 'user' %}"
 | 
			
		||||
            "{{ " + user_message + " }}"
 | 
			
		||||
            "{% elif message['role'] == 'assistant' %}"
 | 
			
		||||
            "{{ " + assistant_message + " }}"
 | 
			
		||||
            "{% endif %}"
 | 
			
		||||
            "{% endfor %}"
 | 
			
		||||
        )
 | 
			
		||||
        return jinja_template
 | 
			
		||||
 | 
			
		||||
    def fix_jinja_template(self, tokenizer: "PreTrainedTokenizer") -> None:
 | 
			
		||||
        r"""
 | 
			
		||||
        Replaces the jinja template in the tokenizer.
 | 
			
		||||
        """
 | 
			
		||||
        if tokenizer.chat_template is None or self.replace_jinja_template:
 | 
			
		||||
            try:
 | 
			
		||||
                tokenizer.chat_template = self._get_jinja_template(tokenizer)
 | 
			
		||||
            except ValueError as e:
 | 
			
		||||
                logger.info_rank0(f"Cannot add this chat template to tokenizer: {e}.")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@dataclass
 | 
			
		||||
@ -170,11 +287,6 @@ class Llama2Template(Template):
 | 
			
		||||
        system: str,
 | 
			
		||||
        tools: str,
 | 
			
		||||
    ) -> List[List[int]]:
 | 
			
		||||
        r"""
 | 
			
		||||
        Encodes formatted inputs to pairs of token ids.
 | 
			
		||||
        Turn 0: prefix + system + query        resp
 | 
			
		||||
        Turn t: sep + query                    resp
 | 
			
		||||
        """
 | 
			
		||||
        system = system or self.default_system
 | 
			
		||||
        encoded_messages = []
 | 
			
		||||
        for i, message in enumerate(messages):
 | 
			
		||||
@ -202,6 +314,36 @@ class Llama2Template(Template):
 | 
			
		||||
 | 
			
		||||
        return encoded_messages
 | 
			
		||||
 | 
			
		||||
    def _get_jinja_template(self, tokenizer: "PreTrainedTokenizer") -> str:
 | 
			
		||||
        prefix = self._convert_slots_to_jinja(self.format_prefix.apply(), tokenizer)
 | 
			
		||||
        system_message = self._convert_slots_to_jinja(
 | 
			
		||||
            self.format_system.apply(), tokenizer, placeholder="system_message"
 | 
			
		||||
        )
 | 
			
		||||
        user_message = self._convert_slots_to_jinja(self.format_user.apply(), tokenizer)
 | 
			
		||||
        assistant_message = self._convert_slots_to_jinja(self.format_assistant.apply(), tokenizer)
 | 
			
		||||
        jinja_template = ""
 | 
			
		||||
        if prefix:
 | 
			
		||||
            jinja_template += "{{ " + prefix + " }}"
 | 
			
		||||
 | 
			
		||||
        if self.default_system:
 | 
			
		||||
            jinja_template += "{% set system_message = '" + self._jinja_escape(self.default_system) + "' %}"
 | 
			
		||||
 | 
			
		||||
        jinja_template += (
 | 
			
		||||
            "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}"
 | 
			
		||||
            "{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% endif %}"
 | 
			
		||||
            "{% for message in loop_messages %}"
 | 
			
		||||
            "{% if loop.index0 == 0 and system_message is defined %}"
 | 
			
		||||
            "{% set content = " + system_message + " + message['content'] %}"
 | 
			
		||||
            "{% else %}{% set content = message['content'] %}{% endif %}"
 | 
			
		||||
            "{% if message['role'] == 'user' %}"
 | 
			
		||||
            "{{ " + user_message + " }}"
 | 
			
		||||
            "{% elif message['role'] == 'assistant' %}"
 | 
			
		||||
            "{{ " + assistant_message + " }}"
 | 
			
		||||
            "{% endif %}"
 | 
			
		||||
            "{% endfor %}"
 | 
			
		||||
        )
 | 
			
		||||
        return jinja_template
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
TEMPLATES: Dict[str, "Template"] = {}
 | 
			
		||||
 | 
			
		||||
@ -222,7 +364,7 @@ def _register_template(
 | 
			
		||||
    replace_eos: bool = False,
 | 
			
		||||
    replace_jinja_template: bool = False,
 | 
			
		||||
    mm_plugin: "BasePlugin" = get_mm_plugin(name="base"),
 | 
			
		||||
    fuse_system_into_user: bool = False,
 | 
			
		||||
    template_class: Type[Template] = Template,
 | 
			
		||||
) -> None:
 | 
			
		||||
    r"""
 | 
			
		||||
    Registers a chat template.
 | 
			
		||||
@ -245,7 +387,6 @@ def _register_template(
 | 
			
		||||
    )
 | 
			
		||||
    ```
 | 
			
		||||
    """
 | 
			
		||||
    template_class = Llama2Template if fuse_system_into_user else Template
 | 
			
		||||
    default_slots = ["{{content}}"] if efficient_eos else ["{{content}}", {"eos_token"}]
 | 
			
		||||
    default_user_formatter = StringFormatter(slots=["{{content}}"])
 | 
			
		||||
    default_assistant_formatter = StringFormatter(slots=default_slots)
 | 
			
		||||
@ -270,86 +411,6 @@ def _register_template(
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _add_or_replace_eos_token(tokenizer: "PreTrainedTokenizer", eos_token: str) -> None:
 | 
			
		||||
    is_added = tokenizer.eos_token_id is None
 | 
			
		||||
    num_added_tokens = tokenizer.add_special_tokens({"eos_token": eos_token})
 | 
			
		||||
 | 
			
		||||
    if is_added:
 | 
			
		||||
        logger.info_rank0(f"Add eos token: {tokenizer.eos_token}")
 | 
			
		||||
    else:
 | 
			
		||||
        logger.info_rank0(f"Replace eos token: {tokenizer.eos_token}")
 | 
			
		||||
 | 
			
		||||
    if num_added_tokens > 0:
 | 
			
		||||
        logger.warning_rank0("New tokens have been added, make sure `resize_vocab` is True.")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _jinja_escape(content: str) -> str:
 | 
			
		||||
    return content.replace("'", r"\'")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _convert_slots_to_jinja(slots: "SLOTS", tokenizer: "PreTrainedTokenizer", placeholder: str = "content") -> str:
 | 
			
		||||
    slot_items = []
 | 
			
		||||
    for slot in slots:
 | 
			
		||||
        if isinstance(slot, str):
 | 
			
		||||
            slot_pieces = slot.split("{{content}}")
 | 
			
		||||
            if slot_pieces[0]:
 | 
			
		||||
                slot_items.append("'" + _jinja_escape(slot_pieces[0]) + "'")
 | 
			
		||||
            if len(slot_pieces) > 1:
 | 
			
		||||
                slot_items.append(placeholder)
 | 
			
		||||
                if slot_pieces[1]:
 | 
			
		||||
                    slot_items.append("'" + _jinja_escape(slot_pieces[1]) + "'")
 | 
			
		||||
        elif isinstance(slot, set):  # do not use {{ eos_token }} since it may be replaced
 | 
			
		||||
            if "bos_token" in slot and tokenizer.bos_token_id is not None:
 | 
			
		||||
                slot_items.append("'" + tokenizer.bos_token + "'")
 | 
			
		||||
            elif "eos_token" in slot and tokenizer.eos_token_id is not None:
 | 
			
		||||
                slot_items.append("'" + tokenizer.eos_token + "'")
 | 
			
		||||
        elif isinstance(slot, dict):
 | 
			
		||||
            raise ValueError("Dict is not supported.")
 | 
			
		||||
 | 
			
		||||
    return " + ".join(slot_items)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _get_jinja_template(template: "Template", tokenizer: "PreTrainedTokenizer") -> str:
 | 
			
		||||
    r"""
 | 
			
		||||
    Returns the jinja template.
 | 
			
		||||
    """
 | 
			
		||||
    jinja_template = ""
 | 
			
		||||
 | 
			
		||||
    prefix = _convert_slots_to_jinja(template.format_prefix.apply(), tokenizer)
 | 
			
		||||
    if prefix:
 | 
			
		||||
        jinja_template += "{{ " + prefix + " }}"
 | 
			
		||||
 | 
			
		||||
    if template.default_system:
 | 
			
		||||
        jinja_template += "{% set system_message = '" + _jinja_escape(template.default_system) + "' %}"
 | 
			
		||||
 | 
			
		||||
    jinja_template += (
 | 
			
		||||
        "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}"
 | 
			
		||||
        "{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% endif %}"
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    system_message = _convert_slots_to_jinja(template.format_system.apply(), tokenizer, placeholder="system_message")
 | 
			
		||||
    if not isinstance(template, Llama2Template):
 | 
			
		||||
        jinja_template += "{% if system_message is defined %}{{ " + system_message + " }}{% endif %}"
 | 
			
		||||
 | 
			
		||||
    jinja_template += "{% for message in loop_messages %}"
 | 
			
		||||
    jinja_template += "{% set content = message['content'] %}"
 | 
			
		||||
    if isinstance(template, Llama2Template):
 | 
			
		||||
        jinja_template += "{% if loop.index0 == 0 and system_message is defined %}"
 | 
			
		||||
        jinja_template += "{% set content = " + system_message + " + message['content'] %}"
 | 
			
		||||
        jinja_template += "{% endif %}"
 | 
			
		||||
 | 
			
		||||
    jinja_template += "{% if message['role'] == 'user' %}"
 | 
			
		||||
    user_message = _convert_slots_to_jinja(template.format_user.apply(), tokenizer)
 | 
			
		||||
    jinja_template += "{{ " + user_message + " }}"
 | 
			
		||||
 | 
			
		||||
    jinja_template += "{% elif message['role'] == 'assistant' %}"
 | 
			
		||||
    assistant_message = _convert_slots_to_jinja(template.format_assistant.apply(), tokenizer)
 | 
			
		||||
    jinja_template += "{{ " + assistant_message + " }}"
 | 
			
		||||
    jinja_template += "{% endif %}"
 | 
			
		||||
    jinja_template += "{% endfor %}"
 | 
			
		||||
    return jinja_template
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args: "DataArguments") -> "Template":
 | 
			
		||||
    r"""
 | 
			
		||||
    Gets chat template and fixes the tokenizer.
 | 
			
		||||
@ -373,35 +434,8 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args:
 | 
			
		||||
        template.format_function = FunctionFormatter(slots=default_slots, tool_format=data_args.tool_format)
 | 
			
		||||
        template.format_tools = ToolFormatter(tool_format=data_args.tool_format)
 | 
			
		||||
 | 
			
		||||
    stop_words = template.stop_words
 | 
			
		||||
    if template.replace_eos:
 | 
			
		||||
        if not stop_words:
 | 
			
		||||
            raise ValueError("Stop words are required to replace the EOS token.")
 | 
			
		||||
 | 
			
		||||
        _add_or_replace_eos_token(tokenizer, eos_token=stop_words[0])
 | 
			
		||||
        stop_words = stop_words[1:]
 | 
			
		||||
 | 
			
		||||
    if tokenizer.eos_token_id is None:
 | 
			
		||||
        _add_or_replace_eos_token(tokenizer, eos_token="<|endoftext|>")
 | 
			
		||||
 | 
			
		||||
    if tokenizer.pad_token_id is None:
 | 
			
		||||
        tokenizer.pad_token = tokenizer.eos_token
 | 
			
		||||
        logger.info_rank0(f"Add pad token: {tokenizer.pad_token}")
 | 
			
		||||
 | 
			
		||||
    if stop_words:
 | 
			
		||||
        num_added_tokens = tokenizer.add_special_tokens(
 | 
			
		||||
            dict(additional_special_tokens=stop_words), replace_additional_special_tokens=False
 | 
			
		||||
        )
 | 
			
		||||
        logger.info_rank0("Add {} to stop words.".format(",".join(stop_words)))
 | 
			
		||||
        if num_added_tokens > 0:
 | 
			
		||||
            logger.warning_rank0("New tokens have been added, make sure `resize_vocab` is True.")
 | 
			
		||||
 | 
			
		||||
    if tokenizer.chat_template is None or template.replace_jinja_template:
 | 
			
		||||
        try:
 | 
			
		||||
            tokenizer.chat_template = _get_jinja_template(template, tokenizer)
 | 
			
		||||
        except ValueError as e:
 | 
			
		||||
            logger.info_rank0(f"Cannot add this chat template to tokenizer: {e}.")
 | 
			
		||||
 | 
			
		||||
    template.fix_special_tokens(tokenizer)
 | 
			
		||||
    template.fix_jinja_template(tokenizer)
 | 
			
		||||
    return template
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -755,7 +789,7 @@ _register_template(
 | 
			
		||||
    name="llama2",
 | 
			
		||||
    format_user=StringFormatter(slots=[{"bos_token"}, "[INST] {{content}} [/INST]"]),
 | 
			
		||||
    format_system=StringFormatter(slots=["<<SYS>>\n{{content}}\n<</SYS>>\n\n"]),
 | 
			
		||||
    fuse_system_into_user=True,
 | 
			
		||||
    template_class=Llama2Template,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -765,7 +799,7 @@ _register_template(
 | 
			
		||||
    format_user=StringFormatter(slots=[{"bos_token"}, "[INST] {{content}} [/INST]"]),
 | 
			
		||||
    format_system=StringFormatter(slots=["<<SYS>>\n{{content}}\n<</SYS>>\n\n"]),
 | 
			
		||||
    default_system="You are a helpful assistant. 你是一个乐于助人的助手。",
 | 
			
		||||
    fuse_system_into_user=True,
 | 
			
		||||
    template_class=Llama2Template,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -889,7 +923,7 @@ _register_template(
 | 
			
		||||
    format_tools=ToolFormatter(tool_format="mistral"),
 | 
			
		||||
    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
 | 
			
		||||
    mm_plugin=get_mm_plugin(name="llava_next", image_token="<image>"),
 | 
			
		||||
    fuse_system_into_user=True,
 | 
			
		||||
    template_class=Llama2Template,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -944,7 +978,7 @@ _register_template(
 | 
			
		||||
    format_tools=ToolFormatter(tool_format="mistral"),
 | 
			
		||||
    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
 | 
			
		||||
    mm_plugin=get_mm_plugin(name="llava_next_video", image_token="<image>", video_token="<video>"),
 | 
			
		||||
    fuse_system_into_user=True,
 | 
			
		||||
    template_class=Llama2Template,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -1006,7 +1040,7 @@ _register_template(
 | 
			
		||||
    format_observation=StringFormatter(slots=["""[TOOL_RESULTS]{"content": {{content}}}[/TOOL_RESULTS]"""]),
 | 
			
		||||
    format_tools=ToolFormatter(tool_format="mistral"),
 | 
			
		||||
    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
 | 
			
		||||
    fuse_system_into_user=True,
 | 
			
		||||
    template_class=Llama2Template,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -1020,7 +1054,7 @@ _register_template(
 | 
			
		||||
    format_observation=StringFormatter(slots=["""[TOOL_RESULTS] {"content": {{content}}}[/TOOL_RESULTS]"""]),
 | 
			
		||||
    format_tools=ToolFormatter(tool_format="mistral"),
 | 
			
		||||
    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
 | 
			
		||||
    fuse_system_into_user=True,
 | 
			
		||||
    template_class=Llama2Template,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -1137,7 +1171,7 @@ _register_template(
 | 
			
		||||
    format_tools=ToolFormatter(tool_format="mistral"),
 | 
			
		||||
    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
 | 
			
		||||
    mm_plugin=get_mm_plugin(name="pixtral", image_token="[IMG]"),
 | 
			
		||||
    fuse_system_into_user=True,
 | 
			
		||||
    template_class=Llama2Template,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -104,7 +104,7 @@ def get_trainer_info(output_path: os.PathLike, do_train: bool) -> Tuple[str, "gr
 | 
			
		||||
    running_log_path = os.path.join(output_path, RUNNING_LOG)
 | 
			
		||||
    if os.path.isfile(running_log_path):
 | 
			
		||||
        with open(running_log_path, encoding="utf-8") as f:
 | 
			
		||||
            running_log = f.read()
 | 
			
		||||
            running_log = f.read()[-20000:]  # avoid lengthy log
 | 
			
		||||
 | 
			
		||||
    trainer_log_path = os.path.join(output_path, TRAINER_LOG)
 | 
			
		||||
    if os.path.isfile(trainer_log_path):
 | 
			
		||||
 | 
			
		||||
@ -19,7 +19,6 @@ import pytest
 | 
			
		||||
from transformers import AutoTokenizer
 | 
			
		||||
 | 
			
		||||
from llamafactory.data import get_template_and_fix_tokenizer
 | 
			
		||||
from llamafactory.data.template import _get_jinja_template
 | 
			
		||||
from llamafactory.hparams import DataArguments
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -115,7 +114,7 @@ def test_jinja_template(use_fast: bool):
 | 
			
		||||
    tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA, use_fast=use_fast)
 | 
			
		||||
    ref_tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA, use_fast=use_fast)
 | 
			
		||||
    template = get_template_and_fix_tokenizer(tokenizer, DataArguments(template="llama3"))
 | 
			
		||||
    tokenizer.chat_template = _get_jinja_template(template, tokenizer)  # llama3 template no replace
 | 
			
		||||
    tokenizer.chat_template = template._get_jinja_template(tokenizer)  # llama3 template no replace
 | 
			
		||||
    assert tokenizer.chat_template != ref_tokenizer.chat_template
 | 
			
		||||
    assert tokenizer.apply_chat_template(MESSAGES) == ref_tokenizer.apply_chat_template(MESSAGES)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user