From 8d4a5ebf6e6e78992bba64d201763ecd4ac0ad80 Mon Sep 17 00:00:00 2001
From: hiyouga <467089858@qq.com>
Date: Sun, 19 May 2024 21:53:54 +0800
Subject: [PATCH] fix zero2 high ram usage

Former-commit-id: 31a0564d4f4886db03250f2c6daee6e042dc3eb4
---
 src/llamafactory/model/adapter.py | 6 +++---
 src/llamafactory/model/patcher.py | 4 ++--
 2 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/src/llamafactory/model/adapter.py b/src/llamafactory/model/adapter.py
index e868afd6..f37f3bbb 100644
--- a/src/llamafactory/model/adapter.py
+++ b/src/llamafactory/model/adapter.py
@@ -3,7 +3,7 @@ from typing import TYPE_CHECKING
 
 import torch
 from peft import LoraConfig, LoraModel, PeftModel, TaskType, get_peft_model
-from transformers.integrations import deepspeed_config, is_deepspeed_zero3_enabled
+from transformers.integrations import is_deepspeed_zero3_enabled
 from transformers.modeling_utils import is_fsdp_enabled
 
 from ..extras.logging import get_logger
@@ -43,8 +43,8 @@ def init_adapter(
     if finetuning_args.finetuning_type != "lora" and getattr(model, "quantization_method", None):
         raise ValueError("You can only use lora for quantized models.")
 
-    if deepspeed_config() is not None or is_fsdp_enabled() or finetuning_args.pure_bf16 or finetuning_args.use_badam:
-        logger.info("DeepSpeed/FSDP/PureBF16/BAdam detected, remaining trainable params as their original precision.")
+    if is_deepspeed_zero3_enabled() or is_fsdp_enabled() or finetuning_args.pure_bf16 or finetuning_args.use_badam:
+        logger.info("ZeRO3/FSDP/PureBF16/BAdam detected, remaining trainable params as their original precision.")
         cast_trainable_params_to_fp32 = False
     else:
         logger.info("Upcasting trainable params to float32.")
diff --git a/src/llamafactory/model/patcher.py b/src/llamafactory/model/patcher.py
index 9297ef00..1a8ce607 100644
--- a/src/llamafactory/model/patcher.py
+++ b/src/llamafactory/model/patcher.py
@@ -5,7 +5,7 @@ from typing import TYPE_CHECKING, Any, Dict
 import torch
 from peft import PeftModel
 from transformers import PreTrainedModel, PreTrainedTokenizerBase, is_torch_npu_available
-from transformers.integrations import deepspeed_config, is_deepspeed_zero3_enabled
+from transformers.integrations import is_deepspeed_zero3_enabled
 from transformers.modeling_utils import is_fsdp_enabled
 
 from ..extras.logging import get_logger
@@ -72,7 +72,7 @@ def patch_config(
     # deepspeed zero3 is not compatible with low_cpu_mem_usage
     init_kwargs["low_cpu_mem_usage"] = model_args.low_cpu_mem_usage and (not is_deepspeed_zero3_enabled())
 
-    if deepspeed_config() is None and not is_fsdp_enabled():  # set dtype and device map if not use deepspeed or fsdp
+    if not is_deepspeed_zero3_enabled() and not is_fsdp_enabled():  # cast dtype and device if not use zero3 or fsdp
         init_kwargs["torch_dtype"] = model_args.compute_dtype
 
         if init_kwargs["low_cpu_mem_usage"]:  # device map requires low_cpu_mem_usage=True
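
For context, a minimal sketch (not part of the patch; the helper name resolve_torch_dtype is hypothetical, LLaMA-Factory sets init_kwargs inline) of the behavior the commit targets: under ZeRO-2, deepspeed_config() returns a non-None dict, so the old guard skipped setting torch_dtype and the model appears to have been loaded in default float32, which is what the subject's "high ram usage" points at; only ZeRO-3 and FSDP, which manage parameter initialization themselves, need to skip that step.

# Illustrative sketch only -- not part of the applied patch.
from typing import Optional

import torch
from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.modeling_utils import is_fsdp_enabled


def resolve_torch_dtype(compute_dtype: torch.dtype) -> Optional[torch.dtype]:
    """Return the dtype to pass to from_pretrained, or None when DeepSpeed
    ZeRO-3 / FSDP should control parameter initialization themselves."""
    if is_deepspeed_zero3_enabled() or is_fsdp_enabled():
        return None  # ZeRO-3/FSDP shard and materialize params on their own
    # ZeRO-2 (and plain DDP) keep full params on each rank, so loading in
    # compute_dtype (e.g. bfloat16) instead of default float32 roughly halves RAM.
    return compute_dtype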