support non-reenterent-gc & fix #6358

Former-commit-id: f319da6937
This commit is contained in:
hiyouga
2024-12-17 11:41:59 +00:00
parent 4caf043cf8
commit 64bac4bc7e
2 changed files with 7 additions and 1 deletions

View File

@@ -156,7 +156,9 @@ def prepare_model_for_training(model: "PreTrainedModel", model_args: "ModelArgum
_gradient_checkpointing_enable, use_unsloth_gc=model_args.use_unsloth_gc
)
model.gradient_checkpointing_enable = MethodType(gradient_checkpointing_enable, model)
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": True})
model.gradient_checkpointing_enable(
gradient_checkpointing_kwargs={"use_reentrant": model_args.use_reentrant_gc}
)
setattr(model.config, "use_cache", False) # turn off when gradient checkpointing is enabled
logger.info_rank0("Gradient checkpointing enabled.")