[misc] support export ollama modelfile (#6899)

* support export ollama modelfile

* update config

* add system and num ctx

Former-commit-id: 9184a6e0ed7ff5f632c848f861bfa448c4cd06fc
This commit is contained in:
hoshi-hiyouga 2025-02-11 19:52:25 +08:00 committed by GitHub
parent 2e954d8fd2
commit c6be9e242c
14 changed files with 126 additions and 224 deletions

View File

@ -106,16 +106,18 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
## Changelog
[25/02/11] We supported saving the **[Ollama](https://github.com/ollama/ollama)** modelfile when exporting the model checkpoints. See [examples](examples/README.md) for usage.
[25/02/05] We supported fine-tuning the **[Qwen2-Audio](Qwen/Qwen2-Audio-7B-Instruct)** and **[MiniCPM-o-2.6](https://huggingface.co/openbmb/MiniCPM-o-2_6)** on audio understanding tasks.
[25/01/31] We supported fine-tuning the **[DeepSeek-R1](https://huggingface.co/deepseek-ai/DeepSeek-R1)** and **[Qwen2.5-VL](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct)** model.
<details><summary>Full Changelog</summary>
[25/01/15] We supported **[APOLLO](https://arxiv.org/abs/2412.05270)** optimizer. See [examples](examples/README.md) for usage.
[25/01/14] We supported fine-tuning the **[MiniCPM-o-2.6](https://huggingface.co/openbmb/MiniCPM-o-2_6)** and **[MiniCPM-V-2.6](https://huggingface.co/openbmb/MiniCPM-V-2_6)** models. Thank [@BUAADreamer](https://github.com/BUAADreamer)'s PR.
<details><summary>Full Changelog</summary>
[25/01/14] We supported fine-tuning the **[InternLM3](https://huggingface.co/collections/internlm/)** models. Thank [@hhaAndroid](https://github.com/hhaAndroid)'s PR.
[25/01/10] We supported fine-tuning the **[Phi-4](https://huggingface.co/microsoft/phi-4)** model.

View File

@ -108,16 +108,18 @@ https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272
## 更新日志
[25/02/11] 我们支持了在导出模型时保存 **[Ollama](https://github.com/ollama/ollama)** 配置文件。详细用法请参照 [examples](examples/README_zh.md)。
[25/02/05] 我们支持了在语音理解任务上微调 **[Qwen2-Audio](Qwen/Qwen2-Audio-7B-Instruct)** 和 **[MiniCPM-o-2.6](https://huggingface.co/openbmb/MiniCPM-o-2_6)** 模型。
[25/01/31] 我们支持了 **[DeepSeek-R1](https://huggingface.co/deepseek-ai/DeepSeek-R1)** 和 **[Qwen2.5-VL](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct)** 模型的微调。
<details><summary>展开日志</summary>
[25/01/15] 我们支持了 **[APOLLO](https://arxiv.org/abs/2412.05270)** 优化器。详细用法请参照 [examples](examples/README_zh.md)。
[25/01/14] 我们支持了 **[MiniCPM-o-2.6](https://huggingface.co/openbmb/MiniCPM-o-2_6)** 和 **[MiniCPM-V-2.6](https://huggingface.co/openbmb/MiniCPM-V-2_6)** 模型的微调。 感谢 [@BUAADreamer](https://github.com/BUAADreamer) 的 PR.
<details><summary>展开日志</summary>
[25/02/05] 我们支持了在语音理解任务上微调 **[Qwen2-Audio](Qwen/Qwen2-Audio-7B-Instruct)** 和 **[MiniCPM-o-2.6](https://huggingface.co/openbmb/MiniCPM-o-2_6)** 模型。
[25/01/14] 我们支持了 **[InternLM3](https://huggingface.co/collections/internlm/)** 模型的微调。感谢 [@hhaAndroid](https://github.com/hhaAndroid) 的 PR。
[25/01/10] 我们支持了 **[Phi-4](https://huggingface.co/microsoft/phi-4)** 模型的微调。

View File

@ -170,6 +170,12 @@ llamafactory-cli export examples/merge_lora/llama3_lora_sft.yaml
llamafactory-cli export examples/merge_lora/llama3_gptq.yaml
```
### Save Ollama modelfile
```bash
llamafactory-cli export examples/merge_lora/llama3_full_sft.yaml
```
### Inferring LoRA Fine-Tuned Models
#### Batch Generation using vLLM Tensor Parallel

View File

@ -170,6 +170,12 @@ llamafactory-cli export examples/merge_lora/llama3_lora_sft.yaml
llamafactory-cli export examples/merge_lora/llama3_gptq.yaml
```
### 保存 Ollama 配置文件
```bash
llamafactory-cli export examples/merge_lora/llama3_full_sft.yaml
```
### 推理 LoRA 模型
#### 使用 vLLM+TP 批量推理

View File

@ -0,0 +1,11 @@
### model
model_name_or_path: saves/llama3-8b/full/sft
template: llama3
trust_remote_code: true
infer_dtype: bfloat16
### export
export_dir: output/llama3_full_sft
export_size: 5
export_device: cpu
export_legacy_format: false

View File

@ -4,9 +4,9 @@ template: llama3
trust_remote_code: true
### export
export_dir: models/llama3_gptq
export_dir: output/llama3_gptq
export_quantization_bit: 4
export_quantization_dataset: data/c4_demo.json
export_size: 2
export_size: 5
export_device: cpu
export_legacy_format: false

View File

@ -4,11 +4,10 @@
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
adapter_name_or_path: saves/llama3-8b/lora/sft
template: llama3
finetuning_type: lora
trust_remote_code: true
### export
export_dir: models/llama3_lora_sft
export_size: 2
export_dir: output/llama3_lora_sft
export_size: 5
export_device: cpu
export_legacy_format: false

View File

@ -4,11 +4,10 @@
model_name_or_path: Qwen/Qwen2-VL-7B-Instruct
adapter_name_or_path: saves/qwen2_vl-7b/lora/sft
template: qwen2_vl
finetuning_type: lora
trust_remote_code: true
### export
export_dir: models/qwen2_vl_lora_sft
export_size: 2
export_dir: output/qwen2_vl_lora_sft
export_size: 5
export_device: cpu
export_legacy_format: false

View File

@ -1,112 +0,0 @@
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import codecs
import os
from typing import TYPE_CHECKING
import fire
from transformers import AutoTokenizer
from llamafactory.data import get_template_and_fix_tokenizer
if TYPE_CHECKING:
from transformers import PreTrainedTokenizer
from llamafactory.data.formatter import SLOTS
from llamafactory.data.template import Template
def _convert_slots_to_ollama(slots: "SLOTS", tokenizer: "PreTrainedTokenizer", placeholder: str = "content") -> str:
slot_items = []
for slot in slots:
if isinstance(slot, str):
slot_pieces = slot.split("{{content}}")
if slot_pieces[0]:
slot_items.append(slot_pieces[0])
if len(slot_pieces) > 1:
slot_items.append(placeholder)
if slot_pieces[1]:
slot_items.append(slot_pieces[1])
elif isinstance(slot, set): # do not use {{ eos_token }} since it may be replaced
if "bos_token" in slot and tokenizer.bos_token_id is not None:
slot_items.append(tokenizer.bos_token)
elif "eos_token" in slot and tokenizer.eos_token_id is not None:
slot_items.append(tokenizer.eos_token)
elif isinstance(slot, dict):
raise ValueError("Dict is not supported.")
return "".join(slot_items)
def _split_round_template(user_template_str: "str", template_obj: "Template", tokenizer: "PreTrainedTokenizer") -> tuple:
if template_obj.format_separator.apply():
format_separator = _convert_slots_to_ollama(template_obj.format_separator.apply(), tokenizer)
round_split_token_list = [tokenizer.eos_token + format_separator, tokenizer.eos_token,
format_separator, "{{ .Prompt }}"]
else:
round_split_token_list = [tokenizer.eos_token, "{{ .Prompt }}"]
for round_split_token in round_split_token_list:
round_split_templates = user_template_str.split(round_split_token)
if len(round_split_templates) >= 2:
user_round_template = "".join(round_split_templates[:-1])
assistant_round_template = round_split_templates[-1]
return user_round_template + round_split_token, assistant_round_template
return user_template_str, ""
def convert_template_obj_to_ollama(template_obj: "Template", tokenizer: "PreTrainedTokenizer") -> str:
ollama_template = ""
if template_obj.format_system:
ollama_template += "{{ if .System }}"
ollama_template += _convert_slots_to_ollama(template_obj.format_system.apply(), tokenizer, "{{ .System }}")
ollama_template += "{{ end }}"
user_template = _convert_slots_to_ollama(template_obj.format_user.apply(), tokenizer, "{{ .Prompt }}")
user_round_template, assistant_round_template = _split_round_template(user_template, template_obj, tokenizer)
ollama_template += "{{ if .Prompt }}"
ollama_template += user_round_template
ollama_template += "{{ end }}"
ollama_template += assistant_round_template
ollama_template += _convert_slots_to_ollama(template_obj.format_assistant.apply(), tokenizer, "{{ .Response }}")
return ollama_template
def export_ollama_modelfile(
model_name_or_path: str,
gguf_path: str,
template: str,
export_dir: str = "./ollama_model_file"
):
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True)
template_obj = get_template_and_fix_tokenizer(tokenizer, name=template)
ollama_template = convert_template_obj_to_ollama(template_obj, tokenizer)
if not os.path.exists(export_dir):
os.mkdir(export_dir)
with codecs.open(os.path.join(export_dir, "Modelfile"), "w", encoding="utf-8") as outf:
outf.write("FROM {}".format(gguf_path) + "\n")
outf.write("TEMPLATE \"\"\"{}\"\"\"".format(ollama_template) + "\n")
if template_obj.stop_words:
for stop_word in template_obj.stop_words:
outf.write("PARAMETER stop \"{}\"".format(stop_word) + "\n")
elif not template_obj.efficient_eos:
outf.write("PARAMETER stop \"{}\"".format(tokenizer.eos_token) + "\n")
if __name__ == '__main__':
fire.Fire(export_ollama_modelfile)

View File

@ -1,87 +0,0 @@
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from transformers import AutoTokenizer
from llamafactory.data import get_template_and_fix_tokenizer
from export_ollama_modelfile import convert_template_obj_to_ollama
def test_qwen2_template():
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-7B-Instruct")
template = get_template_and_fix_tokenizer(tokenizer, name="qwen")
ollama_template = convert_template_obj_to_ollama(template, tokenizer)
assert ollama_template == ("{{ if .System }}<|im_start|>system\n"
"{{ .System }}<|im_end|>\n"
"{{ end }}{{ if .Prompt }}<|im_start|>user\n"
"{{ .Prompt }}<|im_end|>\n"
"{{ end }}<|im_start|>assistant\n"
"{{ .Response }}<|im_end|>")
def test_yi_template():
tokenizer = AutoTokenizer.from_pretrained("01-ai/Yi-1.5-9B-Chat")
template = get_template_and_fix_tokenizer(tokenizer, name="yi")
ollama_template = convert_template_obj_to_ollama(template, tokenizer)
assert ollama_template == ("{{ if .System }}<|im_start|>system\n"
"{{ .System }}<|im_end|>\n"
"{{ end }}{{ if .Prompt }}<|im_start|>user\n"
"{{ .Prompt }}<|im_end|>\n"
"{{ end }}<|im_start|>assistant\n"
"{{ .Response }}<|im_end|>")
def test_llama2_template():
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
template = get_template_and_fix_tokenizer(tokenizer, name="llama2")
ollama_template = convert_template_obj_to_ollama(template, tokenizer)
assert ollama_template == ("{{ if .System }}<<SYS>>\n"
"{{ .System }}\n"
"<</SYS>>\n\n"
"{{ end }}{{ if .Prompt }}<s>[INST] {{ .Prompt }}{{ end }} [/INST]{{ .Response }}</s>")
def test_llama3_template():
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
template = get_template_and_fix_tokenizer(tokenizer, name="llama3")
ollama_template = convert_template_obj_to_ollama(template, tokenizer)
assert ollama_template == ("{{ if .System }}<|start_header_id|>system<|end_header_id|>\n\n"
"{{ .System }}<|eot_id|>{{ end }}"
"{{ if .Prompt }}<|start_header_id|>user<|end_header_id|>\n\n"
"{{ .Prompt }}<|eot_id|>{{ end }}<|start_header_id|>assistant<|end_header_id|>\n\n"
"{{ .Response }}<|eot_id|>")
def test_phi3_template():
tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-4k-instruct")
template = get_template_and_fix_tokenizer(tokenizer, name="phi")
ollama_template = convert_template_obj_to_ollama(template, tokenizer)
assert ollama_template == ("{{ if .System }}<|system|>\n"
"{{ .System }}<|end|>\n"
"{{ end }}{{ if .Prompt }}<|user|>\n"
"{{ .Prompt }}<|end|>\n"
"{{ end }}<|assistant|>\n"
"{{ .Response }}<|end|>")
if __name__ == '__main__':
test_qwen2_template()
test_yi_template()
test_llama2_template()
test_llama3_template()
test_phi3_template()

View File

@ -239,11 +239,9 @@ class Template:
Returns the jinja template.
"""
prefix = self._convert_slots_to_jinja(self.format_prefix.apply(), tokenizer)
system_message = self._convert_slots_to_jinja(
self.format_system.apply(), tokenizer, placeholder="system_message"
)
user_message = self._convert_slots_to_jinja(self.format_user.apply(), tokenizer)
assistant_message = self._convert_slots_to_jinja(self.format_assistant.apply(), tokenizer)
system = self._convert_slots_to_jinja(self.format_system.apply(), tokenizer, placeholder="system_message")
user = self._convert_slots_to_jinja(self.format_user.apply(), tokenizer)
assistant = self._convert_slots_to_jinja(self.format_assistant.apply(), tokenizer)
jinja_template = ""
if prefix:
jinja_template += "{{ " + prefix + " }}"
@ -254,13 +252,13 @@ class Template:
jinja_template += (
"{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}"
"{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% endif %}"
"{% if system_message is defined %}{{ " + system_message + " }}{% endif %}"
"{% if system_message is defined %}{{ " + system + " }}{% endif %}"
"{% for message in loop_messages %}"
"{% set content = message['content'] %}"
"{% if message['role'] == 'user' %}"
"{{ " + user_message + " }}"
"{{ " + user + " }}"
"{% elif message['role'] == 'assistant' %}"
"{{ " + assistant_message + " }}"
"{{ " + assistant + " }}"
"{% endif %}"
"{% endfor %}"
)
@ -276,6 +274,64 @@ class Template:
except ValueError as e:
logger.info_rank0(f"Cannot add this chat template to tokenizer: {e}.")
@staticmethod
def _convert_slots_to_ollama(
slots: "SLOTS", tokenizer: "PreTrainedTokenizer", placeholder: str = "content"
) -> str:
r"""
Converts slots to ollama template.
"""
slot_items = []
for slot in slots:
if isinstance(slot, str):
slot_pieces = slot.split("{{content}}")
if slot_pieces[0]:
slot_items.append(slot_pieces[0])
if len(slot_pieces) > 1:
slot_items.append("{{ " + placeholder + " }}")
if slot_pieces[1]:
slot_items.append(slot_pieces[1])
elif isinstance(slot, set): # do not use {{ eos_token }} since it may be replaced
if "bos_token" in slot and tokenizer.bos_token_id is not None:
slot_items.append(tokenizer.bos_token)
elif "eos_token" in slot and tokenizer.eos_token_id is not None:
slot_items.append(tokenizer.eos_token)
elif isinstance(slot, dict):
raise ValueError("Dict is not supported.")
return "".join(slot_items)
def _get_ollama_template(self, tokenizer: "PreTrainedTokenizer") -> str:
r"""
Returns the ollama template.
"""
prefix = self._convert_slots_to_ollama(self.format_prefix.apply(), tokenizer)
system = self._convert_slots_to_ollama(self.format_system.apply(), tokenizer, placeholder=".System")
user = self._convert_slots_to_ollama(self.format_user.apply(), tokenizer, placeholder=".Content")
assistant = self._convert_slots_to_ollama(self.format_assistant.apply(), tokenizer, placeholder=".Content")
return (
f"{prefix}{{{{ if .System }}}}{system}{{{{ end }}}}"
f"""{{{{ range .Messages }}}}{{{{ if eq .Role "user" }}}}{user}"""
f"""{{{{ else if eq .Role "assistant" }}}}{assistant}{{{{ end }}}}{{{{ end }}}}"""
)
def get_ollama_modelfile(self, tokenizer: "PreTrainedTokenizer") -> str:
r"""
Returns the ollama modelfile.
TODO: support function calling.
"""
modelfile = f'FROM .\n\nTEMPLATE """{self._get_ollama_template(tokenizer)}"""\n\n'
if self.default_system:
modelfile += f'SYSTEM system "{self.default_system}"\n\n'
for stop_token_id in self.get_stop_token_ids(tokenizer):
modelfile += f'PARAMETER stop "{tokenizer.convert_ids_to_tokens(stop_token_id)}"\n'
modelfile += "PARAMETER num_ctx 4096\n"
return modelfile
@dataclass
class Llama2Template(Template):
@ -1020,7 +1076,7 @@ _register_template(
)
# copied from chatml template
# copied from minicpm_v template
_register_template(
name="minicpm_o",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),

View File

@ -87,7 +87,7 @@ class ExportArguments:
metadata={"help": "Path to the directory to save the exported model."},
)
export_size: int = field(
default=1,
default=5,
metadata={"help": "The file shard size (in GB) of the exported model."},
)
export_device: Literal["cpu", "auto"] = field(

View File

@ -104,7 +104,7 @@ def export_model(args: Optional[Dict[str, Any]] = None) -> None:
tokenizer_module = load_tokenizer(model_args)
tokenizer = tokenizer_module["tokenizer"]
processor = tokenizer_module["processor"]
get_template_and_fix_tokenizer(tokenizer, data_args)
template = get_template_and_fix_tokenizer(tokenizer, data_args)
model = load_model(tokenizer, model_args, finetuning_args) # must after fixing tokenizer to resize vocab
if getattr(model, "quantization_method", None) is not None and model_args.adapter_name_or_path is not None:
@ -171,3 +171,7 @@ def export_model(args: Optional[Dict[str, Any]] = None) -> None:
except Exception as e:
logger.warning_rank0(f"Cannot save tokenizer, please copy the files manually: {e}.")
with open(os.path.join(model_args.export_dir, "Modelfile"), "w", encoding="utf-8") as f:
f.write(template.get_ollama_modelfile(tokenizer))
logger.info_rank0(f"Saved ollama modelfile to {model_args.export_dir}.")

View File

@ -119,6 +119,22 @@ def test_jinja_template(use_fast: bool):
assert tokenizer.apply_chat_template(MESSAGES) == ref_tokenizer.apply_chat_template(MESSAGES)
def test_ollama_modelfile():
tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA)
template = get_template_and_fix_tokenizer(tokenizer, DataArguments(template="llama3"))
assert template.get_ollama_modelfile(tokenizer) == (
"FROM .\n\n"
'TEMPLATE """<|begin_of_text|>'
"{{ if .System }}<|start_header_id|>system<|end_header_id|>\n\n{{ .System }}<|eot_id|>{{ end }}"
'{{ range .Messages }}{{ if eq .Role "user" }}<|start_header_id|>user<|end_header_id|>\n\n{{ .Content }}'
"<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
'{{ else if eq .Role "assistant" }}{{ .Content }}<|eot_id|>{{ end }}{{ end }}"""\n\n'
'PARAMETER stop "<|eom_id|>"\n'
'PARAMETER stop "<|eot_id|>"\n'
"PARAMETER num_ctx 4096\n"
)
def test_get_stop_token_ids():
tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA)
template = get_template_and_fix_tokenizer(tokenizer, DataArguments(template="llama3"))