diff --git a/src/llmtuner/tuner/core/loader.py b/src/llmtuner/tuner/core/loader.py
index a111a8c5..b7cf6cfb 100644
--- a/src/llmtuner/tuner/core/loader.py
+++ b/src/llmtuner/tuner/core/loader.py
@@ -10,6 +10,7 @@ from transformers import (
 )
 from transformers.utils import check_min_version
 from transformers.utils.versions import require_version
+from transformers.trainer import is_deepspeed_zero3_enabled
 from transformers.modeling_utils import PretrainedConfig, PreTrainedModel
 from transformers.tokenization_utils import PreTrainedTokenizerBase
 from trl import AutoModelForCausalLMWithValueHead
@@ -108,7 +109,7 @@ def load_model_and_tokenizer(
         model_to_load,
         config=config,
         torch_dtype=torch.bfloat16 if model_args.compute_dtype == torch.bfloat16 else torch.float16,
-        low_cpu_mem_usage=True,
+        low_cpu_mem_usage=(not is_deepspeed_zero3_enabled()),
         **config_kwargs
     )
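
Note on the change: under DeepSpeed ZeRO-3, parameters are partitioned at initialization (zero.Init), which does not combine with the meta-device loading path that low_cpu_mem_usage=True enables in from_pretrained, so the flag is disabled whenever ZeRO-3 is active. Below is a minimal sketch of how the guarded call reads in isolation; the helper name load_pretrained_model is hypothetical, and model_to_load, config, model_args, and config_kwargs are assumed to be prepared as elsewhere in load_model_and_tokenizer.

import torch
from transformers import AutoModelForCausalLM
# is_deepspeed_zero3_enabled is imported via transformers.trainer, as in the patch above
from transformers.trainer import is_deepspeed_zero3_enabled

def load_pretrained_model(model_to_load, config, model_args, **config_kwargs):
    # Skip the low-memory (meta-device) loading path when DeepSpeed ZeRO-3 is
    # active, since ZeRO-3 already shards parameter initialization itself.
    return AutoModelForCausalLM.from_pretrained(
        model_to_load,
        config=config,
        torch_dtype=torch.bfloat16 if model_args.compute_dtype == torch.bfloat16 else torch.float16,
        low_cpu_mem_usage=(not is_deepspeed_zero3_enabled()),
        **config_kwargs
    )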