Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-08-22 21:52:51 +08:00)

commit f2b4237db1 (parent 967b9c0a49)

fix fsdp model loading

Former-commit-id: 008e3b3b1075199d1a62d510a8e0f212207a06b9
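This commit extends the existing DeepSpeed ZeRO-3 guards to PyTorch FSDP in two places: model loading no longer passes `low_cpu_mem_usage` when FSDP is active, and the "only 4-bit quantization with auto device map" check now also fires under FSDP.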
@@ -6,6 +6,7 @@ import torch
 from peft import PeftModel
 from transformers import PreTrainedModel, PreTrainedTokenizerBase, is_torch_npu_available
 from transformers.integrations import is_deepspeed_zero3_enabled
+from transformers.modeling_utils import is_fsdp_enabled

 from ..extras.logging import get_logger
 from ..extras.misc import infer_optim_dtype
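For context, `is_fsdp_enabled()` is the transformers helper imported above; it reports whether the process was launched under Accelerate's FSDP integration (it consults the torch.distributed state and Accelerate's FSDP environment flag; the exact checks vary by transformers version). A minimal usage sketch:

from transformers.modeling_utils import is_fsdp_enabled

# `accelerate launch` with an FSDP config exports ACCELERATE_USE_FSDP,
# which this helper checks together with the torch.distributed state.
if is_fsdp_enabled():
    print("FSDP run detected: use the FSDP-aware loading path")
else:
    print("no FSDP: low_cpu_mem_usage can be passed to from_pretrained")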
@@ -69,7 +70,7 @@ def patch_config(
         setattr(config, "use_cache", False)  # qwen2 does not support use_cache when using flashattn

     init_kwargs["torch_dtype"] = model_args.compute_dtype
-    if not is_deepspeed_zero3_enabled():
+    if not is_deepspeed_zero3_enabled() and not is_fsdp_enabled():
         init_kwargs["low_cpu_mem_usage"] = model_args.low_cpu_mem_usage
         if init_kwargs["low_cpu_mem_usage"]:
             if "device_map" not in init_kwargs and model_args.device_map:
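The behavioral change: `low_cpu_mem_usage` loads checkpoints via meta-device tensors, which clashes with FSDP's sharded loading the same way it clashes with DeepSpeed ZeRO-3, so the guard now excludes both. A minimal sketch of the patched control flow, not the verbatim function (`model_args` here is a hypothetical stand-in for LLaMA-Factory's model arguments):

from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.modeling_utils import is_fsdp_enabled


def build_init_kwargs(model_args) -> dict:
    """Sketch of the patched logic in patch_config."""
    init_kwargs = {"torch_dtype": model_args.compute_dtype}
    # low_cpu_mem_usage materializes weights lazily on the meta device;
    # DeepSpeed ZeRO-3 and FSDP both shard parameters themselves and
    # need the regular loading path, so skip it for either backend.
    if not is_deepspeed_zero3_enabled() and not is_fsdp_enabled():
        init_kwargs["low_cpu_mem_usage"] = model_args.low_cpu_mem_usage
    return init_kwargs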
@@ -7,6 +7,7 @@ import torch
 from datasets import load_dataset
 from transformers import BitsAndBytesConfig, GPTQConfig
 from transformers.integrations import is_deepspeed_zero3_enabled
+from transformers.modeling_utils import is_fsdp_enabled
 from transformers.utils.versions import require_version

 from ...extras.constants import FILEEXT2TYPE
@@ -133,7 +134,7 @@ def configure_quantization(
             bnb_4bit_quant_storage=model_args.compute_dtype,  # crucial for fsdp qlora
         )

-    if is_deepspeed_zero3_enabled() or model_args.quantization_device_map == "auto":
+    if is_deepspeed_zero3_enabled() or is_fsdp_enabled() or model_args.quantization_device_map == "auto":
         if model_args.quantization_bit != 4:
             raise ValueError("Only 4-bit quantized model can use auto device map.")

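Same motivation on the quantization side: only 4-bit bitsandbytes quantization (with `bnb_4bit_quant_storage` set to the compute dtype, per the comment above) works with sharded backends, so the 4-bit check now also covers FSDP runs. A minimal sketch of the extended guard, assuming hypothetical `quantization_bit` / `quantization_device_map` parameters matching the fields used above:

from typing import Optional

from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.modeling_utils import is_fsdp_enabled


def check_auto_device_map(quantization_bit: Optional[int], quantization_device_map: Optional[str]) -> None:
    # ZeRO-3, FSDP, and an explicit "auto" quantization device map
    # all require 4-bit (bnb) quantization in this code path.
    if is_deepspeed_zero3_enabled() or is_fsdp_enabled() or quantization_device_map == "auto":
        if quantization_bit != 4:
            raise ValueError("Only 4-bit quantized model can use auto device map.")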