diff --git a/src/llmtuner/hparams/model_args.py b/src/llmtuner/hparams/model_args.py
index bc89cc8e..20b02219 100644
--- a/src/llmtuner/hparams/model_args.py
+++ b/src/llmtuner/hparams/model_args.py
@@ -97,6 +97,10 @@ class ModelArguments:
         default=False,
         metadata={"help": "Whether or not to disable CUDA graph in the vLLM engine."},
     )
+    offload_folder: str = field(
+        default="offload",
+        metadata={"help": "Path to offload model weights."},
+    )
     hf_hub_token: Optional[str] = field(
         default=None,
         metadata={"help": "Auth token to log in with Hugging Face Hub."},
diff --git a/src/llmtuner/model/patcher.py b/src/llmtuner/model/patcher.py
index 401bdfe0..7335b1c1 100644
--- a/src/llmtuner/model/patcher.py
+++ b/src/llmtuner/model/patcher.py
@@ -291,6 +291,9 @@ def patch_config(
     if "device_map" not in init_kwargs:  # quant models cannot use auto device map
         init_kwargs["device_map"] = model_args.device_map or {"": get_current_device()}
 
+    if init_kwargs["device_map"] == "auto":
+        init_kwargs["offload_folder"] = model_args.offload_folder
+
 
 def patch_model(
     model: "PreTrainedModel", tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments", is_trainable: bool