Mirror of https://github.com/hiyouga/LLaMA-Factory.git
Synced 2025-12-16 11:50:35 +08:00
fix #1550
@@ -140,7 +140,7 @@ def prepare_model_for_training(
         model.get_input_embeddings().register_forward_hook(neftune_forward_hook)
         logger.info("Using noisy embedding with alpha={:.2f}".format(finetuning_args.neft_alpha))
 
-    if use_gradient_checkpointing:
+    if use_gradient_checkpointing and getattr(model, "supports_gradient_checkpointing", False):
         if hasattr(model, "enable_input_require_grads"):
             model.enable_input_require_grads()
         else:
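For context, here is a minimal sketch of what the changed branch does after this commit, written against the Hugging Face transformers API. Only the `getattr` guard on `supports_gradient_checkpointing` comes from the diff itself; the wrapper name `enable_gradient_checkpointing_if_supported`, the `make_inputs_require_grad` fallback hook, and the `gradient_checkpointing_enable()` / `use_cache` handling are illustrative assumptions, not part of this hunk.

```python
def enable_gradient_checkpointing_if_supported(model, use_gradient_checkpointing: bool = True):
    """Enable gradient checkpointing only on models that declare support for it.

    Sketch of the guarded branch introduced by this commit: models without a
    truthy `supports_gradient_checkpointing` attribute are left untouched
    instead of entering an unsupported code path.
    """
    if use_gradient_checkpointing and getattr(model, "supports_gradient_checkpointing", False):
        if hasattr(model, "enable_input_require_grads"):
            model.enable_input_require_grads()
        else:
            # Common fallback (assumption, not shown in this hunk): force the
            # embedding output to require grad so checkpointed blocks still
            # receive gradients even though the inputs are integer token ids.
            def make_inputs_require_grad(module, inputs, output):
                output.requires_grad_(True)

            model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)

        # Also assumptions, for completeness: switch checkpointing on and
        # disable the KV cache, which is incompatible with it during training.
        model.gradient_checkpointing_enable()
        model.config.use_cache = False
    return model


# Illustrative usage with any transformers causal LM:
#   from transformers import AutoModelForCausalLM
#   model = AutoModelForCausalLM.from_pretrained("gpt2")
#   model = enable_gradient_checkpointing_if_supported(model)
```

The practical effect of the added guard is that models which do not expose the transformers gradient-checkpointing machinery simply skip this block rather than erroring when checkpointing is requested.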