From 13e1b7ee2bdca483f4acf0c12ecd026c69cb6d66 Mon Sep 17 00:00:00 2001 From: hoshi-hiyouga Date: Thu, 13 Feb 2025 23:27:51 +0800 Subject: [PATCH] [misc] fix grad ckpt (#6931) Former-commit-id: c31c63b41109e616997757ec2da6e0ab89ed3b6e --- src/llamafactory/model/model_utils/checkpointing.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/llamafactory/model/model_utils/checkpointing.py b/src/llamafactory/model/model_utils/checkpointing.py index bbd44ba4..916f1934 100644 --- a/src/llamafactory/model/model_utils/checkpointing.py +++ b/src/llamafactory/model/model_utils/checkpointing.py @@ -85,13 +85,18 @@ def get_custom_gradient_checkpointing_func(gradient_checkpointing_func: Callable def custom_gradient_checkpointing_func(func: Callable, *args: Union["torch.Tensor", Any], **kwargs): module: "torch.nn.Module" = func.__self__ + has_grad = False if any(param.requires_grad for param in module.parameters()): + has_grad = True for arg in args: if torch.is_tensor(arg) and torch.is_floating_point(arg): arg.requires_grad_(True) break # assume the first tensor is always the hidden states - return gradient_checkpointing_func(func, *args, **kwargs) + if has_grad: + return gradient_checkpointing_func(func, *args, **kwargs) + else: + return func(*args, **kwargs) return custom_gradient_checkpointing_func