Mirror of https://github.com/hiyouga/LLaMA-Factory.git, synced 2025-08-02 03:32:50 +08:00
[model] fix vit gradient checkpointing (#7830)
parent 2989d39239
commit 4fbdc65fcb
@@ -52,12 +52,12 @@ def get_unsloth_gradient_checkpointing_func() -> Callable:
         ) -> "torch.Tensor":
             saved_hidden_states = hidden_states.to("cpu", non_blocking=True)
             with torch.no_grad():
-                output = forward_function(hidden_states, *args)
+                outputs = forward_function(hidden_states, *args)
 
             ctx.save_for_backward(saved_hidden_states)
             ctx.forward_function = forward_function
             ctx.args = args
-            return output
+            return outputs
 
         @staticmethod
         @torch.cuda.amp.custom_bwd
@@ -66,7 +66,8 @@ def get_unsloth_gradient_checkpointing_func() -> Callable:
             hidden_states = hidden_states.to("cuda", non_blocking=True).detach()
             hidden_states.requires_grad_(True)
             with torch.enable_grad():
-                (output,) = ctx.forward_function(hidden_states, *ctx.args)
+                outputs = ctx.forward_function(hidden_states, *ctx.args)
+                output = outputs[0] if isinstance(outputs, tuple) else outputs
 
             torch.autograd.backward(output, grad_output)
             return (None, hidden_states.grad) + (None,) * len(ctx.args)
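For context, below is a minimal sketch of the autograd.Function these two hunks patch, reconstructed from the diff context and the commit title. Only the lines visible in the hunks come from the patch; the class name UnslothGradientCheckpointing, the full forward/backward signatures, the custom_fwd decorator, and the closing `.apply` return are assumptions made for illustration.

# Minimal sketch of the patched checkpointing function, assembled from the hunks
# above. Everything outside the lines shown in the diff (class name, signatures,
# the custom_fwd decorator, the .apply return) is an assumption, not part of the
# patch.
from typing import Any, Callable, Union

import torch


def get_unsloth_gradient_checkpointing_func() -> Callable:
    class UnslothGradientCheckpointing(torch.autograd.Function):
        """Gradient checkpointing that offloads the saved activation to CPU RAM."""

        @staticmethod
        @torch.cuda.amp.custom_fwd  # assumed; the diff only shows custom_bwd
        def forward(
            ctx: "torch.autograd.Function",
            forward_function: Callable,
            hidden_states: "torch.Tensor",
            *args: Union["torch.Tensor", Any],
        ) -> "torch.Tensor":
            # Keep only a CPU copy of the input between forward and backward.
            saved_hidden_states = hidden_states.to("cpu", non_blocking=True)
            with torch.no_grad():
                # The block may return a tuple (e.g. a ViT layer) or a plain
                # tensor; pass it through unchanged instead of assuming one tensor.
                outputs = forward_function(hidden_states, *args)

            ctx.save_for_backward(saved_hidden_states)
            ctx.forward_function = forward_function
            ctx.args = args
            return outputs

        @staticmethod
        @torch.cuda.amp.custom_bwd
        def backward(ctx: "torch.autograd.Function", grad_output: "torch.Tensor") -> tuple:
            (hidden_states,) = ctx.saved_tensors
            # Move the offloaded activation back to the GPU and recompute the block.
            hidden_states = hidden_states.to("cuda", non_blocking=True).detach()
            hidden_states.requires_grad_(True)
            with torch.enable_grad():
                outputs = ctx.forward_function(hidden_states, *ctx.args)
                # The actual fix: unpack a tuple output if present, otherwise use
                # the tensor directly; the old `(output,) = ...` failed on ViT
                # blocks that do not return a 1-tuple.
                output = outputs[0] if isinstance(outputs, tuple) else outputs

            torch.autograd.backward(output, grad_output)
            # No gradients for forward_function or the extra args.
            return (None, hidden_states.grad) + (None,) * len(ctx.args)

    return UnslothGradientCheckpointing.apply

If this reconstruction is right, the returned `.apply` is wired in as the model's gradient-checkpointing function, so it is called with the same `(forward_function, hidden_states, *args)` shape as `torch.utils.checkpoint.checkpoint`.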