Fix FSDP model loading

This commit is contained in:
hiyouga
2024-05-15 16:32:28 +08:00
parent 11bf282dcc
commit 008e3b3b10
2 changed files with 4 additions and 2 deletions

View File

@@ -6,6 +6,7 @@ import torch
from peft import PeftModel
from transformers import PreTrainedModel, PreTrainedTokenizerBase, is_torch_npu_available
from transformers.integrations import is_deepspeed_zero3_enabled
+from transformers.modeling_utils import is_fsdp_enabled
from ..extras.logging import get_logger
from ..extras.misc import infer_optim_dtype
@@ -69,7 +70,7 @@ def patch_config(
setattr(config, "use_cache", False) # qwen2 does not support use_cache when using flashattn
init_kwargs["torch_dtype"] = model_args.compute_dtype
-if not is_deepspeed_zero3_enabled():
+if not is_deepspeed_zero3_enabled() and not is_fsdp_enabled():
init_kwargs["low_cpu_mem_usage"] = model_args.low_cpu_mem_usage
if init_kwargs["low_cpu_mem_usage"]:
if "device_map" not in init_kwargs and model_args.device_map: