[v1] add init plugin (#9716)

This commit is contained in:
Yaowei Zheng
2026-01-04 20:51:46 +08:00
committed by GitHub
parent 81b8a50aa5
commit f60a6e3d01
14 changed files with 307 additions and 74 deletions

View File

@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""LLaMA-Factory test configuration.
"""LlamaFactory test configuration.
Contains shared fixtures, pytest configuration, and custom markers.
"""
@@ -22,6 +22,7 @@ import sys
import pytest
import torch
import torch.distributed as dist
from pytest import Config, FixtureRequest, Item, MonkeyPatch
from llamafactory.v1.accelerator.helper import get_current_accelerator, get_device_count
@@ -109,17 +110,24 @@ def _handle_device_visibility(items: list[Item]):
def pytest_collection_modifyitems(config: Config, items: list[Item]):
"""Modify test collection based on markers and environment."""
# Handle version compatibility (from HEAD)
if not is_transformers_version_greater_than("4.57.0"):
skip_bc = pytest.mark.skip(reason="Skip backward compatibility tests")
for item in items:
if "tests_v1" in str(item.fspath):
item.add_marker(skip_bc)
skip_bc = pytest.mark.skip(reason="Skip backward compatibility tests")
for item in items:
if "tests_v1" in str(item.fspath) and not is_transformers_version_greater_than("4.57.0"):
item.add_marker(skip_bc)
_handle_slow_tests(items)
_handle_runs_on(items)
_handle_device_visibility(items)
@pytest.fixture(autouse=True)
def _cleanup_distributed_state():
"""Cleanup distributed state after each test."""
yield
if dist.is_initialized():
dist.destroy_process_group()
@pytest.fixture(autouse=True)
def _manage_distributed_env(request: FixtureRequest, monkeypatch: MonkeyPatch) -> None:
"""Set environment variables for distributed tests if specific devices are requested."""
@@ -155,6 +163,7 @@ def _manage_distributed_env(request: FixtureRequest, monkeypatch: MonkeyPatch) -
monkeypatch.setenv(env_key, visible_devices[0] if visible_devices else "0")
else:
monkeypatch.setenv(env_key, "0")
if CURRENT_DEVICE == "cuda":
monkeypatch.setattr(torch.cuda, "device_count", lambda: 1)
elif CURRENT_DEVICE == "npu":