[model] fix vit gradient checkpointing (#7830)

commit 4fbdc65fcb
parent 2989d39239
Author: hoshi-hiyouga
Date: 2025-04-23 22:48:48 +08:00
Committed by: GitHub


@@ -52,12 +52,12 @@ def get_unsloth_gradient_checkpointing_func() -> Callable:
         ) -> "torch.Tensor":
             saved_hidden_states = hidden_states.to("cpu", non_blocking=True)
             with torch.no_grad():
-                output = forward_function(hidden_states, *args)
+                outputs = forward_function(hidden_states, *args)
 
             ctx.save_for_backward(saved_hidden_states)
             ctx.forward_function = forward_function
             ctx.args = args
-            return output
+            return outputs
 
         @staticmethod
         @torch.cuda.amp.custom_bwd
@@ -66,7 +66,8 @@ def get_unsloth_gradient_checkpointing_func() -> Callable:
             hidden_states = hidden_states.to("cuda", non_blocking=True).detach()
             hidden_states.requires_grad_(True)
             with torch.enable_grad():
-                (output,) = ctx.forward_function(hidden_states, *ctx.args)
+                outputs = ctx.forward_function(hidden_states, *ctx.args)
+                output = outputs[0] if isinstance(outputs, tuple) else outputs
 
             torch.autograd.backward(output, grad_output)
             return (None, hidden_states.grad) + (None,) * len(ctx.args)
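
The old backward pass unconditionally unpacked a one-element tuple with "(output,) = ...", which works for decoder layers that return tuples but crashes when the checkpointed module returns a bare tensor, as ViT blocks do. Below is a minimal, self-contained sketch of the offloaded gradient-checkpointing pattern with the patched tuple handling. It is illustrative only: OffloadedCheckpointFunction and block are hypothetical names, not identifiers from the repository, and a CUDA device is assumed, matching the hard-coded "cuda" in the original code.

# Sketch of the offloaded-checkpointing pattern patched above; the names
# OffloadedCheckpointFunction and block are hypothetical, not from the repo.
import torch


class OffloadedCheckpointFunction(torch.autograd.Function):
    @staticmethod
    @torch.cuda.amp.custom_fwd
    def forward(ctx, forward_function, hidden_states, *args):
        # Offload the input activation to CPU so it does not occupy GPU
        # memory between the forward and backward passes.
        saved_hidden_states = hidden_states.to("cpu", non_blocking=True)
        with torch.no_grad():
            outputs = forward_function(hidden_states, *args)

        ctx.save_for_backward(saved_hidden_states)
        ctx.forward_function = forward_function
        ctx.args = args
        return outputs

    @staticmethod
    @torch.cuda.amp.custom_bwd
    def backward(ctx, grad_output):
        (hidden_states,) = ctx.saved_tensors
        # Move the saved activation back to the GPU and recompute the
        # forward pass with gradients enabled.
        hidden_states = hidden_states.to("cuda", non_blocking=True).detach()
        hidden_states.requires_grad_(True)
        with torch.enable_grad():
            outputs = ctx.forward_function(hidden_states, *ctx.args)
            # The fix: decoder layers return tuples while ViT blocks return
            # a bare tensor, so "(output,) = ..." crashed on ViT. Take the
            # first element only when a tuple comes back.
            output = outputs[0] if isinstance(outputs, tuple) else outputs

        torch.autograd.backward(output, grad_output)
        # One None per non-tensor input: forward_function, then *args.
        return (None, hidden_states.grad) + (None,) * len(ctx.args)


if torch.cuda.is_available():
    # A stand-in for a ViT block: its forward returns a bare tensor.
    block = torch.nn.Linear(16, 16).cuda()
    x = torch.randn(4, 16, device="cuda", requires_grad=True)
    y = OffloadedCheckpointFunction.apply(block, x)
    y.sum().backward()
    print(x.grad.shape)  # torch.Size([4, 16])

Note the design choice this pattern embodies: like torch.utils.checkpoint it recomputes the forward during backward, but it additionally parks the saved input activation on the CPU in between, trading a host-device copy for GPU memory.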