@@ -51,6 +51,12 @@ def test_checkpointing_disable():
         assert getattr(module, "gradient_checkpointing") is False
 
 
+def test_unsloth_gradient_checkpointing():
+    model = load_train_model(use_unsloth_gc=True, **TRAIN_ARGS)
+    for module in filter(lambda m: hasattr(m, "gradient_checkpointing"), model.modules()):
+        assert module._gradient_checkpointing_func.__self__.__name__ == "UnslothGradientCheckpointing"
+
+
 def test_upcast_layernorm():
     model = load_train_model(upcast_layernorm=True, **TRAIN_ARGS)
     for name, param in model.named_parameters():
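
The added test asserts on _gradient_checkpointing_func.__self__.__name__. Assuming the Unsloth checkpointing function is installed as a bound classmethod of its class (for example, an autograd Function's apply), a minimal plain-Python sketch of why that lookup yields the class name:

    class UnslothGradientCheckpointing:
        """Stand-in for the real checkpointing class; only the lookup pattern matters here."""

        @classmethod
        def apply(cls, *args):
            pass

    func = UnslothGradientCheckpointing.apply
    # Accessing a classmethod on the class returns a bound method whose
    # __self__ is the class itself, so __name__ recovers the class name.
    assert func.__self__.__name__ == "UnslothGradientCheckpointing"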
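
The hunk's trailing context stops inside test_upcast_layernorm, so its assertion body is not visible here. A minimal sketch of how such a test plausibly concludes, assuming upcasting targets 1-D parameters with "norm" in their name and promotes them to torch.float32 (the filter and dtype check are assumptions; load_train_model and TRAIN_ARGS are the helpers already used in the hunk's context lines):

    import torch

    def test_upcast_layernorm():
        model = load_train_model(upcast_layernorm=True, **TRAIN_ARGS)
        for name, param in model.named_parameters():
            # Assumed check: layernorm weights were upcast to full precision.
            if param.ndim == 1 and "norm" in name:
                assert param.dtype == torch.float32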