This commit is contained in:
hiyouga
2023-11-17 17:23:13 +08:00
parent 999bc0ed93
commit 1bbc1be95e
2 changed files with 7 additions and 2 deletions

View File

@@ -140,7 +140,7 @@ def prepare_model_for_training(
model.get_input_embeddings().register_forward_hook(neftune_forward_hook)
logger.info("Using noisy embedding with alpha={:.2f}".format(finetuning_args.neft_alpha))
if use_gradient_checkpointing:
if use_gradient_checkpointing and getattr(model, "supports_gradient_checkpointing", False):
if hasattr(model, "enable_input_require_grads"):
model.enable_input_require_grads()
else: