From 48fb0be1b920478f037100b0cf3eb6f840e94f84 Mon Sep 17 00:00:00 2001
From: hoshi-hiyouga
Date: Tue, 16 Apr 2024 17:29:19 +0800
Subject: [PATCH] Update patcher.py

Former-commit-id: a950f3b81de701f5f23ce3efa60ff0382bb40dfe
---
 src/llmtuner/model/patcher.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/src/llmtuner/model/patcher.py b/src/llmtuner/model/patcher.py
index 563b1827..fb2835e8 100644
--- a/src/llmtuner/model/patcher.py
+++ b/src/llmtuner/model/patcher.py
@@ -133,7 +133,9 @@ def _configure_quantization(
         if is_deepspeed_zero3_enabled():
             raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantized models.")
 
-        init_kwargs["device_map"] = {"": get_current_device()}
+        if model_args.quantization_device_map != "auto":
+            init_kwargs["device_map"] = {"": get_current_device()}
+
         quantization_config: Dict[str, Any] = getattr(config, "quantization_config", None)
         quant_method = quantization_config.get("quant_method", "")
 
@@ -268,7 +270,6 @@ def _prepare_model_for_training(
             # According to: https://github.com/huggingface/transformers/issues/28339
             model.gradient_checkpointing_enable = MethodType(gradient_checkpointing_enable, model)
             model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": True})
-            # model.enable_input_require_grads()
             setattr(model.config, "use_cache", False)  # turn off when gradient checkpointing is enabled
             logger.info("Gradient checkpointing enabled.")
 
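
Note on the first hunk: before this change, loading an already-quantized model always pinned it to the current device via device_map={"": get_current_device()}; now that pin is skipped when the user sets quantization_device_map to "auto", leaving device placement to whatever device_map the caller configured. The sketch below is a minimal, self-contained Python illustration of that conditional only, not the project code; the ModelArguments dataclass, its default value, and the get_current_device stub are simplified stand-ins.

    from dataclasses import dataclass
    from typing import Any, Dict


    @dataclass
    class ModelArguments:
        # Simplified stand-in: the real argument class has many more fields,
        # and the default here is an assumption for illustration.
        quantization_device_map: str = ""


    def get_current_device() -> str:
        # Stand-in for the project's helper, which resolves the local
        # accelerator (e.g. a CUDA device) for the running process.
        return "cuda:0"


    def configure_device_map(model_args: ModelArguments, init_kwargs: Dict[str, Any]) -> None:
        # Post-patch behavior: only pin the quantized model to the current
        # device when the user did not ask for "auto" placement.
        if model_args.quantization_device_map != "auto":
            init_kwargs["device_map"] = {"": get_current_device()}


    if __name__ == "__main__":
        auto_kwargs: Dict[str, Any] = {}
        configure_device_map(ModelArguments(quantization_device_map="auto"), auto_kwargs)
        print(auto_kwargs)  # {} -> placement left to the caller's device_map

        pinned_kwargs: Dict[str, Any] = {}
        configure_device_map(ModelArguments(), pinned_kwargs)
        print(pinned_kwargs)  # {'': 'cuda:0'} -> pinned to the current device

The second hunk only deletes the commented-out model.enable_input_require_grads() call, which is redundant once gradient checkpointing is enabled through the patched gradient_checkpointing_enable method.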