fix fsdp model loading

Former-commit-id: 008e3b3b1075199d1a62d510a8e0f212207a06b9
This commit is contained in:
hiyouga 2024-05-15 16:32:28 +08:00
parent 967b9c0a49
commit f2b4237db1
2 changed files with 4 additions and 2 deletions

View File

@ -6,6 +6,7 @@ import torch
from peft import PeftModel from peft import PeftModel
from transformers import PreTrainedModel, PreTrainedTokenizerBase, is_torch_npu_available from transformers import PreTrainedModel, PreTrainedTokenizerBase, is_torch_npu_available
from transformers.integrations import is_deepspeed_zero3_enabled from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.modeling_utils import is_fsdp_enabled
from ..extras.logging import get_logger from ..extras.logging import get_logger
from ..extras.misc import infer_optim_dtype from ..extras.misc import infer_optim_dtype
@ -69,7 +70,7 @@ def patch_config(
setattr(config, "use_cache", False) # qwen2 does not support use_cache when using flashattn setattr(config, "use_cache", False) # qwen2 does not support use_cache when using flashattn
init_kwargs["torch_dtype"] = model_args.compute_dtype init_kwargs["torch_dtype"] = model_args.compute_dtype
if not is_deepspeed_zero3_enabled(): if not is_deepspeed_zero3_enabled() and not is_fsdp_enabled():
init_kwargs["low_cpu_mem_usage"] = model_args.low_cpu_mem_usage init_kwargs["low_cpu_mem_usage"] = model_args.low_cpu_mem_usage
if init_kwargs["low_cpu_mem_usage"]: if init_kwargs["low_cpu_mem_usage"]:
if "device_map" not in init_kwargs and model_args.device_map: if "device_map" not in init_kwargs and model_args.device_map:

View File

@ -7,6 +7,7 @@ import torch
from datasets import load_dataset from datasets import load_dataset
from transformers import BitsAndBytesConfig, GPTQConfig from transformers import BitsAndBytesConfig, GPTQConfig
from transformers.integrations import is_deepspeed_zero3_enabled from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.modeling_utils import is_fsdp_enabled
from transformers.utils.versions import require_version from transformers.utils.versions import require_version
from ...extras.constants import FILEEXT2TYPE from ...extras.constants import FILEEXT2TYPE
@ -133,7 +134,7 @@ def configure_quantization(
bnb_4bit_quant_storage=model_args.compute_dtype, # crucial for fsdp qlora bnb_4bit_quant_storage=model_args.compute_dtype, # crucial for fsdp qlora
) )
if is_deepspeed_zero3_enabled() or model_args.quantization_device_map == "auto": if is_deepspeed_zero3_enabled() or is_fsdp_enabled() or model_args.quantization_device_map == "auto":
if model_args.quantization_bit != 4: if model_args.quantization_bit != 4:
raise ValueError("Only 4-bit quantized model can use auto device map.") raise ValueError("Only 4-bit quantized model can use auto device map.")