diff --git a/src/llmtuner/data/loader.py b/src/llmtuner/data/loader.py index c917b215..9db10b04 100644 --- a/src/llmtuner/data/loader.py +++ b/src/llmtuner/data/loader.py @@ -156,6 +156,9 @@ def get_dataset( dataset = dataset.to_iterable_dataset() return dataset + if data_args.streaming: + raise ValueError("Turn off `streaming` when saving dataset to disk.") + with training_args.main_process_first(desc="load dataset"): all_datasets = [] for dataset_attr in get_dataset_list(data_args): diff --git a/src/llmtuner/data/template.py b/src/llmtuner/data/template.py index cb1a1811..895f2698 100644 --- a/src/llmtuner/data/template.py +++ b/src/llmtuner/data/template.py @@ -9,7 +9,7 @@ from .utils import Role, infer_max_len if TYPE_CHECKING: from transformers import PreTrainedTokenizer - from .formatter import Formatter + from .formatter import SLOTS, Formatter logger = get_logger(__name__) @@ -276,6 +276,71 @@ def _add_or_replace_eos_token(tokenizer: "PreTrainedTokenizer", eos_token: str) logger.warning("New tokens have been added, make sure `resize_vocab` is True.") +def _jinja_escape(content: str) -> str: + return content.replace("\n", r"\n").replace("'", r"\'") + + +def _convert_slots_to_jinja( + slots: "SLOTS", tokenizer: "PreTrainedTokenizer", placeholder: Optional[str] = "content" +) -> str: + slot_items = [] + for slot in slots: + if isinstance(slot, str): + slot_pieces = slot.split("{{content}}") + if slot_pieces[0]: + slot_items.append("'" + _jinja_escape(slot_pieces[0]) + "'") + if len(slot_pieces) > 1: + slot_items.append(placeholder) + if slot_pieces[1]: + slot_items.append("'" + _jinja_escape(slot_pieces[1]) + "'") + elif isinstance(slot, set): + if "bos_token" in slot: + slot_items.append("'" + tokenizer.bos_token + "'") + elif "eos_token" in slot: # do not use {{ eos_token }} since it may be replaced + slot_items.append("'" + tokenizer.eos_token + "'") + elif isinstance(slot, dict): + raise ValueError("Dict is not supported.") + + return " + ".join(slot_items) + + +def _get_jinja_template(template: "Template", tokenizer: "PreTrainedTokenizer") -> str: + jinja_template = "" + + if template.default_system: + jinja_template += "{% set system_message = '" + _jinja_escape(template.default_system) + "' %}" + + jinja_template += ( + "{% if messages[0]['role'] == 'system' %}" "{% set system_message = messages[0]['content'] %}" "{% endif %}" + ) + + system_message = _convert_slots_to_jinja(template.format_system.apply(), tokenizer, placeholder="system_message") + if isinstance(template, Llama2Template): + pass + elif template.force_system: + jinja_template += "{{ " + system_message + " }}" + else: + jinja_template += "{% if system_message is defined %}{{ " + system_message + " }}{% endif %}" + + jinja_template += "{% for message in messages %}" + jinja_template += "{% set content = message['content'] %}" + if isinstance(template, Llama2Template): + jinja_template += "{% if loop.index0 == 0 and system_message is defined %}" + jinja_template += "{% set content = " + system_message + " + message['content'] %}" + jinja_template += "{% endif %}" + jinja_template += "{% if message['role'] == 'user' %}" + user_message = _convert_slots_to_jinja(template.format_user.apply(), tokenizer) + jinja_template += "{{ " + user_message + " }}" + jinja_template += "{% elif message['role'] == 'assistant' %}" + assistant_message = _convert_slots_to_jinja( + template.format_assistant.apply() + template.format_separator.apply(), tokenizer + ) + jinja_template += "{{ " + assistant_message + " }}" + jinja_template += "{% endif %}" + jinja_template += "{% endfor %}" + return jinja_template + + def get_template_and_fix_tokenizer( tokenizer: "PreTrainedTokenizer", name: Optional[str] = None, @@ -308,6 +373,11 @@ def get_template_and_fix_tokenizer( ) logger.info("Add {} to stop words.".format(",".join(stop_words))) + try: + tokenizer.chat_template = _get_jinja_template(template, tokenizer) + except ValueError: + logger.info("Cannot add this chat template to tokenizer.") + return template @@ -345,14 +415,14 @@ _register_template( _register_template( name="baichuan", - format_user=StringFormatter(slots=[{"token": ""}, "{{content}}", {"token": ""}]), + format_user=StringFormatter(slots=["{{content}}"]), efficient_eos=True, ) _register_template( name="baichuan2", - format_user=StringFormatter(slots=[{"token": ""}, "{{content}}", {"token": ""}]), + format_user=StringFormatter(slots=["{{content}}"]), efficient_eos=True, ) @@ -465,7 +535,7 @@ _register_template( name="deepseekcoder", format_user=StringFormatter(slots=["### Instruction:\n{{content}}\n### Response:"]), format_assistant=StringFormatter(slots=["\n", "{{content}}"]), - format_separator=EmptyFormatter(slots=["\n", {"token": "<|EOT|>"}, "\n"]), + format_separator=EmptyFormatter(slots=["\n<|EOT|>\n"]), default_system=( "You are an AI programming assistant, utilizing the Deepseek Coder model, " "developed by Deepseek Company, and you only answer questions related to computer science. " @@ -600,10 +670,8 @@ _register_template( _register_template( name="starchat", - format_user=StringFormatter( - slots=[{"token": "<|user|>"}, "\n{{content}}", {"token": "<|end|>"}, "\n", {"token": "<|assistant|>"}] - ), - format_system=StringFormatter(slots=[{"token": "<|system|>"}, "\n{{content}}", {"token": "<|end|>"}, "\n"]), + format_user=StringFormatter(slots=["<|user|>\n{{content}}<|end|>\n<|assistant|>"]), + format_system=StringFormatter(slots=["<|system|>\n{{content}}<|end|>\n"]), format_separator=EmptyFormatter(slots=["\n"]), stop_words=["<|end|>"], replace_eos=True, @@ -684,6 +752,7 @@ _register_template( _register_template( name="zephyr", format_user=StringFormatter(slots=["<|user|>\n{{content}}", {"eos_token"}, "<|assistant|>"]), + format_assistant=StringFormatter(slots=["\n{{content}}", {"eos_token"}]), format_system=StringFormatter(slots=["<|system|>\n{{content}}", {"eos_token"}]), default_system="You are a friendly chatbot who always responds in the style of a pirate", ) @@ -691,6 +760,6 @@ _register_template( _register_template( name="ziya", - format_user=StringFormatter(slots=[{"token": ""}, ":{{content}}\n", {"token": ""}, ":"]), + format_user=StringFormatter(slots=[":{{content}}\n:"]), format_separator=EmptyFormatter(slots=["\n"]), )