Mirror of https://github.com/hiyouga/LLaMA-Factory.git, synced 2025-10-16 00:28:10 +08:00
[misc] fix grad ckpt (#6931)
Former-commit-id: deae1fc9a0bea5c8b8be1564cf9c81c9c02a0b3a
This commit is contained in:
parent 5e5fc337f9
commit ed25e051a9
@@ -85,13 +85,18 @@ def get_custom_gradient_checkpointing_func(gradient_checkpointing_func: Callable
     def custom_gradient_checkpointing_func(func: Callable, *args: Union["torch.Tensor", Any], **kwargs):
         module: "torch.nn.Module" = func.__self__
 
+        has_grad = False
         if any(param.requires_grad for param in module.parameters()):
+            has_grad = True
             for arg in args:
                 if torch.is_tensor(arg) and torch.is_floating_point(arg):
                     arg.requires_grad_(True)
                     break  # assume the first tensor is always the hidden states
 
-        return gradient_checkpointing_func(func, *args, **kwargs)
+        if has_grad:
+            return gradient_checkpointing_func(func, *args, **kwargs)
+        else:
+            return func(*args, **kwargs)
 
     return custom_gradient_checkpointing_func
 
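In short, the patch makes the wrapper skip gradient checkpointing for modules that have no trainable parameters: only when has_grad is set does it mark the hidden states as requiring grad and call gradient_checkpointing_func; otherwise it invokes the layer directly, avoiding the extra recomputation pass for frozen layers. Below is a minimal, self-contained sketch of that behavior; the toy Linear modules and the use of torch.utils.checkpoint.checkpoint as the underlying checkpointing function are illustrative assumptions, not part of this commit.

# Sketch only: mirrors the patched wrapper logic; the demo modules and the
# choice of torch.utils.checkpoint.checkpoint are assumptions for illustration.
from functools import partial
from typing import Any, Callable, Union

import torch
from torch.utils.checkpoint import checkpoint


def get_custom_gradient_checkpointing_func(gradient_checkpointing_func: Callable) -> Callable:
    def custom_gradient_checkpointing_func(func: Callable, *args: Union["torch.Tensor", Any], **kwargs):
        module: "torch.nn.Module" = func.__self__
        has_grad = False
        if any(param.requires_grad for param in module.parameters()):
            has_grad = True
            for arg in args:
                if torch.is_tensor(arg) and torch.is_floating_point(arg):
                    arg.requires_grad_(True)
                    break  # assume the first tensor is always the hidden states

        if has_grad:
            # trainable layer: checkpoint it so activations are recomputed in backward
            return gradient_checkpointing_func(func, *args, **kwargs)
        else:
            # frozen layer: run the forward pass directly, no checkpointing overhead
            return func(*args, **kwargs)

    return custom_gradient_checkpointing_func


if __name__ == "__main__":
    frozen = torch.nn.Linear(4, 4).requires_grad_(False)
    trainable = torch.nn.Linear(4, 4)

    ckpt = get_custom_gradient_checkpointing_func(partial(checkpoint, use_reentrant=False))
    x = torch.randn(2, 4)
    y = ckpt(frozen.__call__, x)      # plain call, no autograd graph is built
    z = ckpt(trainable.__call__, x)   # checkpointed call, x is marked requires_grad
    z.sum().backward()
    print(y.requires_grad, x.grad is not None, trainable.weight.grad is not None)

Marking the first floating-point argument as requiring grad keeps checkpointing functional when the hidden states arrive detached, while the has_grad branch ensures fully frozen layers are neither marked nor recomputed during backprop.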