mirror of https://github.com/hiyouga/LLaMA-Factory.git, synced 2025-12-16 20:00:36 +08:00
refactor mllm param logic
@@ -14,6 +14,7 @@
 import os
 
+import pytest
 import torch
 
 from llamafactory.extras.misc import get_current_device
@@ -39,16 +40,11 @@ TRAIN_ARGS = {
 }
 
 
-def test_checkpointing_enable():
-    model = load_train_model(disable_gradient_checkpointing=False, **TRAIN_ARGS)
+@pytest.mark.parametrize("disable_gradient_checkpointing", [False, True])
+def test_vanilla_checkpointing(disable_gradient_checkpointing: bool):
+    model = load_train_model(disable_gradient_checkpointing=disable_gradient_checkpointing, **TRAIN_ARGS)
     for module in filter(lambda m: hasattr(m, "gradient_checkpointing"), model.modules()):
-        assert getattr(module, "gradient_checkpointing") is True
-
-
-def test_checkpointing_disable():
-    model = load_train_model(disable_gradient_checkpointing=True, **TRAIN_ARGS)
-    for module in filter(lambda m: hasattr(m, "gradient_checkpointing"), model.modules()):
-        assert getattr(module, "gradient_checkpointing") is False
+        assert getattr(module, "gradient_checkpointing") != disable_gradient_checkpointing
 
 
 def test_unsloth_gradient_checkpointing():
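The hunk above folds the old test_checkpointing_enable / test_checkpointing_disable pair into a single pytest-parametrized test: the same body runs once with the flag False and once with True, and the assertion checks that each module's gradient_checkpointing attribute is the inverse of the disable argument. Below is a minimal, self-contained sketch of that pattern; fake_load_train_model and _Block are hypothetical stand-ins for the repository's load_train_model helper and model layers, not code from this commit.

# Sketch only: illustrates the parametrized checkpointing test pattern.
import pytest
import torch


class _Block(torch.nn.Module):
    """Toy layer exposing the gradient_checkpointing flag the test inspects."""

    def __init__(self) -> None:
        super().__init__()
        self.gradient_checkpointing = False

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x


def fake_load_train_model(disable_gradient_checkpointing: bool) -> torch.nn.Module:
    # Mimic a loader that sets the flag on each submodule to the inverse
    # of the disable_gradient_checkpointing argument.
    model = torch.nn.Sequential(_Block(), _Block())
    for module in model.modules():
        if hasattr(module, "gradient_checkpointing"):
            module.gradient_checkpointing = not disable_gradient_checkpointing
    return model


@pytest.mark.parametrize("disable_gradient_checkpointing", [False, True])
def test_vanilla_checkpointing(disable_gradient_checkpointing: bool) -> None:
    model = fake_load_train_model(disable_gradient_checkpointing)
    for module in filter(lambda m: hasattr(m, "gradient_checkpointing"), model.modules()):
        # One assertion covers both of the old enable/disable tests.
        assert module.gradient_checkpointing != disable_gradient_checkpointing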