Former-commit-id: e0a717aa3aa97ca5a5d6a2ad8c1711a1f92b001b
This commit is contained in:
hiyouga 2024-01-21 14:11:38 +08:00
parent 8c10530181
commit fb2d563be5

View File

@ -223,6 +223,7 @@ def _prepare_model_for_training(
logger.warning("Current model does not support gradient checkpointing.") logger.warning("Current model does not support gradient checkpointing.")
else: else:
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False}) model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
model.enable_input_require_grads()
model.config.use_cache = False # turn off when gradient checkpointing is enabled model.config.use_cache = False # turn off when gradient checkpointing is enabled
logger.info("Gradient checkpointing enabled.") logger.info("Gradient checkpointing enabled.")