From 6d8ef03741a012dee283bfd6fc38e397589e7dcc Mon Sep 17 00:00:00 2001 From: hiyouga <467089858@qq.com> Date: Thu, 23 May 2024 23:32:45 +0800 Subject: [PATCH] fix oom issues in export Former-commit-id: 67ebc7b388c61b9d880c02d7fd217c29299fdf43 --- src/llamafactory/hparams/model_args.py | 2 +- src/llamafactory/hparams/parser.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/llamafactory/hparams/model_args.py b/src/llamafactory/hparams/model_args.py index 5885bb09..650d1c22 100644 --- a/src/llamafactory/hparams/model_args.py +++ b/src/llamafactory/hparams/model_args.py @@ -145,7 +145,7 @@ class ModelArguments: default=1, metadata={"help": "The file shard size (in GB) of the exported model."}, ) - export_device: str = field( + export_device: Literal["cpu", "cuda"] = field( default="cpu", metadata={"help": "The device used in model export, use cuda to avoid addmm errors."}, ) diff --git a/src/llamafactory/hparams/parser.py b/src/llamafactory/hparams/parser.py index 20f9a003..6311297e 100644 --- a/src/llamafactory/hparams/parser.py +++ b/src/llamafactory/hparams/parser.py @@ -328,8 +328,8 @@ def get_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS: _verify_model_args(model_args, finetuning_args) _check_extra_dependencies(model_args, finetuning_args) - if model_args.export_dir is not None: - model_args.device_map = {"": torch.device(model_args.export_device)} + if model_args.export_dir is not None and model_args.export_device == "cpu": + model_args.device_map = {"": torch.device("cpu")} else: model_args.device_map = "auto"