[data] fix ollama template (#6902)

* fix ollama template

* add meta info

* use half precision

Former-commit-id: e1a7c1242cd1e0a1ca9ee7d04377a53872488126
hoshi-hiyouga authored 2025-02-11 22:43:09 +08:00, committed by GitHub
parent c6be9e242c
commit 197aa3baf4
4 changed files with 8 additions and 4 deletions

@@ -2,7 +2,6 @@
 model_name_or_path: saves/llama3-8b/full/sft
 template: llama3
 trust_remote_code: true
-infer_dtype: bfloat16

 ### export
 export_dir: output/llama3_full_sft

@@ -321,10 +321,11 @@ class Template:
         TODO: support function calling.
         """
-        modelfile = f'FROM .\n\nTEMPLATE """{self._get_ollama_template(tokenizer)}"""\n\n'
+        modelfile = "# ollama modelfile auto-generated by llamafactory\n\n"
+        modelfile += f'FROM .\n\nTEMPLATE """{self._get_ollama_template(tokenizer)}"""\n\n'
         if self.default_system:
-            modelfile += f'SYSTEM system "{self.default_system}"\n\n'
+            modelfile += f'SYSTEM """{self.default_system}"""\n\n'
         for stop_token_id in self.get_stop_token_ids(tokenizer):
             modelfile += f'PARAMETER stop "{tokenizer.convert_ids_to_tokens(stop_token_id)}"\n'
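The hunk above only changes how the Modelfile string is built; end-to-end usage is not shown in this commit. The sketch below is an assumed usage example based on the test further down: the import paths, the output filename, and the `llama3-sft` model name are illustrative, while `get_template_and_fix_tokenizer` and `get_ollama_modelfile(tokenizer)` come from the code in this diff.

# Hedged usage sketch (not part of this commit): write the generated Modelfile
# to disk and register it with a local Ollama install.
from transformers import AutoTokenizer

from llamafactory.data import get_template_and_fix_tokenizer
from llamafactory.hparams import DataArguments  # import paths assumed from the test file

tokenizer = AutoTokenizer.from_pretrained("saves/llama3-8b/full/sft")  # checkpoint path taken from the example config
template = get_template_and_fix_tokenizer(tokenizer, DataArguments(template="llama3"))

with open("Modelfile", "w", encoding="utf-8") as f:
    f.write(template.get_ollama_modelfile(tokenizer))

# Afterwards, something like `ollama create llama3-sft -f Modelfile` registers the model with Ollama.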

@@ -22,6 +22,7 @@ from transformers import PreTrainedModel
 from ..data import get_template_and_fix_tokenizer
 from ..extras import logging
 from ..extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
+from ..extras.misc import infer_optim_dtype
 from ..extras.packages import is_ray_available
 from ..hparams import get_infer_args, get_ray_args, get_train_args, read_args
 from ..model import load_model, load_tokenizer
@@ -117,7 +118,9 @@ def export_model(args: Optional[Dict[str, Any]] = None) -> None:
         setattr(model.config, "torch_dtype", torch.float16)
     else:
         if model_args.infer_dtype == "auto":
-            output_dtype = getattr(model.config, "torch_dtype", torch.float16)
+            output_dtype = getattr(model.config, "torch_dtype", torch.float32)
+            if output_dtype == torch.float32:  # if infer_dtype is auto, try using half precision first
+                output_dtype = infer_optim_dtype(torch.bfloat16)
         else:
             output_dtype = getattr(torch, model_args.infer_dtype)
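The new half-precision fallback delegates to `infer_optim_dtype`, imported from `..extras.misc` in the hunk above. Its actual implementation is not part of this diff; the sketch below only illustrates the assumed behavior, namely preferring bfloat16 on hardware that supports it and falling back to float16 otherwise.

import torch

def infer_optim_dtype_sketch(model_dtype: torch.dtype) -> torch.dtype:
    # Simplified stand-in for the real infer_optim_dtype helper (assumed behavior).
    if model_dtype == torch.bfloat16 and torch.cuda.is_available() and torch.cuda.is_bf16_supported():
        return torch.bfloat16  # bf16 requested and the GPU supports it
    if torch.cuda.is_available():
        return torch.float16  # otherwise fall back to fp16 on GPU
    return torch.float32  # keep full precision on CPU-only setups

# With the change above, a model whose config reports float32 is exported in half
# precision when infer_dtype is "auto", rather than staying in float32.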

@@ -123,6 +123,7 @@ def test_ollama_modelfile():
     tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA)
     template = get_template_and_fix_tokenizer(tokenizer, DataArguments(template="llama3"))
     assert template.get_ollama_modelfile(tokenizer) == (
+        "# ollama modelfile auto-generated by llamafactory\n\n"
         "FROM .\n\n"
         'TEMPLATE """<|begin_of_text|>'
         "{{ if .System }}<|start_header_id|>system<|end_header_id|>\n\n{{ .System }}<|eot_id|>{{ end }}"