mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-10-14 15:52:49 +08:00)
fix test case
Former-commit-id: b075b2971c6acb2c6039b36420a296f1f4e1b91b
parent 158e0e1f63
commit b6810b209a
@@ -89,6 +89,9 @@ def get_custom_gradient_checkpointing_func(gradient_checkpointing_func):
         return gradient_checkpointing_func(func, *args, **kwargs)
 
+    if hasattr(gradient_checkpointing_func, "__self__"):  # fix test case
+        custom_gradient_checkpointing_func.__self__ = gradient_checkpointing_func.__self__
+
     return custom_gradient_checkpointing_func
 
 
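Why this hunk works: a plain wrapper function has no __self__ of its own, and functools.wraps copies metadata such as __name__ but never __self__. When the wrapped gradient-checkpointing callable is a bound classmethod (the Unsloth case checked by the test below), copying __self__ onto the wrapper by hand keeps introspection like wrapper.__self__.__name__ working. A minimal sketch of the pattern; get_wrapped and the apply classmethod are illustrative names, not code from the repository:

# Sketch only: this UnslothGradientCheckpointing is a stand-in class reusing the
# name the test asserts on; "apply" and "get_wrapped" are hypothetical.
from functools import wraps


class UnslothGradientCheckpointing:
    @classmethod
    def apply(cls, func, *args, **kwargs):
        return func(*args, **kwargs)


def get_wrapped(gradient_checkpointing_func):
    @wraps(gradient_checkpointing_func)
    def wrapped(func, *args, **kwargs):
        return gradient_checkpointing_func(func, *args, **kwargs)

    # A plain function lacks __self__; copy it from the bound (class)method
    # so wrapped.__self__.__name__ still reports the owning class.
    if hasattr(gradient_checkpointing_func, "__self__"):
        wrapped.__self__ = gradient_checkpointing_func.__self__

    return wrapped


wrapped = get_wrapped(UnslothGradientCheckpointing.apply)
assert wrapped.__self__ is UnslothGradientCheckpointing
assert wrapped.__self__.__name__ == "UnslothGradientCheckpointing"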
@@ -54,7 +54,7 @@ def test_checkpointing_disable():
 def test_unsloth_gradient_checkpointing():
     model = load_train_model(use_unsloth_gc=True, **TRAIN_ARGS)
     for module in filter(lambda m: hasattr(m, "gradient_checkpointing"), model.modules()):
-        assert module._gradient_checkpointing_func.__self__.__name__ == "UnslothGradientCheckpointing"
+        assert module._gradient_checkpointing_func.__self__.__name__ == "UnslothGradientCheckpointing"  # classmethod
 
 
 def test_upcast_layernorm():
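The added "# classmethod" note points at why __self__.__name__ yields the class name here: for a method bound through classmethod, __self__ is the class itself rather than an instance. A quick standalone check of that behavior, using a hypothetical stand-in class Foo:

# Foo is a made-up example class, only to show classmethod __self__ semantics.
class Foo:
    @classmethod
    def bar(cls):
        pass


assert Foo.bar.__self__ is Foo
assert Foo.bar.__self__.__name__ == "Foo"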