diff --git a/src/llamafactory/data/template.py b/src/llamafactory/data/template.py
index 3106c734..47541b54 100644
--- a/src/llamafactory/data/template.py
+++ b/src/llamafactory/data/template.py
@@ -13,7 +13,7 @@
 # limitations under the License.

 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Union
+from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Type, Union

 from typing_extensions import override

@@ -68,8 +68,8 @@ class Template:
         for encoded_ids in encoded_messages[:-1]:
             prompt_ids += encoded_ids

-        answer_ids = encoded_messages[-1]
-        return prompt_ids, answer_ids
+        response_ids = encoded_messages[-1]
+        return prompt_ids, response_ids

     def encode_multiturn(
         self,
@@ -100,6 +100,27 @@

         return list(stop_token_ids)

+    def _convert_elements_to_ids(self, tokenizer: "PreTrainedTokenizer", elements: "SLOTS") -> List[int]:
+        r"""
+        Converts elements to token ids.
+        """
+        token_ids = []
+        for elem in elements:
+            if isinstance(elem, str):
+                if len(elem) != 0:
+                    token_ids += tokenizer.encode(elem, add_special_tokens=False)
+            elif isinstance(elem, dict):
+                token_ids += [tokenizer.convert_tokens_to_ids(elem.get("token"))]
+            elif isinstance(elem, set):
+                if "bos_token" in elem and tokenizer.bos_token_id is not None:
+                    token_ids += [tokenizer.bos_token_id]
+                elif "eos_token" in elem and tokenizer.eos_token_id is not None:
+                    token_ids += [tokenizer.eos_token_id]
+            else:
+                raise ValueError(f"Input must be string, set[str] or dict[str, str], got {type(elem)}")
+
+        return token_ids
+
     def _encode(
         self,
         tokenizer: "PreTrainedTokenizer",
@@ -110,7 +131,7 @@
         r"""
         Encodes formatted inputs to pairs of token ids.
         Turn 0: prefix + system + query resp
-        Turn t: sep + query resp
+        Turn t: query resp
         """
         system = system or self.default_system
         encoded_messages = []
@@ -138,26 +159,122 @@

         return encoded_messages

-    def _convert_elements_to_ids(self, tokenizer: "PreTrainedTokenizer", elements: "SLOTS") -> List[int]:
+    @staticmethod
+    def _add_or_replace_eos_token(tokenizer: "PreTrainedTokenizer", eos_token: str) -> None:
         r"""
-        Converts elements to token ids.
+        Adds or replaces eos token to the tokenizer.
         """
-        token_ids = []
-        for elem in elements:
-            if isinstance(elem, str):
-                if len(elem) != 0:
-                    token_ids += tokenizer.encode(elem, add_special_tokens=False)
-            elif isinstance(elem, dict):
-                token_ids += [tokenizer.convert_tokens_to_ids(elem.get("token"))]
-            elif isinstance(elem, set):
-                if "bos_token" in elem and tokenizer.bos_token_id is not None:
-                    token_ids += [tokenizer.bos_token_id]
-                elif "eos_token" in elem and tokenizer.eos_token_id is not None:
-                    token_ids += [tokenizer.eos_token_id]
-            else:
-                raise ValueError(f"Input must be string, set[str] or dict[str, str], got {type(elem)}")
+        is_added = tokenizer.eos_token_id is None
+        num_added_tokens = tokenizer.add_special_tokens({"eos_token": eos_token})

-        return token_ids
+        if is_added:
+            logger.info_rank0(f"Add eos token: {tokenizer.eos_token}.")
+        else:
+            logger.info_rank0(f"Replace eos token: {tokenizer.eos_token}.")
+
+        if num_added_tokens > 0:
+            logger.warning_rank0("New tokens have been added, make sure `resize_vocab` is True.")
+
+    def fix_special_tokens(self, tokenizer: "PreTrainedTokenizer") -> None:
+        r"""
+        Adds eos token and pad token to the tokenizer.
+        """
+        stop_words = self.stop_words
+        if self.replace_eos:
+            if not stop_words:
+                raise ValueError("Stop words are required to replace the EOS token.")
+
+            self._add_or_replace_eos_token(tokenizer, eos_token=stop_words[0])
+            stop_words = stop_words[1:]
+
+        if tokenizer.eos_token_id is None:
+            self._add_or_replace_eos_token(tokenizer, eos_token="<|endoftext|>")
+
+        if tokenizer.pad_token_id is None:
+            tokenizer.pad_token = tokenizer.eos_token
+            logger.info_rank0(f"Add pad token: {tokenizer.pad_token}")
+
+        if stop_words:
+            num_added_tokens = tokenizer.add_special_tokens(
+                dict(additional_special_tokens=stop_words), replace_additional_special_tokens=False
+            )
+            logger.info_rank0("Add {} to stop words.".format(",".join(stop_words)))
+            if num_added_tokens > 0:
+                logger.warning_rank0("New tokens have been added, make sure `resize_vocab` is True.")
+
+    @staticmethod
+    def _jinja_escape(content: str) -> str:
+        r"""
+        Escape single quotes in content.
+        """
+        return content.replace("'", r"\'")
+
+    @staticmethod
+    def _convert_slots_to_jinja(slots: "SLOTS", tokenizer: "PreTrainedTokenizer", placeholder: str = "content") -> str:
+        r"""
+        Converts slots to jinja template.
+        """
+        slot_items = []
+        for slot in slots:
+            if isinstance(slot, str):
+                slot_pieces = slot.split("{{content}}")
+                if slot_pieces[0]:
+                    slot_items.append("'" + Template._jinja_escape(slot_pieces[0]) + "'")
+                if len(slot_pieces) > 1:
+                    slot_items.append(placeholder)
+                    if slot_pieces[1]:
+                        slot_items.append("'" + Template._jinja_escape(slot_pieces[1]) + "'")
+            elif isinstance(slot, set):  # do not use {{ eos_token }} since it may be replaced
+                if "bos_token" in slot and tokenizer.bos_token_id is not None:
+                    slot_items.append("'" + tokenizer.bos_token + "'")
+                elif "eos_token" in slot and tokenizer.eos_token_id is not None:
+                    slot_items.append("'" + tokenizer.eos_token + "'")
+            elif isinstance(slot, dict):
+                raise ValueError("Dict is not supported.")
+
+        return " + ".join(slot_items)
+
+    def _get_jinja_template(self, tokenizer: "PreTrainedTokenizer") -> str:
+        r"""
+        Returns the jinja template.
+        """
+        prefix = self._convert_slots_to_jinja(self.format_prefix.apply(), tokenizer)
+        system_message = self._convert_slots_to_jinja(
+            self.format_system.apply(), tokenizer, placeholder="system_message"
+        )
+        user_message = self._convert_slots_to_jinja(self.format_user.apply(), tokenizer)
+        assistant_message = self._convert_slots_to_jinja(self.format_assistant.apply(), tokenizer)
+        jinja_template = ""
+        if prefix:
+            jinja_template += "{{ " + prefix + " }}"
+
+        if self.default_system:
+            jinja_template += "{% set system_message = '" + self._jinja_escape(self.default_system) + "' %}"
+
+        jinja_template += (
+            "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}"
+            "{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% endif %}"
+            "{% if system_message is defined %}{{ " + system_message + " }}{% endif %}"
+            "{% for message in loop_messages %}"
+            "{% set content = message['content'] %}"
+            "{% if message['role'] == 'user' %}"
+            "{{ " + user_message + " }}"
+            "{% elif message['role'] == 'assistant' %}"
+            "{{ " + assistant_message + " }}"
+            "{% endif %}"
+            "{% endfor %}"
+        )
+        return jinja_template
+
+    def fix_jinja_template(self, tokenizer: "PreTrainedTokenizer") -> None:
+        r"""
+        Replaces the jinja template in the tokenizer.
+        """
+        if tokenizer.chat_template is None or self.replace_jinja_template:
+            try:
+                tokenizer.chat_template = self._get_jinja_template(tokenizer)
+            except ValueError as e:
+                logger.info_rank0(f"Cannot add this chat template to tokenizer: {e}.")


 @dataclass
@@ -170,11 +287,6 @@ class Llama2Template(Template):
         system: str,
         tools: str,
     ) -> List[List[int]]:
-        r"""
-        Encodes formatted inputs to pairs of token ids.
-        Turn 0: prefix + system + query resp
-        Turn t: sep + query resp
-        """
         system = system or self.default_system
         encoded_messages = []
         for i, message in enumerate(messages):
@@ -202,6 +314,36 @@ class Llama2Template(Template):

         return encoded_messages

+    def _get_jinja_template(self, tokenizer: "PreTrainedTokenizer") -> str:
+        prefix = self._convert_slots_to_jinja(self.format_prefix.apply(), tokenizer)
+        system_message = self._convert_slots_to_jinja(
+            self.format_system.apply(), tokenizer, placeholder="system_message"
+        )
+        user_message = self._convert_slots_to_jinja(self.format_user.apply(), tokenizer)
+        assistant_message = self._convert_slots_to_jinja(self.format_assistant.apply(), tokenizer)
+        jinja_template = ""
+        if prefix:
+            jinja_template += "{{ " + prefix + " }}"
+
+        if self.default_system:
+            jinja_template += "{% set system_message = '" + self._jinja_escape(self.default_system) + "' %}"
+
+        jinja_template += (
+            "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}"
+            "{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% endif %}"
+            "{% for message in loop_messages %}"
+            "{% if loop.index0 == 0 and system_message is defined %}"
+            "{% set content = " + system_message + " + message['content'] %}"
+            "{% else %}{% set content = message['content'] %}{% endif %}"
+            "{% if message['role'] == 'user' %}"
+            "{{ " + user_message + " }}"
+            "{% elif message['role'] == 'assistant' %}"
+            "{{ " + assistant_message + " }}"
+            "{% endif %}"
+            "{% endfor %}"
+        )
+        return jinja_template
+

 TEMPLATES: Dict[str, "Template"] = {}

@@ -222,7 +364,7 @@ def _register_template(
     replace_eos: bool = False,
     replace_jinja_template: bool = False,
     mm_plugin: "BasePlugin" = get_mm_plugin(name="base"),
-    fuse_system_into_user: bool = False,
+    template_class: Type[Template] = Template,
 ) -> None:
     r"""
     Registers a chat template.
@@ -245,7 +387,6 @@ _register_template(
     )
     ```
     """
-    template_class = Llama2Template if fuse_system_into_user else Template
     default_slots = ["{{content}}"] if efficient_eos else ["{{content}}", {"eos_token"}]
     default_user_formatter = StringFormatter(slots=["{{content}}"])
     default_assistant_formatter = StringFormatter(slots=default_slots)
@@ -270,86 +411,6 @@
     )


-def _add_or_replace_eos_token(tokenizer: "PreTrainedTokenizer", eos_token: str) -> None:
-    is_added = tokenizer.eos_token_id is None
-    num_added_tokens = tokenizer.add_special_tokens({"eos_token": eos_token})
-
-    if is_added:
-        logger.info_rank0(f"Add eos token: {tokenizer.eos_token}")
-    else:
-        logger.info_rank0(f"Replace eos token: {tokenizer.eos_token}")
-
-    if num_added_tokens > 0:
-        logger.warning_rank0("New tokens have been added, make sure `resize_vocab` is True.")
-
-
-def _jinja_escape(content: str) -> str:
-    return content.replace("'", r"\'")
-
-
-def _convert_slots_to_jinja(slots: "SLOTS", tokenizer: "PreTrainedTokenizer", placeholder: str = "content") -> str:
-    slot_items = []
-    for slot in slots:
-        if isinstance(slot, str):
-            slot_pieces = slot.split("{{content}}")
-            if slot_pieces[0]:
-                slot_items.append("'" + _jinja_escape(slot_pieces[0]) + "'")
-            if len(slot_pieces) > 1:
-                slot_items.append(placeholder)
-                if slot_pieces[1]:
-                    slot_items.append("'" + _jinja_escape(slot_pieces[1]) + "'")
-        elif isinstance(slot, set):  # do not use {{ eos_token }} since it may be replaced
-            if "bos_token" in slot and tokenizer.bos_token_id is not None:
-                slot_items.append("'" + tokenizer.bos_token + "'")
-            elif "eos_token" in slot and tokenizer.eos_token_id is not None:
-                slot_items.append("'" + tokenizer.eos_token + "'")
-        elif isinstance(slot, dict):
-            raise ValueError("Dict is not supported.")
-
-    return " + ".join(slot_items)
-
-
-def _get_jinja_template(template: "Template", tokenizer: "PreTrainedTokenizer") -> str:
-    r"""
-    Returns the jinja template.
-    """
-    jinja_template = ""
-
-    prefix = _convert_slots_to_jinja(template.format_prefix.apply(), tokenizer)
-    if prefix:
-        jinja_template += "{{ " + prefix + " }}"
-
-    if template.default_system:
-        jinja_template += "{% set system_message = '" + _jinja_escape(template.default_system) + "' %}"
-
-    jinja_template += (
-        "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}"
-        "{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% endif %}"
-    )
-
-    system_message = _convert_slots_to_jinja(template.format_system.apply(), tokenizer, placeholder="system_message")
-    if not isinstance(template, Llama2Template):
-        jinja_template += "{% if system_message is defined %}{{ " + system_message + " }}{% endif %}"
-
-    jinja_template += "{% for message in loop_messages %}"
-    jinja_template += "{% set content = message['content'] %}"
-    if isinstance(template, Llama2Template):
-        jinja_template += "{% if loop.index0 == 0 and system_message is defined %}"
-        jinja_template += "{% set content = " + system_message + " + message['content'] %}"
-        jinja_template += "{% endif %}"
-
-    jinja_template += "{% if message['role'] == 'user' %}"
-    user_message = _convert_slots_to_jinja(template.format_user.apply(), tokenizer)
-    jinja_template += "{{ " + user_message + " }}"
-
-    jinja_template += "{% elif message['role'] == 'assistant' %}"
-    assistant_message = _convert_slots_to_jinja(template.format_assistant.apply(), tokenizer)
-    jinja_template += "{{ " + assistant_message + " }}"
-    jinja_template += "{% endif %}"
-    jinja_template += "{% endfor %}"
-    return jinja_template
-
-
 def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args: "DataArguments") -> "Template":
     r"""
     Gets chat template and fixes the tokenizer.
@@ -373,35 +434,8 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args:
         template.format_function = FunctionFormatter(slots=default_slots, tool_format=data_args.tool_format)
         template.format_tools = ToolFormatter(tool_format=data_args.tool_format)

-    stop_words = template.stop_words
-    if template.replace_eos:
-        if not stop_words:
-            raise ValueError("Stop words are required to replace the EOS token.")
-
-        _add_or_replace_eos_token(tokenizer, eos_token=stop_words[0])
-        stop_words = stop_words[1:]
-
-    if tokenizer.eos_token_id is None:
-        _add_or_replace_eos_token(tokenizer, eos_token="<|endoftext|>")
-
-    if tokenizer.pad_token_id is None:
-        tokenizer.pad_token = tokenizer.eos_token
-        logger.info_rank0(f"Add pad token: {tokenizer.pad_token}")
-
-    if stop_words:
-        num_added_tokens = tokenizer.add_special_tokens(
-            dict(additional_special_tokens=stop_words), replace_additional_special_tokens=False
-        )
-        logger.info_rank0("Add {} to stop words.".format(",".join(stop_words)))
-        if num_added_tokens > 0:
-            logger.warning_rank0("New tokens have been added, make sure `resize_vocab` is True.")
-
-    if tokenizer.chat_template is None or template.replace_jinja_template:
-        try:
-            tokenizer.chat_template = _get_jinja_template(template, tokenizer)
-        except ValueError as e:
-            logger.info_rank0(f"Cannot add this chat template to tokenizer: {e}.")
-
+    template.fix_special_tokens(tokenizer)
+    template.fix_jinja_template(tokenizer)
     return template


@@ -755,7 +789,7 @@ _register_template(
     name="llama2",
     format_user=StringFormatter(slots=[{"bos_token"}, "[INST] {{content}} [/INST]"]),
     format_system=StringFormatter(slots=["<<SYS>>\n{{content}}\n<</SYS>>\n\n"]),
-    fuse_system_into_user=True,
+    template_class=Llama2Template,
 )


@@ -765,7 +799,7 @@ _register_template(
     format_user=StringFormatter(slots=[{"bos_token"}, "[INST] {{content}} [/INST]"]),
     format_system=StringFormatter(slots=["<<SYS>>\n{{content}}\n<</SYS>>\n\n"]),
     default_system="You are a helpful assistant. 你是一个乐于助人的助手。",
-    fuse_system_into_user=True,
+    template_class=Llama2Template,
 )


@@ -889,7 +923,7 @@ _register_template(
     format_tools=ToolFormatter(tool_format="mistral"),
     format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
     mm_plugin=get_mm_plugin(name="llava_next", image_token="<image>"),
-    fuse_system_into_user=True,
+    template_class=Llama2Template,
 )


@@ -944,7 +978,7 @@ _register_template(
     format_tools=ToolFormatter(tool_format="mistral"),
     format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
     mm_plugin=get_mm_plugin(name="llava_next_video", image_token="<image>", video_token="
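
Note: a minimal sketch of how the refactored entry points fit together after this patch, assuming a template registered above and a stock Hugging Face tokenizer; the checkpoint id is illustrative and not taken from the diff.

```python
# Sketch only: mirrors the call sequence this diff introduces.
from transformers import AutoTokenizer

from llamafactory.data.template import TEMPLATES

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")  # illustrative checkpoint
template = TEMPLATES["llama2"]  # registered above with template_class=Llama2Template

# The two new Template methods that replace the old module-level helpers:
template.fix_special_tokens(tokenizer)  # eos/pad handling plus extra stop words
template.fix_jinja_template(tokenizer)  # sets tokenizer.chat_template when it is missing or replace_jinja_template is set
```

With this structure, `get_template_and_fix_tokenizer` shrinks to the two method calls shown in the hunk above, and `Llama2Template` overrides `_get_jinja_template` instead of being special-cased through `isinstance` checks and the `fuse_system_into_user` flag.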