From c9d3cc181a404a1355c0334aa3a4ab554b3efb74 Mon Sep 17 00:00:00 2001 From: hiyouga Date: Fri, 12 Apr 2024 14:28:11 +0800 Subject: [PATCH] fix #3238 Former-commit-id: 232642a6215de9b57a1627f39b4efea127948a01 --- src/llmtuner/hparams/parser.py | 6 +++++- src/llmtuner/train/tuner.py | 3 +-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/src/llmtuner/hparams/parser.py b/src/llmtuner/hparams/parser.py index 9264d1ee..4abd3f03 100644 --- a/src/llmtuner/hparams/parser.py +++ b/src/llmtuner/hparams/parser.py @@ -277,7 +277,11 @@ def get_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS: _verify_model_args(model_args, finetuning_args) - model_args.device_map = "auto" + if model_args.export_dir is not None: + model_args.device_map = {"": "cpu"} + model_args.compute_dtype = torch.float32 + else: + model_args.device_map = "auto" return model_args, data_args, finetuning_args, generating_args diff --git a/src/llmtuner/train/tuner.py b/src/llmtuner/train/tuner.py index f6c2e16b..a8a2b8e9 100644 --- a/src/llmtuner/train/tuner.py +++ b/src/llmtuner/train/tuner.py @@ -65,8 +65,7 @@ def export_model(args: Optional[Dict[str, Any]] = None): if getattr(model, "quantization_method", None) is None: # cannot convert dtype of a quantized model output_dtype = getattr(model.config, "torch_dtype", torch.float16) setattr(model.config, "torch_dtype", output_dtype) - for param in model.parameters(): - param.data = param.data.to(output_dtype) + model = model.to(output_dtype) model.save_pretrained( save_directory=model_args.export_dir,