diff --git a/src/llamafactory/model/model_utils/checkpointing.py b/src/llamafactory/model/model_utils/checkpointing.py
index bbd44ba4..916f1934 100644
--- a/src/llamafactory/model/model_utils/checkpointing.py
+++ b/src/llamafactory/model/model_utils/checkpointing.py
@@ -85,13 +85,18 @@ def get_custom_gradient_checkpointing_func(gradient_checkpointing_func: Callable
     def custom_gradient_checkpointing_func(func: Callable, *args: Union["torch.Tensor", Any], **kwargs):
         module: "torch.nn.Module" = func.__self__
 
+        has_grad = False
         if any(param.requires_grad for param in module.parameters()):
+            has_grad = True
             for arg in args:
                 if torch.is_tensor(arg) and torch.is_floating_point(arg):
                     arg.requires_grad_(True)
                     break  # assume the first tensor is always the hidden states
 
-        return gradient_checkpointing_func(func, *args, **kwargs)
+        if has_grad:
+            return gradient_checkpointing_func(func, *args, **kwargs)
+        else:
+            return func(*args, **kwargs)
 
     return custom_gradient_checkpointing_func
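
For context (not part of the patch): the change makes the wrapper skip gradient checkpointing entirely when a layer has no trainable parameters, since recomputing a frozen layer's activations in the backward pass buys nothing. Below is a minimal sketch of that dispatch under stated assumptions: the wrapper name forward_maybe_checkpointed is illustrative, and it uses plain torch.utils.checkpoint.checkpoint rather than LLaMA-Factory's own checkpointing function.

    # Illustrative sketch only; not LLaMA-Factory's API.
    import torch
    from torch.utils.checkpoint import checkpoint


    def forward_maybe_checkpointed(module: torch.nn.Module, hidden_states: torch.Tensor) -> torch.Tensor:
        has_grad = any(param.requires_grad for param in module.parameters())
        if has_grad:
            # Reentrant checkpointing needs an input that carries grad,
            # mirroring the requires_grad_(True) call in the patch above.
            hidden_states = hidden_states.requires_grad_(True)
            return checkpoint(module, hidden_states, use_reentrant=True)
        return module(hidden_states)  # frozen layer: plain forward, no recomputation


    if __name__ == "__main__":
        frozen = torch.nn.Linear(8, 8).requires_grad_(False)
        trainable = torch.nn.Linear(8, 8)
        x = torch.randn(2, 8)
        out = forward_maybe_checkpointed(trainable, forward_maybe_checkpointed(frozen, x))
        out.sum().backward()
        print(trainable.weight.grad is not None)  # True: gradients reach the trainable layer
        print(frozen.weight.grad is None)         # True: frozen layer never checkpointed

The requires_grad_(True) on the hidden states mirrors the patch: with reentrant checkpointing at least one floating-point input must require grad for the backward pass to reach the layer's parameters, while fully frozen layers now take the direct-forward branch.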