mirror of https://github.com/hiyouga/LLaMA-Factory.git
fix zero2 high ram usage
Former-commit-id: 31a0564d4f4886db03250f2c6daee6e042dc3eb4
parent 5f48c282d3
commit 8d4a5ebf6e
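The diff below narrows every "running under DeepSpeed" check down to a "running under DeepSpeed ZeRO-3" check. As a minimal, hedged sketch of the distinction the fix relies on (the active_zero_stage helper and its return convention are illustrative, not part of the repository):

# deepspeed_config() is truthy for any DeepSpeed run (ZeRO stage 0-3), while
# is_deepspeed_zero3_enabled() is True only when ZeRO stage 3 partitions the
# parameters; the old code therefore treated a ZeRO-2 run like a ZeRO-3 run.
from transformers.integrations import deepspeed_config, is_deepspeed_zero3_enabled


def active_zero_stage() -> int:
    """Return the configured ZeRO stage, or -1 when not running under DeepSpeed."""
    ds_config = deepspeed_config()  # dict-like config when the HF integration is active, else None
    if ds_config is None:
        return -1
    return int(ds_config.get("zero_optimization", {}).get("stage", 0))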
@@ -3,7 +3,7 @@ from typing import TYPE_CHECKING
 
 import torch
 from peft import LoraConfig, LoraModel, PeftModel, TaskType, get_peft_model
-from transformers.integrations import deepspeed_config, is_deepspeed_zero3_enabled
+from transformers.integrations import is_deepspeed_zero3_enabled
 from transformers.modeling_utils import is_fsdp_enabled
 
 from ..extras.logging import get_logger
@@ -43,8 +43,8 @@ def init_adapter(
     if finetuning_args.finetuning_type != "lora" and getattr(model, "quantization_method", None):
         raise ValueError("You can only use lora for quantized models.")
 
-    if deepspeed_config() is not None or is_fsdp_enabled() or finetuning_args.pure_bf16 or finetuning_args.use_badam:
-        logger.info("DeepSpeed/FSDP/PureBF16/BAdam detected, remaining trainable params as their original precision.")
+    if is_deepspeed_zero3_enabled() or is_fsdp_enabled() or finetuning_args.pure_bf16 or finetuning_args.use_badam:
+        logger.info("ZeRO3/FSDP/PureBF16/BAdam detected, remaining trainable params as their original precision.")
         cast_trainable_params_to_fp32 = False
     else:
         logger.info("Upcasting trainable params to float32.")
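After this hunk, only ZeRO-3, FSDP, pure bf16, or BAdam leave trainable parameters in their original precision; a plain ZeRO-2 run now upcasts them to float32 again. A self-contained sketch of that rule, with illustrative function names and boolean flags standing in for the finetuning_args fields:

import torch
from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.modeling_utils import is_fsdp_enabled


def should_upcast_trainable_params(pure_bf16: bool = False, use_badam: bool = False) -> bool:
    # ZeRO-3 and FSDP shard parameters and manage their dtype themselves, and
    # pure bf16 / BAdam training also keeps the original precision; everything
    # else (including ZeRO-2) upcasts trainable params to float32.
    return not (is_deepspeed_zero3_enabled() or is_fsdp_enabled() or pure_bf16 or use_badam)


def upcast_trainable_params(model: torch.nn.Module) -> None:
    # Only parameters that require gradients (e.g. LoRA weights) are touched;
    # frozen base-model weights keep the dtype they were loaded in.
    if should_upcast_trainable_params():
        for param in filter(lambda p: p.requires_grad, model.parameters()):
            param.data = param.data.to(torch.float32)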
@@ -5,7 +5,7 @@ from typing import TYPE_CHECKING, Any, Dict
 import torch
 from peft import PeftModel
 from transformers import PreTrainedModel, PreTrainedTokenizerBase, is_torch_npu_available
-from transformers.integrations import deepspeed_config, is_deepspeed_zero3_enabled
+from transformers.integrations import is_deepspeed_zero3_enabled
 from transformers.modeling_utils import is_fsdp_enabled
 
 from ..extras.logging import get_logger
@@ -72,7 +72,7 @@ def patch_config(
     # deepspeed zero3 is not compatible with low_cpu_mem_usage
     init_kwargs["low_cpu_mem_usage"] = model_args.low_cpu_mem_usage and (not is_deepspeed_zero3_enabled())
 
-    if deepspeed_config() is None and not is_fsdp_enabled():  # set dtype and device map if not use deepspeed or fsdp
+    if not is_deepspeed_zero3_enabled() and not is_fsdp_enabled():  # cast dtype and device if not use zero3 or fsdp
         init_kwargs["torch_dtype"] = model_args.compute_dtype
 
         if init_kwargs["low_cpu_mem_usage"]:  # device map requires low_cpu_mem_usage=True
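The last hunk is where the RAM saving comes from: the torch_dtype branch used to be skipped for any DeepSpeed run, so a ZeRO-2 launch loaded the checkpoint in default float32 on the host. A hedged sketch of the kwargs assembly after the fix; build_init_kwargs, compute_dtype, and low_cpu_mem_usage are illustrative stand-ins for the model_args fields, and the device-map setup inside the last branch is omitted because the diff does not show it:

import torch
from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.modeling_utils import is_fsdp_enabled


def build_init_kwargs(compute_dtype: torch.dtype = torch.bfloat16, low_cpu_mem_usage: bool = True) -> dict:
    init_kwargs = {}
    # deepspeed zero3 is not compatible with low_cpu_mem_usage
    init_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage and not is_deepspeed_zero3_enabled()
    if not is_deepspeed_zero3_enabled() and not is_fsdp_enabled():
        # ZeRO-2 now takes this path too, so the checkpoint is loaded directly in
        # compute_dtype (e.g. bf16) instead of float32, roughly halving host RAM.
        init_kwargs["torch_dtype"] = compute_dtype
        if init_kwargs["low_cpu_mem_usage"]:
            # device map requires low_cpu_mem_usage=True; the device-map logic
            # itself is outside the hunk and left out of this sketch.
            pass
    return init_kwargs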