diff --git a/src/llamafactory/model/model_utils/checkpointing.py b/src/llamafactory/model/model_utils/checkpointing.py
index 28e2a795..714aca03 100644
--- a/src/llamafactory/model/model_utils/checkpointing.py
+++ b/src/llamafactory/model/model_utils/checkpointing.py
@@ -52,12 +52,12 @@ def get_unsloth_gradient_checkpointing_func() -> Callable:
         ) -> "torch.Tensor":
             saved_hidden_states = hidden_states.to("cpu", non_blocking=True)
             with torch.no_grad():
-                output = forward_function(hidden_states, *args)
+                outputs = forward_function(hidden_states, *args)
 
             ctx.save_for_backward(saved_hidden_states)
             ctx.forward_function = forward_function
             ctx.args = args
-            return output
+            return outputs
 
         @staticmethod
         @torch.cuda.amp.custom_bwd
@@ -66,7 +66,8 @@ def get_unsloth_gradient_checkpointing_func() -> Callable:
             hidden_states = hidden_states.to("cuda", non_blocking=True).detach()
             hidden_states.requires_grad_(True)
             with torch.enable_grad():
-                (output,) = ctx.forward_function(hidden_states, *ctx.args)
+                outputs = ctx.forward_function(hidden_states, *ctx.args)
+                output = outputs[0] if isinstance(outputs, tuple) else outputs
 
             torch.autograd.backward(output, grad_output)
             return (None, hidden_states.grad) + (None,) * len(ctx.args)
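
For context (not part of the patch): a minimal, self-contained sketch of the case the new line guards against. A layer's forward function may return either a bare tensor or a tuple such as (hidden_states, attention_weights), so the backward pass can no longer assume a 1-tuple. The helper names below are illustrative only.

import torch


def tensor_forward(hidden_states: torch.Tensor) -> torch.Tensor:
    # e.g. a plain feed-forward block that returns a bare tensor
    return hidden_states * 2


def tuple_forward(hidden_states: torch.Tensor) -> tuple:
    # e.g. a decoder layer that returns (hidden_states, attention_weights)
    return hidden_states * 2, None


for forward_function in (tensor_forward, tuple_forward):
    outputs = forward_function(torch.ones(2, 3, requires_grad=True))
    # Patched logic: unwrap tuples, pass bare tensors through unchanged.
    # The old `(output,) = ...` unpacking assumed forward_function always
    # returned a 1-tuple and broke for either of the cases above.
    output = outputs[0] if isinstance(outputs, tuple) else outputs
    assert torch.is_tensor(output)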