From fb2d563be5f0b06b40f77cdb3421602ef8275b14 Mon Sep 17 00:00:00 2001 From: hiyouga Date: Sun, 21 Jan 2024 14:11:38 +0800 Subject: [PATCH] fix #2268 Former-commit-id: e0a717aa3aa97ca5a5d6a2ad8c1711a1f92b001b --- src/llmtuner/model/patcher.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/llmtuner/model/patcher.py b/src/llmtuner/model/patcher.py index 8a373760..5f67f618 100644 --- a/src/llmtuner/model/patcher.py +++ b/src/llmtuner/model/patcher.py @@ -223,6 +223,7 @@ def _prepare_model_for_training( logger.warning("Current model does not support gradient checkpointing.") else: model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False}) + model.enable_input_require_grads() model.config.use_cache = False # turn off when gradient checkpointing is enabled logger.info("Gradient checkpointing enabled.")