mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-10-14 23:58:11 +08:00
[data] refactor template (#6896)
Former-commit-id: f78d5a3eca947ed965ca2f6c87d60441b1a59867
This commit is contained in:
parent
b72c4bd118
commit
3f7bd98bfa
@ -13,7 +13,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Union
|
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Type, Union
|
||||||
|
|
||||||
from typing_extensions import override
|
from typing_extensions import override
|
||||||
|
|
||||||
@ -68,8 +68,8 @@ class Template:
|
|||||||
for encoded_ids in encoded_messages[:-1]:
|
for encoded_ids in encoded_messages[:-1]:
|
||||||
prompt_ids += encoded_ids
|
prompt_ids += encoded_ids
|
||||||
|
|
||||||
answer_ids = encoded_messages[-1]
|
response_ids = encoded_messages[-1]
|
||||||
return prompt_ids, answer_ids
|
return prompt_ids, response_ids
|
||||||
|
|
||||||
def encode_multiturn(
|
def encode_multiturn(
|
||||||
self,
|
self,
|
||||||
@ -100,6 +100,27 @@ class Template:
|
|||||||
|
|
||||||
return list(stop_token_ids)
|
return list(stop_token_ids)
|
||||||
|
|
||||||
|
def _convert_elements_to_ids(self, tokenizer: "PreTrainedTokenizer", elements: "SLOTS") -> List[int]:
|
||||||
|
r"""
|
||||||
|
Converts elements to token ids.
|
||||||
|
"""
|
||||||
|
token_ids = []
|
||||||
|
for elem in elements:
|
||||||
|
if isinstance(elem, str):
|
||||||
|
if len(elem) != 0:
|
||||||
|
token_ids += tokenizer.encode(elem, add_special_tokens=False)
|
||||||
|
elif isinstance(elem, dict):
|
||||||
|
token_ids += [tokenizer.convert_tokens_to_ids(elem.get("token"))]
|
||||||
|
elif isinstance(elem, set):
|
||||||
|
if "bos_token" in elem and tokenizer.bos_token_id is not None:
|
||||||
|
token_ids += [tokenizer.bos_token_id]
|
||||||
|
elif "eos_token" in elem and tokenizer.eos_token_id is not None:
|
||||||
|
token_ids += [tokenizer.eos_token_id]
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Input must be string, set[str] or dict[str, str], got {type(elem)}")
|
||||||
|
|
||||||
|
return token_ids
|
||||||
|
|
||||||
def _encode(
|
def _encode(
|
||||||
self,
|
self,
|
||||||
tokenizer: "PreTrainedTokenizer",
|
tokenizer: "PreTrainedTokenizer",
|
||||||
@ -110,7 +131,7 @@ class Template:
|
|||||||
r"""
|
r"""
|
||||||
Encodes formatted inputs to pairs of token ids.
|
Encodes formatted inputs to pairs of token ids.
|
||||||
Turn 0: prefix + system + query resp
|
Turn 0: prefix + system + query resp
|
||||||
Turn t: sep + query resp
|
Turn t: query resp
|
||||||
"""
|
"""
|
||||||
system = system or self.default_system
|
system = system or self.default_system
|
||||||
encoded_messages = []
|
encoded_messages = []
|
||||||
@ -138,26 +159,122 @@ class Template:
|
|||||||
|
|
||||||
return encoded_messages
|
return encoded_messages
|
||||||
|
|
||||||
def _convert_elements_to_ids(self, tokenizer: "PreTrainedTokenizer", elements: "SLOTS") -> List[int]:
|
@staticmethod
|
||||||
|
def _add_or_replace_eos_token(tokenizer: "PreTrainedTokenizer", eos_token: str) -> None:
|
||||||
r"""
|
r"""
|
||||||
Converts elements to token ids.
|
Adds or replaces eos token to the tokenizer.
|
||||||
"""
|
"""
|
||||||
token_ids = []
|
is_added = tokenizer.eos_token_id is None
|
||||||
for elem in elements:
|
num_added_tokens = tokenizer.add_special_tokens({"eos_token": eos_token})
|
||||||
if isinstance(elem, str):
|
|
||||||
if len(elem) != 0:
|
|
||||||
token_ids += tokenizer.encode(elem, add_special_tokens=False)
|
|
||||||
elif isinstance(elem, dict):
|
|
||||||
token_ids += [tokenizer.convert_tokens_to_ids(elem.get("token"))]
|
|
||||||
elif isinstance(elem, set):
|
|
||||||
if "bos_token" in elem and tokenizer.bos_token_id is not None:
|
|
||||||
token_ids += [tokenizer.bos_token_id]
|
|
||||||
elif "eos_token" in elem and tokenizer.eos_token_id is not None:
|
|
||||||
token_ids += [tokenizer.eos_token_id]
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Input must be string, set[str] or dict[str, str], got {type(elem)}")
|
|
||||||
|
|
||||||
return token_ids
|
if is_added:
|
||||||
|
logger.info_rank0(f"Add eos token: {tokenizer.eos_token}.")
|
||||||
|
else:
|
||||||
|
logger.info_rank0(f"Replace eos token: {tokenizer.eos_token}.")
|
||||||
|
|
||||||
|
if num_added_tokens > 0:
|
||||||
|
logger.warning_rank0("New tokens have been added, make sure `resize_vocab` is True.")
|
||||||
|
|
||||||
|
def fix_special_tokens(self, tokenizer: "PreTrainedTokenizer") -> None:
|
||||||
|
r"""
|
||||||
|
Adds eos token and pad token to the tokenizer.
|
||||||
|
"""
|
||||||
|
stop_words = self.stop_words
|
||||||
|
if self.replace_eos:
|
||||||
|
if not stop_words:
|
||||||
|
raise ValueError("Stop words are required to replace the EOS token.")
|
||||||
|
|
||||||
|
self._add_or_replace_eos_token(tokenizer, eos_token=stop_words[0])
|
||||||
|
stop_words = stop_words[1:]
|
||||||
|
|
||||||
|
if tokenizer.eos_token_id is None:
|
||||||
|
self._add_or_replace_eos_token(tokenizer, eos_token="<|endoftext|>")
|
||||||
|
|
||||||
|
if tokenizer.pad_token_id is None:
|
||||||
|
tokenizer.pad_token = tokenizer.eos_token
|
||||||
|
logger.info_rank0(f"Add pad token: {tokenizer.pad_token}")
|
||||||
|
|
||||||
|
if stop_words:
|
||||||
|
num_added_tokens = tokenizer.add_special_tokens(
|
||||||
|
dict(additional_special_tokens=stop_words), replace_additional_special_tokens=False
|
||||||
|
)
|
||||||
|
logger.info_rank0("Add {} to stop words.".format(",".join(stop_words)))
|
||||||
|
if num_added_tokens > 0:
|
||||||
|
logger.warning_rank0("New tokens have been added, make sure `resize_vocab` is True.")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _jinja_escape(content: str) -> str:
|
||||||
|
r"""
|
||||||
|
Escape single quotes in content.
|
||||||
|
"""
|
||||||
|
return content.replace("'", r"\'")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _convert_slots_to_jinja(slots: "SLOTS", tokenizer: "PreTrainedTokenizer", placeholder: str = "content") -> str:
|
||||||
|
r"""
|
||||||
|
Converts slots to jinja template.
|
||||||
|
"""
|
||||||
|
slot_items = []
|
||||||
|
for slot in slots:
|
||||||
|
if isinstance(slot, str):
|
||||||
|
slot_pieces = slot.split("{{content}}")
|
||||||
|
if slot_pieces[0]:
|
||||||
|
slot_items.append("'" + Template._jinja_escape(slot_pieces[0]) + "'")
|
||||||
|
if len(slot_pieces) > 1:
|
||||||
|
slot_items.append(placeholder)
|
||||||
|
if slot_pieces[1]:
|
||||||
|
slot_items.append("'" + Template._jinja_escape(slot_pieces[1]) + "'")
|
||||||
|
elif isinstance(slot, set): # do not use {{ eos_token }} since it may be replaced
|
||||||
|
if "bos_token" in slot and tokenizer.bos_token_id is not None:
|
||||||
|
slot_items.append("'" + tokenizer.bos_token + "'")
|
||||||
|
elif "eos_token" in slot and tokenizer.eos_token_id is not None:
|
||||||
|
slot_items.append("'" + tokenizer.eos_token + "'")
|
||||||
|
elif isinstance(slot, dict):
|
||||||
|
raise ValueError("Dict is not supported.")
|
||||||
|
|
||||||
|
return " + ".join(slot_items)
|
||||||
|
|
||||||
|
def _get_jinja_template(self, tokenizer: "PreTrainedTokenizer") -> str:
|
||||||
|
r"""
|
||||||
|
Returns the jinja template.
|
||||||
|
"""
|
||||||
|
prefix = self._convert_slots_to_jinja(self.format_prefix.apply(), tokenizer)
|
||||||
|
system_message = self._convert_slots_to_jinja(
|
||||||
|
self.format_system.apply(), tokenizer, placeholder="system_message"
|
||||||
|
)
|
||||||
|
user_message = self._convert_slots_to_jinja(self.format_user.apply(), tokenizer)
|
||||||
|
assistant_message = self._convert_slots_to_jinja(self.format_assistant.apply(), tokenizer)
|
||||||
|
jinja_template = ""
|
||||||
|
if prefix:
|
||||||
|
jinja_template += "{{ " + prefix + " }}"
|
||||||
|
|
||||||
|
if self.default_system:
|
||||||
|
jinja_template += "{% set system_message = '" + self._jinja_escape(self.default_system) + "' %}"
|
||||||
|
|
||||||
|
jinja_template += (
|
||||||
|
"{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}"
|
||||||
|
"{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% endif %}"
|
||||||
|
"{% if system_message is defined %}{{ " + system_message + " }}{% endif %}"
|
||||||
|
"{% for message in loop_messages %}"
|
||||||
|
"{% set content = message['content'] %}"
|
||||||
|
"{% if message['role'] == 'user' %}"
|
||||||
|
"{{ " + user_message + " }}"
|
||||||
|
"{% elif message['role'] == 'assistant' %}"
|
||||||
|
"{{ " + assistant_message + " }}"
|
||||||
|
"{% endif %}"
|
||||||
|
"{% endfor %}"
|
||||||
|
)
|
||||||
|
return jinja_template
|
||||||
|
|
||||||
|
def fix_jinja_template(self, tokenizer: "PreTrainedTokenizer") -> None:
|
||||||
|
r"""
|
||||||
|
Replaces the jinja template in the tokenizer.
|
||||||
|
"""
|
||||||
|
if tokenizer.chat_template is None or self.replace_jinja_template:
|
||||||
|
try:
|
||||||
|
tokenizer.chat_template = self._get_jinja_template(tokenizer)
|
||||||
|
except ValueError as e:
|
||||||
|
logger.info_rank0(f"Cannot add this chat template to tokenizer: {e}.")
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -170,11 +287,6 @@ class Llama2Template(Template):
|
|||||||
system: str,
|
system: str,
|
||||||
tools: str,
|
tools: str,
|
||||||
) -> List[List[int]]:
|
) -> List[List[int]]:
|
||||||
r"""
|
|
||||||
Encodes formatted inputs to pairs of token ids.
|
|
||||||
Turn 0: prefix + system + query resp
|
|
||||||
Turn t: sep + query resp
|
|
||||||
"""
|
|
||||||
system = system or self.default_system
|
system = system or self.default_system
|
||||||
encoded_messages = []
|
encoded_messages = []
|
||||||
for i, message in enumerate(messages):
|
for i, message in enumerate(messages):
|
||||||
@ -202,6 +314,36 @@ class Llama2Template(Template):
|
|||||||
|
|
||||||
return encoded_messages
|
return encoded_messages
|
||||||
|
|
||||||
|
def _get_jinja_template(self, tokenizer: "PreTrainedTokenizer") -> str:
|
||||||
|
prefix = self._convert_slots_to_jinja(self.format_prefix.apply(), tokenizer)
|
||||||
|
system_message = self._convert_slots_to_jinja(
|
||||||
|
self.format_system.apply(), tokenizer, placeholder="system_message"
|
||||||
|
)
|
||||||
|
user_message = self._convert_slots_to_jinja(self.format_user.apply(), tokenizer)
|
||||||
|
assistant_message = self._convert_slots_to_jinja(self.format_assistant.apply(), tokenizer)
|
||||||
|
jinja_template = ""
|
||||||
|
if prefix:
|
||||||
|
jinja_template += "{{ " + prefix + " }}"
|
||||||
|
|
||||||
|
if self.default_system:
|
||||||
|
jinja_template += "{% set system_message = '" + self._jinja_escape(self.default_system) + "' %}"
|
||||||
|
|
||||||
|
jinja_template += (
|
||||||
|
"{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}"
|
||||||
|
"{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% endif %}"
|
||||||
|
"{% for message in loop_messages %}"
|
||||||
|
"{% if loop.index0 == 0 and system_message is defined %}"
|
||||||
|
"{% set content = " + system_message + " + message['content'] %}"
|
||||||
|
"{% else %}{% set content = message['content'] %}{% endif %}"
|
||||||
|
"{% if message['role'] == 'user' %}"
|
||||||
|
"{{ " + user_message + " }}"
|
||||||
|
"{% elif message['role'] == 'assistant' %}"
|
||||||
|
"{{ " + assistant_message + " }}"
|
||||||
|
"{% endif %}"
|
||||||
|
"{% endfor %}"
|
||||||
|
)
|
||||||
|
return jinja_template
|
||||||
|
|
||||||
|
|
||||||
TEMPLATES: Dict[str, "Template"] = {}
|
TEMPLATES: Dict[str, "Template"] = {}
|
||||||
|
|
||||||
@ -222,7 +364,7 @@ def _register_template(
|
|||||||
replace_eos: bool = False,
|
replace_eos: bool = False,
|
||||||
replace_jinja_template: bool = False,
|
replace_jinja_template: bool = False,
|
||||||
mm_plugin: "BasePlugin" = get_mm_plugin(name="base"),
|
mm_plugin: "BasePlugin" = get_mm_plugin(name="base"),
|
||||||
fuse_system_into_user: bool = False,
|
template_class: Type[Template] = Template,
|
||||||
) -> None:
|
) -> None:
|
||||||
r"""
|
r"""
|
||||||
Registers a chat template.
|
Registers a chat template.
|
||||||
@ -245,7 +387,6 @@ def _register_template(
|
|||||||
)
|
)
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
template_class = Llama2Template if fuse_system_into_user else Template
|
|
||||||
default_slots = ["{{content}}"] if efficient_eos else ["{{content}}", {"eos_token"}]
|
default_slots = ["{{content}}"] if efficient_eos else ["{{content}}", {"eos_token"}]
|
||||||
default_user_formatter = StringFormatter(slots=["{{content}}"])
|
default_user_formatter = StringFormatter(slots=["{{content}}"])
|
||||||
default_assistant_formatter = StringFormatter(slots=default_slots)
|
default_assistant_formatter = StringFormatter(slots=default_slots)
|
||||||
@ -270,86 +411,6 @@ def _register_template(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _add_or_replace_eos_token(tokenizer: "PreTrainedTokenizer", eos_token: str) -> None:
|
|
||||||
is_added = tokenizer.eos_token_id is None
|
|
||||||
num_added_tokens = tokenizer.add_special_tokens({"eos_token": eos_token})
|
|
||||||
|
|
||||||
if is_added:
|
|
||||||
logger.info_rank0(f"Add eos token: {tokenizer.eos_token}")
|
|
||||||
else:
|
|
||||||
logger.info_rank0(f"Replace eos token: {tokenizer.eos_token}")
|
|
||||||
|
|
||||||
if num_added_tokens > 0:
|
|
||||||
logger.warning_rank0("New tokens have been added, make sure `resize_vocab` is True.")
|
|
||||||
|
|
||||||
|
|
||||||
def _jinja_escape(content: str) -> str:
|
|
||||||
return content.replace("'", r"\'")
|
|
||||||
|
|
||||||
|
|
||||||
def _convert_slots_to_jinja(slots: "SLOTS", tokenizer: "PreTrainedTokenizer", placeholder: str = "content") -> str:
|
|
||||||
slot_items = []
|
|
||||||
for slot in slots:
|
|
||||||
if isinstance(slot, str):
|
|
||||||
slot_pieces = slot.split("{{content}}")
|
|
||||||
if slot_pieces[0]:
|
|
||||||
slot_items.append("'" + _jinja_escape(slot_pieces[0]) + "'")
|
|
||||||
if len(slot_pieces) > 1:
|
|
||||||
slot_items.append(placeholder)
|
|
||||||
if slot_pieces[1]:
|
|
||||||
slot_items.append("'" + _jinja_escape(slot_pieces[1]) + "'")
|
|
||||||
elif isinstance(slot, set): # do not use {{ eos_token }} since it may be replaced
|
|
||||||
if "bos_token" in slot and tokenizer.bos_token_id is not None:
|
|
||||||
slot_items.append("'" + tokenizer.bos_token + "'")
|
|
||||||
elif "eos_token" in slot and tokenizer.eos_token_id is not None:
|
|
||||||
slot_items.append("'" + tokenizer.eos_token + "'")
|
|
||||||
elif isinstance(slot, dict):
|
|
||||||
raise ValueError("Dict is not supported.")
|
|
||||||
|
|
||||||
return " + ".join(slot_items)
|
|
||||||
|
|
||||||
|
|
||||||
def _get_jinja_template(template: "Template", tokenizer: "PreTrainedTokenizer") -> str:
|
|
||||||
r"""
|
|
||||||
Returns the jinja template.
|
|
||||||
"""
|
|
||||||
jinja_template = ""
|
|
||||||
|
|
||||||
prefix = _convert_slots_to_jinja(template.format_prefix.apply(), tokenizer)
|
|
||||||
if prefix:
|
|
||||||
jinja_template += "{{ " + prefix + " }}"
|
|
||||||
|
|
||||||
if template.default_system:
|
|
||||||
jinja_template += "{% set system_message = '" + _jinja_escape(template.default_system) + "' %}"
|
|
||||||
|
|
||||||
jinja_template += (
|
|
||||||
"{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}"
|
|
||||||
"{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% endif %}"
|
|
||||||
)
|
|
||||||
|
|
||||||
system_message = _convert_slots_to_jinja(template.format_system.apply(), tokenizer, placeholder="system_message")
|
|
||||||
if not isinstance(template, Llama2Template):
|
|
||||||
jinja_template += "{% if system_message is defined %}{{ " + system_message + " }}{% endif %}"
|
|
||||||
|
|
||||||
jinja_template += "{% for message in loop_messages %}"
|
|
||||||
jinja_template += "{% set content = message['content'] %}"
|
|
||||||
if isinstance(template, Llama2Template):
|
|
||||||
jinja_template += "{% if loop.index0 == 0 and system_message is defined %}"
|
|
||||||
jinja_template += "{% set content = " + system_message + " + message['content'] %}"
|
|
||||||
jinja_template += "{% endif %}"
|
|
||||||
|
|
||||||
jinja_template += "{% if message['role'] == 'user' %}"
|
|
||||||
user_message = _convert_slots_to_jinja(template.format_user.apply(), tokenizer)
|
|
||||||
jinja_template += "{{ " + user_message + " }}"
|
|
||||||
|
|
||||||
jinja_template += "{% elif message['role'] == 'assistant' %}"
|
|
||||||
assistant_message = _convert_slots_to_jinja(template.format_assistant.apply(), tokenizer)
|
|
||||||
jinja_template += "{{ " + assistant_message + " }}"
|
|
||||||
jinja_template += "{% endif %}"
|
|
||||||
jinja_template += "{% endfor %}"
|
|
||||||
return jinja_template
|
|
||||||
|
|
||||||
|
|
||||||
def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args: "DataArguments") -> "Template":
|
def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args: "DataArguments") -> "Template":
|
||||||
r"""
|
r"""
|
||||||
Gets chat template and fixes the tokenizer.
|
Gets chat template and fixes the tokenizer.
|
||||||
@ -373,35 +434,8 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args:
|
|||||||
template.format_function = FunctionFormatter(slots=default_slots, tool_format=data_args.tool_format)
|
template.format_function = FunctionFormatter(slots=default_slots, tool_format=data_args.tool_format)
|
||||||
template.format_tools = ToolFormatter(tool_format=data_args.tool_format)
|
template.format_tools = ToolFormatter(tool_format=data_args.tool_format)
|
||||||
|
|
||||||
stop_words = template.stop_words
|
template.fix_special_tokens(tokenizer)
|
||||||
if template.replace_eos:
|
template.fix_jinja_template(tokenizer)
|
||||||
if not stop_words:
|
|
||||||
raise ValueError("Stop words are required to replace the EOS token.")
|
|
||||||
|
|
||||||
_add_or_replace_eos_token(tokenizer, eos_token=stop_words[0])
|
|
||||||
stop_words = stop_words[1:]
|
|
||||||
|
|
||||||
if tokenizer.eos_token_id is None:
|
|
||||||
_add_or_replace_eos_token(tokenizer, eos_token="<|endoftext|>")
|
|
||||||
|
|
||||||
if tokenizer.pad_token_id is None:
|
|
||||||
tokenizer.pad_token = tokenizer.eos_token
|
|
||||||
logger.info_rank0(f"Add pad token: {tokenizer.pad_token}")
|
|
||||||
|
|
||||||
if stop_words:
|
|
||||||
num_added_tokens = tokenizer.add_special_tokens(
|
|
||||||
dict(additional_special_tokens=stop_words), replace_additional_special_tokens=False
|
|
||||||
)
|
|
||||||
logger.info_rank0("Add {} to stop words.".format(",".join(stop_words)))
|
|
||||||
if num_added_tokens > 0:
|
|
||||||
logger.warning_rank0("New tokens have been added, make sure `resize_vocab` is True.")
|
|
||||||
|
|
||||||
if tokenizer.chat_template is None or template.replace_jinja_template:
|
|
||||||
try:
|
|
||||||
tokenizer.chat_template = _get_jinja_template(template, tokenizer)
|
|
||||||
except ValueError as e:
|
|
||||||
logger.info_rank0(f"Cannot add this chat template to tokenizer: {e}.")
|
|
||||||
|
|
||||||
return template
|
return template
|
||||||
|
|
||||||
|
|
||||||
@ -755,7 +789,7 @@ _register_template(
|
|||||||
name="llama2",
|
name="llama2",
|
||||||
format_user=StringFormatter(slots=[{"bos_token"}, "[INST] {{content}} [/INST]"]),
|
format_user=StringFormatter(slots=[{"bos_token"}, "[INST] {{content}} [/INST]"]),
|
||||||
format_system=StringFormatter(slots=["<<SYS>>\n{{content}}\n<</SYS>>\n\n"]),
|
format_system=StringFormatter(slots=["<<SYS>>\n{{content}}\n<</SYS>>\n\n"]),
|
||||||
fuse_system_into_user=True,
|
template_class=Llama2Template,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -765,7 +799,7 @@ _register_template(
|
|||||||
format_user=StringFormatter(slots=[{"bos_token"}, "[INST] {{content}} [/INST]"]),
|
format_user=StringFormatter(slots=[{"bos_token"}, "[INST] {{content}} [/INST]"]),
|
||||||
format_system=StringFormatter(slots=["<<SYS>>\n{{content}}\n<</SYS>>\n\n"]),
|
format_system=StringFormatter(slots=["<<SYS>>\n{{content}}\n<</SYS>>\n\n"]),
|
||||||
default_system="You are a helpful assistant. 你是一个乐于助人的助手。",
|
default_system="You are a helpful assistant. 你是一个乐于助人的助手。",
|
||||||
fuse_system_into_user=True,
|
template_class=Llama2Template,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -889,7 +923,7 @@ _register_template(
|
|||||||
format_tools=ToolFormatter(tool_format="mistral"),
|
format_tools=ToolFormatter(tool_format="mistral"),
|
||||||
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
|
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
|
||||||
mm_plugin=get_mm_plugin(name="llava_next", image_token="<image>"),
|
mm_plugin=get_mm_plugin(name="llava_next", image_token="<image>"),
|
||||||
fuse_system_into_user=True,
|
template_class=Llama2Template,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -944,7 +978,7 @@ _register_template(
|
|||||||
format_tools=ToolFormatter(tool_format="mistral"),
|
format_tools=ToolFormatter(tool_format="mistral"),
|
||||||
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
|
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
|
||||||
mm_plugin=get_mm_plugin(name="llava_next_video", image_token="<image>", video_token="<video>"),
|
mm_plugin=get_mm_plugin(name="llava_next_video", image_token="<image>", video_token="<video>"),
|
||||||
fuse_system_into_user=True,
|
template_class=Llama2Template,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -1006,7 +1040,7 @@ _register_template(
|
|||||||
format_observation=StringFormatter(slots=["""[TOOL_RESULTS]{"content": {{content}}}[/TOOL_RESULTS]"""]),
|
format_observation=StringFormatter(slots=["""[TOOL_RESULTS]{"content": {{content}}}[/TOOL_RESULTS]"""]),
|
||||||
format_tools=ToolFormatter(tool_format="mistral"),
|
format_tools=ToolFormatter(tool_format="mistral"),
|
||||||
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
|
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
|
||||||
fuse_system_into_user=True,
|
template_class=Llama2Template,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -1020,7 +1054,7 @@ _register_template(
|
|||||||
format_observation=StringFormatter(slots=["""[TOOL_RESULTS] {"content": {{content}}}[/TOOL_RESULTS]"""]),
|
format_observation=StringFormatter(slots=["""[TOOL_RESULTS] {"content": {{content}}}[/TOOL_RESULTS]"""]),
|
||||||
format_tools=ToolFormatter(tool_format="mistral"),
|
format_tools=ToolFormatter(tool_format="mistral"),
|
||||||
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
|
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
|
||||||
fuse_system_into_user=True,
|
template_class=Llama2Template,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -1137,7 +1171,7 @@ _register_template(
|
|||||||
format_tools=ToolFormatter(tool_format="mistral"),
|
format_tools=ToolFormatter(tool_format="mistral"),
|
||||||
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
|
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
|
||||||
mm_plugin=get_mm_plugin(name="pixtral", image_token="[IMG]"),
|
mm_plugin=get_mm_plugin(name="pixtral", image_token="[IMG]"),
|
||||||
fuse_system_into_user=True,
|
template_class=Llama2Template,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -104,7 +104,7 @@ def get_trainer_info(output_path: os.PathLike, do_train: bool) -> Tuple[str, "gr
|
|||||||
running_log_path = os.path.join(output_path, RUNNING_LOG)
|
running_log_path = os.path.join(output_path, RUNNING_LOG)
|
||||||
if os.path.isfile(running_log_path):
|
if os.path.isfile(running_log_path):
|
||||||
with open(running_log_path, encoding="utf-8") as f:
|
with open(running_log_path, encoding="utf-8") as f:
|
||||||
running_log = f.read()
|
running_log = f.read()[-20000:] # avoid lengthy log
|
||||||
|
|
||||||
trainer_log_path = os.path.join(output_path, TRAINER_LOG)
|
trainer_log_path = os.path.join(output_path, TRAINER_LOG)
|
||||||
if os.path.isfile(trainer_log_path):
|
if os.path.isfile(trainer_log_path):
|
||||||
|
@ -19,7 +19,6 @@ import pytest
|
|||||||
from transformers import AutoTokenizer
|
from transformers import AutoTokenizer
|
||||||
|
|
||||||
from llamafactory.data import get_template_and_fix_tokenizer
|
from llamafactory.data import get_template_and_fix_tokenizer
|
||||||
from llamafactory.data.template import _get_jinja_template
|
|
||||||
from llamafactory.hparams import DataArguments
|
from llamafactory.hparams import DataArguments
|
||||||
|
|
||||||
|
|
||||||
@ -115,7 +114,7 @@ def test_jinja_template(use_fast: bool):
|
|||||||
tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA, use_fast=use_fast)
|
tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA, use_fast=use_fast)
|
||||||
ref_tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA, use_fast=use_fast)
|
ref_tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA, use_fast=use_fast)
|
||||||
template = get_template_and_fix_tokenizer(tokenizer, DataArguments(template="llama3"))
|
template = get_template_and_fix_tokenizer(tokenizer, DataArguments(template="llama3"))
|
||||||
tokenizer.chat_template = _get_jinja_template(template, tokenizer) # llama3 template no replace
|
tokenizer.chat_template = template._get_jinja_template(tokenizer) # llama3 template no replace
|
||||||
assert tokenizer.chat_template != ref_tokenizer.chat_template
|
assert tokenizer.chat_template != ref_tokenizer.chat_template
|
||||||
assert tokenizer.apply_chat_template(MESSAGES) == ref_tokenizer.apply_chat_template(MESSAGES)
|
assert tokenizer.apply_chat_template(MESSAGES) == ref_tokenizer.apply_chat_template(MESSAGES)
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user