Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-08-02 03:32:50 +08:00)

commit c6be9e242c (parent 2e954d8fd2)

[misc] support export ollama modelfile (#6899)

* support export ollama modelfile
* update config
* add system and num ctx

Former-commit-id: 9184a6e0ed7ff5f632c848f861bfa448c4cd06fc
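For orientation, this is roughly what the exported Modelfile looks like for the llama3 template. It is reconstructed from the updated test in tests/data/test_template.py below, with the `\n` escapes rendered as real newlines; when a template defines a default system prompt, a SYSTEM line is emitted as well.

```text
FROM .

TEMPLATE """<|begin_of_text|>{{ if .System }}<|start_header_id|>system<|end_header_id|>

{{ .System }}<|eot_id|>{{ end }}{{ range .Messages }}{{ if eq .Role "user" }}<|start_header_id|>user<|end_header_id|>

{{ .Content }}<|eot_id|><|start_header_id|>assistant<|end_header_id|>

{{ else if eq .Role "assistant" }}{{ .Content }}<|eot_id|>{{ end }}{{ end }}"""

PARAMETER stop "<|eom_id|>"
PARAMETER stop "<|eot_id|>"
PARAMETER num_ctx 4096
```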
README.md
@@ -106,16 +106,18 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
 ## Changelog
 
+[25/02/11] We supported saving the **[Ollama](https://github.com/ollama/ollama)** modelfile when exporting the model checkpoints. See [examples](examples/README.md) for usage.
+
 [25/02/05] We supported fine-tuning the **[Qwen2-Audio](Qwen/Qwen2-Audio-7B-Instruct)** and **[MiniCPM-o-2.6](https://huggingface.co/openbmb/MiniCPM-o-2_6)** on audio understanding tasks.
 
 [25/01/31] We supported fine-tuning the **[DeepSeek-R1](https://huggingface.co/deepseek-ai/DeepSeek-R1)** and **[Qwen2.5-VL](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct)** model.
 
+<details><summary>Full Changelog</summary>
+
 [25/01/15] We supported **[APOLLO](https://arxiv.org/abs/2412.05270)** optimizer. See [examples](examples/README.md) for usage.
 
 [25/01/14] We supported fine-tuning the **[MiniCPM-o-2.6](https://huggingface.co/openbmb/MiniCPM-o-2_6)** and **[MiniCPM-V-2.6](https://huggingface.co/openbmb/MiniCPM-V-2_6)** models. Thank [@BUAADreamer](https://github.com/BUAADreamer)'s PR.
 
-<details><summary>Full Changelog</summary>
-
 [25/01/14] We supported fine-tuning the **[InternLM3](https://huggingface.co/collections/internlm/)** models. Thank [@hhaAndroid](https://github.com/hhaAndroid)'s PR.
 
 [25/01/10] We supported fine-tuning the **[Phi-4](https://huggingface.co/microsoft/phi-4)** model.
README_zh.md
@@ -108,16 +108,18 @@ https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272
 ## 更新日志
 
+[25/02/11] 我们支持了在导出模型时保存 **[Ollama](https://github.com/ollama/ollama)** 配置文件。详细用法请参照 [examples](examples/README_zh.md)。
+
+[25/02/05] 我们支持了在语音理解任务上微调 **[Qwen2-Audio](Qwen/Qwen2-Audio-7B-Instruct)** 和 **[MiniCPM-o-2.6](https://huggingface.co/openbmb/MiniCPM-o-2_6)** 模型。
+
 [25/01/31] 我们支持了 **[DeepSeek-R1](https://huggingface.co/deepseek-ai/DeepSeek-R1)** 和 **[Qwen2.5-VL](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct)** 模型的微调。
 
+<details><summary>展开日志</summary>
+
 [25/01/15] 我们支持了 **[APOLLO](https://arxiv.org/abs/2412.05270)** 优化器。详细用法请参照 [examples](examples/README_zh.md)。
 
 [25/01/14] 我们支持了 **[MiniCPM-o-2.6](https://huggingface.co/openbmb/MiniCPM-o-2_6)** 和 **[MiniCPM-V-2.6](https://huggingface.co/openbmb/MiniCPM-V-2_6)** 模型的微调。感谢 [@BUAADreamer](https://github.com/BUAADreamer) 的 PR。
 
-<details><summary>展开日志</summary>
-
-[25/02/05] 我们支持了在语音理解任务上微调 **[Qwen2-Audio](Qwen/Qwen2-Audio-7B-Instruct)** 和 **[MiniCPM-o-2.6](https://huggingface.co/openbmb/MiniCPM-o-2_6)** 模型。
-
 [25/01/14] 我们支持了 **[InternLM3](https://huggingface.co/collections/internlm/)** 模型的微调。感谢 [@hhaAndroid](https://github.com/hhaAndroid) 的 PR。
 
 [25/01/10] 我们支持了 **[Phi-4](https://huggingface.co/microsoft/phi-4)** 模型的微调。
examples/README.md
@@ -170,6 +170,12 @@ llamafactory-cli export examples/merge_lora/llama3_lora_sft.yaml
 llamafactory-cli export examples/merge_lora/llama3_gptq.yaml
 ```
 
+### Save Ollama modelfile
+
+```bash
+llamafactory-cli export examples/merge_lora/llama3_full_sft.yaml
+```
+
 ### Inferring LoRA Fine-Tuned Models
 
 #### Batch Generation using vLLM Tensor Parallel
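After the export finishes, the export directory holds the merged weights plus the generated Modelfile, so it can be imported into Ollama directly. A minimal sketch, assuming the `output/llama3_full_sft` path from the example config and an installed `ollama` CLI (the model name `llama3-sft` is arbitrary):

```bash
llamafactory-cli export examples/merge_lora/llama3_full_sft.yaml
cd output/llama3_full_sft
# The Modelfile begins with "FROM .", so run `ollama create` from inside the export directory.
ollama create llama3-sft -f Modelfile
ollama run llama3-sft
```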
examples/README_zh.md
@@ -170,6 +170,12 @@ llamafactory-cli export examples/merge_lora/llama3_lora_sft.yaml
 llamafactory-cli export examples/merge_lora/llama3_gptq.yaml
 ```
 
+### 保存 Ollama 配置文件
+
+```bash
+llamafactory-cli export examples/merge_lora/llama3_full_sft.yaml
+```
+
 ### 推理 LoRA 模型
 
 #### 使用 vLLM+TP 批量推理
examples/merge_lora/llama3_full_sft.yaml (new file)
@@ -0,0 +1,11 @@
+### model
+model_name_or_path: saves/llama3-8b/full/sft
+template: llama3
+trust_remote_code: true
+infer_dtype: bfloat16
+
+### export
+export_dir: output/llama3_full_sft
+export_size: 5
+export_device: cpu
+export_legacy_format: false
examples/merge_lora/llama3_gptq.yaml
@@ -4,9 +4,9 @@ template: llama3
 trust_remote_code: true
 
 ### export
-export_dir: models/llama3_gptq
+export_dir: output/llama3_gptq
 export_quantization_bit: 4
 export_quantization_dataset: data/c4_demo.json
-export_size: 2
+export_size: 5
 export_device: cpu
 export_legacy_format: false
examples/merge_lora/llama3_lora_sft.yaml
@@ -4,11 +4,10 @@
 model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
 adapter_name_or_path: saves/llama3-8b/lora/sft
 template: llama3
-finetuning_type: lora
 trust_remote_code: true
 
 ### export
-export_dir: models/llama3_lora_sft
-export_size: 2
+export_dir: output/llama3_lora_sft
+export_size: 5
 export_device: cpu
 export_legacy_format: false
examples/merge_lora/qwen2_vl_lora_sft.yaml
@@ -4,11 +4,10 @@
 model_name_or_path: Qwen/Qwen2-VL-7B-Instruct
 adapter_name_or_path: saves/qwen2_vl-7b/lora/sft
 template: qwen2_vl
-finetuning_type: lora
 trust_remote_code: true
 
 ### export
-export_dir: models/qwen2_vl_lora_sft
-export_size: 2
+export_dir: output/qwen2_vl_lora_sft
+export_size: 5
 export_device: cpu
 export_legacy_format: false
export_ollama_modelfile.py (deleted standalone script)
@@ -1,112 +0,0 @@
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import codecs
import os
from typing import TYPE_CHECKING

import fire
from transformers import AutoTokenizer

from llamafactory.data import get_template_and_fix_tokenizer

if TYPE_CHECKING:
    from transformers import PreTrainedTokenizer
    from llamafactory.data.formatter import SLOTS
    from llamafactory.data.template import Template


def _convert_slots_to_ollama(slots: "SLOTS", tokenizer: "PreTrainedTokenizer", placeholder: str = "content") -> str:
    slot_items = []
    for slot in slots:
        if isinstance(slot, str):
            slot_pieces = slot.split("{{content}}")
            if slot_pieces[0]:
                slot_items.append(slot_pieces[0])
            if len(slot_pieces) > 1:
                slot_items.append(placeholder)
                if slot_pieces[1]:
                    slot_items.append(slot_pieces[1])
        elif isinstance(slot, set):  # do not use {{ eos_token }} since it may be replaced
            if "bos_token" in slot and tokenizer.bos_token_id is not None:
                slot_items.append(tokenizer.bos_token)
            elif "eos_token" in slot and tokenizer.eos_token_id is not None:
                slot_items.append(tokenizer.eos_token)
        elif isinstance(slot, dict):
            raise ValueError("Dict is not supported.")

    return "".join(slot_items)


def _split_round_template(user_template_str: "str", template_obj: "Template", tokenizer: "PreTrainedTokenizer") -> tuple:
    if template_obj.format_separator.apply():
        format_separator = _convert_slots_to_ollama(template_obj.format_separator.apply(), tokenizer)
        round_split_token_list = [tokenizer.eos_token + format_separator, tokenizer.eos_token,
                                  format_separator, "{{ .Prompt }}"]
    else:
        round_split_token_list = [tokenizer.eos_token, "{{ .Prompt }}"]

    for round_split_token in round_split_token_list:
        round_split_templates = user_template_str.split(round_split_token)
        if len(round_split_templates) >= 2:
            user_round_template = "".join(round_split_templates[:-1])
            assistant_round_template = round_split_templates[-1]
            return user_round_template + round_split_token, assistant_round_template

    return user_template_str, ""


def convert_template_obj_to_ollama(template_obj: "Template", tokenizer: "PreTrainedTokenizer") -> str:
    ollama_template = ""
    if template_obj.format_system:
        ollama_template += "{{ if .System }}"
        ollama_template += _convert_slots_to_ollama(template_obj.format_system.apply(), tokenizer, "{{ .System }}")
        ollama_template += "{{ end }}"

    user_template = _convert_slots_to_ollama(template_obj.format_user.apply(), tokenizer, "{{ .Prompt }}")
    user_round_template, assistant_round_template = _split_round_template(user_template, template_obj, tokenizer)

    ollama_template += "{{ if .Prompt }}"
    ollama_template += user_round_template
    ollama_template += "{{ end }}"
    ollama_template += assistant_round_template

    ollama_template += _convert_slots_to_ollama(template_obj.format_assistant.apply(), tokenizer, "{{ .Response }}")

    return ollama_template


def export_ollama_modelfile(
    model_name_or_path: str,
    gguf_path: str,
    template: str,
    export_dir: str = "./ollama_model_file"
):
    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True)
    template_obj = get_template_and_fix_tokenizer(tokenizer, name=template)
    ollama_template = convert_template_obj_to_ollama(template_obj, tokenizer)

    if not os.path.exists(export_dir):
        os.mkdir(export_dir)
    with codecs.open(os.path.join(export_dir, "Modelfile"), "w", encoding="utf-8") as outf:
        outf.write("FROM {}".format(gguf_path) + "\n")
        outf.write("TEMPLATE \"\"\"{}\"\"\"".format(ollama_template) + "\n")

        if template_obj.stop_words:
            for stop_word in template_obj.stop_words:
                outf.write("PARAMETER stop \"{}\"".format(stop_word) + "\n")
        elif not template_obj.efficient_eos:
            outf.write("PARAMETER stop \"{}\"".format(tokenizer.eos_token) + "\n")


if __name__ == '__main__':
    fire.Fire(export_ollama_modelfile)
Test script for export_ollama_modelfile.py (deleted)
@@ -1,87 +0,0 @@
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from transformers import AutoTokenizer

from llamafactory.data import get_template_and_fix_tokenizer
from export_ollama_modelfile import convert_template_obj_to_ollama


def test_qwen2_template():
    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-7B-Instruct")
    template = get_template_and_fix_tokenizer(tokenizer, name="qwen")
    ollama_template = convert_template_obj_to_ollama(template, tokenizer)

    assert ollama_template == ("{{ if .System }}<|im_start|>system\n"
                               "{{ .System }}<|im_end|>\n"
                               "{{ end }}{{ if .Prompt }}<|im_start|>user\n"
                               "{{ .Prompt }}<|im_end|>\n"
                               "{{ end }}<|im_start|>assistant\n"
                               "{{ .Response }}<|im_end|>")


def test_yi_template():
    tokenizer = AutoTokenizer.from_pretrained("01-ai/Yi-1.5-9B-Chat")
    template = get_template_and_fix_tokenizer(tokenizer, name="yi")
    ollama_template = convert_template_obj_to_ollama(template, tokenizer)

    assert ollama_template == ("{{ if .System }}<|im_start|>system\n"
                               "{{ .System }}<|im_end|>\n"
                               "{{ end }}{{ if .Prompt }}<|im_start|>user\n"
                               "{{ .Prompt }}<|im_end|>\n"
                               "{{ end }}<|im_start|>assistant\n"
                               "{{ .Response }}<|im_end|>")


def test_llama2_template():
    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
    template = get_template_and_fix_tokenizer(tokenizer, name="llama2")
    ollama_template = convert_template_obj_to_ollama(template, tokenizer)

    assert ollama_template == ("{{ if .System }}<<SYS>>\n"
                               "{{ .System }}\n"
                               "<</SYS>>\n\n"
                               "{{ end }}{{ if .Prompt }}<s>[INST] {{ .Prompt }}{{ end }} [/INST]{{ .Response }}</s>")


def test_llama3_template():
    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
    template = get_template_and_fix_tokenizer(tokenizer, name="llama3")
    ollama_template = convert_template_obj_to_ollama(template, tokenizer)

    assert ollama_template == ("{{ if .System }}<|start_header_id|>system<|end_header_id|>\n\n"
                               "{{ .System }}<|eot_id|>{{ end }}"
                               "{{ if .Prompt }}<|start_header_id|>user<|end_header_id|>\n\n"
                               "{{ .Prompt }}<|eot_id|>{{ end }}<|start_header_id|>assistant<|end_header_id|>\n\n"
                               "{{ .Response }}<|eot_id|>")


def test_phi3_template():
    tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-4k-instruct")
    template = get_template_and_fix_tokenizer(tokenizer, name="phi")
    ollama_template = convert_template_obj_to_ollama(template, tokenizer)
    assert ollama_template == ("{{ if .System }}<|system|>\n"
                               "{{ .System }}<|end|>\n"
                               "{{ end }}{{ if .Prompt }}<|user|>\n"
                               "{{ .Prompt }}<|end|>\n"
                               "{{ end }}<|assistant|>\n"
                               "{{ .Response }}<|end|>")


if __name__ == '__main__':
    test_qwen2_template()
    test_yi_template()
    test_llama2_template()
    test_llama3_template()
    test_phi3_template()
src/llamafactory/data/template.py
@@ -239,11 +239,9 @@ class Template:
         Returns the jinja template.
         """
         prefix = self._convert_slots_to_jinja(self.format_prefix.apply(), tokenizer)
-        system_message = self._convert_slots_to_jinja(
-            self.format_system.apply(), tokenizer, placeholder="system_message"
-        )
-        user_message = self._convert_slots_to_jinja(self.format_user.apply(), tokenizer)
-        assistant_message = self._convert_slots_to_jinja(self.format_assistant.apply(), tokenizer)
+        system = self._convert_slots_to_jinja(self.format_system.apply(), tokenizer, placeholder="system_message")
+        user = self._convert_slots_to_jinja(self.format_user.apply(), tokenizer)
+        assistant = self._convert_slots_to_jinja(self.format_assistant.apply(), tokenizer)
         jinja_template = ""
         if prefix:
             jinja_template += "{{ " + prefix + " }}"
@@ -254,13 +252,13 @@ class Template:
         jinja_template += (
             "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}"
             "{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% endif %}"
-            "{% if system_message is defined %}{{ " + system_message + " }}{% endif %}"
+            "{% if system_message is defined %}{{ " + system + " }}{% endif %}"
             "{% for message in loop_messages %}"
             "{% set content = message['content'] %}"
             "{% if message['role'] == 'user' %}"
-            "{{ " + user_message + " }}"
+            "{{ " + user + " }}"
             "{% elif message['role'] == 'assistant' %}"
-            "{{ " + assistant_message + " }}"
+            "{{ " + assistant + " }}"
             "{% endif %}"
             "{% endfor %}"
         )
@@ -276,6 +274,64 @@ class Template:
             except ValueError as e:
                 logger.info_rank0(f"Cannot add this chat template to tokenizer: {e}.")
 
+    @staticmethod
+    def _convert_slots_to_ollama(
+        slots: "SLOTS", tokenizer: "PreTrainedTokenizer", placeholder: str = "content"
+    ) -> str:
+        r"""
+        Converts slots to ollama template.
+        """
+        slot_items = []
+        for slot in slots:
+            if isinstance(slot, str):
+                slot_pieces = slot.split("{{content}}")
+                if slot_pieces[0]:
+                    slot_items.append(slot_pieces[0])
+                if len(slot_pieces) > 1:
+                    slot_items.append("{{ " + placeholder + " }}")
+                    if slot_pieces[1]:
+                        slot_items.append(slot_pieces[1])
+            elif isinstance(slot, set):  # do not use {{ eos_token }} since it may be replaced
+                if "bos_token" in slot and tokenizer.bos_token_id is not None:
+                    slot_items.append(tokenizer.bos_token)
+                elif "eos_token" in slot and tokenizer.eos_token_id is not None:
+                    slot_items.append(tokenizer.eos_token)
+            elif isinstance(slot, dict):
+                raise ValueError("Dict is not supported.")
+
+        return "".join(slot_items)
+
+    def _get_ollama_template(self, tokenizer: "PreTrainedTokenizer") -> str:
+        r"""
+        Returns the ollama template.
+        """
+        prefix = self._convert_slots_to_ollama(self.format_prefix.apply(), tokenizer)
+        system = self._convert_slots_to_ollama(self.format_system.apply(), tokenizer, placeholder=".System")
+        user = self._convert_slots_to_ollama(self.format_user.apply(), tokenizer, placeholder=".Content")
+        assistant = self._convert_slots_to_ollama(self.format_assistant.apply(), tokenizer, placeholder=".Content")
+        return (
+            f"{prefix}{{{{ if .System }}}}{system}{{{{ end }}}}"
+            f"""{{{{ range .Messages }}}}{{{{ if eq .Role "user" }}}}{user}"""
+            f"""{{{{ else if eq .Role "assistant" }}}}{assistant}{{{{ end }}}}{{{{ end }}}}"""
+        )
+
+    def get_ollama_modelfile(self, tokenizer: "PreTrainedTokenizer") -> str:
+        r"""
+        Returns the ollama modelfile.
+
+        TODO: support function calling.
+        """
+        modelfile = f'FROM .\n\nTEMPLATE """{self._get_ollama_template(tokenizer)}"""\n\n'
+
+        if self.default_system:
+            modelfile += f'SYSTEM system "{self.default_system}"\n\n'
+
+        for stop_token_id in self.get_stop_token_ids(tokenizer):
+            modelfile += f'PARAMETER stop "{tokenizer.convert_ids_to_tokens(stop_token_id)}"\n'
+
+        modelfile += "PARAMETER num_ctx 4096\n"
+        return modelfile
+
 
 @dataclass
 class Llama2Template(Template):
@@ -1020,7 +1076,7 @@ _register_template(
 )
 
 
-# copied from chatml template
+# copied from minicpm_v template
 _register_template(
     name="minicpm_o",
     format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
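For reference, a minimal sketch of calling the new Template API directly instead of going through `llamafactory-cli export`; it assumes `DataArguments` is importable from `llamafactory.hparams`, as in the updated test further below:

```python
from transformers import AutoTokenizer

from llamafactory.data import get_template_and_fix_tokenizer
from llamafactory.hparams import DataArguments

# Build and fix the template the same way export_model() now does, then dump the Modelfile.
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
template = get_template_and_fix_tokenizer(tokenizer, DataArguments(template="llama3"))

with open("Modelfile", "w", encoding="utf-8") as f:
    f.write(template.get_ollama_modelfile(tokenizer))  # FROM, TEMPLATE, PARAMETER stop / num_ctx lines
```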
@ -87,7 +87,7 @@ class ExportArguments:
|
|||||||
metadata={"help": "Path to the directory to save the exported model."},
|
metadata={"help": "Path to the directory to save the exported model."},
|
||||||
)
|
)
|
||||||
export_size: int = field(
|
export_size: int = field(
|
||||||
default=1,
|
default=5,
|
||||||
metadata={"help": "The file shard size (in GB) of the exported model."},
|
metadata={"help": "The file shard size (in GB) of the exported model."},
|
||||||
)
|
)
|
||||||
export_device: Literal["cpu", "auto"] = field(
|
export_device: Literal["cpu", "auto"] = field(
|
||||||
|
src/llamafactory/train/tuner.py
@@ -104,7 +104,7 @@ def export_model(args: Optional[Dict[str, Any]] = None) -> None:
     tokenizer_module = load_tokenizer(model_args)
     tokenizer = tokenizer_module["tokenizer"]
     processor = tokenizer_module["processor"]
-    get_template_and_fix_tokenizer(tokenizer, data_args)
+    template = get_template_and_fix_tokenizer(tokenizer, data_args)
     model = load_model(tokenizer, model_args, finetuning_args)  # must after fixing tokenizer to resize vocab
 
     if getattr(model, "quantization_method", None) is not None and model_args.adapter_name_or_path is not None:
@@ -171,3 +171,7 @@ def export_model(args: Optional[Dict[str, Any]] = None) -> None:
 
     except Exception as e:
         logger.warning_rank0(f"Cannot save tokenizer, please copy the files manually: {e}.")
+
+    with open(os.path.join(model_args.export_dir, "Modelfile"), "w", encoding="utf-8") as f:
+        f.write(template.get_ollama_modelfile(tokenizer))
+    logger.info_rank0(f"Saved ollama modelfile to {model_args.export_dir}.")
tests/data/test_template.py
@@ -119,6 +119,22 @@ def test_jinja_template(use_fast: bool):
     assert tokenizer.apply_chat_template(MESSAGES) == ref_tokenizer.apply_chat_template(MESSAGES)
 
 
+def test_ollama_modelfile():
+    tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA)
+    template = get_template_and_fix_tokenizer(tokenizer, DataArguments(template="llama3"))
+    assert template.get_ollama_modelfile(tokenizer) == (
+        "FROM .\n\n"
+        'TEMPLATE """<|begin_of_text|>'
+        "{{ if .System }}<|start_header_id|>system<|end_header_id|>\n\n{{ .System }}<|eot_id|>{{ end }}"
+        '{{ range .Messages }}{{ if eq .Role "user" }}<|start_header_id|>user<|end_header_id|>\n\n{{ .Content }}'
+        "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
+        '{{ else if eq .Role "assistant" }}{{ .Content }}<|eot_id|>{{ end }}{{ end }}"""\n\n'
+        'PARAMETER stop "<|eom_id|>"\n'
+        'PARAMETER stop "<|eot_id|>"\n'
+        "PARAMETER num_ctx 4096\n"
+    )
+
+
 def test_get_stop_token_ids():
     tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA)
     template = get_template_and_fix_tokenizer(tokenizer, DataArguments(template="llama3"))
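To exercise just the new test, something like the following should work with the repository's usual pytest setup:

```bash
pytest tests/data/test_template.py -k test_ollama_modelfile
```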