diff --git a/src/llmtuner/train/tuner.py b/src/llmtuner/train/tuner.py
index decacbce..a1b7bec4 100644
--- a/src/llmtuner/train/tuner.py
+++ b/src/llmtuner/train/tuner.py
@@ -3,6 +3,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional
 import torch
 from transformers import PreTrainedModel

+from ..data import get_template_and_fix_tokenizer
 from ..extras.callbacks import LogCallback
 from ..extras.logging import get_logger
 from ..hparams import get_infer_args, get_train_args
@@ -40,7 +41,7 @@ def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: Optional[List["Tra


 def export_model(args: Optional[Dict[str, Any]] = None):
-    model_args, _, finetuning_args, _ = get_infer_args(args)
+    model_args, data_args, finetuning_args, _ = get_infer_args(args)

     if model_args.export_dir is None:
         raise ValueError("Please specify `export_dir`.")
@@ -49,6 +50,7 @@ def export_model(args: Optional[Dict[str, Any]] = None):
         raise ValueError("Please merge adapters before quantizing the model.")

     model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
+    get_template_and_fix_tokenizer(tokenizer, data_args.template)

     if getattr(model, "quantization_method", None) and model_args.adapter_name_or_path is not None:
         raise ValueError("Cannot merge adapters to a quantized model.")
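
The patch threads `data_args` through `export_model` so that `get_template_and_fix_tokenizer(tokenizer, data_args.template)` runs before the model and tokenizer are exported, applying the template's tokenizer fixes to the saved artifacts. A minimal usage sketch follows; only `export_dir`, `adapter_name_or_path`, and `template` appear in the diff above, so the remaining argument names (`model_name_or_path`, the paths, and the template value) are assumptions for illustration, not part of this patch.

# Sketch: invoking the patched export_model with a dict of arguments,
# matching the signature export_model(args: Optional[Dict[str, Any]] = None).
from llmtuner.train.tuner import export_model

export_model(
    args={
        "model_name_or_path": "path/to/base-model",      # assumed argument name
        "adapter_name_or_path": "path/to/lora-adapter",  # adapters merged before export
        "template": "default",                           # now forwarded via data_args.template
        "export_dir": "path/to/exported-model",          # required; error if omitted
    }
)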