mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-12-31 19:20:36 +08:00
[ci] add cuda workflow (#9682)
Co-authored-by: frozenleaves <frozen@Mac.local> Co-authored-by: Yaowei Zheng <hiyouga@buaa.edu.cn>
This commit is contained in:
@@ -18,8 +18,11 @@ Contains shared fixtures, pytest configuration, and custom markers.
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from pytest import Config, FixtureRequest, Item, MonkeyPatch
|
||||
|
||||
from llamafactory.extras.misc import get_current_device, get_device_count, is_env_enabled
|
||||
@@ -70,7 +73,7 @@ def _handle_slow_tests(items: list[Item]):
|
||||
item.add_marker(skip_slow)
|
||||
|
||||
|
||||
def _get_visible_devices_env() -> str | None:
|
||||
def _get_visible_devices_env() -> Optional[str]:
|
||||
"""Return device visibility env var name."""
|
||||
if CURRENT_DEVICE == "cuda":
|
||||
return "CUDA_VISIBLE_DEVICES"
|
||||
@@ -118,6 +121,14 @@ def pytest_collection_modifyitems(config: Config, items: list[Item]):
|
||||
_handle_device_visibility(items)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _cleanup_distributed_state():
|
||||
"""Cleanup distributed state after each test."""
|
||||
yield
|
||||
if dist.is_initialized():
|
||||
dist.destroy_process_group()
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _manage_distributed_env(request: FixtureRequest, monkeypatch: MonkeyPatch) -> None:
|
||||
"""Set environment variables for distributed tests if specific devices are requested."""
|
||||
@@ -145,6 +156,10 @@ def _manage_distributed_env(request: FixtureRequest, monkeypatch: MonkeyPatch) -
|
||||
monkeypatch.setenv(env_key, visible_devices[0] if visible_devices else "0")
|
||||
else:
|
||||
monkeypatch.setenv(env_key, "0")
|
||||
if CURRENT_DEVICE == "cuda":
|
||||
monkeypatch.setattr(torch.cuda, "device_count", lambda: 1)
|
||||
elif CURRENT_DEVICE == "npu":
|
||||
monkeypatch.setattr(torch.npu, "device_count", lambda: 1)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
||||
Reference in New Issue
Block a user