From a754604c1117592f948536a03f217e697ecc3c0c Mon Sep 17 00:00:00 2001
From: Yaowei Zheng
Date: Thu, 25 Dec 2025 02:11:04 +0800
Subject: [PATCH] [misc] fix accelerator (#9661)

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
---
 .../model/model_utils/quantization.py         |   5 +-
 src/llamafactory/v1/accelerator/helper.py     | 145 ++++++++++-------
 src/llamafactory/v1/accelerator/interface.py  | 147 ++++++++++--------
 src/llamafactory/v1/core/model_loader.py      |   2 +-
 .../core/{ => trainer_utils}/data_loader.py   |   8 +-
 .../plugins/model_plugins/initialization.py   |   0
 .../v1/plugins/model_plugins/quantization.py  |   0
 .../v1/utils/{utils.py => env.py}             |  10 +-
 src/llamafactory/v1/utils/pytest.py           |  35 +++++
 tests/conftest.py                             |  27 ++--
 tests/data/processor/test_feedback.py         |   2 +-
 tests/data/processor/test_pairwise.py         |   2 +-
 tests/data/processor/test_processor_utils.py  |   2 +-
 tests/data/processor/test_supervised.py       |   8 +-
 tests/data/processor/test_unsupervised.py     |   2 +-
 tests/data/test_collator.py                   |   4 +-
 tests/data/test_converter.py                  |   4 +-
 tests/data/test_formatter.py                  |  50 +++---
 tests/data/test_loader.py                     |   6 +-
 tests/data/test_mm_plugin.py                  |  26 ++--
 tests/data/test_template.py                   |  34 ++--
 tests/e2e/test_chat.py                        |   4 +-
 tests/e2e/test_sglang.py                      |   4 +-
 tests/e2e/test_train.py                       |   4 +-
 tests/eval/test_eval_template.py              |   4 +-
 tests/model/model_utils/test_add_tokens.py    |   1 -
 tests/model/model_utils/test_attention.py     |   1 -
 tests/model/model_utils/test_checkpointing.py |   4 -
 tests/model/model_utils/test_misc.py          |   1 -
 tests/model/model_utils/test_packing.py       |   1 -
 tests/model/model_utils/test_visual.py        |   3 -
 tests/model/test_base.py                      |   2 -
 tests/model/test_freeze.py                    |   4 -
 tests/model/test_full.py                      |   3 -
 tests/model/test_lora.py                      |   7 -
 tests/model/test_pissa.py                     |   2 -
 tests/train/test_sft_trainer.py               |   1 -
 tests/utils.py                                |  18 ---
 tests/version.txt                             |   2 +-
 tests_v1/accelerator/test_allreduce.py        |  93 -----------
 tests_v1/accelerator/test_interface.py        |  47 +++++-
 tests_v1/conftest.py                          |  38 ++---
 tests_v1/core/test_data_loader.py             |   2 +-
 .../model_plugins/test_kernel_plugin.py       |  79 +++++-----
 44 files changed, 396 insertions(+), 448 deletions(-)
 rename src/llamafactory/v1/core/{ => trainer_utils}/data_loader.py (98%)
 create mode 100644 src/llamafactory/v1/plugins/model_plugins/initialization.py
 create mode 100644 src/llamafactory/v1/plugins/model_plugins/quantization.py
 rename src/llamafactory/v1/utils/{utils.py => env.py} (77%)
 create mode 100644 src/llamafactory/v1/utils/pytest.py
 delete mode 100644 tests/utils.py
 delete mode 100644 tests_v1/accelerator/test_allreduce.py

diff --git a/src/llamafactory/model/model_utils/quantization.py b/src/llamafactory/model/model_utils/quantization.py
index 417ab1112..6ad607db3 100644
--- a/src/llamafactory/model/model_utils/quantization.py
+++ b/src/llamafactory/model/model_utils/quantization.py
@@ -94,9 +94,8 @@
         quantization_config: dict[str, Any] = getattr(config, "quantization_config", None)
         quant_method = quantization_config.get("quant_method", "")
 
-        if (
-            quant_method not in (QuantizationMethod.MXFP4 and QuantizationMethod.FP8)
-            and (is_deepspeed_zero3_enabled() or is_fsdp_enabled())
+        if quant_method not in (QuantizationMethod.MXFP4, QuantizationMethod.FP8) and (
+            is_deepspeed_zero3_enabled() or is_fsdp_enabled()
         ):  # mxfp4 will dequant the model weights
             raise ValueError("DeepSpeed ZeRO-3 or FSDP is incompatible with PTQ-quantized models.")
diff --git a/src/llamafactory/v1/accelerator/helper.py
b/src/llamafactory/v1/accelerator/helper.py index 9db5605a2..76ed3ad46 100644 --- a/src/llamafactory/v1/accelerator/helper.py +++ b/src/llamafactory/v1/accelerator/helper.py @@ -15,11 +15,20 @@ # See the License for the specific language governing permissions and # limitations under the License. +"""Utility functions used by the distributed interface. + +Including: +- Environment info (rank, world_size, local_rank, etc.) +- Accelerator info (device type, device count, etc.) +- Collective communication operations (all_gather, all_reduce, broadcast) +- Synchronize processes and ensure main-process-first execution order +""" + import os from contextlib import contextmanager from enum import Enum, unique -from functools import lru_cache -from typing import Optional +from functools import lru_cache, wraps +from typing import Callable, Optional import numpy as np import torch @@ -46,6 +55,22 @@ class ReduceOp(str, Enum): MIN = "min" +def requires_accelerator(fn): + """Decorator to check if torch.accelerator is available. + + Note: this api requires torch>=2.7.0, otherwise it will raise an AttributeError or RuntimeError + """ + + @wraps(fn) + def wrapper(*args, **kwargs): + if not hasattr(torch, "accelerator"): + raise RuntimeError("torch.accelerator is not available, please upgrade torch to 2.7.0 or higher.") + + return fn(*args, **kwargs) + + return wrapper + + def is_distributed() -> bool: """Check if distributed environment is available.""" return os.getenv("RANK") is not None @@ -72,105 +97,105 @@ def get_local_world_size() -> int: @lru_cache +@requires_accelerator def get_current_accelerator(check_available: bool = True) -> torch.device: - """Get current accelerator. - - Note: this api requires torch>=2.7.0, otherwise it will raise an AttributeError or RuntimeError - """ - if not hasattr(torch, "accelerator"): - raise RuntimeError("torch.accelerator is not available, please upgrade torch to 2.7.0 or higher.") - + """Get current accelerator.""" accelerator = torch.accelerator.current_accelerator(check_available=check_available) - if accelerator is None: - return torch.device(DeviceType.CPU.value) + return accelerator or torch.device(DeviceType.CPU.value) - return accelerator + +@lru_cache +@requires_accelerator +def get_device_count() -> int: + """Get the number of available devices.""" + return torch.accelerator.device_count() + + +@requires_accelerator +def synchronize() -> None: + """Synchronize all processes.""" + torch.accelerator.synchronize() + + +@requires_accelerator +def set_device() -> None: + """Set current accelerator.""" + torch.accelerator.set_device_index(get_local_rank()) def is_torch_cuda_available(): + """Check if CUDA is available.""" return get_current_accelerator().type == DeviceType.CUDA def is_torch_mps_available(): + """Check if MPS is available.""" return get_current_accelerator().type == DeviceType.MPS def is_torch_npu_available(): + """Check if NPU is available.""" return get_current_accelerator().type == DeviceType.NPU def is_torch_xpu_available(): + """Check if XPU is available.""" return get_current_accelerator().type == DeviceType.XPU -def get_current_device() -> "torch.device": - r"""Get the current available device.""" - if is_torch_xpu_available(): - device = "xpu:{}".format(os.getenv("LOCAL_RANK", "0")) - elif is_torch_npu_available(): - device = "npu:{}".format(os.getenv("LOCAL_RANK", "0")) - elif is_torch_mps_available(): - device = "mps:{}".format(os.getenv("LOCAL_RANK", "0")) - elif is_torch_cuda_available(): - device = 
"cuda:{}".format(os.getenv("LOCAL_RANK", "0")) +def operate_tensorlike(fn: Callable[[...], Tensor], data: TensorLike, **kwargs) -> TensorLike: + """Operate tensorlike data on current accelerator.""" + device = get_current_accelerator() + is_tensor = isinstance(data, torch.Tensor) + is_ndarray = isinstance(data, np.ndarray) + + if is_tensor: + orig_device = data.device + data = data.to(device=device) + elif is_ndarray: + data = torch.from_numpy(data).to(device=device, dtype=torch.float) else: - device = "cpu" + data = torch.tensor(data, dtype=torch.float, device=device) - return torch.device(device) + result = fn(data, **kwargs) - -def get_device_count() -> int: - r"""Get the number of available devices.""" - if is_torch_xpu_available(): - return torch.xpu.device_count() - elif is_torch_npu_available(): - return torch.npu.device_count() - elif is_torch_mps_available(): - return torch.mps.device_count() - elif is_torch_cuda_available(): - return torch.cuda.device_count() + if is_tensor: + return result.to(orig_device) + elif is_ndarray: + return result.cpu().numpy() + elif result.numel() == 1: + return result.item() else: - return 0 + return result.tolist() def all_gather(tensor: Tensor, group: Optional[ProcessGroup] = None) -> Tensor: - """Gathers the tensor from all ranks and concats them along the first dim.""" + """Gathers the tensor from all ranks and stacks them at the first dim.""" world_size = get_world_size() - device = get_current_accelerator() - output_tensor = torch.empty(world_size * tensor.numel(), dtype=tensor.dtype, device=device) + output_tensor = torch.empty(world_size * tensor.numel(), dtype=tensor.dtype, device=tensor.device) dist.all_gather_into_tensor(output_tensor, tensor, group=group) - return output_tensor.view(-1, *tensor.size()[1:]) + return output_tensor.view(-1, *tensor.size()) -def all_reduce(data: TensorLike, op: ReduceOp = ReduceOp.MEAN, group: Optional[ProcessGroup] = None) -> TensorLike: +def all_reduce(tensor: Tensor, op: ReduceOp = ReduceOp.MEAN, group: Optional[ProcessGroup] = None) -> Tensor: """Performs all reduce in the given process group.""" - device = get_current_accelerator() - is_ndarray = isinstance(data, np.ndarray) - is_tensor = isinstance(data, torch.Tensor) - - if is_ndarray: - data = torch.from_numpy(data).to(device=device, dtype=torch.float) - elif not is_tensor: - data = torch.tensor(data, dtype=torch.float, device=device) - reduce_ops = { ReduceOp.MEAN: dist.ReduceOp.SUM, ReduceOp.SUM: dist.ReduceOp.SUM, ReduceOp.MAX: dist.ReduceOp.MAX, ReduceOp.MIN: dist.ReduceOp.MIN, } - dist.all_reduce(data, op=reduce_ops[op], group=group) + dist.all_reduce(tensor, op=reduce_ops[op], group=group) if op == ReduceOp.MEAN: # ReduceOp.AVG is not supported by the NPU backend - data /= dist.get_world_size(group=group) + tensor /= dist.get_world_size(group=group) - if is_tensor: - return data - elif is_ndarray: - return data.cpu().numpy() - elif data.numel() == 1: - return data.item() - else: - return data.tolist() + return tensor + + +def broadcast(tensor: Tensor, src: int = 0, group: Optional[ProcessGroup] = None) -> Tensor: + """Broadcasts the tensor from the src process to all other processes.""" + dist.broadcast(tensor, src=src, group=group) + return tensor @contextmanager diff --git a/src/llamafactory/v1/accelerator/interface.py b/src/llamafactory/v1/accelerator/interface.py index b24833186..810776342 100644 --- a/src/llamafactory/v1/accelerator/interface.py +++ b/src/llamafactory/v1/accelerator/interface.py @@ -15,26 +15,27 @@ # See the License for the 
specific language governing permissions and # limitations under the License. +"""A unified interface for model parallelism and data parallelism. + +Supports model parallelism types: +- mp_replicate: Replicate model across multiple devices. +- mp_shard: Shard model across multiple devices. + +And data parallelism types: +- dp: Data parallelism. +- cp: Context parallelism. +""" + from dataclasses import dataclass from datetime import timedelta from enum import Enum from typing import Any, Optional -from torch.distributed import init_process_group +from torch.distributed import barrier, destroy_process_group, init_process_group from torch.distributed.device_mesh import DeviceMesh, init_device_mesh from ..utils.types import DistributedConfig, ProcessGroup, Tensor, TensorLike -from .helper import ( - ReduceOp, - all_gather, - all_reduce, - get_current_accelerator, - get_local_rank, - get_local_world_size, - get_rank, - get_world_size, - is_distributed, -) +from . import helper class Dim(str, Enum): @@ -60,24 +61,24 @@ class DistributedStrategy: """Context parallel size, default to 1.""" def __post_init__(self) -> None: - if not is_distributed(): + if not helper.is_distributed(): self.mp_shard_size = 1 elif self.mp_shard_size is None: - self.mp_shard_size = get_world_size() // self.mp_replicate_size - elif self.mp_replicate_size * self.mp_shard_size != get_world_size(): + self.mp_shard_size = helper.get_world_size() // self.mp_replicate_size + elif self.mp_replicate_size * self.mp_shard_size != helper.get_world_size(): raise ValueError( f"mp_replicate_size * mp_shard_size must equal to world_size, " - f"got {self.mp_replicate_size} * {self.mp_shard_size} != {get_world_size()}." + f"got {self.mp_replicate_size} * {self.mp_shard_size} != {helper.get_world_size()}." ) - if not is_distributed(): + if not helper.is_distributed(): self.dp_size = 1 elif self.dp_size is None: - self.dp_size = get_world_size() // self.cp_size - elif self.dp_size * self.cp_size != get_world_size(): + self.dp_size = helper.get_world_size() // self.cp_size + elif self.dp_size * self.cp_size != helper.get_world_size(): raise ValueError( f"dp_size * cp_size must equal to world_size, " - f"got {self.dp_size} * {self.cp_size} != {get_world_size()}." + f"got {self.dp_size} * {self.cp_size} != {helper.get_world_size()}." 
) @property @@ -106,20 +107,6 @@ class DistributedInterface: _instance: Optional["DistributedInterface"] = None _initialized: bool = False - _is_distributed = is_distributed() - _rank = get_rank() - _world_size = get_world_size() - _local_rank = get_local_rank() - _local_world_size = get_local_world_size() - - strategy: Optional[DistributedStrategy] = None - """Distributed strategy.""" - model_device_mesh: Optional[DeviceMesh] = None - """Model parallel device mesh.""" - data_device_mesh: Optional[DeviceMesh] = None - """Data parallel device mesh.""" - current_accelerator = get_current_accelerator() - """Current accelerator.""" def __new__(cls, *args: Any, **kwargs: Any) -> "DistributedInterface": """Singleton pattern.""" @@ -132,6 +119,14 @@ class DistributedInterface: if self._initialized: return + self._is_distributed = helper.is_distributed() + self._rank = helper.get_rank() + self._world_size = helper.get_world_size() + self._local_rank = helper.get_local_rank() + self._local_world_size = helper.get_local_world_size() + self.current_accelerator = helper.get_current_accelerator() + self.device_count = helper.get_device_count() + if config is None: self.strategy = DistributedStrategy() timeout = 18000 @@ -145,6 +140,7 @@ class DistributedInterface: timeout = config.get("timeout", 18000) if self._is_distributed: + helper.set_device() init_process_group(timeout=timedelta(seconds=timeout)) self.model_device_mesh = init_device_mesh( device_type=self.current_accelerator.type, @@ -169,65 +165,84 @@ class DistributedInterface: f"model_device_mesh={self.model_device_mesh}, data_device_mesh={self.data_device_mesh}" ) - @classmethod - def get_device_mesh(cls, dim: Optional[Dim] = None) -> Optional[DeviceMesh]: + def get_device_mesh(self, dim: Optional[Dim] = None) -> Optional[DeviceMesh]: """Get device mesh for specified dimension.""" if dim is None: raise ValueError("dim must be specified.") - elif cls.model_device_mesh is None: + elif self.model_device_mesh is None: return None - elif dim in cls.strategy.data_mesh_dim_names: - return cls.data_device_mesh[dim.value] + elif dim in self.strategy.data_mesh_dim_names: + return self.data_device_mesh[dim.value] else: - return cls.model_device_mesh[dim.value] + return self.model_device_mesh[dim.value] - @classmethod - def get_group(cls, dim: Optional[Dim] = None) -> Optional[ProcessGroup]: + def get_group(self, dim: Optional[Dim] = None) -> Optional[ProcessGroup]: """Get process group for specified dimension.""" - if cls.model_device_mesh is None or dim is None: + if self.model_device_mesh is None or dim is None: return None else: - return cls.get_device_mesh(dim).get_group() + return self.get_device_mesh(dim).get_group() - @classmethod - def get_rank(cls, dim: Optional[Dim] = None) -> int: + def get_rank(self, dim: Optional[Dim] = None) -> int: """Get parallel rank for specified dimension.""" - if cls.model_device_mesh is None: + if self.model_device_mesh is None: return 0 elif dim is None: - return cls._rank + return self._rank else: - return cls.get_device_mesh(dim).get_local_rank() + return self.get_device_mesh(dim).get_local_rank() - @classmethod - def get_world_size(cls, dim: Optional[Dim] = None) -> int: + def get_world_size(self, dim: Optional[Dim] = None) -> int: """Get parallel size for specified dimension.""" - if cls.model_device_mesh is None: + if self.model_device_mesh is None: return 1 elif dim is None: - return cls._world_size + return self._world_size else: - return cls.get_device_mesh(dim).size() + return 
self.get_device_mesh(dim).size() - @classmethod - def get_local_rank(cls) -> int: + def get_local_rank(self) -> int: """Get parallel local rank.""" - return cls._local_rank + return self._local_rank - @classmethod - def get_local_world_size(cls) -> int: + def get_local_world_size(self) -> int: """Get parallel local world size.""" - return cls._local_world_size + return self._local_world_size - @classmethod - def all_gather(cls, data: Tensor, dim: Optional[Dim] = Dim.DP) -> Tensor: + def all_gather(self, data: Tensor, dim: Optional[Dim] = Dim.DP) -> Tensor: """Gather tensor across specified parallel group.""" - return all_gather(data, cls.get_group(dim)) if cls.model_device_mesh is not None else data + if self.model_device_mesh is not None: + return helper.operate_tensorlike(helper.all_gather, data, group=self.get_group(dim)) + else: + return data - @classmethod - def all_reduce(cls, data: TensorLike, op: ReduceOp = ReduceOp.MEAN, dim: Optional[Dim] = Dim.DP) -> TensorLike: + def all_reduce( + self, data: TensorLike, op: helper.ReduceOp = helper.ReduceOp.MEAN, dim: Optional[Dim] = Dim.DP + ) -> TensorLike: """Reduce tensor across specified parallel group.""" - return all_reduce(data, op, cls.get_group(dim)) if cls.model_device_mesh is not None else data + if self.model_device_mesh is not None: + return helper.operate_tensorlike(helper.all_reduce, data, op=op, group=self.get_group(dim)) + else: + return data + + def broadcast(self, data: TensorLike, src: int = 0, dim: Optional[Dim] = Dim.DP) -> TensorLike: + """Broadcast tensor across specified parallel group.""" + if self.model_device_mesh is not None: + return helper.operate_tensorlike(helper.broadcast, data, src=src, group=self.get_group(dim)) + else: + return data + + def sync(self) -> None: + """Synchronize all processes.""" + helper.synchronize() + + def barrier(self) -> None: + """Barrier all processes.""" + barrier() + + def destroy(self) -> None: + """Destroy all processes.""" + destroy_process_group() if __name__ == "__main__": diff --git a/src/llamafactory/v1/core/model_loader.py b/src/llamafactory/v1/core/model_loader.py index 870f31f0f..e77d0ae27 100644 --- a/src/llamafactory/v1/core/model_loader.py +++ b/src/llamafactory/v1/core/model_loader.py @@ -97,7 +97,7 @@ class ModelLoader: self.args.model, config=self.model_config, dtype="auto", - device_map=DistributedInterface.current_accelerator, + device_map=DistributedInterface().current_accelerator, trust_remote_code=self.args.trust_remote_code, ) diff --git a/src/llamafactory/v1/core/data_loader.py b/src/llamafactory/v1/core/trainer_utils/data_loader.py similarity index 98% rename from src/llamafactory/v1/core/data_loader.py rename to src/llamafactory/v1/core/trainer_utils/data_loader.py index 1d580589a..a3bb9bdbe 100644 --- a/src/llamafactory/v1/core/data_loader.py +++ b/src/llamafactory/v1/core/trainer_utils/data_loader.py @@ -22,10 +22,10 @@ from typing import Optional from torchdata.stateful_dataloader import StatefulDataLoader from torchdata.stateful_dataloader.sampler import StatefulDistributedSampler -from ..utils.batching_queue import BaseBatchingQueue -from ..utils.logging import get_logger -from ..utils.types import Processor, TorchDataset -from .trainer_utils.data_collator import DataCollator +from ...utils.batching_queue import BaseBatchingQueue +from ...utils.logging import get_logger +from ...utils.types import Processor, TorchDataset +from .data_collator import DataCollator logger = get_logger(__name__) diff --git 
a/src/llamafactory/v1/plugins/model_plugins/initialization.py b/src/llamafactory/v1/plugins/model_plugins/initialization.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/llamafactory/v1/plugins/model_plugins/quantization.py b/src/llamafactory/v1/plugins/model_plugins/quantization.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/llamafactory/v1/utils/utils.py b/src/llamafactory/v1/utils/env.py similarity index 77% rename from src/llamafactory/v1/utils/utils.py rename to src/llamafactory/v1/utils/env.py index 33c38826d..683cf0357 100644 --- a/src/llamafactory/v1/utils/utils.py +++ b/src/llamafactory/v1/utils/env.py @@ -17,7 +17,7 @@ import socket def find_available_port() -> int: - r"""Find an available port on the local machine.""" + """Find an available port on the local machine.""" sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock.bind(("", 0)) port = sock.getsockname()[1] @@ -26,9 +26,5 @@ def find_available_port() -> int: def is_env_enabled(env_var: str, default: str = "0") -> bool: - r"""Check if the environment variable is enabled.""" - return os.getenv(env_var, default).lower() in ["true", "y", "1"] - - -if __name__ == "__main__": - print(find_available_port()) + """Check if the environment variable is enabled.""" + return os.getenv(env_var, default).lower() in ["true", "yes", "on", "t", "y", "1"] diff --git a/src/llamafactory/v1/utils/pytest.py b/src/llamafactory/v1/utils/pytest.py new file mode 100644 index 000000000..bbbaa08cf --- /dev/null +++ b/src/llamafactory/v1/utils/pytest.py @@ -0,0 +1,35 @@ +# Copyright 2025 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from contextlib import contextmanager + + +@contextmanager +def dist_env(local_rank: int = 0, world_size: int = 1, master_port: int = 25595): + """Set distributed environment variables.""" + env_vars = { + "MASTER_ADDR": "127.0.0.1", + "MASTER_PORT": str(master_port), + "RANK": str(local_rank), + "LOCAL_RANK": str(local_rank), + "WORLD_SIZE": str(world_size), + "LOCAL_WORLD_SIZE": str(world_size), + } + os.environ.update(env_vars) + try: + yield + finally: + for key in env_vars.keys(): + os.environ.pop(key, None) diff --git a/tests/conftest.py b/tests/conftest.py index cba9a30d4..71f7339f3 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -18,19 +18,17 @@ Contains shared fixtures, pytest configuration, and custom markers. 
""" import os +from typing import Optional import pytest -from pytest import Config, Item +from pytest import Config, FixtureRequest, Item, MonkeyPatch from llamafactory.extras.misc import get_current_device, get_device_count, is_env_enabled from llamafactory.extras.packages import is_transformers_version_greater_than from llamafactory.train.test_utils import patch_valuehead_model -try: - CURRENT_DEVICE = get_current_device().type # cpu | cuda | npu -except Exception: - CURRENT_DEVICE = "cpu" +CURRENT_DEVICE = get_current_device().type def pytest_configure(config: Config): @@ -66,26 +64,27 @@ def _handle_runs_on(items: list[Item]): def _handle_slow_tests(items: list[Item]): """Skip slow tests unless RUN_SLOW is enabled.""" - if not is_env_enabled("RUN_SLOW", "0"): + if not is_env_enabled("RUN_SLOW"): skip_slow = pytest.mark.skip(reason="slow test (set RUN_SLOW=1 to run)") for item in items: if "slow" in item.keywords: item.add_marker(skip_slow) -def _get_visible_devices_env(): +def _get_visible_devices_env() -> Optional[str]: """Return device visibility env var name.""" if CURRENT_DEVICE == "cuda": return "CUDA_VISIBLE_DEVICES" - if CURRENT_DEVICE == "npu": + elif CURRENT_DEVICE == "npu": return "ASCEND_RT_VISIBLE_DEVICES" - return None + else: + return None def _handle_device_visibility(items: list[Item]): """Handle device visibility based on test markers.""" env_key = _get_visible_devices_env() - if env_key is None or CURRENT_DEVICE == "cpu": + if env_key is None or CURRENT_DEVICE in ("cpu", "mps"): return # Parse visible devices @@ -121,7 +120,7 @@ def pytest_collection_modifyitems(config: Config, items: list[Item]): @pytest.fixture(autouse=True) -def _manage_distributed_env(request, monkeypatch): +def _manage_distributed_env(request: FixtureRequest, monkeypatch: MonkeyPatch) -> None: """Set environment variables for distributed tests if specific devices are requested.""" env_key = _get_visible_devices_env() if not env_key: @@ -131,8 +130,7 @@ def _manage_distributed_env(request, monkeypatch): old_value = os.environ.get(env_key) marker = request.node.get_closest_marker("require_distributed") - if marker: - # Distributed test + if marker: # distributed test required = marker.args[0] if marker.args else 2 specific_devices = marker.args[1] if len(marker.args) > 1 else None @@ -142,8 +140,7 @@ def _manage_distributed_env(request, monkeypatch): devices_str = ",".join(str(i) for i in range(required)) monkeypatch.setenv(env_key, devices_str) - else: - # Non-distributed test + else: # non-distributed test if old_value: visible_devices = [v for v in old_value.split(",") if v != ""] monkeypatch.setenv(env_key, visible_devices[0] if visible_devices else "0") diff --git a/tests/data/processor/test_feedback.py b/tests/data/processor/test_feedback.py index fecd79487..bcd85424f 100644 --- a/tests/data/processor/test_feedback.py +++ b/tests/data/processor/test_feedback.py @@ -42,7 +42,7 @@ TRAIN_ARGS = { } -@pytest.mark.runs_on(["cpu", "npu", "cuda"]) +@pytest.mark.runs_on(["cpu", "mps"]) @pytest.mark.parametrize("num_samples", [16]) def test_feedback_data(num_samples: int): train_dataset = load_dataset_module(**TRAIN_ARGS)["train_dataset"] diff --git a/tests/data/processor/test_pairwise.py b/tests/data/processor/test_pairwise.py index d3b8dbced..6047afd02 100644 --- a/tests/data/processor/test_pairwise.py +++ b/tests/data/processor/test_pairwise.py @@ -51,7 +51,7 @@ def _convert_sharegpt_to_openai(messages: list[dict[str, str]]) -> list[dict[str return new_messages -@pytest.mark.runs_on(["cpu"]) 
+@pytest.mark.runs_on(["cpu", "mps"]) @pytest.mark.parametrize("num_samples", [16]) def test_pairwise_data(num_samples: int): train_dataset = load_dataset_module(**TRAIN_ARGS)["train_dataset"] diff --git a/tests/data/processor/test_processor_utils.py b/tests/data/processor/test_processor_utils.py index 256f5a6ef..a2a3b7ebe 100644 --- a/tests/data/processor/test_processor_utils.py +++ b/tests/data/processor/test_processor_utils.py @@ -18,7 +18,7 @@ import pytest from llamafactory.data.processor.processor_utils import infer_seqlen -@pytest.mark.runs_on(["cpu"]) +@pytest.mark.runs_on(["cpu", "mps"]) @pytest.mark.parametrize( "test_input,test_output", [ diff --git a/tests/data/processor/test_supervised.py b/tests/data/processor/test_supervised.py index 903930b8e..0179c1a36 100644 --- a/tests/data/processor/test_supervised.py +++ b/tests/data/processor/test_supervised.py @@ -42,7 +42,7 @@ TRAIN_ARGS = { } -@pytest.mark.runs_on(["cpu"]) +@pytest.mark.runs_on(["cpu", "mps"]) @pytest.mark.parametrize("num_samples", [16]) def test_supervised_single_turn(num_samples: int): train_dataset = load_dataset_module(dataset_dir="ONLINE", dataset=TINY_DATA, **TRAIN_ARGS)["train_dataset"] @@ -62,7 +62,7 @@ def test_supervised_single_turn(num_samples: int): assert train_dataset["input_ids"][index] == ref_input_ids -@pytest.mark.runs_on(["cpu"]) +@pytest.mark.runs_on(["cpu", "mps"]) @pytest.mark.parametrize("num_samples", [8]) def test_supervised_multi_turn(num_samples: int): train_dataset = load_dataset_module(dataset_dir="REMOTE:" + DEMO_DATA, dataset="system_chat", **TRAIN_ARGS)[ @@ -76,7 +76,7 @@ def test_supervised_multi_turn(num_samples: int): assert train_dataset["input_ids"][index] == ref_input_ids -@pytest.mark.runs_on(["cpu"]) +@pytest.mark.runs_on(["cpu", "mps"]) @pytest.mark.parametrize("num_samples", [4]) def test_supervised_train_on_prompt(num_samples: int): train_dataset = load_dataset_module( @@ -91,7 +91,7 @@ def test_supervised_train_on_prompt(num_samples: int): assert train_dataset["labels"][index] == ref_ids -@pytest.mark.runs_on(["cpu"]) +@pytest.mark.runs_on(["cpu", "mps"]) @pytest.mark.parametrize("num_samples", [4]) def test_supervised_mask_history(num_samples: int): train_dataset = load_dataset_module( diff --git a/tests/data/processor/test_unsupervised.py b/tests/data/processor/test_unsupervised.py index 6566f1471..05f6cf9a0 100644 --- a/tests/data/processor/test_unsupervised.py +++ b/tests/data/processor/test_unsupervised.py @@ -46,7 +46,7 @@ TRAIN_ARGS = { } -@pytest.mark.runs_on(["cpu"]) +@pytest.mark.runs_on(["cpu", "mps"]) @pytest.mark.parametrize("num_samples", [16]) def test_unsupervised_data(num_samples: int): train_dataset = load_dataset_module(**TRAIN_ARGS)["train_dataset"] diff --git a/tests/data/test_collator.py b/tests/data/test_collator.py index 047354ab6..888030d08 100644 --- a/tests/data/test_collator.py +++ b/tests/data/test_collator.py @@ -29,7 +29,7 @@ from llamafactory.model import load_tokenizer TINY_LLAMA3 = os.getenv("TINY_LLAMA3", "llamafactory/tiny-random-Llama-3") -@pytest.mark.runs_on(["cpu"]) +@pytest.mark.runs_on(["cpu", "mps"]) def test_base_collator(): model_args, data_args, *_ = get_infer_args({"model_name_or_path": TINY_LLAMA3, "template": "default"}) tokenizer_module = load_tokenizer(model_args) @@ -73,7 +73,7 @@ def test_base_collator(): assert batch_input[k].eq(torch.tensor(expected_input[k])).all() -@pytest.mark.runs_on(["cpu"]) +@pytest.mark.runs_on(["cpu", "mps"]) def test_multimodal_collator(): model_args, data_args, *_ = get_infer_args( 
{"model_name_or_path": "Qwen/Qwen2-VL-2B-Instruct", "template": "qwen2_vl"} diff --git a/tests/data/test_converter.py b/tests/data/test_converter.py index 23929c24e..6b411aed5 100644 --- a/tests/data/test_converter.py +++ b/tests/data/test_converter.py @@ -20,7 +20,7 @@ from llamafactory.data.parser import DatasetAttr from llamafactory.hparams import DataArguments -@pytest.mark.runs_on(["cpu"]) +@pytest.mark.runs_on(["cpu", "mps"]) def test_alpaca_converter(): dataset_attr = DatasetAttr("hf_hub", "llamafactory/tiny-supervised-dataset") data_args = DataArguments() @@ -41,7 +41,7 @@ def test_alpaca_converter(): } -@pytest.mark.runs_on(["cpu"]) +@pytest.mark.runs_on(["cpu", "mps"]) def test_sharegpt_converter(): dataset_attr = DatasetAttr("hf_hub", "llamafactory/tiny-supervised-dataset") data_args = DataArguments() diff --git a/tests/data/test_formatter.py b/tests/data/test_formatter.py index c2da6dbfb..969b0be32 100644 --- a/tests/data/test_formatter.py +++ b/tests/data/test_formatter.py @@ -38,19 +38,19 @@ TOOLS = [ ] -@pytest.mark.runs_on(["cpu"]) +@pytest.mark.runs_on(["cpu", "mps"]) def test_empty_formatter(): formatter = EmptyFormatter(slots=["\n"]) assert formatter.apply() == ["\n"] -@pytest.mark.runs_on(["cpu"]) +@pytest.mark.runs_on(["cpu", "mps"]) def test_string_formatter(): formatter = StringFormatter(slots=["", "Human: {{content}}\nAssistant:"]) assert formatter.apply(content="Hi") == ["", "Human: Hi\nAssistant:"] -@pytest.mark.runs_on(["cpu"]) +@pytest.mark.runs_on(["cpu", "mps"]) def test_function_formatter(): formatter = FunctionFormatter(slots=["{{content}}", ""], tool_format="default") tool_calls = json.dumps(FUNCTION) @@ -60,7 +60,7 @@ def test_function_formatter(): ] -@pytest.mark.runs_on(["cpu"]) +@pytest.mark.runs_on(["cpu", "mps"]) def test_multi_function_formatter(): formatter = FunctionFormatter(slots=["{{content}}", ""], tool_format="default") tool_calls = json.dumps([FUNCTION] * 2) @@ -71,7 +71,7 @@ def test_multi_function_formatter(): ] -@pytest.mark.runs_on(["cpu"]) +@pytest.mark.runs_on(["cpu", "mps"]) def test_default_tool_formatter(): formatter = ToolFormatter(tool_format="default") assert formatter.apply(content=json.dumps(TOOLS)) == [ @@ -90,14 +90,14 @@ def test_default_tool_formatter(): ] -@pytest.mark.runs_on(["cpu"]) +@pytest.mark.runs_on(["cpu", "mps"]) def test_default_tool_extractor(): formatter = ToolFormatter(tool_format="default") result = """Action: test_tool\nAction Input: {"foo": "bar", "size": 10}""" assert formatter.extract(result) == [("test_tool", """{"foo": "bar", "size": 10}""")] -@pytest.mark.runs_on(["cpu"]) +@pytest.mark.runs_on(["cpu", "mps"]) def test_default_multi_tool_extractor(): formatter = ToolFormatter(tool_format="default") result = ( @@ -110,14 +110,14 @@ def test_default_multi_tool_extractor(): ] -@pytest.mark.runs_on(["cpu"]) +@pytest.mark.runs_on(["cpu", "mps"]) def test_glm4_function_formatter(): formatter = FunctionFormatter(slots=["{{content}}"], tool_format="glm4") tool_calls = json.dumps(FUNCTION) assert formatter.apply(content=tool_calls) == ["""tool_name\n{"foo": "bar", "size": 10}"""] -@pytest.mark.runs_on(["cpu"]) +@pytest.mark.runs_on(["cpu", "mps"]) def test_glm4_tool_formatter(): formatter = ToolFormatter(tool_format="glm4") assert formatter.apply(content=json.dumps(TOOLS)) == [ @@ -128,14 +128,14 @@ def test_glm4_tool_formatter(): ] -@pytest.mark.runs_on(["cpu"]) +@pytest.mark.runs_on(["cpu", "mps"]) def test_glm4_tool_extractor(): formatter = ToolFormatter(tool_format="glm4") result = """test_tool\n{"foo": 
"bar", "size": 10}\n""" assert formatter.extract(result) == [("test_tool", """{"foo": "bar", "size": 10}""")] -@pytest.mark.runs_on(["cpu"]) +@pytest.mark.runs_on(["cpu", "mps"]) def test_llama3_function_formatter(): formatter = FunctionFormatter(slots=["{{content}}<|eot_id|>"], tool_format="llama3") tool_calls = json.dumps(FUNCTION) @@ -144,7 +144,7 @@ def test_llama3_function_formatter(): ] -@pytest.mark.runs_on(["cpu"]) +@pytest.mark.runs_on(["cpu", "mps"]) def test_llama3_multi_function_formatter(): formatter = FunctionFormatter(slots=["{{content}}<|eot_id|>"], tool_format="llama3") tool_calls = json.dumps([FUNCTION] * 2) @@ -155,7 +155,7 @@ def test_llama3_multi_function_formatter(): ] -@pytest.mark.runs_on(["cpu"]) +@pytest.mark.runs_on(["cpu", "mps"]) def test_llama3_tool_formatter(): formatter = ToolFormatter(tool_format="llama3") date = datetime.now().strftime("%d %b %Y") @@ -169,14 +169,14 @@ def test_llama3_tool_formatter(): ] -@pytest.mark.runs_on(["cpu"]) +@pytest.mark.runs_on(["cpu", "mps"]) def test_llama3_tool_extractor(): formatter = ToolFormatter(tool_format="llama3") result = """{"name": "test_tool", "parameters": {"foo": "bar", "size": 10}}\n""" assert formatter.extract(result) == [("test_tool", """{"foo": "bar", "size": 10}""")] -@pytest.mark.runs_on(["cpu"]) +@pytest.mark.runs_on(["cpu", "mps"]) def test_llama3_multi_tool_extractor(): formatter = ToolFormatter(tool_format="llama3") result = ( @@ -189,7 +189,7 @@ def test_llama3_multi_tool_extractor(): ] -@pytest.mark.runs_on(["cpu"]) +@pytest.mark.runs_on(["cpu", "mps"]) def test_mistral_function_formatter(): formatter = FunctionFormatter(slots=["[TOOL_CALLS] {{content}}", ""], tool_format="mistral") tool_calls = json.dumps(FUNCTION) @@ -199,7 +199,7 @@ def test_mistral_function_formatter(): ] -@pytest.mark.runs_on(["cpu"]) +@pytest.mark.runs_on(["cpu", "mps"]) def test_mistral_multi_function_formatter(): formatter = FunctionFormatter(slots=["[TOOL_CALLS] {{content}}", ""], tool_format="mistral") tool_calls = json.dumps([FUNCTION] * 2) @@ -211,7 +211,7 @@ def test_mistral_multi_function_formatter(): ] -@pytest.mark.runs_on(["cpu"]) +@pytest.mark.runs_on(["cpu", "mps"]) def test_mistral_tool_formatter(): formatter = ToolFormatter(tool_format="mistral") wrapped_tool = {"type": "function", "function": TOOLS[0]} @@ -220,14 +220,14 @@ def test_mistral_tool_formatter(): ] -@pytest.mark.runs_on(["cpu"]) +@pytest.mark.runs_on(["cpu", "mps"]) def test_mistral_tool_extractor(): formatter = ToolFormatter(tool_format="mistral") result = """{"name": "test_tool", "arguments": {"foo": "bar", "size": 10}}""" assert formatter.extract(result) == [("test_tool", """{"foo": "bar", "size": 10}""")] -@pytest.mark.runs_on(["cpu"]) +@pytest.mark.runs_on(["cpu", "mps"]) def test_mistral_multi_tool_extractor(): formatter = ToolFormatter(tool_format="mistral") result = ( @@ -240,7 +240,7 @@ def test_mistral_multi_tool_extractor(): ] -@pytest.mark.runs_on(["cpu"]) +@pytest.mark.runs_on(["cpu", "mps"]) def test_qwen_function_formatter(): formatter = FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen") tool_calls = json.dumps(FUNCTION) @@ -249,7 +249,7 @@ def test_qwen_function_formatter(): ] -@pytest.mark.runs_on(["cpu"]) +@pytest.mark.runs_on(["cpu", "mps"]) def test_qwen_multi_function_formatter(): formatter = FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen") tool_calls = json.dumps([FUNCTION] * 2) @@ -260,7 +260,7 @@ def test_qwen_multi_function_formatter(): ] -@pytest.mark.runs_on(["cpu"]) 
+@pytest.mark.runs_on(["cpu", "mps"]) def test_qwen_tool_formatter(): formatter = ToolFormatter(tool_format="qwen") wrapped_tool = {"type": "function", "function": TOOLS[0]} @@ -274,14 +274,14 @@ def test_qwen_tool_formatter(): ] -@pytest.mark.runs_on(["cpu"]) +@pytest.mark.runs_on(["cpu", "mps"]) def test_qwen_tool_extractor(): formatter = ToolFormatter(tool_format="qwen") result = """\n{"name": "test_tool", "arguments": {"foo": "bar", "size": 10}}\n""" assert formatter.extract(result) == [("test_tool", """{"foo": "bar", "size": 10}""")] -@pytest.mark.runs_on(["cpu"]) +@pytest.mark.runs_on(["cpu", "mps"]) def test_qwen_multi_tool_extractor(): formatter = ToolFormatter(tool_format="qwen") result = ( diff --git a/tests/data/test_loader.py b/tests/data/test_loader.py index 9546cf4f4..907bda347 100644 --- a/tests/data/test_loader.py +++ b/tests/data/test_loader.py @@ -40,21 +40,21 @@ TRAIN_ARGS = { } -@pytest.mark.runs_on(["cpu"]) +@pytest.mark.runs_on(["cpu", "mps"]) def test_load_train_only(): dataset_module = load_dataset_module(**TRAIN_ARGS) assert dataset_module.get("train_dataset") is not None assert dataset_module.get("eval_dataset") is None -@pytest.mark.runs_on(["cpu"]) +@pytest.mark.runs_on(["cpu", "mps"]) def test_load_val_size(): dataset_module = load_dataset_module(val_size=0.1, **TRAIN_ARGS) assert dataset_module.get("train_dataset") is not None assert dataset_module.get("eval_dataset") is not None -@pytest.mark.runs_on(["cpu"]) +@pytest.mark.runs_on(["cpu", "mps"]) def test_load_eval_data(): dataset_module = load_dataset_module(eval_dataset=TINY_DATA, **TRAIN_ARGS) assert dataset_module.get("train_dataset") is not None diff --git a/tests/data/test_mm_plugin.py b/tests/data/test_mm_plugin.py index 6efc9e431..bb416ed44 100644 --- a/tests/data/test_mm_plugin.py +++ b/tests/data/test_mm_plugin.py @@ -179,7 +179,7 @@ def _check_plugin( ) -@pytest.mark.runs_on(["cpu"]) +@pytest.mark.runs_on(["cpu", "mps"]) def test_base_plugin(): tokenizer_module = _load_tokenizer_module(model_name_or_path=TINY_LLAMA3) base_plugin = get_mm_plugin(name="base") @@ -187,7 +187,7 @@ def test_base_plugin(): _check_plugin(**check_inputs) -@pytest.mark.runs_on(["cpu"]) +@pytest.mark.runs_on(["cpu", "mps"]) @pytest.mark.skipif(not HF_TOKEN, reason="Gated model.") @pytest.mark.skipif(not is_transformers_version_greater_than("4.50.0"), reason="Requires transformers>=4.50.0") def test_gemma3_plugin(): @@ -210,7 +210,7 @@ def test_gemma3_plugin(): _check_plugin(**check_inputs) -@pytest.mark.runs_on(["cpu"]) +@pytest.mark.runs_on(["cpu", "mps"]) @pytest.mark.skipif(not is_transformers_version_greater_than("4.52.0"), reason="Requires transformers>=4.52.0") def test_internvl_plugin(): image_seqlen = 256 @@ -229,7 +229,7 @@ def test_internvl_plugin(): _check_plugin(**check_inputs) -@pytest.mark.runs_on(["cpu"]) +@pytest.mark.runs_on(["cpu", "mps"]) @pytest.mark.skipif(not is_transformers_version_greater_than("4.51.0"), reason="Requires transformers>=4.51.0") def test_llama4_plugin(): tokenizer_module = _load_tokenizer_module(model_name_or_path=TINY_LLAMA4) @@ -251,7 +251,7 @@ def test_llama4_plugin(): _check_plugin(**check_inputs) -@pytest.mark.runs_on(["cpu"]) +@pytest.mark.runs_on(["cpu", "mps"]) def test_llava_plugin(): image_seqlen = 576 tokenizer_module = _load_tokenizer_module(model_name_or_path="llava-hf/llava-1.5-7b-hf") @@ -265,7 +265,7 @@ def test_llava_plugin(): _check_plugin(**check_inputs) -@pytest.mark.runs_on(["cpu"]) +@pytest.mark.runs_on(["cpu", "mps"]) def test_llava_next_plugin(): image_seqlen = 
1176 tokenizer_module = _load_tokenizer_module(model_name_or_path="llava-hf/llava-v1.6-vicuna-7b-hf") @@ -279,7 +279,7 @@ def test_llava_next_plugin(): _check_plugin(**check_inputs) -@pytest.mark.runs_on(["cpu"]) +@pytest.mark.runs_on(["cpu", "mps"]) def test_llava_next_video_plugin(): image_seqlen = 1176 tokenizer_module = _load_tokenizer_module(model_name_or_path="llava-hf/LLaVA-NeXT-Video-7B-hf") @@ -293,7 +293,7 @@ def test_llava_next_video_plugin(): _check_plugin(**check_inputs) -@pytest.mark.runs_on(["cpu"]) +@pytest.mark.runs_on(["cpu", "mps"]) @pytest.mark.skipif(not HF_TOKEN, reason="Gated model.") def test_paligemma_plugin(): image_seqlen = 256 @@ -313,7 +313,7 @@ def test_paligemma_plugin(): _check_plugin(**check_inputs) -@pytest.mark.runs_on(["cpu"]) +@pytest.mark.runs_on(["cpu", "mps"]) @pytest.mark.skipif(not is_transformers_version_greater_than("4.50.0"), reason="Requires transformers>=4.50.0") def test_pixtral_plugin(): image_slice_height, image_slice_width = 2, 2 @@ -336,7 +336,7 @@ def test_pixtral_plugin(): _check_plugin(**check_inputs) -@pytest.mark.runs_on(["cpu"]) +@pytest.mark.runs_on(["cpu", "mps"]) @pytest.mark.skipif(not is_transformers_version_greater_than("4.52.0"), reason="Requires transformers>=4.52.0") def test_qwen2_omni_plugin(): image_seqlen, audio_seqlen = 4, 2 @@ -367,7 +367,7 @@ def test_qwen2_omni_plugin(): _check_plugin(**check_inputs) -@pytest.mark.runs_on(["cpu"]) +@pytest.mark.runs_on(["cpu", "mps"]) def test_qwen2_vl_plugin(): image_seqlen = 4 tokenizer_module = _load_tokenizer_module(model_name_or_path="Qwen/Qwen2-VL-7B-Instruct") @@ -384,7 +384,7 @@ def test_qwen2_vl_plugin(): _check_plugin(**check_inputs) -@pytest.mark.runs_on(["cpu"]) +@pytest.mark.runs_on(["cpu", "mps"]) @pytest.mark.skipif(not is_transformers_version_greater_than("4.57.0"), reason="Requires transformers>=4.57.0") def test_qwen3_vl_plugin(): frame_seqlen = 1 @@ -406,7 +406,7 @@ def test_qwen3_vl_plugin(): _check_plugin(**check_inputs) -@pytest.mark.runs_on(["cpu"]) +@pytest.mark.runs_on(["cpu", "mps"]) @pytest.mark.skipif(not is_transformers_version_greater_than("4.47.0"), reason="Requires transformers>=4.47.0") def test_video_llava_plugin(): image_seqlen = 256 diff --git a/tests/data/test_template.py b/tests/data/test_template.py index 8a6ea3ea3..9f1018976 100644 --- a/tests/data/test_template.py +++ b/tests/data/test_template.py @@ -89,7 +89,7 @@ def _check_template( _check_tokenization(tokenizer, (prompt_ids, answer_ids), (prompt_str, answer_str)) -@pytest.mark.runs_on(["cpu"]) +@pytest.mark.runs_on(["cpu", "mps"]) @pytest.mark.parametrize("use_fast", [True, False]) def test_encode_oneturn(use_fast: bool): tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA3, use_fast=use_fast) @@ -105,7 +105,7 @@ def test_encode_oneturn(use_fast: bool): _check_tokenization(tokenizer, (prompt_ids, answer_ids), (prompt_str, answer_str)) -@pytest.mark.runs_on(["cpu"]) +@pytest.mark.runs_on(["cpu", "mps"]) @pytest.mark.parametrize("use_fast", [True, False]) def test_encode_multiturn(use_fast: bool): tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA3, use_fast=use_fast) @@ -127,7 +127,7 @@ def test_encode_multiturn(use_fast: bool): ) -@pytest.mark.runs_on(["cpu"]) +@pytest.mark.runs_on(["cpu", "mps"]) @pytest.mark.parametrize("use_fast", [True, False]) @pytest.mark.parametrize("cot_messages", [True, False]) @pytest.mark.parametrize("enable_thinking", [True, False, None]) @@ -154,7 +154,7 @@ def test_reasoning_encode_oneturn(use_fast: bool, cot_messages: bool, enable_thi 
_check_tokenization(tokenizer, (prompt_ids, answer_ids), (prompt_str, answer_str)) -@pytest.mark.runs_on(["cpu"]) +@pytest.mark.runs_on(["cpu", "mps"]) @pytest.mark.parametrize("use_fast", [True, False]) @pytest.mark.parametrize("cot_messages", [True, False]) @pytest.mark.parametrize("enable_thinking", [True, False, None]) @@ -184,7 +184,7 @@ def test_reasoning_encode_multiturn(use_fast: bool, cot_messages: bool, enable_t ) -@pytest.mark.runs_on(["cpu"]) +@pytest.mark.runs_on(["cpu", "mps"]) @pytest.mark.parametrize("use_fast", [True, False]) def test_jinja_template(use_fast: bool): tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA3, use_fast=use_fast) @@ -195,7 +195,7 @@ def test_jinja_template(use_fast: bool): assert tokenizer.apply_chat_template(MESSAGES) == ref_tokenizer.apply_chat_template(MESSAGES) -@pytest.mark.runs_on(["cpu"]) +@pytest.mark.runs_on(["cpu", "mps"]) def test_ollama_modelfile(): tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA3) template = get_template_and_fix_tokenizer(tokenizer, DataArguments(template="llama3")) @@ -213,14 +213,14 @@ def test_ollama_modelfile(): ) -@pytest.mark.runs_on(["cpu"]) +@pytest.mark.runs_on(["cpu", "mps"]) def test_get_stop_token_ids(): tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA3) template = get_template_and_fix_tokenizer(tokenizer, DataArguments(template="llama3")) assert set(template.get_stop_token_ids(tokenizer)) == {128008, 128009} -@pytest.mark.runs_on(["cpu"]) +@pytest.mark.runs_on(["cpu", "mps"]) @pytest.mark.skipif(not HF_TOKEN, reason="Gated model.") @pytest.mark.parametrize("use_fast", [True, False]) def test_gemma_template(use_fast: bool): @@ -234,7 +234,7 @@ def test_gemma_template(use_fast: bool): _check_template("google/gemma-3-4b-it", "gemma", prompt_str, answer_str, use_fast) -@pytest.mark.runs_on(["cpu"]) +@pytest.mark.runs_on(["cpu", "mps"]) @pytest.mark.skipif(not HF_TOKEN, reason="Gated model.") @pytest.mark.parametrize("use_fast", [True, False]) def test_gemma2_template(use_fast: bool): @@ -248,7 +248,7 @@ def test_gemma2_template(use_fast: bool): _check_template("google/gemma-2-2b-it", "gemma2", prompt_str, answer_str, use_fast) -@pytest.mark.runs_on(["cpu"]) +@pytest.mark.runs_on(["cpu", "mps"]) @pytest.mark.skipif(not HF_TOKEN, reason="Gated model.") @pytest.mark.parametrize("use_fast", [True, False]) def test_llama3_template(use_fast: bool): @@ -262,7 +262,7 @@ def test_llama3_template(use_fast: bool): _check_template("meta-llama/Meta-Llama-3-8B-Instruct", "llama3", prompt_str, answer_str, use_fast) -@pytest.mark.runs_on(["cpu"]) +@pytest.mark.runs_on(["cpu", "mps"]) @pytest.mark.parametrize( "use_fast", [True, pytest.param(False, marks=pytest.mark.xfail(reason="Llama 4 has no slow tokenizer."))] ) @@ -284,7 +284,7 @@ def test_llama4_template(use_fast: bool): pytest.param(False, marks=pytest.mark.xfail(reason="Phi-4 slow tokenizer is broken.")), ], ) -@pytest.mark.runs_on(["cpu"]) +@pytest.mark.runs_on(["cpu", "mps"]) def test_phi4_template(use_fast: bool): prompt_str = ( f"<|im_start|>user<|im_sep|>{MESSAGES[0]['content']}<|im_end|>" @@ -296,7 +296,7 @@ def test_phi4_template(use_fast: bool): _check_template("microsoft/phi-4", "phi4", prompt_str, answer_str, use_fast) -@pytest.mark.runs_on(["cpu"]) +@pytest.mark.runs_on(["cpu", "mps"]) @pytest.mark.xfail(not HF_TOKEN, reason="Authorization.") @pytest.mark.parametrize("use_fast", [True, False]) def test_qwen2_5_template(use_fast: bool): @@ -311,7 +311,7 @@ def test_qwen2_5_template(use_fast: bool): _check_template("Qwen/Qwen2.5-7B-Instruct", "qwen", 
prompt_str, answer_str, use_fast) -@pytest.mark.runs_on(["cpu"]) +@pytest.mark.runs_on(["cpu", "mps"]) @pytest.mark.parametrize("use_fast", [True, False]) @pytest.mark.parametrize("cot_messages", [True, False]) def test_qwen3_template(use_fast: bool, cot_messages: bool): @@ -331,7 +331,7 @@ def test_qwen3_template(use_fast: bool, cot_messages: bool): _check_template("Qwen/Qwen3-8B", "qwen3", prompt_str, answer_str, use_fast, messages=messages) -@pytest.mark.runs_on(["cpu"]) +@pytest.mark.runs_on(["cpu", "mps"]) def test_parse_llama3_template(): tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA3, token=HF_TOKEN) template = parse_template(tokenizer) @@ -345,7 +345,7 @@ def test_parse_llama3_template(): assert template.default_system == "" -@pytest.mark.runs_on(["cpu"]) +@pytest.mark.runs_on(["cpu", "mps"]) @pytest.mark.xfail(not HF_TOKEN, reason="Authorization.") def test_parse_qwen_template(): tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct", token=HF_TOKEN) @@ -358,7 +358,7 @@ def test_parse_qwen_template(): assert template.default_system == "You are Qwen, created by Alibaba Cloud. You are a helpful assistant." -@pytest.mark.runs_on(["cpu"]) +@pytest.mark.runs_on(["cpu", "mps"]) @pytest.mark.xfail(not HF_TOKEN, reason="Authorization.") def test_parse_qwen3_template(): tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B", token=HF_TOKEN) diff --git a/tests/e2e/test_chat.py b/tests/e2e/test_chat.py index b05c79626..e33f32c56 100644 --- a/tests/e2e/test_chat.py +++ b/tests/e2e/test_chat.py @@ -37,13 +37,13 @@ MESSAGES = [ EXPECTED_RESPONSE = "_rho" -@pytest.mark.runs_on(["cpu"]) +@pytest.mark.runs_on(["cpu", "mps"]) def test_chat(): chat_model = ChatModel(INFER_ARGS) assert chat_model.chat(MESSAGES)[0].response_text == EXPECTED_RESPONSE -@pytest.mark.runs_on(["cpu"]) +@pytest.mark.runs_on(["cpu", "mps"]) def test_stream_chat(): chat_model = ChatModel(INFER_ARGS) response = "" diff --git a/tests/e2e/test_sglang.py b/tests/e2e/test_sglang.py index 8db1703a6..7182ed382 100644 --- a/tests/e2e/test_sglang.py +++ b/tests/e2e/test_sglang.py @@ -39,7 +39,7 @@ MESSAGES = [ ] -@pytest.mark.runs_on(["cpu"]) +@pytest.mark.runs_on(["cuda"]) @pytest.mark.skipif(not is_sglang_available(), reason="SGLang is not installed") def test_chat(): r"""Test the SGLang engine's basic chat functionality.""" @@ -49,7 +49,7 @@ def test_chat(): print(response.response_text) -@pytest.mark.runs_on(["cpu"]) +@pytest.mark.runs_on(["cuda"]) @pytest.mark.skipif(not is_sglang_available(), reason="SGLang is not installed") def test_stream_chat(): r"""Test the SGLang engine's streaming chat functionality.""" diff --git a/tests/e2e/test_train.py b/tests/e2e/test_train.py index c405dc976..dabb888b9 100644 --- a/tests/e2e/test_train.py +++ b/tests/e2e/test_train.py @@ -49,7 +49,7 @@ INFER_ARGS = { OS_NAME = os.getenv("OS_NAME", "") -@pytest.mark.runs_on(["cpu"]) +@pytest.mark.runs_on(["cpu", "mps"]) @pytest.mark.parametrize( "stage,dataset", [ @@ -66,7 +66,7 @@ def test_run_exp(stage: str, dataset: str): assert os.path.exists(output_dir) -@pytest.mark.runs_on(["cpu"]) +@pytest.mark.runs_on(["cpu", "mps"]) def test_export(): export_dir = os.path.join("output", "llama3_export") export_model({"export_dir": export_dir, **INFER_ARGS}) diff --git a/tests/eval/test_eval_template.py b/tests/eval/test_eval_template.py index 91ad869ac..783d0b9e3 100644 --- a/tests/eval/test_eval_template.py +++ b/tests/eval/test_eval_template.py @@ -17,7 +17,7 @@ import pytest from llamafactory.eval.template import get_eval_template 
-@pytest.mark.runs_on(["cpu"]) +@pytest.mark.runs_on(["cpu", "mps"]) def test_eval_template_en(): support_set = [ { @@ -56,7 +56,7 @@ def test_eval_template_en(): ] -@pytest.mark.runs_on(["cpu"]) +@pytest.mark.runs_on(["cpu", "mps"]) def test_eval_template_zh(): support_set = [ { diff --git a/tests/model/model_utils/test_add_tokens.py b/tests/model/model_utils/test_add_tokens.py index 771c67480..cb1c414ab 100644 --- a/tests/model/model_utils/test_add_tokens.py +++ b/tests/model/model_utils/test_add_tokens.py @@ -25,7 +25,6 @@ TINY_LLAMA3 = os.getenv("TINY_LLAMA3", "llamafactory/tiny-random-Llama-3") UNUSED_TOKEN = "<|UNUSED_TOKEN|>" -@pytest.mark.runs_on(["cpu", "npu", "cuda"]) @pytest.mark.parametrize("special_tokens", [False, True]) def test_add_tokens(special_tokens: bool): if special_tokens: diff --git a/tests/model/model_utils/test_attention.py b/tests/model/model_utils/test_attention.py index 2cd879702..075caeaee 100644 --- a/tests/model/model_utils/test_attention.py +++ b/tests/model/model_utils/test_attention.py @@ -39,7 +39,6 @@ INFER_ARGS = { } -@pytest.mark.runs_on(["cpu", "npu", "cuda"]) @pytest.mark.xfail(is_transformers_version_greater_than("4.48"), reason="Attention refactor.") def test_attention(): attention_available = ["disabled"] diff --git a/tests/model/model_utils/test_checkpointing.py b/tests/model/model_utils/test_checkpointing.py index 63039d821..2402e6fb7 100644 --- a/tests/model/model_utils/test_checkpointing.py +++ b/tests/model/model_utils/test_checkpointing.py @@ -39,7 +39,6 @@ TRAIN_ARGS = { } -@pytest.mark.runs_on(["cpu", "npu", "cuda"]) @pytest.mark.parametrize("disable_gradient_checkpointing", [False, True]) def test_vanilla_checkpointing(disable_gradient_checkpointing: bool): model = load_train_model(disable_gradient_checkpointing=disable_gradient_checkpointing, **TRAIN_ARGS) @@ -47,14 +46,12 @@ def test_vanilla_checkpointing(disable_gradient_checkpointing: bool): assert getattr(module, "gradient_checkpointing") != disable_gradient_checkpointing -@pytest.mark.runs_on(["cpu", "npu", "cuda"]) def test_unsloth_gradient_checkpointing(): model = load_train_model(use_unsloth_gc=True, **TRAIN_ARGS) for module in filter(lambda m: hasattr(m, "gradient_checkpointing"), model.modules()): assert module._gradient_checkpointing_func.__self__.__name__ == "UnslothGradientCheckpointing" -@pytest.mark.runs_on(["cpu", "npu", "cuda"]) def test_upcast_layernorm(): model = load_train_model(upcast_layernorm=True, **TRAIN_ARGS) for name, param in model.named_parameters(): @@ -62,7 +59,6 @@ def test_upcast_layernorm(): assert param.dtype == torch.float32 -@pytest.mark.runs_on(["cpu", "npu", "cuda"]) def test_upcast_lmhead_output(): model = load_train_model(upcast_lmhead_output=True, **TRAIN_ARGS) inputs = torch.randn((1, 16), dtype=torch.float16, device=get_current_device()) diff --git a/tests/model/model_utils/test_misc.py b/tests/model/model_utils/test_misc.py index cc0dcc54c..b2c8b3bf9 100644 --- a/tests/model/model_utils/test_misc.py +++ b/tests/model/model_utils/test_misc.py @@ -24,7 +24,6 @@ from llamafactory.model.model_utils.misc import find_expanded_modules HF_TOKEN = os.getenv("HF_TOKEN") -@pytest.mark.runs_on(["cpu", "npu", "cuda"]) @pytest.mark.skipif(not HF_TOKEN, reason="Gated model.") def test_expanded_modules(): config = AutoConfig.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct") diff --git a/tests/model/model_utils/test_packing.py b/tests/model/model_utils/test_packing.py index 3302154ec..81e0d66a5 100644 --- a/tests/model/model_utils/test_packing.py +++ 
b/tests/model/model_utils/test_packing.py @@ -18,7 +18,6 @@ import torch from llamafactory.model.model_utils.packing import get_seqlens_in_batch, get_unpad_data -@pytest.mark.runs_on(["cpu", "npu", "cuda"]) @pytest.mark.parametrize( "attention_mask,golden_seq_lens", [ diff --git a/tests/model/model_utils/test_visual.py b/tests/model/model_utils/test_visual.py index 003394386..fc53b69c2 100644 --- a/tests/model/model_utils/test_visual.py +++ b/tests/model/model_utils/test_visual.py @@ -23,7 +23,6 @@ from llamafactory.hparams import FinetuningArguments, ModelArguments from llamafactory.model.adapter import init_adapter -@pytest.mark.runs_on(["cpu", "npu", "cuda"]) @pytest.mark.parametrize("freeze_vision_tower", (False, True)) @pytest.mark.parametrize("freeze_multi_modal_projector", (False, True)) @pytest.mark.parametrize("freeze_language_model", (False, True)) @@ -49,7 +48,6 @@ def test_visual_full(freeze_vision_tower: bool, freeze_multi_modal_projector: bo assert param.requires_grad != freeze_language_model -@pytest.mark.runs_on(["cpu", "npu", "cuda"]) @pytest.mark.parametrize("freeze_vision_tower,freeze_language_model", ((False, False), (False, True), (True, False))) def test_visual_lora(freeze_vision_tower: bool, freeze_language_model: bool): model_args = ModelArguments(model_name_or_path="Qwen/Qwen2-VL-2B-Instruct") @@ -82,7 +80,6 @@ def test_visual_lora(freeze_vision_tower: bool, freeze_language_model: bool): assert (merger_param_name in trainable_params) is False -@pytest.mark.runs_on(["cpu", "npu", "cuda"]) def test_visual_model_save_load(): # check VLM's state dict: https://github.com/huggingface/transformers/pull/38385 model_args = ModelArguments(model_name_or_path="Qwen/Qwen2-VL-2B-Instruct") diff --git a/tests/model/test_base.py b/tests/model/test_base.py index bbdde32fa..14afff633 100644 --- a/tests/model/test_base.py +++ b/tests/model/test_base.py @@ -30,14 +30,12 @@ INFER_ARGS = { } -@pytest.mark.runs_on(["cpu", "npu", "cuda"]) def test_base(): model = load_infer_model(**INFER_ARGS) ref_model = load_reference_model(TINY_LLAMA3) compare_model(model, ref_model) -@pytest.mark.runs_on(["cpu"]) @pytest.mark.usefixtures("fix_valuehead_cpu_loading") def test_valuehead(): model = load_infer_model(add_valuehead=True, **INFER_ARGS) diff --git a/tests/model/test_freeze.py b/tests/model/test_freeze.py index 46054101f..b82ec88d5 100644 --- a/tests/model/test_freeze.py +++ b/tests/model/test_freeze.py @@ -14,7 +14,6 @@ import os -import pytest import torch from llamafactory.train.test_utils import load_infer_model, load_train_model @@ -44,7 +43,6 @@ INFER_ARGS = { } -@pytest.mark.runs_on(["cpu", "npu", "cuda"]) def test_freeze_train_all_modules(): model = load_train_model(freeze_trainable_layers=1, **TRAIN_ARGS) for name, param in model.named_parameters(): @@ -56,7 +54,6 @@ def test_freeze_train_all_modules(): assert param.dtype == torch.float16 -@pytest.mark.runs_on(["cpu", "npu", "cuda"]) def test_freeze_train_extra_modules(): model = load_train_model(freeze_trainable_layers=1, freeze_extra_modules="embed_tokens,lm_head", **TRAIN_ARGS) for name, param in model.named_parameters(): @@ -68,7 +65,6 @@ def test_freeze_train_extra_modules(): assert param.dtype == torch.float16 -@pytest.mark.runs_on(["cpu", "npu", "cuda"]) def test_freeze_inference(): model = load_infer_model(**INFER_ARGS) for param in model.parameters(): diff --git a/tests/model/test_full.py b/tests/model/test_full.py index 9912661d9..9058b6acf 100644 --- a/tests/model/test_full.py +++ b/tests/model/test_full.py @@ -14,7 +14,6 
@@ import os -import pytest import torch from llamafactory.train.test_utils import load_infer_model, load_train_model @@ -44,7 +43,6 @@ INFER_ARGS = { } -@pytest.mark.runs_on(["cpu", "npu", "cuda"]) def test_full_train(): model = load_train_model(**TRAIN_ARGS) for param in model.parameters(): @@ -52,7 +50,6 @@ def test_full_train(): assert param.dtype == torch.float32 -@pytest.mark.runs_on(["cpu", "npu", "cuda"]) def test_full_inference(): model = load_infer_model(**INFER_ARGS) for param in model.parameters(): diff --git a/tests/model/test_lora.py b/tests/model/test_lora.py index d7739d177..38b6b505d 100644 --- a/tests/model/test_lora.py +++ b/tests/model/test_lora.py @@ -55,35 +55,30 @@ INFER_ARGS = { } -@pytest.mark.runs_on(["cpu", "npu", "cuda"]) def test_lora_train_qv_modules(): model = load_train_model(lora_target="q_proj,v_proj", **TRAIN_ARGS) linear_modules, _ = check_lora_model(model) assert linear_modules == {"q_proj", "v_proj"} -@pytest.mark.runs_on(["cpu", "npu", "cuda"]) def test_lora_train_all_modules(): model = load_train_model(lora_target="all", **TRAIN_ARGS) linear_modules, _ = check_lora_model(model) assert linear_modules == {"q_proj", "k_proj", "v_proj", "o_proj", "up_proj", "gate_proj", "down_proj"} -@pytest.mark.runs_on(["cpu", "npu", "cuda"]) def test_lora_train_extra_modules(): model = load_train_model(additional_target="embed_tokens,lm_head", **TRAIN_ARGS) _, extra_modules = check_lora_model(model) assert extra_modules == {"embed_tokens", "lm_head"} -@pytest.mark.runs_on(["cpu", "npu", "cuda"]) def test_lora_train_old_adapters(): model = load_train_model(adapter_name_or_path=TINY_LLAMA_ADAPTER, create_new_adapter=False, **TRAIN_ARGS) ref_model = load_reference_model(TINY_LLAMA3, TINY_LLAMA_ADAPTER, use_lora=True, is_trainable=True) compare_model(model, ref_model) -@pytest.mark.runs_on(["cpu", "npu", "cuda"]) def test_lora_train_new_adapters(): model = load_train_model(adapter_name_or_path=TINY_LLAMA_ADAPTER, create_new_adapter=True, **TRAIN_ARGS) ref_model = load_reference_model(TINY_LLAMA3, TINY_LLAMA_ADAPTER, use_lora=True, is_trainable=True) @@ -92,7 +87,6 @@ def test_lora_train_new_adapters(): ) -@pytest.mark.runs_on(["cpu", "npu", "cuda"]) @pytest.mark.usefixtures("fix_valuehead_cpu_loading") def test_lora_train_valuehead(): model = load_train_model(add_valuehead=True, **TRAIN_ARGS) @@ -103,7 +97,6 @@ def test_lora_train_valuehead(): assert torch.allclose(state_dict["v_head.summary.bias"], ref_state_dict["v_head.summary.bias"]) -@pytest.mark.runs_on(["cpu", "npu", "cuda"]) def test_lora_inference(): model = load_infer_model(**INFER_ARGS) ref_model = load_reference_model(TINY_LLAMA3, TINY_LLAMA_ADAPTER, use_lora=True).merge_and_unload() diff --git a/tests/model/test_pissa.py b/tests/model/test_pissa.py index 331c7adb0..3b6101f84 100644 --- a/tests/model/test_pissa.py +++ b/tests/model/test_pissa.py @@ -49,7 +49,6 @@ INFER_ARGS = { } -@pytest.mark.runs_on(["cpu", "npu", "cuda"]) @pytest.mark.xfail(reason="PiSSA initialization is not stable in different platform.") def test_pissa_train(): model = load_train_model(**TRAIN_ARGS) @@ -57,7 +56,6 @@ def test_pissa_train(): compare_model(model, ref_model) -@pytest.mark.runs_on(["cpu", "npu", "cuda"]) @pytest.mark.xfail(reason="Known connection error.") def test_pissa_inference(): model = load_infer_model(**INFER_ARGS) diff --git a/tests/train/test_sft_trainer.py b/tests/train/test_sft_trainer.py index 3e68ba38f..9f6ebe418 100644 --- a/tests/train/test_sft_trainer.py +++ b/tests/train/test_sft_trainer.py @@ -59,7 +59,6 @@ 
class DataCollatorWithVerbose(DataCollatorWithPadding): return {k: v[:, :1] for k, v in batch.items()} # truncate input length -@pytest.mark.runs_on(["cpu", "npu", "cuda"]) @pytest.mark.parametrize("disable_shuffling", [False, True]) def test_shuffle(disable_shuffling: bool): model_args, data_args, training_args, finetuning_args, _ = get_train_args( diff --git a/tests/utils.py b/tests/utils.py deleted file mode 100644 index 6ff924207..000000000 --- a/tests/utils.py +++ /dev/null @@ -1,18 +0,0 @@ -# Copyright 2025 the LlamaFactory team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import pytest - - -runs_on = pytest.mark.runs_on diff --git a/tests/version.txt b/tests/version.txt index a9b60f523..08eb069fe 100644 --- a/tests/version.txt +++ b/tests/version.txt @@ -1,2 +1,2 @@ # change if test fails or cache is outdated -0.9.4.104 +0.9.4.105 diff --git a/tests_v1/accelerator/test_allreduce.py b/tests_v1/accelerator/test_allreduce.py deleted file mode 100644 index 8d0c733e0..000000000 --- a/tests_v1/accelerator/test_allreduce.py +++ /dev/null @@ -1,93 +0,0 @@ -# Copyright 2025 the LlamaFactory team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-
-
-import pytest
-import torch
-import torch.distributed as dist
-import torch.multiprocessing as mp
-
-from llamafactory.v1.accelerator.helper import ReduceOp, all_reduce, is_torch_cuda_available, is_torch_npu_available
-from llamafactory.v1.utils.utils import find_available_port
-
-
-def _dist_worker(rank, world_size):
-    if is_torch_cuda_available():
-        backend = "nccl"
-        device = torch.device(f"cuda:{rank}")
-        torch.cuda.set_device(rank)
-    elif is_torch_npu_available():
-        backend = "hccl"
-        device = torch.device(f"npu:{rank}")
-        torch.npu.set_device(rank)
-    else:
-        backend = "gloo"
-        device = torch.device("cpu")
-
-    dist.init_process_group(
-        backend=backend,
-        rank=rank,
-        world_size=world_size,
-    )
-
-    # --------------------
-    # Test all_reduce SUM
-    # --------------------
-    y = torch.tensor(rank + 1.0, device=device)
-    y_sum = all_reduce(y.clone(), op=ReduceOp.SUM)
-    assert y_sum.item() == 3.0
-
-    # --------------------
-    # Test all_reduce MEAN
-    # --------------------
-    y_mean = all_reduce(y.clone(), op=ReduceOp.MEAN)
-    assert y_mean.item() == pytest.approx(1.5)
-
-    # --------------------
-    # Test all_reduce MAX
-    # --------------------
-    y_max = all_reduce(y.clone(), op=ReduceOp.MAX)
-    assert y_max.item() == 2.0
-
-    dist.destroy_process_group()
-
-
-@pytest.mark.runs_on(["npu", "cuda"])
-@pytest.mark.require_distributed(2)
-def test_distributed_ops(monkeypatch):
-    monkeypatch.setenv("MASTER_ADDR", "127.0.0.1")
-    monkeypatch.setenv("MASTER_PORT", str(find_available_port()))
-    WORLD_SIZE = 2
-    mp.spawn(
-        _dist_worker,
-        args=(WORLD_SIZE,),
-        nprocs=WORLD_SIZE,
-        join=True,
-    )
-
-
-@pytest.mark.runs_on(["npu", "cuda"])
-@pytest.mark.require_distributed(4)
-def test_required_multi():
-    # test require_distributed mark ok
-    pass
-
-
-@pytest.mark.runs_on(["npu", "cuda"])
-@pytest.mark.require_distributed(999)
-def test_required_invalid():
-    # test require_distributed mark not ok,
-    raise RuntimeError(
-        "this case should not be run, please check whether the require_distributed mark implementation is correct"
-    )
diff --git a/tests_v1/accelerator/test_interface.py b/tests_v1/accelerator/test_interface.py
index be2485f92..8bce16b56 100644
--- a/tests_v1/accelerator/test_interface.py
+++ b/tests_v1/accelerator/test_interface.py
@@ -12,15 +12,48 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-
 import os

+import pytest
+import torch.multiprocessing as mp
+
+from llamafactory.v1.accelerator.helper import ReduceOp
 from llamafactory.v1.accelerator.interface import DistributedInterface
+from llamafactory.v1.utils.env import find_available_port
+from llamafactory.v1.utils.pytest import dist_env


-def test_distributed_interface():
-    DistributedInterface()
-    assert DistributedInterface.get_rank() == int(os.getenv("RANK", "0"))
-    assert DistributedInterface.get_world_size() == int(os.getenv("WORLD_SIZE", "1"))
-    assert DistributedInterface.get_local_rank() == int(os.getenv("LOCAL_RANK", "0"))
-    assert DistributedInterface.get_local_world_size() == int(os.getenv("LOCAL_WORLD_SIZE", "1"))
+def _all_reduce_tests(local_rank: int, world_size: int, master_port: int):
+    with dist_env(local_rank, world_size, master_port):
+        rank = DistributedInterface().get_rank()
+        world_size = DistributedInterface().get_world_size()
+        assert world_size == 2
+
+        y_sum = DistributedInterface().all_reduce(rank + 1.0, op=ReduceOp.SUM)
+        assert y_sum == pytest.approx(3.0)
+
+        y_mean = DistributedInterface().all_reduce(rank + 1.0, op=ReduceOp.MEAN)
+        assert y_mean == pytest.approx(1.5)
+
+        y_max = DistributedInterface().all_reduce(rank + 1.0, op=ReduceOp.MAX)
+        assert y_max == pytest.approx(2.0)
+
+        z = DistributedInterface().all_gather(rank + 1.0)
+        assert z == pytest.approx([1.0, 2.0])
+
+        z = DistributedInterface().broadcast(rank + 1.0)
+        assert z == pytest.approx(1.0)
+
+
+def test_all_device():
+    assert DistributedInterface().get_rank() == int(os.getenv("RANK", "0"))
+    assert DistributedInterface().get_world_size() == int(os.getenv("WORLD_SIZE", "1"))
+    assert DistributedInterface().get_local_rank() == int(os.getenv("LOCAL_RANK", "0"))
+    assert DistributedInterface().get_local_world_size() == int(os.getenv("LOCAL_WORLD_SIZE", "1"))
+
+
+@pytest.mark.runs_on(["cuda", "npu"])
+@pytest.mark.require_distributed(2)
+def test_multi_device():
+    master_port = find_available_port()
+    mp.spawn(_all_reduce_tests, args=(2, master_port), nprocs=2)
diff --git a/tests_v1/conftest.py b/tests_v1/conftest.py
index 897809543..69d40fa5f 100644
--- a/tests_v1/conftest.py
+++ b/tests_v1/conftest.py
@@ -18,20 +18,17 @@ Contains shared fixtures, pytest configuration, and custom markers.
""" import os +from typing import Optional import pytest -from pytest import Config, Item +from pytest import Config, FixtureRequest, Item, MonkeyPatch -from llamafactory.train.test_utils import patch_valuehead_model -from llamafactory.v1.accelerator.helper import get_current_device, get_device_count +from llamafactory.v1.accelerator.helper import get_current_accelerator, get_device_count +from llamafactory.v1.utils.env import is_env_enabled from llamafactory.v1.utils.packages import is_transformers_version_greater_than -from llamafactory.v1.utils.utils import is_env_enabled -try: - CURRENT_DEVICE = get_current_device().type # cpu | cuda | npu -except Exception: - CURRENT_DEVICE = "cpu" +CURRENT_DEVICE = get_current_accelerator().type def pytest_configure(config: Config): @@ -67,26 +64,27 @@ def _handle_runs_on(items: list[Item]): def _handle_slow_tests(items: list[Item]): """Skip slow tests unless RUN_SLOW is enabled.""" - if not is_env_enabled("RUN_SLOW", "0"): + if not is_env_enabled("RUN_SLOW"): skip_slow = pytest.mark.skip(reason="slow test (set RUN_SLOW=1 to run)") for item in items: if "slow" in item.keywords: item.add_marker(skip_slow) -def _get_visible_devices_env(): +def _get_visible_devices_env() -> Optional[str]: """Return device visibility env var name.""" if CURRENT_DEVICE == "cuda": return "CUDA_VISIBLE_DEVICES" - if CURRENT_DEVICE == "npu": + elif CURRENT_DEVICE == "npu": return "ASCEND_RT_VISIBLE_DEVICES" - return None + else: + return None def _handle_device_visibility(items: list[Item]): """Handle device visibility based on test markers.""" env_key = _get_visible_devices_env() - if env_key is None or CURRENT_DEVICE == "cpu": + if env_key is None or CURRENT_DEVICE in ("cpu", "mps"): return # Parse visible devices @@ -122,7 +120,7 @@ def pytest_collection_modifyitems(config: Config, items: list[Item]): @pytest.fixture(autouse=True) -def _manage_distributed_env(request, monkeypatch): +def _manage_distributed_env(request: FixtureRequest, monkeypatch: MonkeyPatch) -> None: """Set environment variables for distributed tests if specific devices are requested.""" env_key = _get_visible_devices_env() if not env_key: @@ -132,8 +130,7 @@ def _manage_distributed_env(request, monkeypatch): old_value = os.environ.get(env_key) marker = request.node.get_closest_marker("require_distributed") - if marker: - # Distributed test + if marker: # distributed test required = marker.args[0] if marker.args else 2 specific_devices = marker.args[1] if len(marker.args) > 1 else None @@ -143,16 +140,9 @@ def _manage_distributed_env(request, monkeypatch): devices_str = ",".join(str(i) for i in range(required)) monkeypatch.setenv(env_key, devices_str) - else: - # Non-distributed test + else: # non-distributed test if old_value: visible_devices = [v for v in old_value.split(",") if v != ""] monkeypatch.setenv(env_key, visible_devices[0] if visible_devices else "0") else: monkeypatch.setenv(env_key, "0") - - -@pytest.fixture -def fix_valuehead_cpu_loading(): - """Fix valuehead model loading.""" - patch_valuehead_model() diff --git a/tests_v1/core/test_data_loader.py b/tests_v1/core/test_data_loader.py index cf25aa59e..329098242 100644 --- a/tests_v1/core/test_data_loader.py +++ b/tests_v1/core/test_data_loader.py @@ -28,10 +28,10 @@ from transformers import AutoTokenizer from llamafactory.v1.config.data_args import DataArguments from llamafactory.v1.core.data_engine import DataEngine -from llamafactory.v1.core.data_loader import DataLoader from llamafactory.v1.core.trainer_utils.data_collator import ( 
     DefaultCollator,
 )
+from llamafactory.v1.core.trainer_utils.data_loader import DataLoader
 from llamafactory.v1.plugins.data_plugins.template import QwenTemplate
 from llamafactory.v1.utils.batching_queue import TextBatchingQueue

diff --git a/tests_v1/plugins/model_plugins/test_kernel_plugin.py b/tests_v1/plugins/model_plugins/test_kernel_plugin.py
index 20c61b29e..04f99a757 100644
--- a/tests_v1/plugins/model_plugins/test_kernel_plugin.py
+++ b/tests_v1/plugins/model_plugins/test_kernel_plugin.py
@@ -12,57 +12,56 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import unittest
 from unittest.mock import MagicMock, patch

+import pytest
 from transformers import AutoModelForCausalLM

 from llamafactory.v1.accelerator.helper import get_current_accelerator
+from llamafactory.v1.plugins.model_plugins.kernels.mlp import npu_swiglu
+from llamafactory.v1.plugins.model_plugins.kernels.registry import apply_available_kernels, apply_kernel
+from llamafactory.v1.plugins.model_plugins.kernels.rms_norm import npu_rms_norm
+from llamafactory.v1.plugins.model_plugins.kernels.rope import npu_rope


-class TestKernelPlugin(unittest.TestCase):
-    @patch("torch.accelerator.current_accelerator")
-    def test_apply_kernel(self, mock_get_accelerator):
-        get_current_accelerator.cache_clear()
-        mock_device = MagicMock()
-        mock_device.type = "npu"
-        mock_get_accelerator.return_value = mock_device
-
-        model = AutoModelForCausalLM.from_pretrained("llamafactory/tiny-random-qwen2.5")
-
-        original_rmsnorm_forward = model.model.layers[0].input_layernorm.forward
-        original_swiglu_forward = model.model.layers[0].mlp.forward
-
-        from llamafactory.v1.plugins.model_plugins.kernels.mlp import npu_swiglu
-        from llamafactory.v1.plugins.model_plugins.kernels.registry import apply_kernel
-        from llamafactory.v1.plugins.model_plugins.kernels.rms_norm import npu_rms_norm
-        from llamafactory.v1.plugins.model_plugins.kernels.rope import npu_rope
-
-        apply_kernel(model, npu_rope.NpuRoPEKernel)
-
-        model = apply_kernel(model, npu_rms_norm.NpuRMSNormKernel)
-        assert model.model.layers[0].input_layernorm is not original_rmsnorm_forward
-
-        model = apply_kernel(model, npu_swiglu.NpuSwiGluKernel)
-        assert model.model.layers[0].mlp.forward is not original_swiglu_forward
+@pytest.fixture(autouse=True)
+def clear_accelerator_cache():
+    get_current_accelerator.cache_clear()


-class Test_Use_V1_Kernels(unittest.TestCase):
-    @patch("torch.accelerator.current_accelerator")
-    def test_use_v1_kernels(self, mock_get_accelerator):
-        get_current_accelerator.cache_clear()
-        mock_device = MagicMock()
-        mock_device.type = "npu"
-        mock_get_accelerator.return_value = mock_device
+@patch("torch.accelerator.current_accelerator")
+def test_apply_kernel(mock_get_accelerator: MagicMock):
+    mock_device = MagicMock()
+    setattr(mock_device, "type", "npu")
+    mock_get_accelerator.return_value = mock_device

-        model = AutoModelForCausalLM.from_pretrained("llamafactory/tiny-random-qwen2.5")
+    model = AutoModelForCausalLM.from_pretrained("llamafactory/tiny-random-qwen2.5")

-        original_rmsnorm_forward = model.model.layers[0].input_layernorm.forward
-        original_swiglu_forward = model.model.layers[0].mlp.forward
+    original_rmsnorm_forward = model.model.layers[0].input_layernorm.forward
+    original_swiglu_forward = model.model.layers[0].mlp.forward

-        from llamafactory.v1.plugins.model_plugins.kernels.registry import apply_available_kernels
+    apply_kernel(model, npu_rope.NpuRoPEKernel)

-        model = apply_available_kernels(model)
+    model = apply_kernel(model, npu_rms_norm.NpuRMSNormKernel)
+    assert model.model.layers[0].input_layernorm is not original_rmsnorm_forward

-        assert model.model.layers[0].input_layernorm is not original_rmsnorm_forward
-        assert model.model.layers[0].mlp.forward is not original_swiglu_forward
+    model = apply_kernel(model, npu_swiglu.NpuSwiGluKernel)
+    assert model.model.layers[0].mlp.forward is not original_swiglu_forward
+
+
+@patch("torch.accelerator.current_accelerator")
+def test_apply_all_kernels(mock_get_accelerator: MagicMock):
+    get_current_accelerator.cache_clear()
+    mock_device = MagicMock()
+    setattr(mock_device, "type", "npu")
+    mock_get_accelerator.return_value = mock_device
+
+    model = AutoModelForCausalLM.from_pretrained("llamafactory/tiny-random-qwen2.5")
+
+    original_rmsnorm_forward = model.model.layers[0].input_layernorm.forward
+    original_swiglu_forward = model.model.layers[0].mlp.forward
+
+    model = apply_available_kernels(model)
+
+    assert model.model.layers[0].input_layernorm is not original_rmsnorm_forward
+    assert model.model.layers[0].mlp.forward is not original_swiglu_forward