From 4fbdc65fcb599a97e4b0e611c857c5a2b095b69b Mon Sep 17 00:00:00 2001
From: hoshi-hiyouga
Date: Wed, 23 Apr 2025 22:48:48 +0800
Subject: [PATCH] [model] fix vit gradient checkpointing (#7830)

---
 src/llamafactory/model/model_utils/checkpointing.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/src/llamafactory/model/model_utils/checkpointing.py b/src/llamafactory/model/model_utils/checkpointing.py
index 28e2a795..714aca03 100644
--- a/src/llamafactory/model/model_utils/checkpointing.py
+++ b/src/llamafactory/model/model_utils/checkpointing.py
@@ -52,12 +52,12 @@ def get_unsloth_gradient_checkpointing_func() -> Callable:
         ) -> "torch.Tensor":
             saved_hidden_states = hidden_states.to("cpu", non_blocking=True)
             with torch.no_grad():
-                output = forward_function(hidden_states, *args)
+                outputs = forward_function(hidden_states, *args)
 
             ctx.save_for_backward(saved_hidden_states)
             ctx.forward_function = forward_function
             ctx.args = args
-            return output
+            return outputs
 
         @staticmethod
         @torch.cuda.amp.custom_bwd
@@ -66,7 +66,8 @@ def get_unsloth_gradient_checkpointing_func() -> Callable:
             hidden_states = hidden_states.to("cuda", non_blocking=True).detach()
             hidden_states.requires_grad_(True)
             with torch.enable_grad():
-                (output,) = ctx.forward_function(hidden_states, *ctx.args)
+                outputs = ctx.forward_function(hidden_states, *ctx.args)
+                output = outputs[0] if isinstance(outputs, tuple) else outputs
 
             torch.autograd.backward(output, grad_output)
             return (None, hidden_states.grad) + (None,) * len(ctx.args)
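Context for the change: the backward pass previously unpacked the recomputed result with `(output,) = ...`, which assumes the checkpointed forward always returns a 1-tuple; ViT blocks in multimodal models can return a bare tensor, so that unpacking failed. The sketch below illustrates the offloading gradient-checkpointing pattern with the tuple-or-tensor handling from this patch. It is a minimal, device-agnostic approximation, not the repository's exact code: the class name, the device bookkeeping, and the toy linear block are illustrative, and the CUDA AMP decorators used in the real function are omitted so it also runs on CPU.

```python
import torch


class OffloadCheckpointFunction(torch.autograd.Function):
    """Sketch of gradient checkpointing that offloads the input activation to CPU."""

    @staticmethod
    def forward(ctx, forward_function, hidden_states, *args):
        # Offload the input activation to CPU to free device memory.
        saved_hidden_states = hidden_states.to("cpu", non_blocking=True)
        with torch.no_grad():
            # The wrapped block may return a bare tensor or a tuple.
            outputs = forward_function(hidden_states, *args)

        ctx.save_for_backward(saved_hidden_states)
        ctx.forward_function = forward_function
        ctx.args = args
        ctx.device = hidden_states.device  # remember where to recompute
        return outputs

    @staticmethod
    def backward(ctx, grad_output):
        (hidden_states,) = ctx.saved_tensors
        # Move the activation back to the original device and recompute with grad enabled.
        hidden_states = hidden_states.to(ctx.device, non_blocking=True).detach()
        hidden_states.requires_grad_(True)
        with torch.enable_grad():
            outputs = ctx.forward_function(hidden_states, *ctx.args)
            # ViT blocks may return a bare tensor; decoder layers often return a tuple.
            output = outputs[0] if isinstance(outputs, tuple) else outputs

        torch.autograd.backward(output, grad_output)
        # Gradients: None for the module, grad for hidden_states, None for extra args.
        return (None, hidden_states.grad) + (None,) * len(ctx.args)


if __name__ == "__main__":
    # Toy check of the case the patch fixes: a block returning a bare tensor.
    block = torch.nn.Linear(8, 8)
    x = torch.randn(2, 8, requires_grad=True)
    y = OffloadCheckpointFunction.apply(block, x)
    y.sum().backward()
    print(x.grad.shape)  # torch.Size([2, 8])
```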