From 554ca3d8dcc1845c6cbdd317fc020ce94a14120b Mon Sep 17 00:00:00 2001
From: hiyouga <467089858@qq.com>
Date: Thu, 23 May 2024 23:32:45 +0800
Subject: [PATCH] fix oom issues in export

Former-commit-id: b7ccc882a192aa1e25b1e5816f875ea304282412
---
 src/llamafactory/hparams/model_args.py | 2 +-
 src/llamafactory/hparams/parser.py     | 4 ++--
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/llamafactory/hparams/model_args.py b/src/llamafactory/hparams/model_args.py
index 5885bb09..650d1c22 100644
--- a/src/llamafactory/hparams/model_args.py
+++ b/src/llamafactory/hparams/model_args.py
@@ -145,7 +145,7 @@ class ModelArguments:
         default=1,
         metadata={"help": "The file shard size (in GB) of the exported model."},
     )
-    export_device: str = field(
+    export_device: Literal["cpu", "cuda"] = field(
         default="cpu",
         metadata={"help": "The device used in model export, use cuda to avoid addmm errors."},
     )
diff --git a/src/llamafactory/hparams/parser.py b/src/llamafactory/hparams/parser.py
index 20f9a003..6311297e 100644
--- a/src/llamafactory/hparams/parser.py
+++ b/src/llamafactory/hparams/parser.py
@@ -328,8 +328,8 @@ def get_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:
     _verify_model_args(model_args, finetuning_args)
     _check_extra_dependencies(model_args, finetuning_args)
 
-    if model_args.export_dir is not None:
-        model_args.device_map = {"": torch.device(model_args.export_device)}
+    if model_args.export_dir is not None and model_args.export_device == "cpu":
+        model_args.device_map = {"": torch.device("cpu")}
     else:
         model_args.device_map = "auto"
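
Note (illustration, not part of the commit): after this patch, the export path pins
the whole model to CPU only when export_device == "cpu"; any other value now falls
through to device_map = "auto", letting accelerate spread weights across available
devices instead of forcing everything onto one torch.device, which is the likely
source of the OOM in the commit title. Below is a minimal sketch of how such a
device_map value is typically consumed, assuming the Hugging Face transformers
loader; the checkpoint name and the standalone export_device variable are
placeholders, not taken from this patch:

import torch
from transformers import AutoModelForCausalLM

export_device = "cuda"  # stands in for model_args.export_device

# Mirror the patched branch in get_infer_args: pin to CPU only when
# explicitly requested, otherwise let accelerate choose the placement.
if export_device == "cpu":
    device_map = {"": torch.device("cpu")}  # "" maps the whole model to one device
else:
    device_map = "auto"

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",  # placeholder checkpoint
    torch_dtype=torch.float16,
    device_map=device_map,
)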