Remove manually set use_cache; torch_dtype is a torch.dtype, not a str, so saving the model as bfloat16 used to fail;

Former-commit-id: 9cdbd3bfc8be3f9adc799af8db9a254a47a577a2
yhyu13 2024-01-21 11:12:15 +08:00
parent 9f11bdfe8a
commit f036b9c7ba

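Why the old check failed: at runtime model.config.torch_dtype holds a torch.dtype object, not a str, so comparing it against "bfloat16" was always False and bfloat16 models silently fell through to the float16 branch. A minimal illustration of the mismatch:

import torch

# transformers stores config.torch_dtype as a torch.dtype at runtime
dtype = torch.bfloat16

print(dtype == "bfloat16")      # False -- the removed check never matched
print(dtype == torch.bfloat16)  # True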

@@ -56,12 +56,11 @@ def export_model(args: Optional[Dict[str, Any]] = None):
     if not isinstance(model, PreTrainedModel):
         raise ValueError("The model is not a `PreTrainedModel`, export aborted.")
 
-    setattr(model.config, "use_cache", True)
-    if getattr(model.config, "torch_dtype", None) == "bfloat16":
-        model = model.to(torch.bfloat16).to("cpu")
+    if hasattr(model.config, "torch_dtype"):
+        model = model.to(getattr(model.config, "torch_dtype")).to("cpu")
     else:
-        model = model.to(torch.float16).to("cpu")
-        setattr(model.config, "torch_dtype", "float16")
+        model = model.to(torch.float32).to("cpu")
+        setattr(model.config, "torch_dtype", "float32")
 
     model.save_pretrained(
         save_directory=model_args.export_dir,
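For reference, a self-contained sketch of the export path after this change; model_name and export_dir are placeholders standing in for the project's model_args, not its actual setup:

import torch
from transformers import AutoModelForCausalLM, PreTrainedModel

model_name = "path/to/finetuned-model"  # hypothetical checkpoint path
export_dir = "path/to/export"           # hypothetical output directory

model = AutoModelForCausalLM.from_pretrained(model_name)
if not isinstance(model, PreTrainedModel):
    raise ValueError("The model is not a `PreTrainedModel`, export aborted.")

# Cast to the dtype recorded in the config (a torch.dtype); fall back to
# float32 and record that in the config when no dtype is present.
if hasattr(model.config, "torch_dtype"):
    model = model.to(getattr(model.config, "torch_dtype")).to("cpu")
else:
    model = model.to(torch.float32).to("cpu")
    setattr(model.config, "torch_dtype", "float32")

model.save_pretrained(save_directory=export_dir)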