This commit is contained in:
hiyouga
2024-01-24 16:19:18 +08:00
parent 51ad35b3c7
commit 2bc30763e9
3 changed files with 43 additions and 38 deletions

View File

@@ -56,7 +56,9 @@ def export_model(args: Optional[Dict[str, Any]] = None):
if not isinstance(model, PreTrainedModel):
raise ValueError("The model is not a `PreTrainedModel`, export aborted.")
if hasattr(model.config, "torch_dtype"):
if getattr(model, "quantization_method", None):
model = model.to("cpu")
elif hasattr(model.config, "torch_dtype"):
model = model.to(getattr(model.config, "torch_dtype")).to("cpu")
else:
model = model.to(torch.float16).to("cpu")