Former-commit-id: 232642a6215de9b57a1627f39b4efea127948a01
This commit is contained in:
hiyouga 2024-04-12 14:28:11 +08:00
parent 1ae6f0a5f3
commit c9d3cc181a
2 changed files with 6 additions and 3 deletions

View File

@@ -277,7 +277,11 @@ def get_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:
_verify_model_args(model_args, finetuning_args)
model_args.device_map = "auto"
if model_args.export_dir is not None:
model_args.device_map = {"": "cpu"}
model_args.compute_dtype = torch.float32
else:
model_args.device_map = "auto"
return model_args, data_args, finetuning_args, generating_args

View File

@@ -65,8 +65,7 @@ def export_model(args: Optional[Dict[str, Any]] = None):
if getattr(model, "quantization_method", None) is None: # cannot convert dtype of a quantized model
output_dtype = getattr(model.config, "torch_dtype", torch.float16)
setattr(model.config, "torch_dtype", output_dtype)
for param in model.parameters():
param.data = param.data.to(output_dtype)
model = model.to(output_dtype)
model.save_pretrained(
save_directory=model_args.export_dir,