diff --git a/src/llamafactory/hparams/model_args.py b/src/llamafactory/hparams/model_args.py index 0434f426..20271173 100644 --- a/src/llamafactory/hparams/model_args.py +++ b/src/llamafactory/hparams/model_args.py @@ -153,9 +153,9 @@ class ModelArguments: default=1, metadata={"help": "The file shard size (in GB) of the exported model."}, ) - export_device: Literal["cpu", "cuda"] = field( + export_device: Literal["cpu", "auto"] = field( default="cpu", - metadata={"help": "The device used in model export, use cuda to avoid addmm errors."}, + metadata={"help": "The device used in model export, use `auto` to accelerate exporting."}, ) export_quantization_bit: Optional[int] = field( default=None, diff --git a/src/llamafactory/webui/components/export.py b/src/llamafactory/webui/components/export.py index 2f354011..7e1493c8 100644 --- a/src/llamafactory/webui/components/export.py +++ b/src/llamafactory/webui/components/export.py @@ -89,7 +89,7 @@ def create_export_tab(engine: "Engine") -> Dict[str, "Component"]: export_size = gr.Slider(minimum=1, maximum=100, value=1, step=1) export_quantization_bit = gr.Dropdown(choices=["none"] + GPTQ_BITS, value="none") export_quantization_dataset = gr.Textbox(value="data/c4_demo.json") - export_device = gr.Radio(choices=["cpu", "cuda"], value="cpu") + export_device = gr.Radio(choices=["cpu", "auto"], value="cpu") export_legacy_format = gr.Checkbox() with gr.Row():