[misc] fix grad ckpt (#6931)

hoshi-hiyouga
2025-02-13 23:27:51 +08:00
committed by GitHub
parent 797043d29c
commit c31c63b411

@@ -85,13 +85,18 @@ def get_custom_gradient_checkpointing_func(gradient_checkpointing_func: Callable
     def custom_gradient_checkpointing_func(func: Callable, *args: Union["torch.Tensor", Any], **kwargs):
         module: "torch.nn.Module" = func.__self__
+        has_grad = False
         if any(param.requires_grad for param in module.parameters()):
+            has_grad = True
             for arg in args:
                 if torch.is_tensor(arg) and torch.is_floating_point(arg):
                     arg.requires_grad_(True)
                     break  # assume the first tensor is always the hidden states

-        return gradient_checkpointing_func(func, *args, **kwargs)
+        if has_grad:
+            return gradient_checkpointing_func(func, *args, **kwargs)
+        else:
+            return func(*args, **kwargs)

     return custom_gradient_checkpointing_func
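
For context: the fix tracks whether the checkpointed module owns any trainable parameters (has_grad) and, if it does not, calls the module's forward directly instead of routing it through the gradient-checkpointing function, so fully frozen layers skip the checkpointing machinery. Below is a minimal, self-contained sketch of that pattern; the toy Linear layers, the partial(checkpoint, use_reentrant=True) wrapper, and the tensor shapes are assumptions for illustration, not part of this commit.

from functools import partial
from typing import Any, Callable, Union

import torch
from torch.utils.checkpoint import checkpoint


def get_custom_gradient_checkpointing_func(gradient_checkpointing_func: Callable) -> Callable:
    def custom_gradient_checkpointing_func(func: Callable, *args: Union["torch.Tensor", Any], **kwargs):
        module: "torch.nn.Module" = func.__self__  # func is a bound method; __self__ is the owning module
        has_grad = False
        if any(param.requires_grad for param in module.parameters()):
            has_grad = True
            for arg in args:
                if torch.is_tensor(arg) and torch.is_floating_point(arg):
                    arg.requires_grad_(True)  # reentrant checkpointing needs a grad-requiring input
                    break  # assume the first tensor is always the hidden states

        if has_grad:
            return gradient_checkpointing_func(func, *args, **kwargs)
        else:
            return func(*args, **kwargs)  # frozen module: run forward directly, no checkpointing

    return custom_gradient_checkpointing_func


# Usage sketch (hypothetical layers, not from the repo): one frozen and one trainable layer.
frozen = torch.nn.Linear(8, 8).requires_grad_(False)
trainable = torch.nn.Linear(8, 8)

ckpt_func = get_custom_gradient_checkpointing_func(partial(checkpoint, use_reentrant=True))

x = torch.randn(2, 8)
h = ckpt_func(frozen.forward, x)       # runs forward directly, skips checkpointing
out = ckpt_func(trainable.forward, h)  # checkpointed: activations recomputed in backward
out.sum().backward()
print(trainable.weight.grad is not None)  # True

Skipping the wrapper for frozen modules matters for reentrant checkpointing in particular, which warns that gradients will be None when none of its inputs require grad; bypassing it also avoids recomputing forward passes that contribute no gradients.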