mirror of
				https://github.com/hiyouga/LLaMA-Factory.git
				synced 2025-11-04 18:02:19 +08:00 
			
		
		
		
	[misc] support export ollama modelfile (#6899)
* support export ollama modelfile * update config * add system and num ctx Former-commit-id: 8c2af7466f4015f300b51841db11bcd2505ebf20
This commit is contained in:
		
							parent
							
								
									3f7bd98bfa
								
							
						
					
					
						commit
						88eafd865b
					
				@ -106,16 +106,18 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
 | 
			
		||||
 | 
			
		||||
## Changelog
 | 
			
		||||
 | 
			
		||||
[25/02/11] We supported saving the **[Ollama](https://github.com/ollama/ollama)** modelfile when exporting the model checkpoints. See [examples](examples/README.md) for usage.
 | 
			
		||||
 | 
			
		||||
[25/02/05] We supported fine-tuning the **[Qwen2-Audio](Qwen/Qwen2-Audio-7B-Instruct)** and **[MiniCPM-o-2.6](https://huggingface.co/openbmb/MiniCPM-o-2_6)** on audio understanding tasks.
 | 
			
		||||
 | 
			
		||||
[25/01/31] We supported fine-tuning the **[DeepSeek-R1](https://huggingface.co/deepseek-ai/DeepSeek-R1)** and **[Qwen2.5-VL](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct)** model.
 | 
			
		||||
 | 
			
		||||
<details><summary>Full Changelog</summary>
 | 
			
		||||
 | 
			
		||||
[25/01/15] We supported **[APOLLO](https://arxiv.org/abs/2412.05270)** optimizer. See [examples](examples/README.md) for usage.
 | 
			
		||||
 | 
			
		||||
[25/01/14] We supported fine-tuning the **[MiniCPM-o-2.6](https://huggingface.co/openbmb/MiniCPM-o-2_6)** and **[MiniCPM-V-2.6](https://huggingface.co/openbmb/MiniCPM-V-2_6)** models. Thank [@BUAADreamer](https://github.com/BUAADreamer)'s PR.
 | 
			
		||||
 | 
			
		||||
<details><summary>Full Changelog</summary>
 | 
			
		||||
 | 
			
		||||
[25/01/14] We supported fine-tuning the **[InternLM3](https://huggingface.co/collections/internlm/)** models. Thank [@hhaAndroid](https://github.com/hhaAndroid)'s PR.
 | 
			
		||||
 | 
			
		||||
[25/01/10] We supported fine-tuning the **[Phi-4](https://huggingface.co/microsoft/phi-4)** model.
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										10
									
								
								README_zh.md
									
									
									
									
									
								
							
							
						
						
									
										10
									
								
								README_zh.md
									
									
									
									
									
								
							@ -108,16 +108,18 @@ https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272
 | 
			
		||||
 | 
			
		||||
## 更新日志
 | 
			
		||||
 | 
			
		||||
[25/02/11] 我们支持了在导出模型时保存 **[Ollama](https://github.com/ollama/ollama)** 配置文件。详细用法请参照 [examples](examples/README_zh.md)。
 | 
			
		||||
 | 
			
		||||
[25/02/05] 我们支持了在语音理解任务上微调 **[Qwen2-Audio](Qwen/Qwen2-Audio-7B-Instruct)** 和 **[MiniCPM-o-2.6](https://huggingface.co/openbmb/MiniCPM-o-2_6)** 模型。
 | 
			
		||||
 | 
			
		||||
[25/01/31] 我们支持了 **[DeepSeek-R1](https://huggingface.co/deepseek-ai/DeepSeek-R1)** 和 **[Qwen2.5-VL](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct)** 模型的微调。
 | 
			
		||||
 | 
			
		||||
<details><summary>展开日志</summary>
 | 
			
		||||
 | 
			
		||||
[25/01/15] 我们支持了 **[APOLLO](https://arxiv.org/abs/2412.05270)** 优化器。详细用法请参照 [examples](examples/README_zh.md)。
 | 
			
		||||
 | 
			
		||||
[25/01/14] 我们支持了 **[MiniCPM-o-2.6](https://huggingface.co/openbmb/MiniCPM-o-2_6)** 和 **[MiniCPM-V-2.6](https://huggingface.co/openbmb/MiniCPM-V-2_6)** 模型的微调。 感谢 [@BUAADreamer](https://github.com/BUAADreamer) 的 PR.
 | 
			
		||||
 | 
			
		||||
<details><summary>展开日志</summary>
 | 
			
		||||
 | 
			
		||||
[25/02/05] 我们支持了在语音理解任务上微调 **[Qwen2-Audio](Qwen/Qwen2-Audio-7B-Instruct)** 和 **[MiniCPM-o-2.6](https://huggingface.co/openbmb/MiniCPM-o-2_6)** 模型。
 | 
			
		||||
 | 
			
		||||
[25/01/14] 我们支持了 **[InternLM3](https://huggingface.co/collections/internlm/)** 模型的微调。感谢 [@hhaAndroid](https://github.com/hhaAndroid) 的 PR。
 | 
			
		||||
 | 
			
		||||
[25/01/10] 我们支持了 **[Phi-4](https://huggingface.co/microsoft/phi-4)** 模型的微调。
 | 
			
		||||
 | 
			
		||||
@ -170,6 +170,12 @@ llamafactory-cli export examples/merge_lora/llama3_lora_sft.yaml
 | 
			
		||||
llamafactory-cli export examples/merge_lora/llama3_gptq.yaml
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
### Save Ollama modelfile
 | 
			
		||||
 | 
			
		||||
```bash
 | 
			
		||||
llamafactory-cli export examples/merge_lora/llama3_full_sft.yaml
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
### Inferring LoRA Fine-Tuned Models
 | 
			
		||||
 | 
			
		||||
#### Batch Generation using vLLM Tensor Parallel
 | 
			
		||||
 | 
			
		||||
@ -170,6 +170,12 @@ llamafactory-cli export examples/merge_lora/llama3_lora_sft.yaml
 | 
			
		||||
llamafactory-cli export examples/merge_lora/llama3_gptq.yaml
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
### 保存 Ollama 配置文件
 | 
			
		||||
 | 
			
		||||
```bash
 | 
			
		||||
llamafactory-cli export examples/merge_lora/llama3_full_sft.yaml
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
### 推理 LoRA 模型
 | 
			
		||||
 | 
			
		||||
#### 使用 vLLM+TP 批量推理
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										11
									
								
								examples/merge_lora/llama3_full_sft.yaml
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										11
									
								
								examples/merge_lora/llama3_full_sft.yaml
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,11 @@
 | 
			
		||||
### model
 | 
			
		||||
model_name_or_path: saves/llama3-8b/full/sft
 | 
			
		||||
template: llama3
 | 
			
		||||
trust_remote_code: true
 | 
			
		||||
infer_dtype: bfloat16
 | 
			
		||||
 | 
			
		||||
### export
 | 
			
		||||
export_dir: output/llama3_full_sft
 | 
			
		||||
export_size: 5
 | 
			
		||||
export_device: cpu
 | 
			
		||||
export_legacy_format: false
 | 
			
		||||
@ -4,9 +4,9 @@ template: llama3
 | 
			
		||||
trust_remote_code: true
 | 
			
		||||
 | 
			
		||||
### export
 | 
			
		||||
export_dir: models/llama3_gptq
 | 
			
		||||
export_dir: output/llama3_gptq
 | 
			
		||||
export_quantization_bit: 4
 | 
			
		||||
export_quantization_dataset: data/c4_demo.json
 | 
			
		||||
export_size: 2
 | 
			
		||||
export_size: 5
 | 
			
		||||
export_device: cpu
 | 
			
		||||
export_legacy_format: false
 | 
			
		||||
 | 
			
		||||
@ -4,11 +4,10 @@
 | 
			
		||||
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
 | 
			
		||||
adapter_name_or_path: saves/llama3-8b/lora/sft
 | 
			
		||||
template: llama3
 | 
			
		||||
finetuning_type: lora
 | 
			
		||||
trust_remote_code: true
 | 
			
		||||
 | 
			
		||||
### export
 | 
			
		||||
export_dir: models/llama3_lora_sft
 | 
			
		||||
export_size: 2
 | 
			
		||||
export_dir: output/llama3_lora_sft
 | 
			
		||||
export_size: 5
 | 
			
		||||
export_device: cpu
 | 
			
		||||
export_legacy_format: false
 | 
			
		||||
 | 
			
		||||
@ -4,11 +4,10 @@
 | 
			
		||||
model_name_or_path: Qwen/Qwen2-VL-7B-Instruct
 | 
			
		||||
adapter_name_or_path: saves/qwen2_vl-7b/lora/sft
 | 
			
		||||
template: qwen2_vl
 | 
			
		||||
finetuning_type: lora
 | 
			
		||||
trust_remote_code: true
 | 
			
		||||
 | 
			
		||||
### export
 | 
			
		||||
export_dir: models/qwen2_vl_lora_sft
 | 
			
		||||
export_size: 2
 | 
			
		||||
export_dir: output/qwen2_vl_lora_sft
 | 
			
		||||
export_size: 5
 | 
			
		||||
export_device: cpu
 | 
			
		||||
export_legacy_format: false
 | 
			
		||||
 | 
			
		||||
@ -1,112 +0,0 @@
 | 
			
		||||
# Copyright 2024 the LlamaFactory team.
 | 
			
		||||
#
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
# you may not use this file except in compliance with the License.
 | 
			
		||||
# You may obtain a copy of the License at
 | 
			
		||||
#
 | 
			
		||||
#     http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
#
 | 
			
		||||
# Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
# distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
# See the License for the specific language governing permissions and
 | 
			
		||||
# limitations under the License.
 | 
			
		||||
import codecs
 | 
			
		||||
import os
 | 
			
		||||
from typing import TYPE_CHECKING
 | 
			
		||||
 | 
			
		||||
import fire
 | 
			
		||||
from transformers import AutoTokenizer
 | 
			
		||||
 | 
			
		||||
from llamafactory.data import get_template_and_fix_tokenizer
 | 
			
		||||
if TYPE_CHECKING:
 | 
			
		||||
    from transformers import PreTrainedTokenizer
 | 
			
		||||
    from llamafactory.data.formatter import SLOTS
 | 
			
		||||
    from llamafactory.data.template import Template
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _convert_slots_to_ollama(slots: "SLOTS", tokenizer: "PreTrainedTokenizer", placeholder: str = "content") -> str:
 | 
			
		||||
    slot_items = []
 | 
			
		||||
    for slot in slots:
 | 
			
		||||
        if isinstance(slot, str):
 | 
			
		||||
            slot_pieces = slot.split("{{content}}")
 | 
			
		||||
            if slot_pieces[0]:
 | 
			
		||||
                slot_items.append(slot_pieces[0])
 | 
			
		||||
            if len(slot_pieces) > 1:
 | 
			
		||||
                slot_items.append(placeholder)
 | 
			
		||||
                if slot_pieces[1]:
 | 
			
		||||
                    slot_items.append(slot_pieces[1])
 | 
			
		||||
        elif isinstance(slot, set):  # do not use {{ eos_token }} since it may be replaced
 | 
			
		||||
            if "bos_token" in slot and tokenizer.bos_token_id is not None:
 | 
			
		||||
                slot_items.append(tokenizer.bos_token)
 | 
			
		||||
            elif "eos_token" in slot and tokenizer.eos_token_id is not None:
 | 
			
		||||
                slot_items.append(tokenizer.eos_token)
 | 
			
		||||
        elif isinstance(slot, dict):
 | 
			
		||||
            raise ValueError("Dict is not supported.")
 | 
			
		||||
 | 
			
		||||
    return "".join(slot_items)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _split_round_template(user_template_str: "str", template_obj: "Template", tokenizer: "PreTrainedTokenizer") -> tuple:
 | 
			
		||||
    if template_obj.format_separator.apply():
 | 
			
		||||
        format_separator = _convert_slots_to_ollama(template_obj.format_separator.apply(), tokenizer)
 | 
			
		||||
        round_split_token_list = [tokenizer.eos_token + format_separator, tokenizer.eos_token,
 | 
			
		||||
                              format_separator, "{{ .Prompt }}"]
 | 
			
		||||
    else:
 | 
			
		||||
        round_split_token_list = [tokenizer.eos_token, "{{ .Prompt }}"]
 | 
			
		||||
 | 
			
		||||
    for round_split_token in round_split_token_list:
 | 
			
		||||
        round_split_templates = user_template_str.split(round_split_token)
 | 
			
		||||
        if len(round_split_templates) >= 2:
 | 
			
		||||
            user_round_template = "".join(round_split_templates[:-1])
 | 
			
		||||
            assistant_round_template = round_split_templates[-1]
 | 
			
		||||
            return user_round_template + round_split_token, assistant_round_template
 | 
			
		||||
 | 
			
		||||
    return user_template_str, ""
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def convert_template_obj_to_ollama(template_obj: "Template", tokenizer: "PreTrainedTokenizer") -> str:
 | 
			
		||||
    ollama_template = ""
 | 
			
		||||
    if template_obj.format_system:
 | 
			
		||||
        ollama_template += "{{ if .System }}"
 | 
			
		||||
        ollama_template += _convert_slots_to_ollama(template_obj.format_system.apply(), tokenizer, "{{ .System }}")
 | 
			
		||||
        ollama_template += "{{ end }}"
 | 
			
		||||
 | 
			
		||||
    user_template = _convert_slots_to_ollama(template_obj.format_user.apply(), tokenizer, "{{ .Prompt }}")
 | 
			
		||||
    user_round_template, assistant_round_template = _split_round_template(user_template, template_obj, tokenizer)
 | 
			
		||||
 | 
			
		||||
    ollama_template += "{{ if .Prompt }}"
 | 
			
		||||
    ollama_template += user_round_template
 | 
			
		||||
    ollama_template += "{{ end }}"
 | 
			
		||||
    ollama_template += assistant_round_template
 | 
			
		||||
 | 
			
		||||
    ollama_template += _convert_slots_to_ollama(template_obj.format_assistant.apply(), tokenizer, "{{ .Response }}")
 | 
			
		||||
 | 
			
		||||
    return ollama_template
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def export_ollama_modelfile(
 | 
			
		||||
    model_name_or_path: str,
 | 
			
		||||
    gguf_path: str,
 | 
			
		||||
    template: str,
 | 
			
		||||
    export_dir: str = "./ollama_model_file"
 | 
			
		||||
):
 | 
			
		||||
    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True)
 | 
			
		||||
    template_obj = get_template_and_fix_tokenizer(tokenizer, name=template)
 | 
			
		||||
    ollama_template = convert_template_obj_to_ollama(template_obj, tokenizer)
 | 
			
		||||
 | 
			
		||||
    if not os.path.exists(export_dir):
 | 
			
		||||
        os.mkdir(export_dir)
 | 
			
		||||
    with codecs.open(os.path.join(export_dir, "Modelfile"), "w", encoding="utf-8") as outf:
 | 
			
		||||
        outf.write("FROM {}".format(gguf_path) + "\n")
 | 
			
		||||
        outf.write("TEMPLATE \"\"\"{}\"\"\"".format(ollama_template) + "\n")
 | 
			
		||||
 | 
			
		||||
        if template_obj.stop_words:
 | 
			
		||||
            for stop_word in template_obj.stop_words:
 | 
			
		||||
                outf.write("PARAMETER stop \"{}\"".format(stop_word) + "\n")
 | 
			
		||||
        elif not template_obj.efficient_eos:
 | 
			
		||||
            outf.write("PARAMETER stop \"{}\"".format(tokenizer.eos_token) + "\n")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == '__main__':
 | 
			
		||||
    fire.Fire(export_ollama_modelfile)
 | 
			
		||||
@ -1,87 +0,0 @@
 | 
			
		||||
# Copyright 2024 the LlamaFactory team.
 | 
			
		||||
#
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
# you may not use this file except in compliance with the License.
 | 
			
		||||
# You may obtain a copy of the License at
 | 
			
		||||
#
 | 
			
		||||
#     http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
#
 | 
			
		||||
# Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
# distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
# See the License for the specific language governing permissions and
 | 
			
		||||
# limitations under the License.
 | 
			
		||||
 | 
			
		||||
from transformers import AutoTokenizer
 | 
			
		||||
 | 
			
		||||
from llamafactory.data import get_template_and_fix_tokenizer
 | 
			
		||||
from export_ollama_modelfile import convert_template_obj_to_ollama
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_qwen2_template():
 | 
			
		||||
    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-7B-Instruct")
 | 
			
		||||
    template = get_template_and_fix_tokenizer(tokenizer, name="qwen")
 | 
			
		||||
    ollama_template = convert_template_obj_to_ollama(template, tokenizer)
 | 
			
		||||
 | 
			
		||||
    assert ollama_template == ("{{ if .System }}<|im_start|>system\n"
 | 
			
		||||
                               "{{ .System }}<|im_end|>\n"
 | 
			
		||||
                               "{{ end }}{{ if .Prompt }}<|im_start|>user\n"
 | 
			
		||||
                               "{{ .Prompt }}<|im_end|>\n"
 | 
			
		||||
                               "{{ end }}<|im_start|>assistant\n"
 | 
			
		||||
                               "{{ .Response }}<|im_end|>")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_yi_template():
 | 
			
		||||
    tokenizer = AutoTokenizer.from_pretrained("01-ai/Yi-1.5-9B-Chat")
 | 
			
		||||
    template = get_template_and_fix_tokenizer(tokenizer, name="yi")
 | 
			
		||||
    ollama_template = convert_template_obj_to_ollama(template, tokenizer)
 | 
			
		||||
 | 
			
		||||
    assert ollama_template == ("{{ if .System }}<|im_start|>system\n"
 | 
			
		||||
                               "{{ .System }}<|im_end|>\n"
 | 
			
		||||
                               "{{ end }}{{ if .Prompt }}<|im_start|>user\n"
 | 
			
		||||
                               "{{ .Prompt }}<|im_end|>\n"
 | 
			
		||||
                               "{{ end }}<|im_start|>assistant\n"
 | 
			
		||||
                               "{{ .Response }}<|im_end|>")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_llama2_template():
 | 
			
		||||
    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
 | 
			
		||||
    template = get_template_and_fix_tokenizer(tokenizer, name="llama2")
 | 
			
		||||
    ollama_template = convert_template_obj_to_ollama(template, tokenizer)
 | 
			
		||||
 | 
			
		||||
    assert ollama_template == ("{{ if .System }}<<SYS>>\n"
 | 
			
		||||
                               "{{ .System }}\n"
 | 
			
		||||
                               "<</SYS>>\n\n"
 | 
			
		||||
                               "{{ end }}{{ if .Prompt }}<s>[INST] {{ .Prompt }}{{ end }} [/INST]{{ .Response }}</s>")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_llama3_template():
 | 
			
		||||
    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
 | 
			
		||||
    template = get_template_and_fix_tokenizer(tokenizer, name="llama3")
 | 
			
		||||
    ollama_template = convert_template_obj_to_ollama(template, tokenizer)
 | 
			
		||||
 | 
			
		||||
    assert ollama_template == ("{{ if .System }}<|start_header_id|>system<|end_header_id|>\n\n"
 | 
			
		||||
                               "{{ .System }}<|eot_id|>{{ end }}"
 | 
			
		||||
                               "{{ if .Prompt }}<|start_header_id|>user<|end_header_id|>\n\n"
 | 
			
		||||
                               "{{ .Prompt }}<|eot_id|>{{ end }}<|start_header_id|>assistant<|end_header_id|>\n\n"
 | 
			
		||||
                               "{{ .Response }}<|eot_id|>")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_phi3_template():
 | 
			
		||||
    tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-4k-instruct")
 | 
			
		||||
    template = get_template_and_fix_tokenizer(tokenizer, name="phi")
 | 
			
		||||
    ollama_template = convert_template_obj_to_ollama(template, tokenizer)
 | 
			
		||||
    assert ollama_template == ("{{ if .System }}<|system|>\n"
 | 
			
		||||
                               "{{ .System }}<|end|>\n"
 | 
			
		||||
                               "{{ end }}{{ if .Prompt }}<|user|>\n"
 | 
			
		||||
                               "{{ .Prompt }}<|end|>\n"
 | 
			
		||||
                               "{{ end }}<|assistant|>\n"
 | 
			
		||||
                               "{{ .Response }}<|end|>")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == '__main__':
 | 
			
		||||
    test_qwen2_template()
 | 
			
		||||
    test_yi_template()
 | 
			
		||||
    test_llama2_template()
 | 
			
		||||
    test_llama3_template()
 | 
			
		||||
    test_phi3_template()
 | 
			
		||||
@ -239,11 +239,9 @@ class Template:
 | 
			
		||||
        Returns the jinja template.
 | 
			
		||||
        """
 | 
			
		||||
        prefix = self._convert_slots_to_jinja(self.format_prefix.apply(), tokenizer)
 | 
			
		||||
        system_message = self._convert_slots_to_jinja(
 | 
			
		||||
            self.format_system.apply(), tokenizer, placeholder="system_message"
 | 
			
		||||
        )
 | 
			
		||||
        user_message = self._convert_slots_to_jinja(self.format_user.apply(), tokenizer)
 | 
			
		||||
        assistant_message = self._convert_slots_to_jinja(self.format_assistant.apply(), tokenizer)
 | 
			
		||||
        system = self._convert_slots_to_jinja(self.format_system.apply(), tokenizer, placeholder="system_message")
 | 
			
		||||
        user = self._convert_slots_to_jinja(self.format_user.apply(), tokenizer)
 | 
			
		||||
        assistant = self._convert_slots_to_jinja(self.format_assistant.apply(), tokenizer)
 | 
			
		||||
        jinja_template = ""
 | 
			
		||||
        if prefix:
 | 
			
		||||
            jinja_template += "{{ " + prefix + " }}"
 | 
			
		||||
@ -254,13 +252,13 @@ class Template:
 | 
			
		||||
        jinja_template += (
 | 
			
		||||
            "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}"
 | 
			
		||||
            "{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% endif %}"
 | 
			
		||||
            "{% if system_message is defined %}{{ " + system_message + " }}{% endif %}"
 | 
			
		||||
            "{% if system_message is defined %}{{ " + system + " }}{% endif %}"
 | 
			
		||||
            "{% for message in loop_messages %}"
 | 
			
		||||
            "{% set content = message['content'] %}"
 | 
			
		||||
            "{% if message['role'] == 'user' %}"
 | 
			
		||||
            "{{ " + user_message + " }}"
 | 
			
		||||
            "{{ " + user + " }}"
 | 
			
		||||
            "{% elif message['role'] == 'assistant' %}"
 | 
			
		||||
            "{{ " + assistant_message + " }}"
 | 
			
		||||
            "{{ " + assistant + " }}"
 | 
			
		||||
            "{% endif %}"
 | 
			
		||||
            "{% endfor %}"
 | 
			
		||||
        )
 | 
			
		||||
@ -276,6 +274,64 @@ class Template:
 | 
			
		||||
            except ValueError as e:
 | 
			
		||||
                logger.info_rank0(f"Cannot add this chat template to tokenizer: {e}.")
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def _convert_slots_to_ollama(
 | 
			
		||||
        slots: "SLOTS", tokenizer: "PreTrainedTokenizer", placeholder: str = "content"
 | 
			
		||||
    ) -> str:
 | 
			
		||||
        r"""
 | 
			
		||||
        Converts slots to ollama template.
 | 
			
		||||
        """
 | 
			
		||||
        slot_items = []
 | 
			
		||||
        for slot in slots:
 | 
			
		||||
            if isinstance(slot, str):
 | 
			
		||||
                slot_pieces = slot.split("{{content}}")
 | 
			
		||||
                if slot_pieces[0]:
 | 
			
		||||
                    slot_items.append(slot_pieces[0])
 | 
			
		||||
                if len(slot_pieces) > 1:
 | 
			
		||||
                    slot_items.append("{{ " + placeholder + " }}")
 | 
			
		||||
                    if slot_pieces[1]:
 | 
			
		||||
                        slot_items.append(slot_pieces[1])
 | 
			
		||||
            elif isinstance(slot, set):  # do not use {{ eos_token }} since it may be replaced
 | 
			
		||||
                if "bos_token" in slot and tokenizer.bos_token_id is not None:
 | 
			
		||||
                    slot_items.append(tokenizer.bos_token)
 | 
			
		||||
                elif "eos_token" in slot and tokenizer.eos_token_id is not None:
 | 
			
		||||
                    slot_items.append(tokenizer.eos_token)
 | 
			
		||||
            elif isinstance(slot, dict):
 | 
			
		||||
                raise ValueError("Dict is not supported.")
 | 
			
		||||
 | 
			
		||||
        return "".join(slot_items)
 | 
			
		||||
 | 
			
		||||
    def _get_ollama_template(self, tokenizer: "PreTrainedTokenizer") -> str:
 | 
			
		||||
        r"""
 | 
			
		||||
        Returns the ollama template.
 | 
			
		||||
        """
 | 
			
		||||
        prefix = self._convert_slots_to_ollama(self.format_prefix.apply(), tokenizer)
 | 
			
		||||
        system = self._convert_slots_to_ollama(self.format_system.apply(), tokenizer, placeholder=".System")
 | 
			
		||||
        user = self._convert_slots_to_ollama(self.format_user.apply(), tokenizer, placeholder=".Content")
 | 
			
		||||
        assistant = self._convert_slots_to_ollama(self.format_assistant.apply(), tokenizer, placeholder=".Content")
 | 
			
		||||
        return (
 | 
			
		||||
            f"{prefix}{{{{ if .System }}}}{system}{{{{ end }}}}"
 | 
			
		||||
            f"""{{{{ range .Messages }}}}{{{{ if eq .Role "user" }}}}{user}"""
 | 
			
		||||
            f"""{{{{ else if eq .Role "assistant" }}}}{assistant}{{{{ end }}}}{{{{ end }}}}"""
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def get_ollama_modelfile(self, tokenizer: "PreTrainedTokenizer") -> str:
 | 
			
		||||
        r"""
 | 
			
		||||
        Returns the ollama modelfile.
 | 
			
		||||
 | 
			
		||||
        TODO: support function calling.
 | 
			
		||||
        """
 | 
			
		||||
        modelfile = f'FROM .\n\nTEMPLATE """{self._get_ollama_template(tokenizer)}"""\n\n'
 | 
			
		||||
 | 
			
		||||
        if self.default_system:
 | 
			
		||||
            modelfile += f'SYSTEM system "{self.default_system}"\n\n'
 | 
			
		||||
 | 
			
		||||
        for stop_token_id in self.get_stop_token_ids(tokenizer):
 | 
			
		||||
            modelfile += f'PARAMETER stop "{tokenizer.convert_ids_to_tokens(stop_token_id)}"\n'
 | 
			
		||||
 | 
			
		||||
        modelfile += "PARAMETER num_ctx 4096\n"
 | 
			
		||||
        return modelfile
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@dataclass
 | 
			
		||||
class Llama2Template(Template):
 | 
			
		||||
@ -1020,7 +1076,7 @@ _register_template(
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# copied from chatml template
 | 
			
		||||
# copied from minicpm_v template
 | 
			
		||||
_register_template(
 | 
			
		||||
    name="minicpm_o",
 | 
			
		||||
    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
 | 
			
		||||
 | 
			
		||||
@ -87,7 +87,7 @@ class ExportArguments:
 | 
			
		||||
        metadata={"help": "Path to the directory to save the exported model."},
 | 
			
		||||
    )
 | 
			
		||||
    export_size: int = field(
 | 
			
		||||
        default=1,
 | 
			
		||||
        default=5,
 | 
			
		||||
        metadata={"help": "The file shard size (in GB) of the exported model."},
 | 
			
		||||
    )
 | 
			
		||||
    export_device: Literal["cpu", "auto"] = field(
 | 
			
		||||
 | 
			
		||||
@ -104,7 +104,7 @@ def export_model(args: Optional[Dict[str, Any]] = None) -> None:
 | 
			
		||||
    tokenizer_module = load_tokenizer(model_args)
 | 
			
		||||
    tokenizer = tokenizer_module["tokenizer"]
 | 
			
		||||
    processor = tokenizer_module["processor"]
 | 
			
		||||
    get_template_and_fix_tokenizer(tokenizer, data_args)
 | 
			
		||||
    template = get_template_and_fix_tokenizer(tokenizer, data_args)
 | 
			
		||||
    model = load_model(tokenizer, model_args, finetuning_args)  # must after fixing tokenizer to resize vocab
 | 
			
		||||
 | 
			
		||||
    if getattr(model, "quantization_method", None) is not None and model_args.adapter_name_or_path is not None:
 | 
			
		||||
@ -171,3 +171,7 @@ def export_model(args: Optional[Dict[str, Any]] = None) -> None:
 | 
			
		||||
 | 
			
		||||
    except Exception as e:
 | 
			
		||||
        logger.warning_rank0(f"Cannot save tokenizer, please copy the files manually: {e}.")
 | 
			
		||||
 | 
			
		||||
    with open(os.path.join(model_args.export_dir, "Modelfile"), "w", encoding="utf-8") as f:
 | 
			
		||||
        f.write(template.get_ollama_modelfile(tokenizer))
 | 
			
		||||
        logger.info_rank0(f"Saved ollama modelfile to {model_args.export_dir}.")
 | 
			
		||||
 | 
			
		||||
@ -119,6 +119,22 @@ def test_jinja_template(use_fast: bool):
 | 
			
		||||
    assert tokenizer.apply_chat_template(MESSAGES) == ref_tokenizer.apply_chat_template(MESSAGES)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_ollama_modelfile():
 | 
			
		||||
    tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA)
 | 
			
		||||
    template = get_template_and_fix_tokenizer(tokenizer, DataArguments(template="llama3"))
 | 
			
		||||
    assert template.get_ollama_modelfile(tokenizer) == (
 | 
			
		||||
        "FROM .\n\n"
 | 
			
		||||
        'TEMPLATE """<|begin_of_text|>'
 | 
			
		||||
        "{{ if .System }}<|start_header_id|>system<|end_header_id|>\n\n{{ .System }}<|eot_id|>{{ end }}"
 | 
			
		||||
        '{{ range .Messages }}{{ if eq .Role "user" }}<|start_header_id|>user<|end_header_id|>\n\n{{ .Content }}'
 | 
			
		||||
        "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
 | 
			
		||||
        '{{ else if eq .Role "assistant" }}{{ .Content }}<|eot_id|>{{ end }}{{ end }}"""\n\n'
 | 
			
		||||
        'PARAMETER stop "<|eom_id|>"\n'
 | 
			
		||||
        'PARAMETER stop "<|eot_id|>"\n'
 | 
			
		||||
        "PARAMETER num_ctx 4096\n"
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_get_stop_token_ids():
 | 
			
		||||
    tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA)
 | 
			
		||||
    template = get_template_and_fix_tokenizer(tokenizer, DataArguments(template="llama3"))
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user