diff --git a/src/llamafactory/model/model_utils/checkpointing.py b/src/llamafactory/model/model_utils/checkpointing.py
index 3130d6d2..0c5c98ec 100644
--- a/src/llamafactory/model/model_utils/checkpointing.py
+++ b/src/llamafactory/model/model_utils/checkpointing.py
@@ -89,6 +89,9 @@ def get_custom_gradient_checkpointing_func(gradient_checkpointing_func):
         return gradient_checkpointing_func(func, *args, **kwargs)
 
+    if hasattr(gradient_checkpointing_func, "__self__"):  # fix test case
+        custom_gradient_checkpointing_func.__self__ = gradient_checkpointing_func.__self__
+
     return custom_gradient_checkpointing_func
 
diff --git a/tests/model/model_utils/test_checkpointing.py b/tests/model/model_utils/test_checkpointing.py
index ac500df2..9367eab2 100644
--- a/tests/model/model_utils/test_checkpointing.py
+++ b/tests/model/model_utils/test_checkpointing.py
@@ -54,7 +54,7 @@ def test_checkpointing_disable():
 def test_unsloth_gradient_checkpointing():
     model = load_train_model(use_unsloth_gc=True, **TRAIN_ARGS)
     for module in filter(lambda m: hasattr(m, "gradient_checkpointing"), model.modules()):
-        assert module._gradient_checkpointing_func.__self__.__name__ == "UnslothGradientCheckpointing"
+        assert module._gradient_checkpointing_func.__self__.__name__ == "UnslothGradientCheckpointing"  # classmethod
 
 def test_upcast_layernorm():
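
Note: the wrapper returned by `get_custom_gradient_checkpointing_func` is a plain function, so it has no `__self__` attribute even when the callable it wraps is a bound classmethod; the test then fails on `__self__.__name__`. Copying `__self__` onto the wrapper preserves that binding. The snippet below is a minimal, self-contained sketch of this behavior; the `UnslothGradientCheckpointing` class and its `apply` classmethod here are stand-ins invented for illustration, not unsloth's real API.

```python
from functools import wraps


class UnslothGradientCheckpointing:
    """Stand-in class (hypothetical); only the classmethod binding matters."""

    @classmethod
    def apply(cls, func, *args, **kwargs):
        # A bound classmethod exposes the owning class via __self__.
        return func(*args, **kwargs)


def get_custom_gradient_checkpointing_func(gradient_checkpointing_func):
    @wraps(gradient_checkpointing_func)
    def custom_gradient_checkpointing_func(func, *args, **kwargs):
        return gradient_checkpointing_func(func, *args, **kwargs)

    # functools.wraps copies __name__, __doc__, etc., but not __self__,
    # so restore it explicitly when the wrapped callable is a bound method.
    if hasattr(gradient_checkpointing_func, "__self__"):
        custom_gradient_checkpointing_func.__self__ = gradient_checkpointing_func.__self__

    return custom_gradient_checkpointing_func


wrapped = get_custom_gradient_checkpointing_func(UnslothGradientCheckpointing.apply)
# For a classmethod, __self__ is the owning class, so __name__ is the class name,
# which is exactly what test_unsloth_gradient_checkpointing asserts.
assert wrapped.__self__.__name__ == "UnslothGradientCheckpointing"
print(wrapped(lambda x: x + 1, 41))  # -> 42
```

Setting `__self__` as a plain attribute works here because ordinary Python functions accept arbitrary attributes, whereas the attribute is read-only on actual bound methods.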