[misc] fix grad ckpt (#6931)

Former-commit-id: deae1fc9a0bea5c8b8be1564cf9c81c9c02a0b3a
hoshi-hiyouga 2025-02-13 23:27:51 +08:00 committed by GitHub
parent 5e5fc337f9
commit ed25e051a9


@@ -85,13 +85,18 @@ def get_custom_gradient_checkpointing_func(gradient_checkpointing_func: Callable
     def custom_gradient_checkpointing_func(func: Callable, *args: Union["torch.Tensor", Any], **kwargs):
         module: "torch.nn.Module" = func.__self__
+        has_grad = False
         if any(param.requires_grad for param in module.parameters()):
+            has_grad = True
             for arg in args:
                 if torch.is_tensor(arg) and torch.is_floating_point(arg):
                     arg.requires_grad_(True)
                     break  # assume the first tensor is always the hidden states
 
-        return gradient_checkpointing_func(func, *args, **kwargs)
+        if has_grad:
+            return gradient_checkpointing_func(func, *args, **kwargs)
+        else:
+            return func(*args, **kwargs)
 
     return custom_gradient_checkpointing_func
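
The effect of the patch: a module whose parameters are all frozen now bypasses gradient checkpointing entirely and runs a plain forward pass, instead of paying the checkpoint recomputation for activations that never need gradients. Below is a minimal, self-contained sketch of the patched wrapper in isolation; wiring it to torch.utils.checkpoint.checkpoint and the toy Linear layer are illustrative assumptions, not part of this commit.

from functools import partial
from typing import Any, Callable, Union

import torch
import torch.utils.checkpoint


def get_custom_gradient_checkpointing_func(gradient_checkpointing_func: Callable) -> Callable:
    # Reimplementation of the patched wrapper from the diff above.
    def custom_gradient_checkpointing_func(func: Callable, *args: Union["torch.Tensor", Any], **kwargs):
        module: "torch.nn.Module" = func.__self__
        has_grad = False
        if any(param.requires_grad for param in module.parameters()):
            has_grad = True
            for arg in args:
                if torch.is_tensor(arg) and torch.is_floating_point(arg):
                    arg.requires_grad_(True)
                    break  # assume the first tensor is always the hidden states

        if has_grad:
            return gradient_checkpointing_func(func, *args, **kwargs)
        else:
            return func(*args, **kwargs)  # frozen module: plain forward, no recomputation

    return custom_gradient_checkpointing_func


# Wrap PyTorch's stock checkpoint function (an assumption for this demo;
# in practice the caller supplies whatever checkpointing func the model uses).
ckpt = get_custom_gradient_checkpointing_func(
    partial(torch.utils.checkpoint.checkpoint, use_reentrant=False)
)

layer = torch.nn.Linear(4, 4)
hidden = torch.randn(2, 4)

# Trainable layer: routed through checkpointing, gradients flow as usual.
out = ckpt(layer.forward, hidden)
out.sum().backward()
assert layer.weight.grad is not None

# Frozen layer: before this fix it was still checkpointed (wasted recompute);
# now it falls through to a plain forward call.
layer.requires_grad_(False)
out = ckpt(layer.forward, hidden.detach())
assert not out.requires_grad

Note that func is expected to be a bound method (e.g. layer.forward), since the wrapper reads the owning module off func.__self__ to decide whether any of its parameters are trainable.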