[misc] fix grad ckpt (#6931)

Former-commit-id: c31c63b41109e616997757ec2da6e0ab89ed3b6e
hoshi-hiyouga 2025-02-13 23:27:51 +08:00 committed by GitHub
parent cd493b91de
commit 13e1b7ee2b

@@ -85,13 +85,18 @@ def get_custom_gradient_checkpointing_func(gradient_checkpointing_func: Callable
     def custom_gradient_checkpointing_func(func: Callable, *args: Union["torch.Tensor", Any], **kwargs):
         module: "torch.nn.Module" = func.__self__
+        has_grad = False
         if any(param.requires_grad for param in module.parameters()):
+            has_grad = True
             for arg in args:
                 if torch.is_tensor(arg) and torch.is_floating_point(arg):
                     arg.requires_grad_(True)
                     break  # assume the first tensor is always the hidden states

-        return gradient_checkpointing_func(func, *args, **kwargs)
+        if has_grad:
+            return gradient_checkpointing_func(func, *args, **kwargs)
+        else:
+            return func(*args, **kwargs)

     return custom_gradient_checkpointing_func