From c6be9e242ca88303a62f7478bfd92b5f5db0101c Mon Sep 17 00:00:00 2001 From: hoshi-hiyouga Date: Tue, 11 Feb 2025 19:52:25 +0800 Subject: [PATCH] [misc] support export ollama modelfile (#6899) * support export ollama modelfile * update config * add system and num ctx Former-commit-id: 9184a6e0ed7ff5f632c848f861bfa448c4cd06fc --- README.md | 6 +- README_zh.md | 10 +- examples/README.md | 6 ++ examples/README_zh.md | 6 ++ examples/merge_lora/llama3_full_sft.yaml | 11 +++ examples/merge_lora/llama3_gptq.yaml | 4 +- examples/merge_lora/llama3_lora_sft.yaml | 5 +- examples/merge_lora/qwen2vl_lora_sft.yaml | 5 +- scripts/export_ollama_modelfile.py | 112 ---------------------- scripts/test_ollama_modelfile.py | 87 ----------------- src/llamafactory/data/template.py | 74 ++++++++++++-- src/llamafactory/hparams/model_args.py | 2 +- src/llamafactory/train/tuner.py | 6 +- tests/data/test_template.py | 16 ++++ 14 files changed, 126 insertions(+), 224 deletions(-) create mode 100644 examples/merge_lora/llama3_full_sft.yaml delete mode 100644 scripts/export_ollama_modelfile.py delete mode 100644 scripts/test_ollama_modelfile.py diff --git a/README.md b/README.md index 73238452..ccae41c0 100644 --- a/README.md +++ b/README.md @@ -106,16 +106,18 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/ ## Changelog +[25/02/11] We supported saving the **[Ollama](https://github.com/ollama/ollama)** modelfile when exporting the model checkpoints. See [examples](examples/README.md) for usage. + [25/02/05] We supported fine-tuning the **[Qwen2-Audio](Qwen/Qwen2-Audio-7B-Instruct)** and **[MiniCPM-o-2.6](https://huggingface.co/openbmb/MiniCPM-o-2_6)** on audio understanding tasks. [25/01/31] We supported fine-tuning the **[DeepSeek-R1](https://huggingface.co/deepseek-ai/DeepSeek-R1)** and **[Qwen2.5-VL](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct)** model. +
Full Changelog + [25/01/15] We supported **[APOLLO](https://arxiv.org/abs/2412.05270)** optimizer. See [examples](examples/README.md) for usage. [25/01/14] We supported fine-tuning the **[MiniCPM-o-2.6](https://huggingface.co/openbmb/MiniCPM-o-2_6)** and **[MiniCPM-V-2.6](https://huggingface.co/openbmb/MiniCPM-V-2_6)** models. Thank [@BUAADreamer](https://github.com/BUAADreamer)'s PR. -
Full Changelog - [25/01/14] We supported fine-tuning the **[InternLM3](https://huggingface.co/collections/internlm/)** models. Thank [@hhaAndroid](https://github.com/hhaAndroid)'s PR. [25/01/10] We supported fine-tuning the **[Phi-4](https://huggingface.co/microsoft/phi-4)** model. diff --git a/README_zh.md b/README_zh.md index ca7139b8..94806b41 100644 --- a/README_zh.md +++ b/README_zh.md @@ -108,16 +108,18 @@ https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272 ## 更新日志 +[25/02/11] 我们支持了在导出模型时保存 **[Ollama](https://github.com/ollama/ollama)** 配置文件。详细用法请参照 [examples](examples/README_zh.md)。 + +[25/02/05] 我们支持了在语音理解任务上微调 **[Qwen2-Audio](Qwen/Qwen2-Audio-7B-Instruct)** 和 **[MiniCPM-o-2.6](https://huggingface.co/openbmb/MiniCPM-o-2_6)** 模型。 + [25/01/31] 我们支持了 **[DeepSeek-R1](https://huggingface.co/deepseek-ai/DeepSeek-R1)** 和 **[Qwen2.5-VL](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct)** 模型的微调。 +
展开日志 + [25/01/15] 我们支持了 **[APOLLO](https://arxiv.org/abs/2412.05270)** 优化器。详细用法请参照 [examples](examples/README_zh.md)。 [25/01/14] 我们支持了 **[MiniCPM-o-2.6](https://huggingface.co/openbmb/MiniCPM-o-2_6)** 和 **[MiniCPM-V-2.6](https://huggingface.co/openbmb/MiniCPM-V-2_6)** 模型的微调。 感谢 [@BUAADreamer](https://github.com/BUAADreamer) 的 PR. -
展开日志 - -[25/02/05] 我们支持了在语音理解任务上微调 **[Qwen2-Audio](Qwen/Qwen2-Audio-7B-Instruct)** 和 **[MiniCPM-o-2.6](https://huggingface.co/openbmb/MiniCPM-o-2_6)** 模型。 - [25/01/14] 我们支持了 **[InternLM3](https://huggingface.co/collections/internlm/)** 模型的微调。感谢 [@hhaAndroid](https://github.com/hhaAndroid) 的 PR。 [25/01/10] 我们支持了 **[Phi-4](https://huggingface.co/microsoft/phi-4)** 模型的微调。 diff --git a/examples/README.md b/examples/README.md index 1b944122..c2d2a52c 100644 --- a/examples/README.md +++ b/examples/README.md @@ -170,6 +170,12 @@ llamafactory-cli export examples/merge_lora/llama3_lora_sft.yaml llamafactory-cli export examples/merge_lora/llama3_gptq.yaml ``` +### Save Ollama modelfile + +```bash +llamafactory-cli export examples/merge_lora/llama3_full_sft.yaml +``` + ### Inferring LoRA Fine-Tuned Models #### Batch Generation using vLLM Tensor Parallel diff --git a/examples/README_zh.md b/examples/README_zh.md index 31d3eda2..316013f7 100644 --- a/examples/README_zh.md +++ b/examples/README_zh.md @@ -170,6 +170,12 @@ llamafactory-cli export examples/merge_lora/llama3_lora_sft.yaml llamafactory-cli export examples/merge_lora/llama3_gptq.yaml ``` +### 保存 Ollama 配置文件 + +```bash +llamafactory-cli export examples/merge_lora/llama3_full_sft.yaml +``` + ### 推理 LoRA 模型 #### 使用 vLLM+TP 批量推理 diff --git a/examples/merge_lora/llama3_full_sft.yaml b/examples/merge_lora/llama3_full_sft.yaml new file mode 100644 index 00000000..3d5682a7 --- /dev/null +++ b/examples/merge_lora/llama3_full_sft.yaml @@ -0,0 +1,11 @@ +### model +model_name_or_path: saves/llama3-8b/full/sft +template: llama3 +trust_remote_code: true +infer_dtype: bfloat16 + +### export +export_dir: output/llama3_full_sft +export_size: 5 +export_device: cpu +export_legacy_format: false diff --git a/examples/merge_lora/llama3_gptq.yaml b/examples/merge_lora/llama3_gptq.yaml index 21bd05dd..3a2d9095 100644 --- a/examples/merge_lora/llama3_gptq.yaml +++ b/examples/merge_lora/llama3_gptq.yaml @@ -4,9 +4,9 @@ template: llama3 trust_remote_code: true ### export -export_dir: models/llama3_gptq +export_dir: output/llama3_gptq export_quantization_bit: 4 export_quantization_dataset: data/c4_demo.json -export_size: 2 +export_size: 5 export_device: cpu export_legacy_format: false diff --git a/examples/merge_lora/llama3_lora_sft.yaml b/examples/merge_lora/llama3_lora_sft.yaml index 24fc0c18..97bb457b 100644 --- a/examples/merge_lora/llama3_lora_sft.yaml +++ b/examples/merge_lora/llama3_lora_sft.yaml @@ -4,11 +4,10 @@ model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct adapter_name_or_path: saves/llama3-8b/lora/sft template: llama3 -finetuning_type: lora trust_remote_code: true ### export -export_dir: models/llama3_lora_sft -export_size: 2 +export_dir: output/llama3_lora_sft +export_size: 5 export_device: cpu export_legacy_format: false diff --git a/examples/merge_lora/qwen2vl_lora_sft.yaml b/examples/merge_lora/qwen2vl_lora_sft.yaml index ebbb4c71..103dbcd8 100644 --- a/examples/merge_lora/qwen2vl_lora_sft.yaml +++ b/examples/merge_lora/qwen2vl_lora_sft.yaml @@ -4,11 +4,10 @@ model_name_or_path: Qwen/Qwen2-VL-7B-Instruct adapter_name_or_path: saves/qwen2_vl-7b/lora/sft template: qwen2_vl -finetuning_type: lora trust_remote_code: true ### export -export_dir: models/qwen2_vl_lora_sft -export_size: 2 +export_dir: output/qwen2_vl_lora_sft +export_size: 5 export_device: cpu export_legacy_format: false diff --git a/scripts/export_ollama_modelfile.py b/scripts/export_ollama_modelfile.py deleted file mode 100644 index 74feebc5..00000000 --- a/scripts/export_ollama_modelfile.py +++ /dev/null @@ -1,112 +0,0 @@ -# Copyright 2024 the LlamaFactory team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import codecs -import os -from typing import TYPE_CHECKING - -import fire -from transformers import AutoTokenizer - -from llamafactory.data import get_template_and_fix_tokenizer -if TYPE_CHECKING: - from transformers import PreTrainedTokenizer - from llamafactory.data.formatter import SLOTS - from llamafactory.data.template import Template - - -def _convert_slots_to_ollama(slots: "SLOTS", tokenizer: "PreTrainedTokenizer", placeholder: str = "content") -> str: - slot_items = [] - for slot in slots: - if isinstance(slot, str): - slot_pieces = slot.split("{{content}}") - if slot_pieces[0]: - slot_items.append(slot_pieces[0]) - if len(slot_pieces) > 1: - slot_items.append(placeholder) - if slot_pieces[1]: - slot_items.append(slot_pieces[1]) - elif isinstance(slot, set): # do not use {{ eos_token }} since it may be replaced - if "bos_token" in slot and tokenizer.bos_token_id is not None: - slot_items.append(tokenizer.bos_token) - elif "eos_token" in slot and tokenizer.eos_token_id is not None: - slot_items.append(tokenizer.eos_token) - elif isinstance(slot, dict): - raise ValueError("Dict is not supported.") - - return "".join(slot_items) - - -def _split_round_template(user_template_str: "str", template_obj: "Template", tokenizer: "PreTrainedTokenizer") -> tuple: - if template_obj.format_separator.apply(): - format_separator = _convert_slots_to_ollama(template_obj.format_separator.apply(), tokenizer) - round_split_token_list = [tokenizer.eos_token + format_separator, tokenizer.eos_token, - format_separator, "{{ .Prompt }}"] - else: - round_split_token_list = [tokenizer.eos_token, "{{ .Prompt }}"] - - for round_split_token in round_split_token_list: - round_split_templates = user_template_str.split(round_split_token) - if len(round_split_templates) >= 2: - user_round_template = "".join(round_split_templates[:-1]) - assistant_round_template = round_split_templates[-1] - return user_round_template + round_split_token, assistant_round_template - - return user_template_str, "" - - -def convert_template_obj_to_ollama(template_obj: "Template", tokenizer: "PreTrainedTokenizer") -> str: - ollama_template = "" - if template_obj.format_system: - ollama_template += "{{ if .System }}" - ollama_template += _convert_slots_to_ollama(template_obj.format_system.apply(), tokenizer, "{{ .System }}") - ollama_template += "{{ end }}" - - user_template = _convert_slots_to_ollama(template_obj.format_user.apply(), tokenizer, "{{ .Prompt }}") - user_round_template, assistant_round_template = _split_round_template(user_template, template_obj, tokenizer) - - ollama_template += "{{ if .Prompt }}" - ollama_template += user_round_template - ollama_template += "{{ end }}" - ollama_template += assistant_round_template - - ollama_template += _convert_slots_to_ollama(template_obj.format_assistant.apply(), tokenizer, "{{ .Response }}") - - return ollama_template - - -def export_ollama_modelfile( - model_name_or_path: str, - gguf_path: str, - template: str, - export_dir: str = "./ollama_model_file" -): - tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True) - template_obj = get_template_and_fix_tokenizer(tokenizer, name=template) - ollama_template = convert_template_obj_to_ollama(template_obj, tokenizer) - - if not os.path.exists(export_dir): - os.mkdir(export_dir) - with codecs.open(os.path.join(export_dir, "Modelfile"), "w", encoding="utf-8") as outf: - outf.write("FROM {}".format(gguf_path) + "\n") - outf.write("TEMPLATE \"\"\"{}\"\"\"".format(ollama_template) + "\n") - - if template_obj.stop_words: - for stop_word in template_obj.stop_words: - outf.write("PARAMETER stop \"{}\"".format(stop_word) + "\n") - elif not template_obj.efficient_eos: - outf.write("PARAMETER stop \"{}\"".format(tokenizer.eos_token) + "\n") - - -if __name__ == '__main__': - fire.Fire(export_ollama_modelfile) diff --git a/scripts/test_ollama_modelfile.py b/scripts/test_ollama_modelfile.py deleted file mode 100644 index e094ac67..00000000 --- a/scripts/test_ollama_modelfile.py +++ /dev/null @@ -1,87 +0,0 @@ -# Copyright 2024 the LlamaFactory team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from transformers import AutoTokenizer - -from llamafactory.data import get_template_and_fix_tokenizer -from export_ollama_modelfile import convert_template_obj_to_ollama - - -def test_qwen2_template(): - tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-7B-Instruct") - template = get_template_and_fix_tokenizer(tokenizer, name="qwen") - ollama_template = convert_template_obj_to_ollama(template, tokenizer) - - assert ollama_template == ("{{ if .System }}<|im_start|>system\n" - "{{ .System }}<|im_end|>\n" - "{{ end }}{{ if .Prompt }}<|im_start|>user\n" - "{{ .Prompt }}<|im_end|>\n" - "{{ end }}<|im_start|>assistant\n" - "{{ .Response }}<|im_end|>") - - -def test_yi_template(): - tokenizer = AutoTokenizer.from_pretrained("01-ai/Yi-1.5-9B-Chat") - template = get_template_and_fix_tokenizer(tokenizer, name="yi") - ollama_template = convert_template_obj_to_ollama(template, tokenizer) - - assert ollama_template == ("{{ if .System }}<|im_start|>system\n" - "{{ .System }}<|im_end|>\n" - "{{ end }}{{ if .Prompt }}<|im_start|>user\n" - "{{ .Prompt }}<|im_end|>\n" - "{{ end }}<|im_start|>assistant\n" - "{{ .Response }}<|im_end|>") - - -def test_llama2_template(): - tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf") - template = get_template_and_fix_tokenizer(tokenizer, name="llama2") - ollama_template = convert_template_obj_to_ollama(template, tokenizer) - - assert ollama_template == ("{{ if .System }}<>\n" - "{{ .System }}\n" - "<>\n\n" - "{{ end }}{{ if .Prompt }}[INST] {{ .Prompt }}{{ end }} [/INST]{{ .Response }}") - - -def test_llama3_template(): - tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct") - template = get_template_and_fix_tokenizer(tokenizer, name="llama3") - ollama_template = convert_template_obj_to_ollama(template, tokenizer) - - assert ollama_template == ("{{ if .System }}<|start_header_id|>system<|end_header_id|>\n\n" - "{{ .System }}<|eot_id|>{{ end }}" - "{{ if .Prompt }}<|start_header_id|>user<|end_header_id|>\n\n" - "{{ .Prompt }}<|eot_id|>{{ end }}<|start_header_id|>assistant<|end_header_id|>\n\n" - "{{ .Response }}<|eot_id|>") - - -def test_phi3_template(): - tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-4k-instruct") - template = get_template_and_fix_tokenizer(tokenizer, name="phi") - ollama_template = convert_template_obj_to_ollama(template, tokenizer) - assert ollama_template == ("{{ if .System }}<|system|>\n" - "{{ .System }}<|end|>\n" - "{{ end }}{{ if .Prompt }}<|user|>\n" - "{{ .Prompt }}<|end|>\n" - "{{ end }}<|assistant|>\n" - "{{ .Response }}<|end|>") - - -if __name__ == '__main__': - test_qwen2_template() - test_yi_template() - test_llama2_template() - test_llama3_template() - test_phi3_template() diff --git a/src/llamafactory/data/template.py b/src/llamafactory/data/template.py index 47541b54..61cb27a1 100644 --- a/src/llamafactory/data/template.py +++ b/src/llamafactory/data/template.py @@ -239,11 +239,9 @@ class Template: Returns the jinja template. """ prefix = self._convert_slots_to_jinja(self.format_prefix.apply(), tokenizer) - system_message = self._convert_slots_to_jinja( - self.format_system.apply(), tokenizer, placeholder="system_message" - ) - user_message = self._convert_slots_to_jinja(self.format_user.apply(), tokenizer) - assistant_message = self._convert_slots_to_jinja(self.format_assistant.apply(), tokenizer) + system = self._convert_slots_to_jinja(self.format_system.apply(), tokenizer, placeholder="system_message") + user = self._convert_slots_to_jinja(self.format_user.apply(), tokenizer) + assistant = self._convert_slots_to_jinja(self.format_assistant.apply(), tokenizer) jinja_template = "" if prefix: jinja_template += "{{ " + prefix + " }}" @@ -254,13 +252,13 @@ class Template: jinja_template += ( "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}" "{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% endif %}" - "{% if system_message is defined %}{{ " + system_message + " }}{% endif %}" + "{% if system_message is defined %}{{ " + system + " }}{% endif %}" "{% for message in loop_messages %}" "{% set content = message['content'] %}" "{% if message['role'] == 'user' %}" - "{{ " + user_message + " }}" + "{{ " + user + " }}" "{% elif message['role'] == 'assistant' %}" - "{{ " + assistant_message + " }}" + "{{ " + assistant + " }}" "{% endif %}" "{% endfor %}" ) @@ -276,6 +274,64 @@ class Template: except ValueError as e: logger.info_rank0(f"Cannot add this chat template to tokenizer: {e}.") + @staticmethod + def _convert_slots_to_ollama( + slots: "SLOTS", tokenizer: "PreTrainedTokenizer", placeholder: str = "content" + ) -> str: + r""" + Converts slots to ollama template. + """ + slot_items = [] + for slot in slots: + if isinstance(slot, str): + slot_pieces = slot.split("{{content}}") + if slot_pieces[0]: + slot_items.append(slot_pieces[0]) + if len(slot_pieces) > 1: + slot_items.append("{{ " + placeholder + " }}") + if slot_pieces[1]: + slot_items.append(slot_pieces[1]) + elif isinstance(slot, set): # do not use {{ eos_token }} since it may be replaced + if "bos_token" in slot and tokenizer.bos_token_id is not None: + slot_items.append(tokenizer.bos_token) + elif "eos_token" in slot and tokenizer.eos_token_id is not None: + slot_items.append(tokenizer.eos_token) + elif isinstance(slot, dict): + raise ValueError("Dict is not supported.") + + return "".join(slot_items) + + def _get_ollama_template(self, tokenizer: "PreTrainedTokenizer") -> str: + r""" + Returns the ollama template. + """ + prefix = self._convert_slots_to_ollama(self.format_prefix.apply(), tokenizer) + system = self._convert_slots_to_ollama(self.format_system.apply(), tokenizer, placeholder=".System") + user = self._convert_slots_to_ollama(self.format_user.apply(), tokenizer, placeholder=".Content") + assistant = self._convert_slots_to_ollama(self.format_assistant.apply(), tokenizer, placeholder=".Content") + return ( + f"{prefix}{{{{ if .System }}}}{system}{{{{ end }}}}" + f"""{{{{ range .Messages }}}}{{{{ if eq .Role "user" }}}}{user}""" + f"""{{{{ else if eq .Role "assistant" }}}}{assistant}{{{{ end }}}}{{{{ end }}}}""" + ) + + def get_ollama_modelfile(self, tokenizer: "PreTrainedTokenizer") -> str: + r""" + Returns the ollama modelfile. + + TODO: support function calling. + """ + modelfile = f'FROM .\n\nTEMPLATE """{self._get_ollama_template(tokenizer)}"""\n\n' + + if self.default_system: + modelfile += f'SYSTEM system "{self.default_system}"\n\n' + + for stop_token_id in self.get_stop_token_ids(tokenizer): + modelfile += f'PARAMETER stop "{tokenizer.convert_ids_to_tokens(stop_token_id)}"\n' + + modelfile += "PARAMETER num_ctx 4096\n" + return modelfile + @dataclass class Llama2Template(Template): @@ -1020,7 +1076,7 @@ _register_template( ) -# copied from chatml template +# copied from minicpm_v template _register_template( name="minicpm_o", format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), diff --git a/src/llamafactory/hparams/model_args.py b/src/llamafactory/hparams/model_args.py index 7f4df68c..e4429167 100644 --- a/src/llamafactory/hparams/model_args.py +++ b/src/llamafactory/hparams/model_args.py @@ -87,7 +87,7 @@ class ExportArguments: metadata={"help": "Path to the directory to save the exported model."}, ) export_size: int = field( - default=1, + default=5, metadata={"help": "The file shard size (in GB) of the exported model."}, ) export_device: Literal["cpu", "auto"] = field( diff --git a/src/llamafactory/train/tuner.py b/src/llamafactory/train/tuner.py index 4e60e2f0..23a39774 100644 --- a/src/llamafactory/train/tuner.py +++ b/src/llamafactory/train/tuner.py @@ -104,7 +104,7 @@ def export_model(args: Optional[Dict[str, Any]] = None) -> None: tokenizer_module = load_tokenizer(model_args) tokenizer = tokenizer_module["tokenizer"] processor = tokenizer_module["processor"] - get_template_and_fix_tokenizer(tokenizer, data_args) + template = get_template_and_fix_tokenizer(tokenizer, data_args) model = load_model(tokenizer, model_args, finetuning_args) # must after fixing tokenizer to resize vocab if getattr(model, "quantization_method", None) is not None and model_args.adapter_name_or_path is not None: @@ -171,3 +171,7 @@ def export_model(args: Optional[Dict[str, Any]] = None) -> None: except Exception as e: logger.warning_rank0(f"Cannot save tokenizer, please copy the files manually: {e}.") + + with open(os.path.join(model_args.export_dir, "Modelfile"), "w", encoding="utf-8") as f: + f.write(template.get_ollama_modelfile(tokenizer)) + logger.info_rank0(f"Saved ollama modelfile to {model_args.export_dir}.") diff --git a/tests/data/test_template.py b/tests/data/test_template.py index 0f1c8ea8..61923678 100644 --- a/tests/data/test_template.py +++ b/tests/data/test_template.py @@ -119,6 +119,22 @@ def test_jinja_template(use_fast: bool): assert tokenizer.apply_chat_template(MESSAGES) == ref_tokenizer.apply_chat_template(MESSAGES) +def test_ollama_modelfile(): + tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA) + template = get_template_and_fix_tokenizer(tokenizer, DataArguments(template="llama3")) + assert template.get_ollama_modelfile(tokenizer) == ( + "FROM .\n\n" + 'TEMPLATE """<|begin_of_text|>' + "{{ if .System }}<|start_header_id|>system<|end_header_id|>\n\n{{ .System }}<|eot_id|>{{ end }}" + '{{ range .Messages }}{{ if eq .Role "user" }}<|start_header_id|>user<|end_header_id|>\n\n{{ .Content }}' + "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" + '{{ else if eq .Role "assistant" }}{{ .Content }}<|eot_id|>{{ end }}{{ end }}"""\n\n' + 'PARAMETER stop "<|eom_id|>"\n' + 'PARAMETER stop "<|eot_id|>"\n' + "PARAMETER num_ctx 4096\n" + ) + + def test_get_stop_token_ids(): tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA) template = get_template_and_fix_tokenizer(tokenizer, DataArguments(template="llama3"))