diff --git a/tests_v1/conftest.py b/tests_v1/conftest.py
index bf1a4d76a..adb08d49f 100644
--- a/tests_v1/conftest.py
+++ b/tests_v1/conftest.py
@@ -120,6 +120,17 @@ def pytest_collection_modifyitems(config: Config, items: list[Item]):
     _handle_device_visibility(items)
 
 
+@pytest.fixture(scope="session", autouse=True)
+def _set_env():
+    # add project root dir to path for mp run
+    project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
+    if project_root not in sys.path:
+        sys.path.insert(0, project_root)
+
+    os.environ["PYTHONPATH"] = project_root + os.pathsep + os.getenv("PYTHONPATH", "")
+    os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+
 @pytest.fixture(autouse=True)
 def _cleanup_distributed_state():
     """Cleanup distributed state after each test."""
@@ -150,13 +161,6 @@ def _manage_distributed_env(request: FixtureRequest, monkeypatch: MonkeyPatch) -
 
         monkeypatch.setenv(env_key, devices_str)
 
-        # add project root dir to path for mp run
-        project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
-        if project_root not in sys.path:
-            sys.path.insert(0, project_root)
-
-        os.environ["PYTHONPATH"] = project_root + os.pathsep + os.environ.get("PYTHONPATH", "")
-
     else:  # non-distributed test
         if old_value:
             visible_devices = [v for v in old_value.split(",") if v != ""]
diff --git a/tests_v1/plugins/model_plugins/test_kernel_plugin.py b/tests_v1/plugins/model_plugins/test_kernel_plugin.py
index 10e8e413a..334c89502 100644
--- a/tests_v1/plugins/model_plugins/test_kernel_plugin.py
+++ b/tests_v1/plugins/model_plugins/test_kernel_plugin.py
@@ -16,59 +16,59 @@ import sys
 from unittest.mock import MagicMock, patch
 
 import pytest
+import torch.multiprocessing as mp
 from transformers import AutoModelForCausalLM
 
-from llamafactory.v1.accelerator.helper import get_current_accelerator
+
+def _apply_kernel(rank) -> None:
+    with patch("torch.accelerator.current_accelerator") as mock_get_accelerator:
+        mock_device = MagicMock()
+        setattr(mock_device, "type", "npu")
+        mock_get_accelerator.return_value = mock_device
+
+        # reload kernel modules to respect mocked accelerator
+        for k in list(sys.modules.keys()):
+            if k.startswith("llamafactory.v1.plugins.model_plugins.kernels"):
+                del sys.modules[k]
+
+        from llamafactory.v1.plugins.model_plugins.kernels.interface import apply_default_kernels
+
+        model = AutoModelForCausalLM.from_pretrained("llamafactory/tiny-random-qwen3")
+        original_rmsnorm_forward = model.model.layers[0].input_layernorm.forward
+        original_swiglu_forward = model.model.layers[0].mlp.forward
+
+        model = apply_default_kernels(model=model, include_kernels="npu_fused_rmsnorm")
+
+        assert model.model.layers[0].input_layernorm.forward.__func__ is not original_rmsnorm_forward.__func__
+        assert model.model.layers[0].mlp.forward.__func__ is original_swiglu_forward.__func__
 
 
-@pytest.fixture(autouse=True)
-def clear_accelerator_cache():
-    get_current_accelerator.cache_clear()
+def _apply_all_kernels(rank) -> None:
+    with patch("torch.accelerator.current_accelerator") as mock_get_accelerator:
+        mock_device = MagicMock()
+        setattr(mock_device, "type", "npu")
+        mock_get_accelerator.return_value = mock_device
+
+        # reload kernel modules to respect mocked accelerator
+        for k in list(sys.modules.keys()):
+            if k.startswith("llamafactory.v1.plugins.model_plugins.kernels"):
+                del sys.modules[k]
+
+        from llamafactory.v1.plugins.model_plugins.kernels.interface import apply_default_kernels
+
+        model = AutoModelForCausalLM.from_pretrained("llamafactory/tiny-random-qwen3")
+        original_rmsnorm_forward = model.model.layers[0].input_layernorm.forward
+        original_swiglu_forward = model.model.layers[0].mlp.forward
+
+        model = apply_default_kernels(model=model, include_kernels=True)
+
+        assert model.model.layers[0].input_layernorm.forward.__func__ is not original_rmsnorm_forward.__func__
+        assert model.model.layers[0].mlp.forward.__func__ is not original_swiglu_forward.__func__
 
 
-def reload_kernels():
-    """Helper to reload kernel modules to respect mocked accelerator."""
-    # Unload kernel interface and registry
-    keys_to_remove = [k for k in sys.modules if k.startswith("llamafactory.v1.plugins.model_plugins.kernels")]
-    for k in keys_to_remove:
-        del sys.modules[k]
+def test_apply_kernel():
+    mp.spawn(_apply_kernel)
 
 
-@patch("torch.accelerator.current_accelerator")
-def test_apply_kernel(mock_get_accelerator: MagicMock):
-    mock_device = MagicMock()
-    setattr(mock_device, "type", "npu")
-    mock_get_accelerator.return_value = mock_device
-    # Force reload of kernels with mocked accelerator
-    reload_kernels()
-    from llamafactory.v1.plugins.model_plugins.kernels.interface import apply_default_kernels
-
-    # NOTE: use a special model to avoid contamination by other tests
-    model = AutoModelForCausalLM.from_pretrained("llamafactory/tiny-random-qwen2.5")
-    original_rmsnorm_forward = model.model.layers[0].input_layernorm.forward
-    original_swiglu_forward = model.model.layers[0].mlp.forward
-    model = apply_default_kernels(model=model, include_kernels="npu_fused_rmsnorm")
-    assert model.model.layers[0].input_layernorm.forward.__func__ is not original_rmsnorm_forward.__func__
-    assert model.model.layers[0].mlp.forward.__func__ is original_swiglu_forward.__func__
-
-
-@patch("torch.accelerator.current_accelerator")
-def test_apply_all_kernels(mock_get_accelerator: MagicMock):
-    get_current_accelerator.cache_clear()
-    mock_device = MagicMock()
-    setattr(mock_device, "type", "npu")
-    mock_get_accelerator.return_value = mock_device
-
-    # Force reload of kernels with mocked accelerator
-    reload_kernels()
-    from llamafactory.v1.plugins.model_plugins.kernels.interface import apply_default_kernels
-
-    # NOTE: use a special model to avoid contamination by other tests
-    model = AutoModelForCausalLM.from_pretrained("llamafactory/tiny-random-qwen2.5")
-
-    original_rmsnorm_forward = model.model.layers[0].input_layernorm.forward
-    original_swiglu_forward = model.model.layers[0].mlp.forward
-
-    model = apply_default_kernels(model=model, include_kernels=True)
-    assert model.model.layers[0].input_layernorm.forward.__func__ is not original_rmsnorm_forward.__func__
-    assert model.model.layers[0].mlp.forward.__func__ is not original_swiglu_forward.__func__
+def test_apply_all_kernels():
+    mp.spawn(_apply_all_kernels)