mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-10-14 15:52:49 +08:00
[misc] fix grad ckpt (#6931)
Former-commit-id: deae1fc9a0bea5c8b8be1564cf9c81c9c02a0b3a
This commit is contained in:
parent
5e5fc337f9
commit
ed25e051a9
@ -85,13 +85,18 @@ def get_custom_gradient_checkpointing_func(gradient_checkpointing_func: Callable
|
||||
def custom_gradient_checkpointing_func(func: Callable, *args: Union["torch.Tensor", Any], **kwargs):
|
||||
module: "torch.nn.Module" = func.__self__
|
||||
|
||||
has_grad = False
|
||||
if any(param.requires_grad for param in module.parameters()):
|
||||
has_grad = True
|
||||
for arg in args:
|
||||
if torch.is_tensor(arg) and torch.is_floating_point(arg):
|
||||
arg.requires_grad_(True)
|
||||
break # assume the first tensor is always the hidden states
|
||||
|
||||
return gradient_checkpointing_func(func, *args, **kwargs)
|
||||
if has_grad:
|
||||
return gradient_checkpointing_func(func, *args, **kwargs)
|
||||
else:
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return custom_gradient_checkpointing_func
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user