diff --git a/examples/merge_lora/llama3_full_sft.yaml b/examples/merge_lora/llama3_full_sft.yaml
index 3d5682a7..4e329fad 100644
--- a/examples/merge_lora/llama3_full_sft.yaml
+++ b/examples/merge_lora/llama3_full_sft.yaml
@@ -2,7 +2,6 @@
 model_name_or_path: saves/llama3-8b/full/sft
 template: llama3
 trust_remote_code: true
-infer_dtype: bfloat16
 
 ### export
 export_dir: output/llama3_full_sft
diff --git a/src/llamafactory/data/template.py b/src/llamafactory/data/template.py
index 61cb27a1..ccd322c4 100644
--- a/src/llamafactory/data/template.py
+++ b/src/llamafactory/data/template.py
@@ -321,10 +321,11 @@ class Template:
 
         TODO: support function calling.
         """
-        modelfile = f'FROM .\n\nTEMPLATE """{self._get_ollama_template(tokenizer)}"""\n\n'
+        modelfile = "# ollama modelfile auto-generated by llamafactory\n\n"
+        modelfile += f'FROM .\n\nTEMPLATE """{self._get_ollama_template(tokenizer)}"""\n\n'
 
         if self.default_system:
-            modelfile += f'SYSTEM system "{self.default_system}"\n\n'
+            modelfile += f'SYSTEM """{self.default_system}"""\n\n'
 
         for stop_token_id in self.get_stop_token_ids(tokenizer):
             modelfile += f'PARAMETER stop "{tokenizer.convert_ids_to_tokens(stop_token_id)}"\n'
diff --git a/src/llamafactory/train/tuner.py b/src/llamafactory/train/tuner.py
index 23a39774..73800694 100644
--- a/src/llamafactory/train/tuner.py
+++ b/src/llamafactory/train/tuner.py
@@ -22,6 +22,7 @@ from transformers import PreTrainedModel
 from ..data import get_template_and_fix_tokenizer
 from ..extras import logging
 from ..extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
+from ..extras.misc import infer_optim_dtype
 from ..extras.packages import is_ray_available
 from ..hparams import get_infer_args, get_ray_args, get_train_args, read_args
 from ..model import load_model, load_tokenizer
@@ -117,7 +118,9 @@ def export_model(args: Optional[Dict[str, Any]] = None) -> None:
         setattr(model.config, "torch_dtype", torch.float16)
     else:
         if model_args.infer_dtype == "auto":
-            output_dtype = getattr(model.config, "torch_dtype", torch.float16)
+            output_dtype = getattr(model.config, "torch_dtype", torch.float32)
+            if output_dtype == torch.float32:  # if infer_dtype is auto, try using half precision first
+                output_dtype = infer_optim_dtype(torch.bfloat16)
         else:
             output_dtype = getattr(torch, model_args.infer_dtype)
 
diff --git a/tests/data/test_template.py b/tests/data/test_template.py
index 61923678..9f8fc835 100644
--- a/tests/data/test_template.py
+++ b/tests/data/test_template.py
@@ -123,6 +123,7 @@ def test_ollama_modelfile():
     tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA)
     template = get_template_and_fix_tokenizer(tokenizer, DataArguments(template="llama3"))
     assert template.get_ollama_modelfile(tokenizer) == (
+        "# ollama modelfile auto-generated by llamafactory\n\n"
        "FROM .\n\n"
        'TEMPLATE """<|begin_of_text|>'
        "{{ if .System }}<|start_header_id|>system<|end_header_id|>\n\n{{ .System }}<|eot_id|>{{ end }}"