diff --git a/src/llamafactory/data/template.py b/src/llamafactory/data/template.py
index 3106c734..47541b54 100644
--- a/src/llamafactory/data/template.py
+++ b/src/llamafactory/data/template.py
@@ -13,7 +13,7 @@
# limitations under the License.
from dataclasses import dataclass
-from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Union
+from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Type, Union
from typing_extensions import override
@@ -68,8 +68,8 @@ class Template:
for encoded_ids in encoded_messages[:-1]:
prompt_ids += encoded_ids
- answer_ids = encoded_messages[-1]
- return prompt_ids, answer_ids
+ response_ids = encoded_messages[-1]
+ return prompt_ids, response_ids
def encode_multiturn(
self,
@@ -100,6 +100,27 @@ class Template:
return list(stop_token_ids)
+ def _convert_elements_to_ids(self, tokenizer: "PreTrainedTokenizer", elements: "SLOTS") -> List[int]:
+ r"""
+ Converts elements to token ids.
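+
+ For example, an illustrative elements list ["Hello.\n", {"token": "<tool>"}, {"eos_token"}]
+ yields the ids of the encoded text, the id of the hypothetical "<tool>" token, and the eos token id, in order.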
+ """
+ token_ids = []
+ for elem in elements:
+ if isinstance(elem, str):
+ if len(elem) != 0:
+ token_ids += tokenizer.encode(elem, add_special_tokens=False)
+ elif isinstance(elem, dict):
+ token_ids += [tokenizer.convert_tokens_to_ids(elem.get("token"))]
+ elif isinstance(elem, set):
+ if "bos_token" in elem and tokenizer.bos_token_id is not None:
+ token_ids += [tokenizer.bos_token_id]
+ elif "eos_token" in elem and tokenizer.eos_token_id is not None:
+ token_ids += [tokenizer.eos_token_id]
+ else:
+ raise ValueError(f"Input must be string, set[str] or dict[str, str], got {type(elem)}")
+
+ return token_ids
+
def _encode(
self,
tokenizer: "PreTrainedTokenizer",
@@ -110,7 +131,7 @@ class Template:
r"""
Encodes formatted inputs to pairs of token ids.
Turn 0: prefix + system + query           resp
- Turn t: sep + query                     resp
+ Turn t: query                           resp
"""
system = system or self.default_system
encoded_messages = []
@@ -138,26 +159,122 @@ class Template:
return encoded_messages
- def _convert_elements_to_ids(self, tokenizer: "PreTrainedTokenizer", elements: "SLOTS") -> List[int]:
+ @staticmethod
+ def _add_or_replace_eos_token(tokenizer: "PreTrainedTokenizer", eos_token: str) -> None:
r"""
- Converts elements to token ids.
+ Adds or replaces the eos token in the tokenizer.
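+
+ If the tokenizer has no eos token yet, `eos_token` is added; otherwise the existing one is replaced.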
"""
- token_ids = []
- for elem in elements:
- if isinstance(elem, str):
- if len(elem) != 0:
- token_ids += tokenizer.encode(elem, add_special_tokens=False)
- elif isinstance(elem, dict):
- token_ids += [tokenizer.convert_tokens_to_ids(elem.get("token"))]
- elif isinstance(elem, set):
- if "bos_token" in elem and tokenizer.bos_token_id is not None:
- token_ids += [tokenizer.bos_token_id]
- elif "eos_token" in elem and tokenizer.eos_token_id is not None:
- token_ids += [tokenizer.eos_token_id]
- else:
- raise ValueError(f"Input must be string, set[str] or dict[str, str], got {type(elem)}")
+ is_added = tokenizer.eos_token_id is None
+ num_added_tokens = tokenizer.add_special_tokens({"eos_token": eos_token})
- return token_ids
+ if is_added:
+ logger.info_rank0(f"Add eos token: {tokenizer.eos_token}.")
+ else:
+ logger.info_rank0(f"Replace eos token: {tokenizer.eos_token}.")
+
+ if num_added_tokens > 0:
+ logger.warning_rank0("New tokens have been added, make sure `resize_vocab` is True.")
+
+ def fix_special_tokens(self, tokenizer: "PreTrainedTokenizer") -> None:
+ r"""
+ Adds the eos token and pad token to the tokenizer, and registers remaining stop words as additional special tokens.
+ """
+ stop_words = self.stop_words
+ if self.replace_eos:
+ if not stop_words:
+ raise ValueError("Stop words are required to replace the EOS token.")
+
+ self._add_or_replace_eos_token(tokenizer, eos_token=stop_words[0])
+ stop_words = stop_words[1:]
+
+ if tokenizer.eos_token_id is None:
+ self._add_or_replace_eos_token(tokenizer, eos_token="<|endoftext|>")
+
+ if tokenizer.pad_token_id is None:
+ tokenizer.pad_token = tokenizer.eos_token
+ logger.info_rank0(f"Add pad token: {tokenizer.pad_token}")
+
+ if stop_words:
+ num_added_tokens = tokenizer.add_special_tokens(
+ dict(additional_special_tokens=stop_words), replace_additional_special_tokens=False
+ )
+ logger.info_rank0("Add {} to stop words.".format(",".join(stop_words)))
+ if num_added_tokens > 0:
+ logger.warning_rank0("New tokens have been added, make sure `resize_vocab` is True.")
+
+ @staticmethod
+ def _jinja_escape(content: str) -> str:
+ r"""
+ Escapes single quotes in the content.
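+
+ For example, "it's" becomes "it\'s".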
+ """
+ return content.replace("'", r"\'")
+
+ @staticmethod
+ def _convert_slots_to_jinja(slots: "SLOTS", tokenizer: "PreTrainedTokenizer", placeholder: str = "content") -> str:
+ r"""
+ Converts slots to a jinja template.
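+
+ For example, an illustrative slot "<|user|>\n{{content}}" becomes the jinja
+ expression '<|user|>\n' + content, with single quotes in the literal parts escaped.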
+ """
+ slot_items = []
+ for slot in slots:
+ if isinstance(slot, str):
+ slot_pieces = slot.split("{{content}}")
+ if slot_pieces[0]:
+ slot_items.append("'" + Template._jinja_escape(slot_pieces[0]) + "'")
+ if len(slot_pieces) > 1:
+ slot_items.append(placeholder)
+ if slot_pieces[1]:
+ slot_items.append("'" + Template._jinja_escape(slot_pieces[1]) + "'")
+ elif isinstance(slot, set): # do not use {{ eos_token }} since it may be replaced
+ if "bos_token" in slot and tokenizer.bos_token_id is not None:
+ slot_items.append("'" + tokenizer.bos_token + "'")
+ elif "eos_token" in slot and tokenizer.eos_token_id is not None:
+ slot_items.append("'" + tokenizer.eos_token + "'")
+ elif isinstance(slot, dict):
+ raise ValueError("Dict is not supported.")
+
+ return " + ".join(slot_items)
+
+ def _get_jinja_template(self, tokenizer: "PreTrainedTokenizer") -> str:
+ r"""
+ Returns the jinja template.
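+
+ The system message, if present, is rendered as a standalone segment before the conversation loop.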
+ """
+ prefix = self._convert_slots_to_jinja(self.format_prefix.apply(), tokenizer)
+ system_message = self._convert_slots_to_jinja(
+ self.format_system.apply(), tokenizer, placeholder="system_message"
+ )
+ user_message = self._convert_slots_to_jinja(self.format_user.apply(), tokenizer)
+ assistant_message = self._convert_slots_to_jinja(self.format_assistant.apply(), tokenizer)
+ jinja_template = ""
+ if prefix:
+ jinja_template += "{{ " + prefix + " }}"
+
+ if self.default_system:
+ jinja_template += "{% set system_message = '" + self._jinja_escape(self.default_system) + "' %}"
+
+ jinja_template += (
+ "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}"
+ "{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% endif %}"
+ "{% if system_message is defined %}{{ " + system_message + " }}{% endif %}"
+ "{% for message in loop_messages %}"
+ "{% set content = message['content'] %}"
+ "{% if message['role'] == 'user' %}"
+ "{{ " + user_message + " }}"
+ "{% elif message['role'] == 'assistant' %}"
+ "{{ " + assistant_message + " }}"
+ "{% endif %}"
+ "{% endfor %}"
+ )
+ return jinja_template
+
+ def fix_jinja_template(self, tokenizer: "PreTrainedTokenizer") -> None:
+ r"""
+ Replaces the jinja template in the tokenizer if none is set or `replace_jinja_template` is enabled.
+ """
+ if tokenizer.chat_template is None or self.replace_jinja_template:
+ try:
+ tokenizer.chat_template = self._get_jinja_template(tokenizer)
+ except ValueError as e:
+ logger.info_rank0(f"Cannot add this chat template to tokenizer: {e}.")
@dataclass
@@ -170,11 +287,6 @@ class Llama2Template(Template):
system: str,
tools: str,
) -> List[List[int]]:
- r"""
- Encodes formatted inputs to pairs of token ids.
- Turn 0: prefix + system + query         resp
- Turn t: sep + query                     resp
- """
system = system or self.default_system
encoded_messages = []
for i, message in enumerate(messages):
@@ -202,6 +314,36 @@ class Llama2Template(Template):
return encoded_messages
+ def _get_jinja_template(self, tokenizer: "PreTrainedTokenizer") -> str:
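+ r"""
+ Returns the jinja template, fusing the system message into the first user message.
+ """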
+ prefix = self._convert_slots_to_jinja(self.format_prefix.apply(), tokenizer)
+ system_message = self._convert_slots_to_jinja(
+ self.format_system.apply(), tokenizer, placeholder="system_message"
+ )
+ user_message = self._convert_slots_to_jinja(self.format_user.apply(), tokenizer)
+ assistant_message = self._convert_slots_to_jinja(self.format_assistant.apply(), tokenizer)
+ jinja_template = ""
+ if prefix:
+ jinja_template += "{{ " + prefix + " }}"
+
+ if self.default_system:
+ jinja_template += "{% set system_message = '" + self._jinja_escape(self.default_system) + "' %}"
+
+ jinja_template += (
+ "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}"
+ "{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% endif %}"
+ "{% for message in loop_messages %}"
+ "{% if loop.index0 == 0 and system_message is defined %}"
+ "{% set content = " + system_message + " + message['content'] %}"
+ "{% else %}{% set content = message['content'] %}{% endif %}"
+ "{% if message['role'] == 'user' %}"
+ "{{ " + user_message + " }}"
+ "{% elif message['role'] == 'assistant' %}"
+ "{{ " + assistant_message + " }}"
+ "{% endif %}"
+ "{% endfor %}"
+ )
+ return jinja_template
+
TEMPLATES: Dict[str, "Template"] = {}
@@ -222,7 +364,7 @@ def _register_template(
replace_eos: bool = False,
replace_jinja_template: bool = False,
mm_plugin: "BasePlugin" = get_mm_plugin(name="base"),
- fuse_system_into_user: bool = False,
+ template_class: Type[Template] = Template,
) -> None:
r"""
Registers a chat template.
@@ -245,7 +387,6 @@ def _register_template(
)
```
"""
- template_class = Llama2Template if fuse_system_into_user else Template
default_slots = ["{{content}}"] if efficient_eos else ["{{content}}", {"eos_token"}]
default_user_formatter = StringFormatter(slots=["{{content}}"])
default_assistant_formatter = StringFormatter(slots=default_slots)
@@ -270,86 +411,6 @@ def _register_template(
)
-def _add_or_replace_eos_token(tokenizer: "PreTrainedTokenizer", eos_token: str) -> None:
- is_added = tokenizer.eos_token_id is None
- num_added_tokens = tokenizer.add_special_tokens({"eos_token": eos_token})
-
- if is_added:
- logger.info_rank0(f"Add eos token: {tokenizer.eos_token}")
- else:
- logger.info_rank0(f"Replace eos token: {tokenizer.eos_token}")
-
- if num_added_tokens > 0:
- logger.warning_rank0("New tokens have been added, make sure `resize_vocab` is True.")
-
-
-def _jinja_escape(content: str) -> str:
- return content.replace("'", r"\'")
-
-
-def _convert_slots_to_jinja(slots: "SLOTS", tokenizer: "PreTrainedTokenizer", placeholder: str = "content") -> str:
- slot_items = []
- for slot in slots:
- if isinstance(slot, str):
- slot_pieces = slot.split("{{content}}")
- if slot_pieces[0]:
- slot_items.append("'" + _jinja_escape(slot_pieces[0]) + "'")
- if len(slot_pieces) > 1:
- slot_items.append(placeholder)
- if slot_pieces[1]:
- slot_items.append("'" + _jinja_escape(slot_pieces[1]) + "'")
- elif isinstance(slot, set): # do not use {{ eos_token }} since it may be replaced
- if "bos_token" in slot and tokenizer.bos_token_id is not None:
- slot_items.append("'" + tokenizer.bos_token + "'")
- elif "eos_token" in slot and tokenizer.eos_token_id is not None:
- slot_items.append("'" + tokenizer.eos_token + "'")
- elif isinstance(slot, dict):
- raise ValueError("Dict is not supported.")
-
- return " + ".join(slot_items)
-
-
-def _get_jinja_template(template: "Template", tokenizer: "PreTrainedTokenizer") -> str:
- r"""
- Returns the jinja template.
- """
- jinja_template = ""
-
- prefix = _convert_slots_to_jinja(template.format_prefix.apply(), tokenizer)
- if prefix:
- jinja_template += "{{ " + prefix + " }}"
-
- if template.default_system:
- jinja_template += "{% set system_message = '" + _jinja_escape(template.default_system) + "' %}"
-
- jinja_template += (
- "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}"
- "{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% endif %}"
- )
-
- system_message = _convert_slots_to_jinja(template.format_system.apply(), tokenizer, placeholder="system_message")
- if not isinstance(template, Llama2Template):
- jinja_template += "{% if system_message is defined %}{{ " + system_message + " }}{% endif %}"
-
- jinja_template += "{% for message in loop_messages %}"
- jinja_template += "{% set content = message['content'] %}"
- if isinstance(template, Llama2Template):
- jinja_template += "{% if loop.index0 == 0 and system_message is defined %}"
- jinja_template += "{% set content = " + system_message + " + message['content'] %}"
- jinja_template += "{% endif %}"
-
- jinja_template += "{% if message['role'] == 'user' %}"
- user_message = _convert_slots_to_jinja(template.format_user.apply(), tokenizer)
- jinja_template += "{{ " + user_message + " }}"
-
- jinja_template += "{% elif message['role'] == 'assistant' %}"
- assistant_message = _convert_slots_to_jinja(template.format_assistant.apply(), tokenizer)
- jinja_template += "{{ " + assistant_message + " }}"
- jinja_template += "{% endif %}"
- jinja_template += "{% endfor %}"
- return jinja_template
-
-
def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args: "DataArguments") -> "Template":
r"""
Gets chat template and fixes the tokenizer.
@@ -373,35 +434,8 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args:
template.format_function = FunctionFormatter(slots=default_slots, tool_format=data_args.tool_format)
template.format_tools = ToolFormatter(tool_format=data_args.tool_format)
- stop_words = template.stop_words
- if template.replace_eos:
- if not stop_words:
- raise ValueError("Stop words are required to replace the EOS token.")
-
- _add_or_replace_eos_token(tokenizer, eos_token=stop_words[0])
- stop_words = stop_words[1:]
-
- if tokenizer.eos_token_id is None:
- _add_or_replace_eos_token(tokenizer, eos_token="<|endoftext|>")
-
- if tokenizer.pad_token_id is None:
- tokenizer.pad_token = tokenizer.eos_token
- logger.info_rank0(f"Add pad token: {tokenizer.pad_token}")
-
- if stop_words:
- num_added_tokens = tokenizer.add_special_tokens(
- dict(additional_special_tokens=stop_words), replace_additional_special_tokens=False
- )
- logger.info_rank0("Add {} to stop words.".format(",".join(stop_words)))
- if num_added_tokens > 0:
- logger.warning_rank0("New tokens have been added, make sure `resize_vocab` is True.")
-
- if tokenizer.chat_template is None or template.replace_jinja_template:
- try:
- tokenizer.chat_template = _get_jinja_template(template, tokenizer)
- except ValueError as e:
- logger.info_rank0(f"Cannot add this chat template to tokenizer: {e}.")
-
+ template.fix_special_tokens(tokenizer)
+ template.fix_jinja_template(tokenizer)
return template
@@ -755,7 +789,7 @@ _register_template(
name="llama2",
format_user=StringFormatter(slots=[{"bos_token"}, "[INST] {{content}} [/INST]"]),
format_system=StringFormatter(slots=["<<SYS>>\n{{content}}\n<</SYS>>\n\n"]),
- fuse_system_into_user=True,
+ template_class=Llama2Template,
)
@@ -765,7 +799,7 @@ _register_template(
format_user=StringFormatter(slots=[{"bos_token"}, "[INST] {{content}} [/INST]"]),
format_system=StringFormatter(slots=["<<SYS>>\n{{content}}\n<</SYS>>\n\n"]),
default_system="You are a helpful assistant. 你是一个乐于助人的助手。",
- fuse_system_into_user=True,
+ template_class=Llama2Template,
)
@@ -889,7 +923,7 @@ _register_template(
format_tools=ToolFormatter(tool_format="mistral"),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
mm_plugin=get_mm_plugin(name="llava_next", image_token="<image>"),
- fuse_system_into_user=True,
+ template_class=Llama2Template,
)
@@ -944,7 +978,7 @@ _register_template(
format_tools=ToolFormatter(tool_format="mistral"),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
mm_plugin=get_mm_plugin(name="llava_next_video", image_token="<image>", video_token="<video>"),