From bbe5ff05701cb85327510a5e8742822299c05349 Mon Sep 17 00:00:00 2001
From: hiyouga
Date: Wed, 7 Feb 2024 00:38:24 +0800
Subject: [PATCH] update gc kwargs

Former-commit-id: 0ae9a16b9d13bc1093662aa0b9bd990400ec2646
---
 src/llmtuner/model/patcher.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/src/llmtuner/model/patcher.py b/src/llmtuner/model/patcher.py
index 672d09d7..bb774e08 100644
--- a/src/llmtuner/model/patcher.py
+++ b/src/llmtuner/model/patcher.py
@@ -227,7 +227,9 @@ def _prepare_model_for_training(
         if not getattr(model, "supports_gradient_checkpointing", False):
             logger.warning("Current model does not support gradient checkpointing.")
         else:
-            model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
+            # use_reentrant=False might increase VRAM usage (has not been empirically verified yet)
+            # According to: https://github.com/huggingface/transformers/issues/28339
+            model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": True})
             model.enable_input_require_grads()
             model.config.use_cache = False  # turn off when gradient checkpointing is enabled
             logger.info("Gradient checkpointing enabled.")
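
Note (not part of the patch): a minimal standalone sketch of what the patched call does with a
Hugging Face Transformers model (>= 4.35, where gradient_checkpointing_kwargs is supported).
The checkpoint name "facebook/opt-125m" is only an illustrative placeholder.

    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")

    if getattr(model, "supports_gradient_checkpointing", False):
        # use_reentrant=True selects the reentrant checkpoint implementation,
        # matching the behavior restored by this patch.
        model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": True})
        model.enable_input_require_grads()
        model.config.use_cache = False  # caching is incompatible with gradient checkpointing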