diff --git a/src/llmtuner/model/patcher.py b/src/llmtuner/model/patcher.py index e7ff0486..5a79387c 100644 --- a/src/llmtuner/model/patcher.py +++ b/src/llmtuner/model/patcher.py @@ -287,6 +287,8 @@ def patch_config( init_kwargs["low_cpu_mem_usage"] = True if is_trainable: init_kwargs["device_map"] = {"": get_current_device()} + elif model_args.export_dir is None: + init_kwargs["device_map"] = "auto" def patch_model(