From ee5853c5650b5999e96e9918d41a7bdbeb93e8a0 Mon Sep 17 00:00:00 2001 From: hoshi-hiyouga Date: Thu, 6 Jun 2024 03:14:23 +0800 Subject: [PATCH] Update model_args.py Former-commit-id: 09c0afd94a8a5f5b45a61b32c983d50e1b9e2941 --- src/llamafactory/hparams/model_args.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/llamafactory/hparams/model_args.py b/src/llamafactory/hparams/model_args.py index 99c02850..024bc2f8 100644 --- a/src/llamafactory/hparams/model_args.py +++ b/src/llamafactory/hparams/model_args.py @@ -145,9 +145,9 @@ class ModelArguments: default=1, metadata={"help": "The file shard size (in GB) of the exported model."}, ) - export_device: Literal["cpu", "cuda", "npu"] = field( + export_device: Literal["cpu", "auto"] = field( default="cpu", - metadata={"help": "The device used in model export, use cuda to avoid addmm errors; use npu/cuda to speed up exporting."}, + metadata={"help": "The device used in model export, use `auto` to accelerate exporting."}, ) export_quantization_bit: Optional[int] = field( default=None,