[v1] add cli sampler (#9721)

This commit is contained in:
Yaowei Zheng
2026-01-06 23:31:27 +08:00
committed by GitHub
parent e944dc442c
commit ea0b4e2466
45 changed files with 1091 additions and 505 deletions

View File

@@ -18,7 +18,7 @@ Contains shared fixtures, pytest configuration, and custom markers.
"""
import os
from typing import Optional
import sys
import pytest
import torch
@@ -73,7 +73,7 @@ def _handle_slow_tests(items: list[Item]):
item.add_marker(skip_slow)
def _get_visible_devices_env() -> Optional[str]:
def _get_visible_devices_env() -> str | None:
"""Return device visibility env var name."""
if CURRENT_DEVICE == "cuda":
return "CUDA_VISIBLE_DEVICES"
@@ -149,6 +149,14 @@ def _manage_distributed_env(request: FixtureRequest, monkeypatch: MonkeyPatch) -
devices_str = ",".join(str(i) for i in range(required))
monkeypatch.setenv(env_key, devices_str)
# add project root dir to path for mp run
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
if project_root not in sys.path:
sys.path.insert(0, project_root)
os.environ["PYTHONPATH"] = project_root + os.pathsep + os.environ.get("PYTHONPATH", "")
else: # non-distributed test
if old_value:
visible_devices = [v for v in old_value.split(",") if v != ""]

View File

@@ -1,2 +1,2 @@
# change if test fails or cache is outdated
0.9.4.105
0.9.5.101