diff --git a/src/llmtuner/model/patcher.py b/src/llmtuner/model/patcher.py index aaedc1a8..e7ff0486 100644 --- a/src/llmtuner/model/patcher.py +++ b/src/llmtuner/model/patcher.py @@ -284,8 +284,9 @@ def patch_config( init_kwargs["torch_dtype"] = model_args.compute_dtype if not is_deepspeed_zero3_enabled(): - init_kwargs["device_map"] = {"": get_current_device()} init_kwargs["low_cpu_mem_usage"] = True + if is_trainable: + init_kwargs["device_map"] = {"": get_current_device()} def patch_model(