Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-11-04 18:02:19 +08:00)
[data] fix ollama template (#6902)

* fix ollama template
* add meta info
* use half precision

Former-commit-id: 1304bbea69d8c8ca57140017515dee7ae2ee6536
This commit is contained in:

parent 88eafd865b
commit 86063e27ea
@@ -2,7 +2,6 @@
 model_name_or_path: saves/llama3-8b/full/sft
 template: llama3
 trust_remote_code: true
-infer_dtype: bfloat16
 
 ### export
 export_dir: output/llama3_full_sft
@@ -321,10 +321,11 @@ class Template:
 
         TODO: support function calling.
         """
-        modelfile = f'FROM .\n\nTEMPLATE """{self._get_ollama_template(tokenizer)}"""\n\n'
+        modelfile = "# ollama modelfile auto-generated by llamafactory\n\n"
+        modelfile += f'FROM .\n\nTEMPLATE """{self._get_ollama_template(tokenizer)}"""\n\n'
 
         if self.default_system:
-            modelfile += f'SYSTEM system "{self.default_system}"\n\n'
+            modelfile += f'SYSTEM """{self.default_system}"""\n\n'
 
         for stop_token_id in self.get_stop_token_ids(tokenizer):
             modelfile += f'PARAMETER stop "{tokenizer.convert_ids_to_tokens(stop_token_id)}"\n'
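For reference, a standalone sketch of the string assembly in the hunk above. The template text, system prompt, and stop token here are hypothetical stand-ins, and `_get_ollama_template` is replaced by a plain argument; only the assembly logic mirrors the new code:

import sys

# Sketch of the Modelfile assembly; inputs are made-up examples.
def build_modelfile(template_text: str, default_system: str, stop_tokens: list) -> str:
    modelfile = "# ollama modelfile auto-generated by llamafactory\n\n"
    modelfile += f'FROM .\n\nTEMPLATE """{template_text}"""\n\n'
    if default_system:
        # Triple quotes let the system prompt span lines and contain quotes.
        modelfile += f'SYSTEM """{default_system}"""\n\n'
    for token in stop_tokens:
        modelfile += f'PARAMETER stop "{token}"\n'
    return modelfile

sys.stdout.write(build_modelfile("{{ .Prompt }}", "You are a helpful assistant.", ["<|eot_id|>"]))

The `SYSTEM` change is a bug fix as well as a formatting one: the old line baked a stray `system` keyword into the directive, while the new triple-quoted form is Ollama's multi-line string syntax, matching the one already used for `TEMPLATE`.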
@@ -22,6 +22,7 @@ from transformers import PreTrainedModel
 from ..data import get_template_and_fix_tokenizer
 from ..extras import logging
 from ..extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
+from ..extras.misc import infer_optim_dtype
 from ..extras.packages import is_ray_available
 from ..hparams import get_infer_args, get_ray_args, get_train_args, read_args
 from ..model import load_model, load_tokenizer
@@ -117,7 +118,9 @@ def export_model(args: Optional[Dict[str, Any]] = None) -> None:
         setattr(model.config, "torch_dtype", torch.float16)
     else:
         if model_args.infer_dtype == "auto":
-            output_dtype = getattr(model.config, "torch_dtype", torch.float16)
+            output_dtype = getattr(model.config, "torch_dtype", torch.float32)
+            if output_dtype == torch.float32:  # if infer_dtype is auto, try using half precision first
+                output_dtype = infer_optim_dtype(torch.bfloat16)
         else:
             output_dtype = getattr(torch, model_args.infer_dtype)
 
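With this change, the `auto` path only consults `infer_optim_dtype` when the checkpoint reports float32; a model that already declares bf16 or fp16 keeps its dtype, which is why the example config no longer needs `infer_dtype: bfloat16`. Below is a minimal sketch of the fallback behavior assumed here; the real helper lives in `..extras.misc` and its exact body may differ:

import torch

# Assumed behavior: prefer the requested bf16 when the GPU supports it,
# otherwise drop to fp16, and keep fp32 when no GPU is available.
def infer_optim_dtype(model_dtype: torch.dtype) -> torch.dtype:
    if model_dtype == torch.bfloat16 and torch.cuda.is_available() and torch.cuda.is_bf16_supported():
        return torch.bfloat16
    if torch.cuda.is_available():
        return torch.float16
    return torch.float32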
@@ -123,6 +123,7 @@ def test_ollama_modelfile():
     tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA)
     template = get_template_and_fix_tokenizer(tokenizer, DataArguments(template="llama3"))
     assert template.get_ollama_modelfile(tokenizer) == (
+        "# ollama modelfile auto-generated by llamafactory\n\n"
         "FROM .\n\n"
         'TEMPLATE """<|begin_of_text|>'
         "{{ if .System }}<|start_header_id|>system<|end_header_id|>\n\n{{ .System }}<|eot_id|>{{ end }}"
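A hedged usage sketch tying the pieces together, using the same public functions the test exercises. The checkpoint path is the hypothetical `export_dir` from the config above, and writing the Modelfile by hand like this is only an illustration, not the exporter's own flow:

from transformers import AutoTokenizer

from llamafactory.data import get_template_and_fix_tokenizer
from llamafactory.hparams import DataArguments

# Hypothetical exported checkpoint directory (see export_dir above).
export_dir = "output/llama3_full_sft"
tokenizer = AutoTokenizer.from_pretrained(export_dir)
template = get_template_and_fix_tokenizer(tokenizer, DataArguments(template="llama3"))

# Write the auto-generated Modelfile next to the weights so that
# `ollama create` can pick it up (FROM . points at this directory).
with open(f"{export_dir}/Modelfile", "w") as f:
    f.write(template.get_ollama_modelfile(tokenizer))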