From f9df6c17ed72d58ef5c4fe25f8debe75c2d5da40 Mon Sep 17 00:00:00 2001
From: hiyouga
Date: Fri, 17 Nov 2023 17:23:13 +0800
Subject: [PATCH] fix #1550

Former-commit-id: 1bbc1be95eedf0796c0b311568dff8c75f87dfbb
---
 src/llmtuner/model/loader.py | 7 ++++++-
 src/llmtuner/model/utils.py  | 2 +-
 2 files changed, 7 insertions(+), 2 deletions(-)

diff --git a/src/llmtuner/model/loader.py b/src/llmtuner/model/loader.py
index 20b9b5d4..4d2e1974 100644
--- a/src/llmtuner/model/loader.py
+++ b/src/llmtuner/model/loader.py
@@ -168,12 +168,17 @@ def load_model_and_tokenizer(
         config_kwargs["device_map"] = {"": get_current_device()}
         logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))
 
+    if is_deepspeed_zero3_enabled() or getattr(config, "model_type", None) == "chatglm":
+        low_cpu_mem_usage = False
+    else:
+        low_cpu_mem_usage = True
+
     # Load pre-trained models (without valuehead)
     model = AutoModelForCausalLM.from_pretrained(
         model_to_load,
         config=config,
         torch_dtype=model_args.compute_dtype,
-        low_cpu_mem_usage=(not is_deepspeed_zero3_enabled()),
+        low_cpu_mem_usage=low_cpu_mem_usage,
         **config_kwargs
     )
 
diff --git a/src/llmtuner/model/utils.py b/src/llmtuner/model/utils.py
index 7badc905..e7445f8d 100644
--- a/src/llmtuner/model/utils.py
+++ b/src/llmtuner/model/utils.py
@@ -140,7 +140,7 @@ def prepare_model_for_training(
         model.get_input_embeddings().register_forward_hook(neftune_forward_hook)
         logger.info("Using noisy embedding with alpha={:.2f}".format(finetuning_args.neft_alpha))
 
-    if use_gradient_checkpointing:
+    if use_gradient_checkpointing and getattr(model, "supports_gradient_checkpointing", False):
         if hasattr(model, "enable_input_require_grads"):
             model.enable_input_require_grads()
         else:
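
Note (not part of the patch): a minimal, self-contained sketch of the two checks this commit introduces, for anyone reviewing the decision logic in isolation. The function names and test values below are illustrative only and do not exist in the repository; the real code reads is_deepspeed_zero3_enabled(), getattr(config, "model_type", None), and model.supports_gradient_checkpointing inside loader.py and utils.py.

from typing import Optional


def use_low_cpu_mem_usage(zero3_enabled: bool, model_type: Optional[str]) -> bool:
    # The patch forces a full CPU load (low_cpu_mem_usage=False) when DeepSpeed
    # ZeRO-3 is active or the model is ChatGLM; ZeRO-3 partitions weights at
    # load time and conflicts with low_cpu_mem_usage, and ChatGLM is presumably
    # the model that triggered #1550.
    return not (zero3_enabled or model_type == "chatglm")


def enable_gradient_checkpointing(use_gc: bool, supports_gc: bool) -> bool:
    # The patch additionally requires model.supports_gradient_checkpointing
    # before enabling gradient checkpointing, instead of relying on the
    # training flag alone.
    return use_gc and supports_gc


if __name__ == "__main__":
    assert use_low_cpu_mem_usage(False, "llama") is True
    assert use_low_cpu_mem_usage(True, "llama") is False
    assert use_low_cpu_mem_usage(False, "chatglm") is False
    assert enable_gradient_checkpointing(True, False) is False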