[ci] using mp to run kernel test (#9754)

Co-authored-by: frozenleaves <frozen@Mac.local>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Yaowei Zheng <hiyouga@buaa.edu.cn>
Authored by 浮梦 on 2026-01-13 19:43:59 +08:00, committed by GitHub
parent 958b9c3468
commit 9829ae0a77
2 changed files with 59 additions and 55 deletions
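
In short: the kernel tests mock the current accelerator and purge cached kernel modules, which previously leaked into the shared pytest process; this commit runs each test body in a torch.multiprocessing child process, so nothing a test patches or reloads survives it. A minimal, repository-independent sketch of that isolation pattern (the worker and test names below are illustrative):

    import torch.multiprocessing as mp


    def _worker(rank: int) -> None:
        # Runs in a spawned child process; mocks, reloaded modules and other
        # global state die with the child.
        assert rank == 0


    def test_runs_in_isolated_process():
        # nprocs defaults to 1 and join defaults to True, so the parent simply
        # waits for the child, and the test fails if the child raises.
        mp.spawn(_worker)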

File 1 of 2: tests conftest (pytest fixtures)

@@ -120,6 +120,17 @@ def pytest_collection_modifyitems(config: Config, items: list[Item]):
     _handle_device_visibility(items)


+@pytest.fixture(scope="session", autouse=True)
+def _set_env():
+    # add project root dir to path for mp run
+    project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
+    if project_root not in sys.path:
+        sys.path.insert(0, project_root)
+
+    os.environ["PYTHONPATH"] = project_root + os.pathsep + os.getenv("PYTHONPATH", "")
+    os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+
 @pytest.fixture(autouse=True)
 def _cleanup_distributed_state():
     """Cleanup distributed state after each test."""

@@ -150,13 +161,6 @@ def _manage_distributed_env(request: FixtureRequest, monkeypatch: MonkeyPatch) -
         monkeypatch.setenv(env_key, devices_str)
-
-        # add project root dir to path for mp run
-        project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
-        if project_root not in sys.path:
-            sys.path.insert(0, project_root)
-
-        os.environ["PYTHONPATH"] = project_root + os.pathsep + os.environ.get("PYTHONPATH", "")
     else:  # non-distributed test
         if old_value:
             visible_devices = [v for v in old_value.split(",") if v != ""]
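
The hunks above move the project-root bookkeeping out of _manage_distributed_env and into a session-scoped autouse fixture, so it runs once per session instead of only for distributed tests. The reason is the switch to spawned workers: a spawned child starts a fresh interpreter and rebuilds sys.path, so the project root has to travel through the PYTHONPATH environment variable rather than through sys.path edits made in the parent (the TOKENIZERS_PARALLELISM export disables tokenizer parallelism, which does not mix well with extra worker processes). A standalone sketch of that behaviour, with illustrative names not taken from the repo:

    import os
    import sys

    import torch.multiprocessing as mp


    def _child(rank: int) -> None:
        # The spawned interpreter builds sys.path from scratch; the project root
        # shows up here only because PYTHONPATH was exported by the parent.
        assert os.environ["DEMO_PROJECT_ROOT"] in sys.path


    def main() -> None:
        project_root = os.path.abspath(os.path.dirname(__file__))
        os.environ["DEMO_PROJECT_ROOT"] = project_root
        # A bare sys.path.insert(0, project_root) here would not reach the child.
        os.environ["PYTHONPATH"] = project_root + os.pathsep + os.getenv("PYTHONPATH", "")
        mp.spawn(_child)


    if __name__ == "__main__":
        main()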

File 2 of 2: kernel plugin tests

@@ -16,59 +16,59 @@ import sys
 from unittest.mock import MagicMock, patch

 import pytest
+import torch.multiprocessing as mp
 from transformers import AutoModelForCausalLM

-from llamafactory.v1.accelerator.helper import get_current_accelerator
-
-
-@pytest.fixture(autouse=True)
-def clear_accelerator_cache():
-    get_current_accelerator.cache_clear()
-
-
-def reload_kernels():
-    """Helper to reload kernel modules to respect mocked accelerator."""
-    # Unload kernel interface and registry
-    keys_to_remove = [k for k in sys.modules if k.startswith("llamafactory.v1.plugins.model_plugins.kernels")]
-    for k in keys_to_remove:
-        del sys.modules[k]
-
-
-@patch("torch.accelerator.current_accelerator")
-def test_apply_kernel(mock_get_accelerator: MagicMock):
-    mock_device = MagicMock()
-    setattr(mock_device, "type", "npu")
-    mock_get_accelerator.return_value = mock_device
-
-    # Force reload of kernels with mocked accelerator
-    reload_kernels()
-    from llamafactory.v1.plugins.model_plugins.kernels.interface import apply_default_kernels
-
-    # NOTE: use a special model to avoid contamination by other tests
-    model = AutoModelForCausalLM.from_pretrained("llamafactory/tiny-random-qwen2.5")
-    original_rmsnorm_forward = model.model.layers[0].input_layernorm.forward
-    original_swiglu_forward = model.model.layers[0].mlp.forward
-    model = apply_default_kernels(model=model, include_kernels="npu_fused_rmsnorm")
-    assert model.model.layers[0].input_layernorm.forward.__func__ is not original_rmsnorm_forward.__func__
-    assert model.model.layers[0].mlp.forward.__func__ is original_swiglu_forward.__func__
-
-
-@patch("torch.accelerator.current_accelerator")
-def test_apply_all_kernels(mock_get_accelerator: MagicMock):
-    get_current_accelerator.cache_clear()
-    mock_device = MagicMock()
-    setattr(mock_device, "type", "npu")
-    mock_get_accelerator.return_value = mock_device
-
-    # Force reload of kernels with mocked accelerator
-    reload_kernels()
-    from llamafactory.v1.plugins.model_plugins.kernels.interface import apply_default_kernels
-
-    # NOTE: use a special model to avoid contamination by other tests
-    model = AutoModelForCausalLM.from_pretrained("llamafactory/tiny-random-qwen2.5")
-    original_rmsnorm_forward = model.model.layers[0].input_layernorm.forward
-    original_swiglu_forward = model.model.layers[0].mlp.forward
-    model = apply_default_kernels(model=model, include_kernels=True)
-    assert model.model.layers[0].input_layernorm.forward.__func__ is not original_rmsnorm_forward.__func__
-    assert model.model.layers[0].mlp.forward.__func__ is not original_swiglu_forward.__func__
+
+
+def _apply_kernel(rank) -> None:
+    with patch("torch.accelerator.current_accelerator") as mock_get_accelerator:
+        mock_device = MagicMock()
+        setattr(mock_device, "type", "npu")
+        mock_get_accelerator.return_value = mock_device
+
+        # reload kernel modules to respect mocked accelerator
+        for k in list(sys.modules.keys()):
+            if k.startswith("llamafactory.v1.plugins.model_plugins.kernels"):
+                del sys.modules[k]
+
+        from llamafactory.v1.plugins.model_plugins.kernels.interface import apply_default_kernels
+
+        model = AutoModelForCausalLM.from_pretrained("llamafactory/tiny-random-qwen3")
+        original_rmsnorm_forward = model.model.layers[0].input_layernorm.forward
+        original_swiglu_forward = model.model.layers[0].mlp.forward
+        model = apply_default_kernels(model=model, include_kernels="npu_fused_rmsnorm")
+        assert model.model.layers[0].input_layernorm.forward.__func__ is not original_rmsnorm_forward.__func__
+        assert model.model.layers[0].mlp.forward.__func__ is original_swiglu_forward.__func__
+
+
+def _apply_all_kernels(rank) -> None:
+    with patch("torch.accelerator.current_accelerator") as mock_get_accelerator:
+        mock_device = MagicMock()
+        setattr(mock_device, "type", "npu")
+        mock_get_accelerator.return_value = mock_device
+
+        # reload kernel modules to respect mocked accelerator
+        for k in list(sys.modules.keys()):
+            if k.startswith("llamafactory.v1.plugins.model_plugins.kernels"):
+                del sys.modules[k]
+
+        from llamafactory.v1.plugins.model_plugins.kernels.interface import apply_default_kernels
+
+        model = AutoModelForCausalLM.from_pretrained("llamafactory/tiny-random-qwen3")
+        original_rmsnorm_forward = model.model.layers[0].input_layernorm.forward
+        original_swiglu_forward = model.model.layers[0].mlp.forward
+        model = apply_default_kernels(model=model, include_kernels=True)
+        assert model.model.layers[0].input_layernorm.forward.__func__ is not original_rmsnorm_forward.__func__
+        assert model.model.layers[0].mlp.forward.__func__ is not original_swiglu_forward.__func__
+
+
+def test_apply_kernel():
+    mp.spawn(_apply_kernel)
+
+
+def test_apply_all_kernels():
+    mp.spawn(_apply_all_kernels)
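
Both workers follow the same reload-under-mock pattern: the kernel modules consult the current accelerator when they are imported (the in-diff comments call this out), so the torch.accelerator.current_accelerator patch must be active while they are re-imported, which is why cached entries are dropped from sys.modules first. A hedged, repository-independent sketch of that pattern (the helper name and package prefix are illustrative):

    import importlib
    import sys
    from unittest.mock import MagicMock, patch


    def reimport_under_mocked_npu(package_prefix: str):
        with patch("torch.accelerator.current_accelerator") as mock_get_accelerator:
            mock_device = MagicMock()
            setattr(mock_device, "type", "npu")
            mock_get_accelerator.return_value = mock_device

            # Drop cached modules so the next import re-runs its accelerator
            # checks against the mocked device instead of the real one.
            for name in list(sys.modules):
                if name.startswith(package_prefix):
                    del sys.modules[name]

            return importlib.import_module(package_prefix)

The two tests then differ only in what they ask apply_default_kernels to patch: the single kernel name "npu_fused_rmsnorm" swaps the RMSNorm forward while leaving the MLP forward untouched, whereas include_kernels=True is expected to replace both.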