support non-reenterent-gc & fix #6358

This commit is contained in:
hiyouga
2024-12-17 11:41:59 +00:00
parent 81815f053f
commit f319da6937
2 changed files with 7 additions and 1 deletions

View File

@@ -156,7 +156,9 @@ def prepare_model_for_training(model: "PreTrainedModel", model_args: "ModelArgum
_gradient_checkpointing_enable, use_unsloth_gc=model_args.use_unsloth_gc
)
model.gradient_checkpointing_enable = MethodType(gradient_checkpointing_enable, model)
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": True})
model.gradient_checkpointing_enable(
gradient_checkpointing_kwargs={"use_reentrant": model_args.use_reentrant_gc}
)
setattr(model.config, "use_cache", False) # turn off when gradient checkpointing is enabled
logger.info_rank0("Gradient checkpointing enabled.")