diff --git a/src/llmtuner/model/patcher.py b/src/llmtuner/model/patcher.py
index 672d09d7..bb774e08 100644
--- a/src/llmtuner/model/patcher.py
+++ b/src/llmtuner/model/patcher.py
@@ -227,7 +227,9 @@ def _prepare_model_for_training(
     if not getattr(model, "supports_gradient_checkpointing", False):
         logger.warning("Current model does not support gradient checkpointing.")
     else:
-        model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
+        # use_reentrant=False might increase VRAM usage (this has not been empirically verified yet).
+        # According to: https://github.com/huggingface/transformers/issues/28339
+        model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": True})
         model.enable_input_require_grads()
         model.config.use_cache = False  # turn off when gradient checkpointing is enabled
         logger.info("Gradient checkpointing enabled.")
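
For context, below is a minimal standalone sketch (not part of the patch) of the same Transformers API the hunk configures: enabling reentrant gradient checkpointing on a pretrained model. It assumes transformers >= 4.35 (where gradient_checkpointing_kwargs was introduced) and uses a placeholder checkpoint name.

    # Hypothetical usage sketch; "gpt2" is only a placeholder checkpoint.
    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained("gpt2")

    if getattr(model, "supports_gradient_checkpointing", False):
        # Reentrant checkpointing (use_reentrant=True), as chosen in this patch;
        # see https://github.com/huggingface/transformers/issues/28339
        model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": True})
        model.enable_input_require_grads()
        model.config.use_cache = False  # KV caching is incompatible with gradient checkpointing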