[misc] fix accelerator (#9661)

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
Yaowei Zheng
2025-12-25 02:11:04 +08:00
committed by GitHub
parent 6a2eafbae3
commit a754604c11
44 changed files with 396 additions and 448 deletions

View File

@@ -1,93 +0,0 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from llamafactory.v1.accelerator.helper import ReduceOp, all_reduce, is_torch_cuda_available, is_torch_npu_available
from llamafactory.v1.utils.utils import find_available_port
def _dist_worker(rank, world_size):
if is_torch_cuda_available():
backend = "nccl"
device = torch.device(f"cuda:{rank}")
torch.cuda.set_device(rank)
elif is_torch_npu_available():
backend = "hccl"
device = torch.device(f"npu:{rank}")
torch.npu.set_device(rank)
else:
backend = "gloo"
device = torch.device("cpu")
dist.init_process_group(
backend=backend,
rank=rank,
world_size=world_size,
)
# --------------------
# Test all_reduce SUM
# --------------------
y = torch.tensor(rank + 1.0, device=device)
y_sum = all_reduce(y.clone(), op=ReduceOp.SUM)
assert y_sum.item() == 3.0
# --------------------
# Test all_reduce MEAN
# --------------------
y_mean = all_reduce(y.clone(), op=ReduceOp.MEAN)
assert y_mean.item() == pytest.approx(1.5)
# --------------------
# Test all_reduce MAX
# --------------------
y_max = all_reduce(y.clone(), op=ReduceOp.MAX)
assert y_max.item() == 2.0
dist.destroy_process_group()
@pytest.mark.runs_on(["npu", "cuda"])
@pytest.mark.require_distributed(2)
def test_distributed_ops(monkeypatch):
monkeypatch.setenv("MASTER_ADDR", "127.0.0.1")
monkeypatch.setenv("MASTER_PORT", str(find_available_port()))
WORLD_SIZE = 2
mp.spawn(
_dist_worker,
args=(WORLD_SIZE,),
nprocs=WORLD_SIZE,
join=True,
)
@pytest.mark.runs_on(["npu", "cuda"])
@pytest.mark.require_distributed(4)
def test_required_multi():
# test require_distributed mark ok
pass
@pytest.mark.runs_on(["npu", "cuda"])
@pytest.mark.require_distributed(999)
def test_required_invalid():
# test require_distributed mark not ok,
raise RuntimeError(
"this case should not be run, please check whether the require_distributed mark implementation is correct"
)

View File

@@ -12,15 +12,48 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import pytest
import torch.multiprocessing as mp
from llamafactory.v1.accelerator.helper import ReduceOp
from llamafactory.v1.accelerator.interface import DistributedInterface
from llamafactory.v1.utils.env import find_available_port
from llamafactory.v1.utils.pytest import dist_env
def test_distributed_interface():
DistributedInterface()
assert DistributedInterface.get_rank() == int(os.getenv("RANK", "0"))
assert DistributedInterface.get_world_size() == int(os.getenv("WORLD_SIZE", "1"))
assert DistributedInterface.get_local_rank() == int(os.getenv("LOCAL_RANK", "0"))
assert DistributedInterface.get_local_world_size() == int(os.getenv("LOCAL_WORLD_SIZE", "1"))
def _all_reduce_tests(local_rank: int, world_size: int, master_port: int):
with dist_env(local_rank, world_size, master_port):
rank = DistributedInterface().get_rank()
world_size = DistributedInterface().get_world_size()
assert world_size == 2
y_sum = DistributedInterface().all_reduce(rank + 1.0, op=ReduceOp.SUM)
assert y_sum == pytest.approx(3.0)
y_mean = DistributedInterface().all_reduce(rank + 1.0, op=ReduceOp.MEAN)
assert y_mean == pytest.approx(1.5)
y_max = DistributedInterface().all_reduce(rank + 1.0, op=ReduceOp.MAX)
assert y_max == pytest.approx(2.0)
z = DistributedInterface().all_gather(rank + 1.0)
assert z == pytest.approx([1.0, 2.0])
z = DistributedInterface().broadcast(rank + 1.0)
assert z == pytest.approx(1.0)
def test_all_device():
assert DistributedInterface().get_rank() == int(os.getenv("RANK", "0"))
assert DistributedInterface().get_world_size() == int(os.getenv("WORLD_SIZE", "1"))
assert DistributedInterface().get_local_rank() == int(os.getenv("LOCAL_RANK", "0"))
assert DistributedInterface().get_local_world_size() == int(os.getenv("LOCAL_WORLD_SIZE", "1"))
@pytest.mark.runs_on(["cuda", "npu"])
@pytest.mark.require_distributed(2)
def test_multi_device():
master_port = find_available_port()
mp.spawn(_all_reduce_tests, args=(2, master_port), nprocs=2)

View File

@@ -18,20 +18,17 @@ Contains shared fixtures, pytest configuration, and custom markers.
"""
import os
from typing import Optional
import pytest
from pytest import Config, Item
from pytest import Config, FixtureRequest, Item, MonkeyPatch
from llamafactory.train.test_utils import patch_valuehead_model
from llamafactory.v1.accelerator.helper import get_current_device, get_device_count
from llamafactory.v1.accelerator.helper import get_current_accelerator, get_device_count
from llamafactory.v1.utils.env import is_env_enabled
from llamafactory.v1.utils.packages import is_transformers_version_greater_than
from llamafactory.v1.utils.utils import is_env_enabled
try:
CURRENT_DEVICE = get_current_device().type # cpu | cuda | npu
except Exception:
CURRENT_DEVICE = "cpu"
CURRENT_DEVICE = get_current_accelerator().type
def pytest_configure(config: Config):
@@ -67,26 +64,27 @@ def _handle_runs_on(items: list[Item]):
def _handle_slow_tests(items: list[Item]):
"""Skip slow tests unless RUN_SLOW is enabled."""
if not is_env_enabled("RUN_SLOW", "0"):
if not is_env_enabled("RUN_SLOW"):
skip_slow = pytest.mark.skip(reason="slow test (set RUN_SLOW=1 to run)")
for item in items:
if "slow" in item.keywords:
item.add_marker(skip_slow)
def _get_visible_devices_env():
def _get_visible_devices_env() -> Optional[str]:
"""Return device visibility env var name."""
if CURRENT_DEVICE == "cuda":
return "CUDA_VISIBLE_DEVICES"
if CURRENT_DEVICE == "npu":
elif CURRENT_DEVICE == "npu":
return "ASCEND_RT_VISIBLE_DEVICES"
return None
else:
return None
def _handle_device_visibility(items: list[Item]):
"""Handle device visibility based on test markers."""
env_key = _get_visible_devices_env()
if env_key is None or CURRENT_DEVICE == "cpu":
if env_key is None or CURRENT_DEVICE in ("cpu", "mps"):
return
# Parse visible devices
@@ -122,7 +120,7 @@ def pytest_collection_modifyitems(config: Config, items: list[Item]):
@pytest.fixture(autouse=True)
def _manage_distributed_env(request, monkeypatch):
def _manage_distributed_env(request: FixtureRequest, monkeypatch: MonkeyPatch) -> None:
"""Set environment variables for distributed tests if specific devices are requested."""
env_key = _get_visible_devices_env()
if not env_key:
@@ -132,8 +130,7 @@ def _manage_distributed_env(request, monkeypatch):
old_value = os.environ.get(env_key)
marker = request.node.get_closest_marker("require_distributed")
if marker:
# Distributed test
if marker: # distributed test
required = marker.args[0] if marker.args else 2
specific_devices = marker.args[1] if len(marker.args) > 1 else None
@@ -143,16 +140,9 @@ def _manage_distributed_env(request, monkeypatch):
devices_str = ",".join(str(i) for i in range(required))
monkeypatch.setenv(env_key, devices_str)
else:
# Non-distributed test
else: # non-distributed test
if old_value:
visible_devices = [v for v in old_value.split(",") if v != ""]
monkeypatch.setenv(env_key, visible_devices[0] if visible_devices else "0")
else:
monkeypatch.setenv(env_key, "0")
@pytest.fixture
def fix_valuehead_cpu_loading():
"""Fix valuehead model loading."""
patch_valuehead_model()

View File

@@ -28,10 +28,10 @@ from transformers import AutoTokenizer
from llamafactory.v1.config.data_args import DataArguments
from llamafactory.v1.core.data_engine import DataEngine
from llamafactory.v1.core.data_loader import DataLoader
from llamafactory.v1.core.trainer_utils.data_collator import (
DefaultCollator,
)
from llamafactory.v1.core.trainer_utils.data_loader import DataLoader
from llamafactory.v1.plugins.data_plugins.template import QwenTemplate
from llamafactory.v1.utils.batching_queue import TextBatchingQueue

View File

@@ -12,57 +12,56 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from unittest.mock import MagicMock, patch
import pytest
from transformers import AutoModelForCausalLM
from llamafactory.v1.accelerator.helper import get_current_accelerator
from llamafactory.v1.plugins.model_plugins.kernels.mlp import npu_swiglu
from llamafactory.v1.plugins.model_plugins.kernels.registry import apply_available_kernels, apply_kernel
from llamafactory.v1.plugins.model_plugins.kernels.rms_norm import npu_rms_norm
from llamafactory.v1.plugins.model_plugins.kernels.rope import npu_rope
class TestKernelPlugin(unittest.TestCase):
@patch("torch.accelerator.current_accelerator")
def test_apply_kernel(self, mock_get_accelerator):
get_current_accelerator.cache_clear()
mock_device = MagicMock()
mock_device.type = "npu"
mock_get_accelerator.return_value = mock_device
model = AutoModelForCausalLM.from_pretrained("llamafactory/tiny-random-qwen2.5")
original_rmsnorm_forward = model.model.layers[0].input_layernorm.forward
original_swiglu_forward = model.model.layers[0].mlp.forward
from llamafactory.v1.plugins.model_plugins.kernels.mlp import npu_swiglu
from llamafactory.v1.plugins.model_plugins.kernels.registry import apply_kernel
from llamafactory.v1.plugins.model_plugins.kernels.rms_norm import npu_rms_norm
from llamafactory.v1.plugins.model_plugins.kernels.rope import npu_rope
apply_kernel(model, npu_rope.NpuRoPEKernel)
model = apply_kernel(model, npu_rms_norm.NpuRMSNormKernel)
assert model.model.layers[0].input_layernorm is not original_rmsnorm_forward
model = apply_kernel(model, npu_swiglu.NpuSwiGluKernel)
assert model.model.layers[0].mlp.forward is not original_swiglu_forward
@pytest.fixture(autouse=True)
def clear_accelerator_cache():
get_current_accelerator.cache_clear()
class Test_Use_V1_Kernels(unittest.TestCase):
@patch("torch.accelerator.current_accelerator")
def test_use_v1_kernels(self, mock_get_accelerator):
get_current_accelerator.cache_clear()
mock_device = MagicMock()
mock_device.type = "npu"
mock_get_accelerator.return_value = mock_device
@patch("torch.accelerator.current_accelerator")
def test_apply_kernel(mock_get_accelerator: MagicMock):
mock_device = MagicMock()
setattr(mock_device, "type", "npu")
mock_get_accelerator.return_value = mock_device
model = AutoModelForCausalLM.from_pretrained("llamafactory/tiny-random-qwen2.5")
model = AutoModelForCausalLM.from_pretrained("llamafactory/tiny-random-qwen2.5")
original_rmsnorm_forward = model.model.layers[0].input_layernorm.forward
original_swiglu_forward = model.model.layers[0].mlp.forward
original_rmsnorm_forward = model.model.layers[0].input_layernorm.forward
original_swiglu_forward = model.model.layers[0].mlp.forward
from llamafactory.v1.plugins.model_plugins.kernels.registry import apply_available_kernels
apply_kernel(model, npu_rope.NpuRoPEKernel)
model = apply_available_kernels(model)
model = apply_kernel(model, npu_rms_norm.NpuRMSNormKernel)
assert model.model.layers[0].input_layernorm is not original_rmsnorm_forward
assert model.model.layers[0].input_layernorm is not original_rmsnorm_forward
assert model.model.layers[0].mlp.forward is not original_swiglu_forward
model = apply_kernel(model, npu_swiglu.NpuSwiGluKernel)
assert model.model.layers[0].mlp.forward is not original_swiglu_forward
@patch("torch.accelerator.current_accelerator")
def test_apply_all_kernels(mock_get_accelerator: MagicMock):
get_current_accelerator.cache_clear()
mock_device = MagicMock()
setattr(mock_device, "type", "npu")
mock_get_accelerator.return_value = mock_device
model = AutoModelForCausalLM.from_pretrained("llamafactory/tiny-random-qwen2.5")
original_rmsnorm_forward = model.model.layers[0].input_layernorm.forward
original_swiglu_forward = model.model.layers[0].mlp.forward
model = apply_available_kernels(model)
assert model.model.layers[0].input_layernorm is not original_rmsnorm_forward
assert model.model.layers[0].mlp.forward is not original_swiglu_forward