diff --git a/src/llmtuner/train/tuner.py b/src/llmtuner/train/tuner.py index 2d57f448..c24e5eac 100644 --- a/src/llmtuner/train/tuner.py +++ b/src/llmtuner/train/tuner.py @@ -56,12 +56,11 @@ def export_model(args: Optional[Dict[str, Any]] = None): if not isinstance(model, PreTrainedModel): raise ValueError("The model is not a `PreTrainedModel`, export aborted.") - setattr(model.config, "use_cache", True) - if getattr(model.config, "torch_dtype", None) == torch.bfloat16: - model = model.to(torch.bfloat16).to("cpu") + if hasattr(model.config, "torch_dtype"): + model = model.to(getattr(model.config, "torch_dtype")).to("cpu") else: model = model.to(torch.float16).to("cpu") - setattr(model.config, "torch_dtype", "float16") + setattr(model.config, "torch_dtype", torch.float16) model.save_pretrained( save_directory=model_args.export_dir,