[ci] add cuda workflow (#9682)

Co-authored-by: frozenleaves <frozen@Mac.local>
Co-authored-by: Yaowei Zheng <hiyouga@buaa.edu.cn>
This commit is contained in:
浮梦
2025-12-29 20:03:00 +08:00
committed by GitHub
parent bb1ba31005
commit 1857fbdd6b
4 changed files with 123 additions and 3 deletions

View File

@@ -18,8 +18,11 @@ Contains shared fixtures, pytest configuration, and custom markers.
"""
import os
from typing import Optional
import pytest
import torch
import torch.distributed as dist
from pytest import Config, FixtureRequest, Item, MonkeyPatch
from llamafactory.extras.misc import get_current_device, get_device_count, is_env_enabled
@@ -70,7 +73,7 @@ def _handle_slow_tests(items: list[Item]):
item.add_marker(skip_slow)
def _get_visible_devices_env() -> str | None:
def _get_visible_devices_env() -> Optional[str]:
"""Return device visibility env var name."""
if CURRENT_DEVICE == "cuda":
return "CUDA_VISIBLE_DEVICES"
@@ -118,6 +121,14 @@ def pytest_collection_modifyitems(config: Config, items: list[Item]):
_handle_device_visibility(items)
@pytest.fixture(autouse=True)
def _cleanup_distributed_state():
"""Cleanup distributed state after each test."""
yield
if dist.is_initialized():
dist.destroy_process_group()
@pytest.fixture(autouse=True)
def _manage_distributed_env(request: FixtureRequest, monkeypatch: MonkeyPatch) -> None:
"""Set environment variables for distributed tests if specific devices are requested."""
@@ -145,6 +156,10 @@ def _manage_distributed_env(request: FixtureRequest, monkeypatch: MonkeyPatch) -
monkeypatch.setenv(env_key, visible_devices[0] if visible_devices else "0")
else:
monkeypatch.setenv(env_key, "0")
if CURRENT_DEVICE == "cuda":
monkeypatch.setattr(torch.cuda, "device_count", lambda: 1)
elif CURRENT_DEVICE == "npu":
monkeypatch.setattr(torch.npu, "device_count", lambda: 1)
@pytest.fixture