From f2b4237db19ff30e816d92554403d3cf96136601 Mon Sep 17 00:00:00 2001
From: hiyouga
Date: Wed, 15 May 2024 16:32:28 +0800
Subject: [PATCH] fix fsdp model loading

Former-commit-id: 008e3b3b1075199d1a62d510a8e0f212207a06b9
---
 src/llmtuner/model/patcher.py            | 3 ++-
 src/llmtuner/model/utils/quantization.py | 3 ++-
 2 files changed, 4 insertions(+), 2 deletions(-)

diff --git a/src/llmtuner/model/patcher.py b/src/llmtuner/model/patcher.py
index b28a23d0..8625f3e1 100644
--- a/src/llmtuner/model/patcher.py
+++ b/src/llmtuner/model/patcher.py
@@ -6,6 +6,7 @@ import torch
 from peft import PeftModel
 from transformers import PreTrainedModel, PreTrainedTokenizerBase, is_torch_npu_available
 from transformers.integrations import is_deepspeed_zero3_enabled
+from transformers.modeling_utils import is_fsdp_enabled
 
 from ..extras.logging import get_logger
 from ..extras.misc import infer_optim_dtype
@@ -69,7 +70,7 @@ def patch_config(
         setattr(config, "use_cache", False)  # qwen2 does not support use_cache when using flashattn
 
     init_kwargs["torch_dtype"] = model_args.compute_dtype
-    if not is_deepspeed_zero3_enabled():
+    if not is_deepspeed_zero3_enabled() and not is_fsdp_enabled():
         init_kwargs["low_cpu_mem_usage"] = model_args.low_cpu_mem_usage
         if init_kwargs["low_cpu_mem_usage"]:
             if "device_map" not in init_kwargs and model_args.device_map:
diff --git a/src/llmtuner/model/utils/quantization.py b/src/llmtuner/model/utils/quantization.py
index 3cf159c1..95412e7c 100644
--- a/src/llmtuner/model/utils/quantization.py
+++ b/src/llmtuner/model/utils/quantization.py
@@ -7,6 +7,7 @@ import torch
 from datasets import load_dataset
 from transformers import BitsAndBytesConfig, GPTQConfig
 from transformers.integrations import is_deepspeed_zero3_enabled
+from transformers.modeling_utils import is_fsdp_enabled
 from transformers.utils.versions import require_version
 
 from ...extras.constants import FILEEXT2TYPE
@@ -133,7 +134,7 @@ def configure_quantization(
                 bnb_4bit_quant_storage=model_args.compute_dtype,  # crucial for fsdp qlora
             )
 
-        if is_deepspeed_zero3_enabled() or model_args.quantization_device_map == "auto":
+        if is_deepspeed_zero3_enabled() or is_fsdp_enabled() or model_args.quantization_device_map == "auto":
             if model_args.quantization_bit != 4:
                 raise ValueError("Only 4-bit quantized model can use auto device map.")
 
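Reviewer note: below is a minimal, self-contained sketch of the gating logic
this patch produces. `is_deepspeed_zero3_enabled` and `is_fsdp_enabled` are
real `transformers` helpers; `ModelArgs`, `prepare_init_kwargs`, and
`check_quantization` are illustrative stand-ins, not names from this repo.

    from dataclasses import dataclass
    from typing import Any, Dict, Optional

    import torch
    from transformers.integrations import is_deepspeed_zero3_enabled
    from transformers.modeling_utils import is_fsdp_enabled


    @dataclass
    class ModelArgs:  # hypothetical, simplified stand-in for ModelArguments
        compute_dtype: torch.dtype = torch.bfloat16
        low_cpu_mem_usage: bool = True
        device_map: Optional[str] = None
        quantization_bit: Optional[int] = None
        quantization_device_map: Optional[str] = None


    def prepare_init_kwargs(model_args: ModelArgs) -> Dict[str, Any]:
        init_kwargs: Dict[str, Any] = {"torch_dtype": model_args.compute_dtype}
        # Under DeepSpeed ZeRO-3 or FSDP, the distributed runtime owns
        # parameter placement, so low_cpu_mem_usage/device_map are left
        # unset; passing them to from_pretrained() conflicts with sharded
        # loading.
        if not is_deepspeed_zero3_enabled() and not is_fsdp_enabled():
            init_kwargs["low_cpu_mem_usage"] = model_args.low_cpu_mem_usage
            if init_kwargs["low_cpu_mem_usage"]:
                if "device_map" not in init_kwargs and model_args.device_map:
                    init_kwargs["device_map"] = model_args.device_map
        return init_kwargs


    def check_quantization(model_args: ModelArgs) -> None:
        # Sharded runtimes (ZeRO-3/FSDP) and quantization_device_map="auto"
        # are only supported with 4-bit bitsandbytes quantization here.
        if (
            is_deepspeed_zero3_enabled()
            or is_fsdp_enabled()
            or model_args.quantization_device_map == "auto"
        ):
            if model_args.quantization_bit != 4:
                raise ValueError("Only 4-bit quantized model can use auto device map.")

The rationale (assumed, not stated in the diff): FSDP, like ZeRO-3,
initializes and shards weights through its own wrapper, so meta-device
loading via low_cpu_mem_usage and an explicit device_map would fight the
sharding, which is why both hunks add the same is_fsdp_enabled() guard
alongside the existing ZeRO-3 check.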