Former-commit-id: 942362d008
This commit is contained in:
hiyouga
2024-04-18 15:34:45 +08:00
parent e2e0bbde12
commit 9aa62ffb57
2 changed files with 3 additions and 2 deletions

View File

@@ -132,8 +132,9 @@ def gradient_checkpointing_enable(
if "value" in inspect.signature(self._set_gradient_checkpointing).parameters: # old GC format
self.apply(partial(self._set_gradient_checkpointing, value=True))
self.enable_input_require_grads()
logger.warning("You are using the old GC format, some features (e.g. BAdam) will be invalid.")
else:
else: # have already enabled input require gradients
self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=custom_gradient_checkpointing_func)