mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-08-02 19:52:50 +08:00)

[model] fix vit gradient checkpointing (#7830)

parent 2989d39239
commit 4fbdc65fcb
@@ -52,12 +52,12 @@ def get_unsloth_gradient_checkpointing_func() -> Callable:
         ) -> "torch.Tensor":
             saved_hidden_states = hidden_states.to("cpu", non_blocking=True)
             with torch.no_grad():
-                output = forward_function(hidden_states, *args)
+                outputs = forward_function(hidden_states, *args)
 
             ctx.save_for_backward(saved_hidden_states)
             ctx.forward_function = forward_function
             ctx.args = args
-            return output
+            return outputs
 
         @staticmethod
         @torch.cuda.amp.custom_bwd
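The forward hunk renames output to outputs and returns whatever forward_function produced, unchanged. The naming reflects that the wrapped module may return either a tuple (as Transformers decoder layers typically do) or a bare tensor (as many ViT blocks do). A toy illustration of the two return styles, using made-up stand-in modules rather than anything from the repository:

import torch
import torch.nn as nn


class TupleBlock(nn.Module):
    """Decoder-layer-like module: returns a 1-tuple."""

    def forward(self, hidden_states: torch.Tensor) -> tuple:
        return (hidden_states * 2,)


class TensorBlock(nn.Module):
    """ViT-block-like module: returns a bare tensor."""

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return hidden_states * 2


x = torch.randn(3, 4)
(out,) = TupleBlock()(x)       # fine: unpacks the 1-tuple
try:
    (out,) = TensorBlock()(x)  # fails: iterating over dim 0 yields 3 rows, not 1
except ValueError as err:
    print(err)                 # "too many values to unpack (expected 1)"

The pre-fix backward used this same (output,) unpacking on the recomputed result, which breaks for tensor-returning blocks like the second one above.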
@@ -66,7 +66,8 @@ def get_unsloth_gradient_checkpointing_func() -> Callable:
             hidden_states = hidden_states.to("cuda", non_blocking=True).detach()
             hidden_states.requires_grad_(True)
             with torch.enable_grad():
-                (output,) = ctx.forward_function(hidden_states, *ctx.args)
+                outputs = ctx.forward_function(hidden_states, *ctx.args)
+                output = outputs[0] if isinstance(outputs, tuple) else outputs
 
             torch.autograd.backward(output, grad_output)
             return (None, hidden_states.grad) + (None,) * len(ctx.args)
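Together, the two hunks make the offloaded checkpointing path agnostic to the wrapped module's return type: the backward recomputation no longer assumes a 1-tuple but takes outputs[0] only when the result actually is a tuple. The following is a self-contained sketch that approximates the patched behavior rather than reproducing the repository code exactly: the real implementation lives inside get_unsloth_gradient_checkpointing_func, applies the torch.cuda.amp.custom_bwd decorator visible in the hunk context, and moves the saved activation back to "cuda", whereas this sketch uses grad_output.device so it also runs on CPU.

import torch


class OffloadedCheckpointFunction(torch.autograd.Function):
    """Sketch of CPU-offloaded gradient checkpointing with tuple/tensor-agnostic outputs."""

    @staticmethod
    def forward(ctx, forward_function, hidden_states, *args):
        # Keep only a CPU copy of the input activation; run the block without building a graph.
        saved_hidden_states = hidden_states.to("cpu", non_blocking=True)
        with torch.no_grad():
            outputs = forward_function(hidden_states, *args)

        ctx.save_for_backward(saved_hidden_states)
        ctx.forward_function = forward_function
        ctx.args = args
        return outputs  # tensor or tuple, passed through unchanged

    @staticmethod
    def backward(ctx, grad_output, *unused_grads):
        (hidden_states,) = ctx.saved_tensors
        # Bring the saved activation back to the compute device and recompute with grad enabled.
        hidden_states = hidden_states.to(grad_output.device, non_blocking=True).detach()
        hidden_states.requires_grad_(True)
        with torch.enable_grad():
            outputs = ctx.forward_function(hidden_states, *ctx.args)
            # The fix: accept both tuple-returning and tensor-returning modules.
            output = outputs[0] if isinstance(outputs, tuple) else outputs

        torch.autograd.backward(output, grad_output)
        # No gradient for forward_function or the extra positional arguments.
        return (None, hidden_states.grad) + (None,) * len(ctx.args)


# Usage sketch: checkpoint a single linear layer.
layer = torch.nn.Linear(4, 4)
x = torch.randn(2, 4, requires_grad=True)
y = OffloadedCheckpointFunction.apply(layer, x)
y.sum().backward()
print(x.grad.shape)  # torch.Size([2, 4])

The leading None in the return value corresponds to forward_function and the trailing Nones to the extra positional arguments, none of which receive gradients through this path; only the offloaded hidden_states gets a gradient, while parameter gradients are populated directly by the recomputation's backward call.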