From 0170ef83a6e30a74fcf292860a7afcca8ecd5007 Mon Sep 17 00:00:00 2001 From: hiyouga Date: Wed, 17 Apr 2024 22:54:34 +0800 Subject: [PATCH] fix #3316 Former-commit-id: c9a477322df82fecdb268ed385e3e0c376c0baeb --- src/llmtuner/model/utils.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/llmtuner/model/utils.py b/src/llmtuner/model/utils.py index 7e4430d1..17b09a60 100644 --- a/src/llmtuner/model/utils.py +++ b/src/llmtuner/model/utils.py @@ -1,3 +1,4 @@ +import inspect from enum import Enum, unique from functools import partial from typing import TYPE_CHECKING, Any, Dict, List, Optional @@ -129,7 +130,11 @@ def gradient_checkpointing_enable( return gradient_checkpointing_func(func, *args, **kwargs) - self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=custom_gradient_checkpointing_func) + if "value" in inspect.signature(self._set_gradient_checkpointing).parameters: # old GC format + self.apply(partial(self._set_gradient_checkpointing, value=True)) + logger.warning("You are using the old GC format, some features (e.g. BAdam) will be invalid.") + else: + self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=custom_gradient_checkpointing_func) def load_valuehead_params(path_or_repo_id: str, model_args: "ModelArguments") -> Dict[str, torch.Tensor]: