Merge pull request #2266 from yhyu13/fix_export_model_dtype

Remove manually set use_cache; torch_dtype is not str, save model as b…

Former-commit-id: ea6db7263157e7e3817d9398931cb0d583588695
Authored by hoshi-hiyouga on 2024-01-21 12:40:39 +08:00, committed by GitHub
commit 6318a6bbcf
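
The fix turns on the fact that `transformers` represents `config.torch_dtype` as a `torch.dtype` object, not a string, so the old code's `setattr(model.config, "torch_dtype", "float16")` wrote a value the rest of the library would not expect. A minimal sketch of that behavior (assuming a recent `transformers` release, which normalizes a string `torch_dtype` to a `torch.dtype` when the config is constructed):

import torch
from transformers import PretrainedConfig

# The string is normalized to an actual torch.dtype object, so comparing
# config.torch_dtype against a str such as "float16" never matches.
config = PretrainedConfig(torch_dtype="bfloat16")
print(type(config.torch_dtype))               # <class 'torch.dtype'>
print(config.torch_dtype == "bfloat16")       # False: dtype, not str
print(config.torch_dtype == torch.bfloat16)   # True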


@@ -56,12 +56,11 @@ def export_model(args: Optional[Dict[str, Any]] = None):
     if not isinstance(model, PreTrainedModel):
         raise ValueError("The model is not a `PreTrainedModel`, export aborted.")
 
-    setattr(model.config, "use_cache", True)
-    if getattr(model.config, "torch_dtype", None) == torch.bfloat16:
-        model = model.to(torch.bfloat16).to("cpu")
+    if hasattr(model.config, "torch_dtype"):
+        model = model.to(getattr(model.config, "torch_dtype")).to("cpu")
     else:
         model = model.to(torch.float16).to("cpu")
-        setattr(model.config, "torch_dtype", "float16")
+        setattr(model.config, "torch_dtype", torch.float16)
 
     model.save_pretrained(
         save_directory=model_args.export_dir,
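
For reference, a self-contained sketch of the patched export path using a toy model; `ToyModel` and the `toy-export` directory are illustrative stand-ins, and the `is not None` guard is a slight tightening of the patch's `hasattr` check:

import torch
from transformers import PretrainedConfig, PreTrainedModel

class ToyModel(PreTrainedModel):
    config_class = PretrainedConfig

    def __init__(self, config):
        super().__init__(config)
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x):
        return self.linear(x)

model = ToyModel(PretrainedConfig(torch_dtype="bfloat16"))

# Mirror the patched logic: cast to whatever dtype the config records,
# falling back to float16 and recording it as a torch.dtype (not a str).
if getattr(model.config, "torch_dtype", None) is not None:
    model = model.to(model.config.torch_dtype).to("cpu")
else:
    model = model.to(torch.float16).to("cpu")
    setattr(model.config, "torch_dtype", torch.float16)

model.save_pretrained("toy-export")  # weights saved to disk in bfloat16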