[misc] fix accelerator (#9661)

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
Yaowei Zheng
2025-12-25 02:11:04 +08:00
committed by GitHub
parent 6a2eafbae3
commit a754604c11
44 changed files with 396 additions and 448 deletions

View File

@@ -18,19 +18,17 @@ Contains shared fixtures, pytest configuration, and custom markers.
"""
import os
from typing import Optional
import pytest
from pytest import Config, Item
from pytest import Config, FixtureRequest, Item, MonkeyPatch
from llamafactory.extras.misc import get_current_device, get_device_count, is_env_enabled
from llamafactory.extras.packages import is_transformers_version_greater_than
from llamafactory.train.test_utils import patch_valuehead_model
try:
CURRENT_DEVICE = get_current_device().type # cpu | cuda | npu
except Exception:
CURRENT_DEVICE = "cpu"
CURRENT_DEVICE = get_current_device().type
def pytest_configure(config: Config):
@@ -66,26 +64,27 @@ def _handle_runs_on(items: list[Item]):
def _handle_slow_tests(items: list[Item]):
"""Skip slow tests unless RUN_SLOW is enabled."""
if not is_env_enabled("RUN_SLOW", "0"):
if not is_env_enabled("RUN_SLOW"):
skip_slow = pytest.mark.skip(reason="slow test (set RUN_SLOW=1 to run)")
for item in items:
if "slow" in item.keywords:
item.add_marker(skip_slow)
def _get_visible_devices_env():
def _get_visible_devices_env() -> Optional[str]:
"""Return device visibility env var name."""
if CURRENT_DEVICE == "cuda":
return "CUDA_VISIBLE_DEVICES"
if CURRENT_DEVICE == "npu":
elif CURRENT_DEVICE == "npu":
return "ASCEND_RT_VISIBLE_DEVICES"
return None
else:
return None
def _handle_device_visibility(items: list[Item]):
"""Handle device visibility based on test markers."""
env_key = _get_visible_devices_env()
if env_key is None or CURRENT_DEVICE == "cpu":
if env_key is None or CURRENT_DEVICE in ("cpu", "mps"):
return
# Parse visible devices
@@ -121,7 +120,7 @@ def pytest_collection_modifyitems(config: Config, items: list[Item]):
@pytest.fixture(autouse=True)
def _manage_distributed_env(request, monkeypatch):
def _manage_distributed_env(request: FixtureRequest, monkeypatch: MonkeyPatch) -> None:
"""Set environment variables for distributed tests if specific devices are requested."""
env_key = _get_visible_devices_env()
if not env_key:
@@ -131,8 +130,7 @@ def _manage_distributed_env(request, monkeypatch):
old_value = os.environ.get(env_key)
marker = request.node.get_closest_marker("require_distributed")
if marker:
# Distributed test
if marker: # distributed test
required = marker.args[0] if marker.args else 2
specific_devices = marker.args[1] if len(marker.args) > 1 else None
@@ -142,8 +140,7 @@ def _manage_distributed_env(request, monkeypatch):
devices_str = ",".join(str(i) for i in range(required))
monkeypatch.setenv(env_key, devices_str)
else:
# Non-distributed test
else: # non-distributed test
if old_value:
visible_devices = [v for v in old_value.split(",") if v != ""]
monkeypatch.setenv(env_key, visible_devices[0] if visible_devices else "0")