[misc] fix accelerator (#9661)

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Yaowei Zheng
2025-12-25 02:11:04 +08:00
committed by GitHub
parent 6a2eafbae3
commit a754604c11
44 changed files with 396 additions and 448 deletions

View File

@@ -94,9 +94,8 @@ def configure_quantization(
     quantization_config: dict[str, Any] = getattr(config, "quantization_config", None)
     quant_method = quantization_config.get("quant_method", "")
-    if (
-        quant_method not in (QuantizationMethod.MXFP4 and QuantizationMethod.FP8)
-        and (is_deepspeed_zero3_enabled() or is_fsdp_enabled())
+    if quant_method not in (QuantizationMethod.MXFP4, QuantizationMethod.FP8) and (
+        is_deepspeed_zero3_enabled() or is_fsdp_enabled()
     ):
         # mxfp4 will dequant the model weights
         raise ValueError("DeepSpeed ZeRO-3 or FSDP is incompatible with PTQ-quantized models.")

View File

@@ -15,11 +15,20 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+"""Utility functions used by the distributed interface.
+
+Including:
+- Environment info (rank, world_size, local_rank, etc.)
+- Accelerator info (device type, device count, etc.)
+- Collective communication operations (all_gather, all_reduce, broadcast)
+- Synchronize processes and ensure main-process-first execution order
+"""
 
 import os
 from contextlib import contextmanager
 from enum import Enum, unique
-from functools import lru_cache
-from typing import Optional
+from functools import lru_cache, wraps
+from typing import Callable, Optional
 
 import numpy as np
 import torch
@@ -46,6 +55,22 @@ class ReduceOp(str, Enum):
MIN = "min" MIN = "min"
def requires_accelerator(fn):
"""Decorator to check if torch.accelerator is available.
Note: this api requires torch>=2.7.0, otherwise it will raise an AttributeError or RuntimeError
"""
@wraps(fn)
def wrapper(*args, **kwargs):
if not hasattr(torch, "accelerator"):
raise RuntimeError("torch.accelerator is not available, please upgrade torch to 2.7.0 or higher.")
return fn(*args, **kwargs)
return wrapper
def is_distributed() -> bool: def is_distributed() -> bool:
"""Check if distributed environment is available.""" """Check if distributed environment is available."""
return os.getenv("RANK") is not None return os.getenv("RANK") is not None
@@ -72,105 +97,105 @@ def get_local_world_size() -> int:
 @lru_cache
+@requires_accelerator
 def get_current_accelerator(check_available: bool = True) -> torch.device:
-    """Get current accelerator.
-
-    Note: this api requires torch>=2.7.0, otherwise it will raise an AttributeError or RuntimeError
-    """
-    if not hasattr(torch, "accelerator"):
-        raise RuntimeError("torch.accelerator is not available, please upgrade torch to 2.7.0 or higher.")
-
+    """Get current accelerator."""
     accelerator = torch.accelerator.current_accelerator(check_available=check_available)
-    if accelerator is None:
-        return torch.device(DeviceType.CPU.value)
-
-    return accelerator
+    return accelerator or torch.device(DeviceType.CPU.value)
+
+
+@lru_cache
+@requires_accelerator
+def get_device_count() -> int:
+    """Get the number of available devices."""
+    return torch.accelerator.device_count()
+
+
+@requires_accelerator
+def synchronize() -> None:
+    """Synchronize all processes."""
+    torch.accelerator.synchronize()
+
+
+@requires_accelerator
+def set_device() -> None:
+    """Set current accelerator."""
+    torch.accelerator.set_device_index(get_local_rank())
 
 
 def is_torch_cuda_available():
+    """Check if CUDA is available."""
     return get_current_accelerator().type == DeviceType.CUDA
 
 
 def is_torch_mps_available():
+    """Check if MPS is available."""
     return get_current_accelerator().type == DeviceType.MPS
 
 
 def is_torch_npu_available():
+    """Check if NPU is available."""
     return get_current_accelerator().type == DeviceType.NPU
 
 
 def is_torch_xpu_available():
+    """Check if XPU is available."""
     return get_current_accelerator().type == DeviceType.XPU
 
 
-def get_current_device() -> "torch.device":
-    r"""Get the current available device."""
-    if is_torch_xpu_available():
-        device = "xpu:{}".format(os.getenv("LOCAL_RANK", "0"))
-    elif is_torch_npu_available():
-        device = "npu:{}".format(os.getenv("LOCAL_RANK", "0"))
-    elif is_torch_mps_available():
-        device = "mps:{}".format(os.getenv("LOCAL_RANK", "0"))
-    elif is_torch_cuda_available():
-        device = "cuda:{}".format(os.getenv("LOCAL_RANK", "0"))
-    else:
-        device = "cpu"
-
-    return torch.device(device)
-
-
-def get_device_count() -> int:
-    r"""Get the number of available devices."""
-    if is_torch_xpu_available():
-        return torch.xpu.device_count()
-    elif is_torch_npu_available():
-        return torch.npu.device_count()
-    elif is_torch_mps_available():
-        return torch.mps.device_count()
-    elif is_torch_cuda_available():
-        return torch.cuda.device_count()
-    else:
-        return 0
+def operate_tensorlike(fn: Callable[[...], Tensor], data: TensorLike, **kwargs) -> TensorLike:
+    """Operate tensorlike data on current accelerator."""
+    device = get_current_accelerator()
+    is_tensor = isinstance(data, torch.Tensor)
+    is_ndarray = isinstance(data, np.ndarray)
+
+    if is_tensor:
+        orig_device = data.device
+        data = data.to(device=device)
+    elif is_ndarray:
+        data = torch.from_numpy(data).to(device=device, dtype=torch.float)
+    else:
+        data = torch.tensor(data, dtype=torch.float, device=device)
+
+    result = fn(data, **kwargs)
+    if is_tensor:
+        return result.to(orig_device)
+    elif is_ndarray:
+        return result.cpu().numpy()
+    elif result.numel() == 1:
+        return result.item()
+    else:
+        return result.tolist()
 
 
 def all_gather(tensor: Tensor, group: Optional[ProcessGroup] = None) -> Tensor:
-    """Gathers the tensor from all ranks and concats them along the first dim."""
+    """Gathers the tensor from all ranks and stacks them at the first dim."""
     world_size = get_world_size()
-    device = get_current_accelerator()
-    output_tensor = torch.empty(world_size * tensor.numel(), dtype=tensor.dtype, device=device)
+    output_tensor = torch.empty(world_size * tensor.numel(), dtype=tensor.dtype, device=tensor.device)
     dist.all_gather_into_tensor(output_tensor, tensor, group=group)
-    return output_tensor.view(-1, *tensor.size()[1:])
+    return output_tensor.view(-1, *tensor.size())
 
 
-def all_reduce(data: TensorLike, op: ReduceOp = ReduceOp.MEAN, group: Optional[ProcessGroup] = None) -> TensorLike:
+def all_reduce(tensor: Tensor, op: ReduceOp = ReduceOp.MEAN, group: Optional[ProcessGroup] = None) -> Tensor:
     """Performs all reduce in the given process group."""
-    device = get_current_accelerator()
-    is_ndarray = isinstance(data, np.ndarray)
-    is_tensor = isinstance(data, torch.Tensor)
-    if is_ndarray:
-        data = torch.from_numpy(data).to(device=device, dtype=torch.float)
-    elif not is_tensor:
-        data = torch.tensor(data, dtype=torch.float, device=device)
-
     reduce_ops = {
         ReduceOp.MEAN: dist.ReduceOp.SUM,
         ReduceOp.SUM: dist.ReduceOp.SUM,
         ReduceOp.MAX: dist.ReduceOp.MAX,
         ReduceOp.MIN: dist.ReduceOp.MIN,
     }
-    dist.all_reduce(data, op=reduce_ops[op], group=group)
+    dist.all_reduce(tensor, op=reduce_ops[op], group=group)
     if op == ReduceOp.MEAN:  # ReduceOp.AVG is not supported by the NPU backend
-        data /= dist.get_world_size(group=group)
+        tensor /= dist.get_world_size(group=group)
 
-    if is_tensor:
-        return data
-    elif is_ndarray:
-        return data.cpu().numpy()
-    elif data.numel() == 1:
-        return data.item()
-    else:
-        return data.tolist()
+    return tensor
+
+
+def broadcast(tensor: Tensor, src: int = 0, group: Optional[ProcessGroup] = None) -> Tensor:
+    """Broadcasts the tensor from the src process to all other processes."""
+    dist.broadcast(tensor, src=src, group=group)
+    return tensor
 
 
 @contextmanager
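For readers skimming the diff: the new `requires_accelerator` decorator centralizes the torch>=2.7.0 guard that `get_current_accelerator` used to carry inline, and `operate_tensorlike` centralizes the tensor/ndarray/scalar conversions that used to live inside `all_reduce`, so `all_gather`, `all_reduce`, and `broadcast` now work on plain tensors only. A self-contained sketch of the decorator pattern (it mirrors the code above rather than importing the project module):

# Minimal sketch of the `requires_accelerator` pattern introduced in this file.
from functools import wraps

import torch


def requires_accelerator(fn):
    """Fail fast with a clear message when torch.accelerator (torch>=2.7.0) is missing."""

    @wraps(fn)
    def wrapper(*args, **kwargs):
        if not hasattr(torch, "accelerator"):
            raise RuntimeError("torch.accelerator is not available, please upgrade torch to 2.7.0 or higher.")

        return fn(*args, **kwargs)

    return wrapper


@requires_accelerator
def device_count() -> int:
    # only reached when torch.accelerator exists
    return torch.accelerator.device_count()


if __name__ == "__main__":
    try:
        print("devices:", device_count())
    except RuntimeError as err:  # older torch builds land here
        print(err)

Moving the conversion logic into `operate_tensorlike` keeps each collective a thin wrapper around `torch.distributed`, while callers such as `DistributedInterface.all_reduce` can still pass arbitrary tensor-like data and get the same type back.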

View File

@@ -15,26 +15,27 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+"""A unified interface for model parallelism and data parallelism.
+
+Supports model parallelism types:
+- mp_replicate: Replicate model across multiple devices.
+- mp_shard: Shard model across multiple devices.
+
+And data parallelism types:
+- dp: Data parallelism.
+- cp: Context parallelism.
+"""
 
 from dataclasses import dataclass
 from datetime import timedelta
 from enum import Enum
 from typing import Any, Optional
 
-from torch.distributed import init_process_group
+from torch.distributed import barrier, destroy_process_group, init_process_group
 from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
 
 from ..utils.types import DistributedConfig, ProcessGroup, Tensor, TensorLike
-from .helper import (
-    ReduceOp,
-    all_gather,
-    all_reduce,
-    get_current_accelerator,
-    get_local_rank,
-    get_local_world_size,
-    get_rank,
-    get_world_size,
-    is_distributed,
-)
+from . import helper
 
 
 class Dim(str, Enum):
@@ -60,24 +61,24 @@ class DistributedStrategy:
"""Context parallel size, default to 1.""" """Context parallel size, default to 1."""
def __post_init__(self) -> None: def __post_init__(self) -> None:
if not is_distributed(): if not helper.is_distributed():
self.mp_shard_size = 1 self.mp_shard_size = 1
elif self.mp_shard_size is None: elif self.mp_shard_size is None:
self.mp_shard_size = get_world_size() // self.mp_replicate_size self.mp_shard_size = helper.get_world_size() // self.mp_replicate_size
elif self.mp_replicate_size * self.mp_shard_size != get_world_size(): elif self.mp_replicate_size * self.mp_shard_size != helper.get_world_size():
raise ValueError( raise ValueError(
f"mp_replicate_size * mp_shard_size must equal to world_size, " f"mp_replicate_size * mp_shard_size must equal to world_size, "
f"got {self.mp_replicate_size} * {self.mp_shard_size} != {get_world_size()}." f"got {self.mp_replicate_size} * {self.mp_shard_size} != {helper.get_world_size()}."
) )
if not is_distributed(): if not helper.is_distributed():
self.dp_size = 1 self.dp_size = 1
elif self.dp_size is None: elif self.dp_size is None:
self.dp_size = get_world_size() // self.cp_size self.dp_size = helper.get_world_size() // self.cp_size
elif self.dp_size * self.cp_size != get_world_size(): elif self.dp_size * self.cp_size != helper.get_world_size():
raise ValueError( raise ValueError(
f"dp_size * cp_size must equal to world_size, " f"dp_size * cp_size must equal to world_size, "
f"got {self.dp_size} * {self.cp_size} != {get_world_size()}." f"got {self.dp_size} * {self.cp_size} != {helper.get_world_size()}."
) )
@property @property
@@ -106,20 +107,6 @@ class DistributedInterface:
     _instance: Optional["DistributedInterface"] = None
     _initialized: bool = False
 
-    _is_distributed = is_distributed()
-    _rank = get_rank()
-    _world_size = get_world_size()
-    _local_rank = get_local_rank()
-    _local_world_size = get_local_world_size()
-
-    strategy: Optional[DistributedStrategy] = None
-    """Distributed strategy."""
-    model_device_mesh: Optional[DeviceMesh] = None
-    """Model parallel device mesh."""
-    data_device_mesh: Optional[DeviceMesh] = None
-    """Data parallel device mesh."""
-    current_accelerator = get_current_accelerator()
-    """Current accelerator."""
-
     def __new__(cls, *args: Any, **kwargs: Any) -> "DistributedInterface":
         """Singleton pattern."""
@@ -132,6 +119,14 @@ class DistributedInterface:
         if self._initialized:
             return
 
+        self._is_distributed = helper.is_distributed()
+        self._rank = helper.get_rank()
+        self._world_size = helper.get_world_size()
+        self._local_rank = helper.get_local_rank()
+        self._local_world_size = helper.get_local_world_size()
+        self.current_accelerator = helper.get_current_accelerator()
+        self.device_count = helper.get_device_count()
+
         if config is None:
             self.strategy = DistributedStrategy()
             timeout = 18000
@@ -145,6 +140,7 @@ class DistributedInterface:
             timeout = config.get("timeout", 18000)
 
         if self._is_distributed:
+            helper.set_device()
             init_process_group(timeout=timedelta(seconds=timeout))
             self.model_device_mesh = init_device_mesh(
                 device_type=self.current_accelerator.type,
@@ -169,65 +165,84 @@ class DistributedInterface:
f"model_device_mesh={self.model_device_mesh}, data_device_mesh={self.data_device_mesh}" f"model_device_mesh={self.model_device_mesh}, data_device_mesh={self.data_device_mesh}"
) )
@classmethod def get_device_mesh(self, dim: Optional[Dim] = None) -> Optional[DeviceMesh]:
def get_device_mesh(cls, dim: Optional[Dim] = None) -> Optional[DeviceMesh]:
"""Get device mesh for specified dimension.""" """Get device mesh for specified dimension."""
if dim is None: if dim is None:
raise ValueError("dim must be specified.") raise ValueError("dim must be specified.")
elif cls.model_device_mesh is None: elif self.model_device_mesh is None:
return None return None
elif dim in cls.strategy.data_mesh_dim_names: elif dim in self.strategy.data_mesh_dim_names:
return cls.data_device_mesh[dim.value] return self.data_device_mesh[dim.value]
else: else:
return cls.model_device_mesh[dim.value] return self.model_device_mesh[dim.value]
@classmethod def get_group(self, dim: Optional[Dim] = None) -> Optional[ProcessGroup]:
def get_group(cls, dim: Optional[Dim] = None) -> Optional[ProcessGroup]:
"""Get process group for specified dimension.""" """Get process group for specified dimension."""
if cls.model_device_mesh is None or dim is None: if self.model_device_mesh is None or dim is None:
return None return None
else: else:
return cls.get_device_mesh(dim).get_group() return self.get_device_mesh(dim).get_group()
@classmethod def get_rank(self, dim: Optional[Dim] = None) -> int:
def get_rank(cls, dim: Optional[Dim] = None) -> int:
"""Get parallel rank for specified dimension.""" """Get parallel rank for specified dimension."""
if cls.model_device_mesh is None: if self.model_device_mesh is None:
return 0 return 0
elif dim is None: elif dim is None:
return cls._rank return self._rank
else: else:
return cls.get_device_mesh(dim).get_local_rank() return self.get_device_mesh(dim).get_local_rank()
@classmethod def get_world_size(self, dim: Optional[Dim] = None) -> int:
def get_world_size(cls, dim: Optional[Dim] = None) -> int:
"""Get parallel size for specified dimension.""" """Get parallel size for specified dimension."""
if cls.model_device_mesh is None: if self.model_device_mesh is None:
return 1 return 1
elif dim is None: elif dim is None:
return cls._world_size return self._world_size
else: else:
return cls.get_device_mesh(dim).size() return self.get_device_mesh(dim).size()
@classmethod def get_local_rank(self) -> int:
def get_local_rank(cls) -> int:
"""Get parallel local rank.""" """Get parallel local rank."""
return cls._local_rank return self._local_rank
@classmethod def get_local_world_size(self) -> int:
def get_local_world_size(cls) -> int:
"""Get parallel local world size.""" """Get parallel local world size."""
return cls._local_world_size return self._local_world_size
@classmethod def all_gather(self, data: Tensor, dim: Optional[Dim] = Dim.DP) -> Tensor:
def all_gather(cls, data: Tensor, dim: Optional[Dim] = Dim.DP) -> Tensor:
"""Gather tensor across specified parallel group.""" """Gather tensor across specified parallel group."""
return all_gather(data, cls.get_group(dim)) if cls.model_device_mesh is not None else data if self.model_device_mesh is not None:
return helper.operate_tensorlike(helper.all_gather, data, group=self.get_group(dim))
else:
return data
@classmethod def all_reduce(
def all_reduce(cls, data: TensorLike, op: ReduceOp = ReduceOp.MEAN, dim: Optional[Dim] = Dim.DP) -> TensorLike: self, data: TensorLike, op: helper.ReduceOp = helper.ReduceOp.MEAN, dim: Optional[Dim] = Dim.DP
) -> TensorLike:
"""Reduce tensor across specified parallel group.""" """Reduce tensor across specified parallel group."""
return all_reduce(data, op, cls.get_group(dim)) if cls.model_device_mesh is not None else data if self.model_device_mesh is not None:
return helper.operate_tensorlike(helper.all_reduce, data, op=op, group=self.get_group(dim))
else:
return data
def broadcast(self, data: TensorLike, src: int = 0, dim: Optional[Dim] = Dim.DP) -> TensorLike:
"""Broadcast tensor across specified parallel group."""
if self.model_device_mesh is not None:
return helper.operate_tensorlike(helper.broadcast, data, src=src, group=self.get_group(dim))
else:
return data
def sync(self) -> None:
"""Synchronize all processes."""
helper.synchronize()
def barrier(self) -> None:
"""Barrier all processes."""
barrier()
def destroy(self) -> None:
"""Destroy all processes."""
destroy_process_group()
if __name__ == "__main__": if __name__ == "__main__":
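The interface now resolves rank, world size, and the accelerator at construction time instead of at class-definition (import) time, which is why the getters became instance methods. A minimal, self-contained sketch of the singleton-with-lazy-state pattern used here (class and attribute names are illustrative, not the project's API):

# Illustrative sketch: state is computed once, when the singleton is first built.
from typing import Any, Optional


class SingletonInterface:
    _instance: Optional["SingletonInterface"] = None
    _initialized: bool = False

    def __new__(cls, *args: Any, **kwargs: Any) -> "SingletonInterface":
        if cls._instance is None:
            cls._instance = super().__new__(cls)

        return cls._instance

    def __init__(self, world_size: int = 1) -> None:
        if self._initialized:
            return  # later constructions reuse the already-initialized instance

        self._world_size = world_size  # resolved once, at first construction
        self.__class__._initialized = True

    def get_world_size(self) -> int:
        return self._world_size


if __name__ == "__main__":
    a = SingletonInterface(world_size=4)
    b = SingletonInterface()  # same object, state unchanged
    assert a is b and b.get_world_size() == 4

This is also why call sites switch from attribute access on the class to `DistributedInterface().current_accelerator`, as in the `ModelLoader` change below.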

View File

@@ -97,7 +97,7 @@ class ModelLoader:
             self.args.model,
             config=self.model_config,
             dtype="auto",
-            device_map=DistributedInterface.current_accelerator,
+            device_map=DistributedInterface().current_accelerator,
             trust_remote_code=self.args.trust_remote_code,
         )

View File

@@ -22,10 +22,10 @@ from typing import Optional
 from torchdata.stateful_dataloader import StatefulDataLoader
 from torchdata.stateful_dataloader.sampler import StatefulDistributedSampler
 
-from ..utils.batching_queue import BaseBatchingQueue
-from ..utils.logging import get_logger
-from ..utils.types import Processor, TorchDataset
-from .trainer_utils.data_collator import DataCollator
+from ...utils.batching_queue import BaseBatchingQueue
+from ...utils.logging import get_logger
+from ...utils.types import Processor, TorchDataset
+from .data_collator import DataCollator
 
 
 logger = get_logger(__name__)

View File

@@ -17,7 +17,7 @@ import socket
 def find_available_port() -> int:
-    r"""Find an available port on the local machine."""
+    """Find an available port on the local machine."""
     sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
     sock.bind(("", 0))
     port = sock.getsockname()[1]
@@ -26,9 +26,5 @@ def find_available_port() -> int:
 def is_env_enabled(env_var: str, default: str = "0") -> bool:
-    r"""Check if the environment variable is enabled."""
-    return os.getenv(env_var, default).lower() in ["true", "y", "1"]
-
-
-if __name__ == "__main__":
-    print(find_available_port())
+    """Check if the environment variable is enabled."""
+    return os.getenv(env_var, default).lower() in ["true", "yes", "on", "t", "y", "1"]
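The broadened check simply accepts more spellings of "enabled". A small usage sketch of the new behavior (the function body is the one shown above):

import os


def is_env_enabled(env_var: str, default: str = "0") -> bool:
    """Check if the environment variable is enabled."""
    return os.getenv(env_var, default).lower() in ["true", "yes", "on", "t", "y", "1"]


os.environ["RUN_SLOW"] = "YES"
assert is_env_enabled("RUN_SLOW")       # "yes" is now accepted, not just "true"/"y"/"1"
assert not is_env_enabled("UNSET_VAR")  # falls back to the default "0"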

View File

@@ -0,0 +1,35 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from contextlib import contextmanager
@contextmanager
def dist_env(local_rank: int = 0, world_size: int = 1, master_port: int = 25595):
"""Set distributed environment variables."""
env_vars = {
"MASTER_ADDR": "127.0.0.1",
"MASTER_PORT": str(master_port),
"RANK": str(local_rank),
"LOCAL_RANK": str(local_rank),
"WORLD_SIZE": str(world_size),
"LOCAL_WORLD_SIZE": str(world_size),
}
os.environ.update(env_vars)
try:
yield
finally:
for key in env_vars.keys():
os.environ.pop(key, None)
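A possible usage of the new `dist_env` helper, shown self-contained (the context manager body is copied from above; note that it pops the variables on exit rather than restoring prior values):

import os
from contextlib import contextmanager


@contextmanager
def dist_env(local_rank: int = 0, world_size: int = 1, master_port: int = 25595):
    """Set distributed environment variables (copied from the helper above)."""
    env_vars = {
        "MASTER_ADDR": "127.0.0.1",
        "MASTER_PORT": str(master_port),
        "RANK": str(local_rank),
        "LOCAL_RANK": str(local_rank),
        "WORLD_SIZE": str(world_size),
        "LOCAL_WORLD_SIZE": str(world_size),
    }
    os.environ.update(env_vars)
    try:
        yield
    finally:
        for key in env_vars.keys():
            os.environ.pop(key, None)


with dist_env(local_rank=1, world_size=2):
    assert os.environ["RANK"] == "1" and os.environ["WORLD_SIZE"] == "2"

assert "RANK" not in os.environ  # assumes RANK was unset beforehand; the helper pops, it does not restore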

View File

@@ -18,19 +18,17 @@ Contains shared fixtures, pytest configuration, and custom markers.
""" """
import os import os
from typing import Optional
import pytest import pytest
from pytest import Config, Item from pytest import Config, FixtureRequest, Item, MonkeyPatch
from llamafactory.extras.misc import get_current_device, get_device_count, is_env_enabled from llamafactory.extras.misc import get_current_device, get_device_count, is_env_enabled
from llamafactory.extras.packages import is_transformers_version_greater_than from llamafactory.extras.packages import is_transformers_version_greater_than
from llamafactory.train.test_utils import patch_valuehead_model from llamafactory.train.test_utils import patch_valuehead_model
try: CURRENT_DEVICE = get_current_device().type
CURRENT_DEVICE = get_current_device().type # cpu | cuda | npu
except Exception:
CURRENT_DEVICE = "cpu"
def pytest_configure(config: Config): def pytest_configure(config: Config):
@@ -66,26 +64,27 @@ def _handle_runs_on(items: list[Item]):
 def _handle_slow_tests(items: list[Item]):
     """Skip slow tests unless RUN_SLOW is enabled."""
-    if not is_env_enabled("RUN_SLOW", "0"):
+    if not is_env_enabled("RUN_SLOW"):
         skip_slow = pytest.mark.skip(reason="slow test (set RUN_SLOW=1 to run)")
         for item in items:
             if "slow" in item.keywords:
                 item.add_marker(skip_slow)
 
 
-def _get_visible_devices_env():
+def _get_visible_devices_env() -> Optional[str]:
     """Return device visibility env var name."""
     if CURRENT_DEVICE == "cuda":
         return "CUDA_VISIBLE_DEVICES"
-    if CURRENT_DEVICE == "npu":
+    elif CURRENT_DEVICE == "npu":
         return "ASCEND_RT_VISIBLE_DEVICES"
-    return None
+    else:
+        return None
 
 
 def _handle_device_visibility(items: list[Item]):
     """Handle device visibility based on test markers."""
     env_key = _get_visible_devices_env()
-    if env_key is None or CURRENT_DEVICE == "cpu":
+    if env_key is None or CURRENT_DEVICE in ("cpu", "mps"):
         return
 
     # Parse visible devices
@@ -121,7 +120,7 @@ def pytest_collection_modifyitems(config: Config, items: list[Item]):
 @pytest.fixture(autouse=True)
-def _manage_distributed_env(request, monkeypatch):
+def _manage_distributed_env(request: FixtureRequest, monkeypatch: MonkeyPatch) -> None:
     """Set environment variables for distributed tests if specific devices are requested."""
     env_key = _get_visible_devices_env()
     if not env_key:
@@ -131,8 +130,7 @@ def _manage_distributed_env(request, monkeypatch):
     old_value = os.environ.get(env_key)
     marker = request.node.get_closest_marker("require_distributed")
 
-    if marker:
-        # Distributed test
+    if marker:  # distributed test
         required = marker.args[0] if marker.args else 2
         specific_devices = marker.args[1] if len(marker.args) > 1 else None
@@ -142,8 +140,7 @@ def _manage_distributed_env(request, monkeypatch):
devices_str = ",".join(str(i) for i in range(required)) devices_str = ",".join(str(i) for i in range(required))
monkeypatch.setenv(env_key, devices_str) monkeypatch.setenv(env_key, devices_str)
else: else: # non-distributed test
# Non-distributed test
if old_value: if old_value:
visible_devices = [v for v in old_value.split(",") if v != ""] visible_devices = [v for v in old_value.split(",") if v != ""]
monkeypatch.setenv(env_key, visible_devices[0] if visible_devices else "0") monkeypatch.setenv(env_key, visible_devices[0] if visible_devices else "0")
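For context, this is roughly how the markers handled above might appear on a test; the marker names and argument positions follow the conftest logic shown here, while the test bodies are placeholders:

import pytest


@pytest.mark.runs_on(["cuda", "npu"])
@pytest.mark.require_distributed(2)  # args[0]: number of devices to expose via *_VISIBLE_DEVICES
def test_two_device_training():
    ...


@pytest.mark.runs_on(["cpu", "mps"])
def test_cpu_only_path():
    ...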

View File

@@ -42,7 +42,7 @@ TRAIN_ARGS = {
 }
 
 
-@pytest.mark.runs_on(["cpu", "npu", "cuda"])
+@pytest.mark.runs_on(["cpu", "mps"])
 @pytest.mark.parametrize("num_samples", [16])
 def test_feedback_data(num_samples: int):
     train_dataset = load_dataset_module(**TRAIN_ARGS)["train_dataset"]

View File

@@ -51,7 +51,7 @@ def _convert_sharegpt_to_openai(messages: list[dict[str, str]]) -> list[dict[str
     return new_messages
 
 
-@pytest.mark.runs_on(["cpu"])
+@pytest.mark.runs_on(["cpu", "mps"])
 @pytest.mark.parametrize("num_samples", [16])
 def test_pairwise_data(num_samples: int):
     train_dataset = load_dataset_module(**TRAIN_ARGS)["train_dataset"]

View File

@@ -18,7 +18,7 @@ import pytest
 from llamafactory.data.processor.processor_utils import infer_seqlen
 
 
-@pytest.mark.runs_on(["cpu"])
+@pytest.mark.runs_on(["cpu", "mps"])
 @pytest.mark.parametrize(
     "test_input,test_output",
     [

View File

@@ -42,7 +42,7 @@ TRAIN_ARGS = {
 }
 
 
-@pytest.mark.runs_on(["cpu"])
+@pytest.mark.runs_on(["cpu", "mps"])
 @pytest.mark.parametrize("num_samples", [16])
 def test_supervised_single_turn(num_samples: int):
     train_dataset = load_dataset_module(dataset_dir="ONLINE", dataset=TINY_DATA, **TRAIN_ARGS)["train_dataset"]
@@ -62,7 +62,7 @@ def test_supervised_single_turn(num_samples: int):
         assert train_dataset["input_ids"][index] == ref_input_ids
 
 
-@pytest.mark.runs_on(["cpu"])
+@pytest.mark.runs_on(["cpu", "mps"])
 @pytest.mark.parametrize("num_samples", [8])
 def test_supervised_multi_turn(num_samples: int):
     train_dataset = load_dataset_module(dataset_dir="REMOTE:" + DEMO_DATA, dataset="system_chat", **TRAIN_ARGS)[
@@ -76,7 +76,7 @@ def test_supervised_multi_turn(num_samples: int):
         assert train_dataset["input_ids"][index] == ref_input_ids
 
 
-@pytest.mark.runs_on(["cpu"])
+@pytest.mark.runs_on(["cpu", "mps"])
 @pytest.mark.parametrize("num_samples", [4])
 def test_supervised_train_on_prompt(num_samples: int):
     train_dataset = load_dataset_module(
@@ -91,7 +91,7 @@ def test_supervised_train_on_prompt(num_samples: int):
         assert train_dataset["labels"][index] == ref_ids
 
 
-@pytest.mark.runs_on(["cpu"])
+@pytest.mark.runs_on(["cpu", "mps"])
 @pytest.mark.parametrize("num_samples", [4])
 def test_supervised_mask_history(num_samples: int):
     train_dataset = load_dataset_module(

View File

@@ -46,7 +46,7 @@ TRAIN_ARGS = {
 }
 
 
-@pytest.mark.runs_on(["cpu"])
+@pytest.mark.runs_on(["cpu", "mps"])
 @pytest.mark.parametrize("num_samples", [16])
 def test_unsupervised_data(num_samples: int):
     train_dataset = load_dataset_module(**TRAIN_ARGS)["train_dataset"]

View File

@@ -29,7 +29,7 @@ from llamafactory.model import load_tokenizer
 TINY_LLAMA3 = os.getenv("TINY_LLAMA3", "llamafactory/tiny-random-Llama-3")
 
 
-@pytest.mark.runs_on(["cpu"])
+@pytest.mark.runs_on(["cpu", "mps"])
 def test_base_collator():
     model_args, data_args, *_ = get_infer_args({"model_name_or_path": TINY_LLAMA3, "template": "default"})
     tokenizer_module = load_tokenizer(model_args)
@@ -73,7 +73,7 @@ def test_base_collator():
         assert batch_input[k].eq(torch.tensor(expected_input[k])).all()
 
 
-@pytest.mark.runs_on(["cpu"])
+@pytest.mark.runs_on(["cpu", "mps"])
 def test_multimodal_collator():
     model_args, data_args, *_ = get_infer_args(
         {"model_name_or_path": "Qwen/Qwen2-VL-2B-Instruct", "template": "qwen2_vl"}

View File

@@ -20,7 +20,7 @@ from llamafactory.data.parser import DatasetAttr
 from llamafactory.hparams import DataArguments
 
 
-@pytest.mark.runs_on(["cpu"])
+@pytest.mark.runs_on(["cpu", "mps"])
 def test_alpaca_converter():
     dataset_attr = DatasetAttr("hf_hub", "llamafactory/tiny-supervised-dataset")
     data_args = DataArguments()
@@ -41,7 +41,7 @@ def test_alpaca_converter():
     }
 
 
-@pytest.mark.runs_on(["cpu"])
+@pytest.mark.runs_on(["cpu", "mps"])
 def test_sharegpt_converter():
     dataset_attr = DatasetAttr("hf_hub", "llamafactory/tiny-supervised-dataset")
     data_args = DataArguments()

View File

@@ -38,19 +38,19 @@ TOOLS = [
 ]
 
 
-@pytest.mark.runs_on(["cpu"])
+@pytest.mark.runs_on(["cpu", "mps"])
 def test_empty_formatter():
     formatter = EmptyFormatter(slots=["\n"])
     assert formatter.apply() == ["\n"]
 
 
-@pytest.mark.runs_on(["cpu"])
+@pytest.mark.runs_on(["cpu", "mps"])
 def test_string_formatter():
     formatter = StringFormatter(slots=["<s>", "Human: {{content}}\nAssistant:"])
     assert formatter.apply(content="Hi") == ["<s>", "Human: Hi\nAssistant:"]
 
 
-@pytest.mark.runs_on(["cpu"])
+@pytest.mark.runs_on(["cpu", "mps"])
 def test_function_formatter():
     formatter = FunctionFormatter(slots=["{{content}}", "</s>"], tool_format="default")
     tool_calls = json.dumps(FUNCTION)
@@ -60,7 +60,7 @@ def test_function_formatter():
     ]
 
 
-@pytest.mark.runs_on(["cpu"])
+@pytest.mark.runs_on(["cpu", "mps"])
 def test_multi_function_formatter():
     formatter = FunctionFormatter(slots=["{{content}}", "</s>"], tool_format="default")
     tool_calls = json.dumps([FUNCTION] * 2)
@@ -71,7 +71,7 @@ def test_multi_function_formatter():
     ]
 
 
-@pytest.mark.runs_on(["cpu"])
+@pytest.mark.runs_on(["cpu", "mps"])
 def test_default_tool_formatter():
     formatter = ToolFormatter(tool_format="default")
     assert formatter.apply(content=json.dumps(TOOLS)) == [
@@ -90,14 +90,14 @@ def test_default_tool_formatter():
     ]
 
 
-@pytest.mark.runs_on(["cpu"])
+@pytest.mark.runs_on(["cpu", "mps"])
 def test_default_tool_extractor():
     formatter = ToolFormatter(tool_format="default")
     result = """Action: test_tool\nAction Input: {"foo": "bar", "size": 10}"""
     assert formatter.extract(result) == [("test_tool", """{"foo": "bar", "size": 10}""")]
 
 
-@pytest.mark.runs_on(["cpu"])
+@pytest.mark.runs_on(["cpu", "mps"])
 def test_default_multi_tool_extractor():
     formatter = ToolFormatter(tool_format="default")
     result = (
@@ -110,14 +110,14 @@ def test_default_multi_tool_extractor():
     ]
 
 
-@pytest.mark.runs_on(["cpu"])
+@pytest.mark.runs_on(["cpu", "mps"])
 def test_glm4_function_formatter():
     formatter = FunctionFormatter(slots=["{{content}}"], tool_format="glm4")
     tool_calls = json.dumps(FUNCTION)
     assert formatter.apply(content=tool_calls) == ["""tool_name\n{"foo": "bar", "size": 10}"""]
 
 
-@pytest.mark.runs_on(["cpu"])
+@pytest.mark.runs_on(["cpu", "mps"])
 def test_glm4_tool_formatter():
     formatter = ToolFormatter(tool_format="glm4")
     assert formatter.apply(content=json.dumps(TOOLS)) == [
@@ -128,14 +128,14 @@ def test_glm4_tool_formatter():
     ]
 
 
-@pytest.mark.runs_on(["cpu"])
+@pytest.mark.runs_on(["cpu", "mps"])
 def test_glm4_tool_extractor():
     formatter = ToolFormatter(tool_format="glm4")
     result = """test_tool\n{"foo": "bar", "size": 10}\n"""
     assert formatter.extract(result) == [("test_tool", """{"foo": "bar", "size": 10}""")]
 
 
-@pytest.mark.runs_on(["cpu"])
+@pytest.mark.runs_on(["cpu", "mps"])
 def test_llama3_function_formatter():
     formatter = FunctionFormatter(slots=["{{content}}<|eot_id|>"], tool_format="llama3")
     tool_calls = json.dumps(FUNCTION)
@@ -144,7 +144,7 @@ def test_llama3_function_formatter():
     ]
 
 
-@pytest.mark.runs_on(["cpu"])
+@pytest.mark.runs_on(["cpu", "mps"])
 def test_llama3_multi_function_formatter():
     formatter = FunctionFormatter(slots=["{{content}}<|eot_id|>"], tool_format="llama3")
     tool_calls = json.dumps([FUNCTION] * 2)
@@ -155,7 +155,7 @@ def test_llama3_multi_function_formatter():
     ]
 
 
-@pytest.mark.runs_on(["cpu"])
+@pytest.mark.runs_on(["cpu", "mps"])
 def test_llama3_tool_formatter():
     formatter = ToolFormatter(tool_format="llama3")
     date = datetime.now().strftime("%d %b %Y")
@@ -169,14 +169,14 @@ def test_llama3_tool_formatter():
     ]
 
 
-@pytest.mark.runs_on(["cpu"])
+@pytest.mark.runs_on(["cpu", "mps"])
 def test_llama3_tool_extractor():
     formatter = ToolFormatter(tool_format="llama3")
     result = """{"name": "test_tool", "parameters": {"foo": "bar", "size": 10}}\n"""
     assert formatter.extract(result) == [("test_tool", """{"foo": "bar", "size": 10}""")]
 
 
-@pytest.mark.runs_on(["cpu"])
+@pytest.mark.runs_on(["cpu", "mps"])
 def test_llama3_multi_tool_extractor():
     formatter = ToolFormatter(tool_format="llama3")
     result = (
@@ -189,7 +189,7 @@ def test_llama3_multi_tool_extractor():
     ]
 
 
-@pytest.mark.runs_on(["cpu"])
+@pytest.mark.runs_on(["cpu", "mps"])
 def test_mistral_function_formatter():
     formatter = FunctionFormatter(slots=["[TOOL_CALLS] {{content}}", "</s>"], tool_format="mistral")
     tool_calls = json.dumps(FUNCTION)
@@ -199,7 +199,7 @@ def test_mistral_function_formatter():
     ]
 
 
-@pytest.mark.runs_on(["cpu"])
+@pytest.mark.runs_on(["cpu", "mps"])
 def test_mistral_multi_function_formatter():
     formatter = FunctionFormatter(slots=["[TOOL_CALLS] {{content}}", "</s>"], tool_format="mistral")
     tool_calls = json.dumps([FUNCTION] * 2)
@@ -211,7 +211,7 @@ def test_mistral_multi_function_formatter():
     ]
 
 
-@pytest.mark.runs_on(["cpu"])
+@pytest.mark.runs_on(["cpu", "mps"])
 def test_mistral_tool_formatter():
     formatter = ToolFormatter(tool_format="mistral")
     wrapped_tool = {"type": "function", "function": TOOLS[0]}
@@ -220,14 +220,14 @@ def test_mistral_tool_formatter():
     ]
 
 
-@pytest.mark.runs_on(["cpu"])
+@pytest.mark.runs_on(["cpu", "mps"])
 def test_mistral_tool_extractor():
     formatter = ToolFormatter(tool_format="mistral")
     result = """{"name": "test_tool", "arguments": {"foo": "bar", "size": 10}}"""
     assert formatter.extract(result) == [("test_tool", """{"foo": "bar", "size": 10}""")]
 
 
-@pytest.mark.runs_on(["cpu"])
+@pytest.mark.runs_on(["cpu", "mps"])
 def test_mistral_multi_tool_extractor():
     formatter = ToolFormatter(tool_format="mistral")
     result = (
@@ -240,7 +240,7 @@ def test_mistral_multi_tool_extractor():
     ]
 
 
-@pytest.mark.runs_on(["cpu"])
+@pytest.mark.runs_on(["cpu", "mps"])
 def test_qwen_function_formatter():
     formatter = FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen")
     tool_calls = json.dumps(FUNCTION)
@@ -249,7 +249,7 @@ def test_qwen_function_formatter():
     ]
 
 
-@pytest.mark.runs_on(["cpu"])
+@pytest.mark.runs_on(["cpu", "mps"])
 def test_qwen_multi_function_formatter():
     formatter = FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen")
     tool_calls = json.dumps([FUNCTION] * 2)
@@ -260,7 +260,7 @@ def test_qwen_multi_function_formatter():
     ]
 
 
-@pytest.mark.runs_on(["cpu"])
+@pytest.mark.runs_on(["cpu", "mps"])
 def test_qwen_tool_formatter():
     formatter = ToolFormatter(tool_format="qwen")
     wrapped_tool = {"type": "function", "function": TOOLS[0]}
@@ -274,14 +274,14 @@ def test_qwen_tool_formatter():
     ]
 
 
-@pytest.mark.runs_on(["cpu"])
+@pytest.mark.runs_on(["cpu", "mps"])
 def test_qwen_tool_extractor():
     formatter = ToolFormatter(tool_format="qwen")
     result = """<tool_call>\n{"name": "test_tool", "arguments": {"foo": "bar", "size": 10}}\n</tool_call>"""
     assert formatter.extract(result) == [("test_tool", """{"foo": "bar", "size": 10}""")]
 
 
-@pytest.mark.runs_on(["cpu"])
+@pytest.mark.runs_on(["cpu", "mps"])
 def test_qwen_multi_tool_extractor():
     formatter = ToolFormatter(tool_format="qwen")
     result = (

View File

@@ -40,21 +40,21 @@ TRAIN_ARGS = {
 }
 
 
-@pytest.mark.runs_on(["cpu"])
+@pytest.mark.runs_on(["cpu", "mps"])
 def test_load_train_only():
     dataset_module = load_dataset_module(**TRAIN_ARGS)
     assert dataset_module.get("train_dataset") is not None
     assert dataset_module.get("eval_dataset") is None
 
 
-@pytest.mark.runs_on(["cpu"])
+@pytest.mark.runs_on(["cpu", "mps"])
 def test_load_val_size():
     dataset_module = load_dataset_module(val_size=0.1, **TRAIN_ARGS)
     assert dataset_module.get("train_dataset") is not None
     assert dataset_module.get("eval_dataset") is not None
 
 
-@pytest.mark.runs_on(["cpu"])
+@pytest.mark.runs_on(["cpu", "mps"])
 def test_load_eval_data():
     dataset_module = load_dataset_module(eval_dataset=TINY_DATA, **TRAIN_ARGS)
     assert dataset_module.get("train_dataset") is not None

View File

@@ -179,7 +179,7 @@ def _check_plugin(
     )
 
 
-@pytest.mark.runs_on(["cpu"])
+@pytest.mark.runs_on(["cpu", "mps"])
 def test_base_plugin():
     tokenizer_module = _load_tokenizer_module(model_name_or_path=TINY_LLAMA3)
     base_plugin = get_mm_plugin(name="base")
@@ -187,7 +187,7 @@ def test_base_plugin():
     _check_plugin(**check_inputs)
 
 
-@pytest.mark.runs_on(["cpu"])
+@pytest.mark.runs_on(["cpu", "mps"])
 @pytest.mark.skipif(not HF_TOKEN, reason="Gated model.")
 @pytest.mark.skipif(not is_transformers_version_greater_than("4.50.0"), reason="Requires transformers>=4.50.0")
 def test_gemma3_plugin():
@@ -210,7 +210,7 @@ def test_gemma3_plugin():
     _check_plugin(**check_inputs)
 
 
-@pytest.mark.runs_on(["cpu"])
+@pytest.mark.runs_on(["cpu", "mps"])
 @pytest.mark.skipif(not is_transformers_version_greater_than("4.52.0"), reason="Requires transformers>=4.52.0")
 def test_internvl_plugin():
     image_seqlen = 256
@@ -229,7 +229,7 @@ def test_internvl_plugin():
     _check_plugin(**check_inputs)
 
 
-@pytest.mark.runs_on(["cpu"])
+@pytest.mark.runs_on(["cpu", "mps"])
 @pytest.mark.skipif(not is_transformers_version_greater_than("4.51.0"), reason="Requires transformers>=4.51.0")
 def test_llama4_plugin():
     tokenizer_module = _load_tokenizer_module(model_name_or_path=TINY_LLAMA4)
@@ -251,7 +251,7 @@ def test_llama4_plugin():
     _check_plugin(**check_inputs)
 
 
-@pytest.mark.runs_on(["cpu"])
+@pytest.mark.runs_on(["cpu", "mps"])
 def test_llava_plugin():
     image_seqlen = 576
     tokenizer_module = _load_tokenizer_module(model_name_or_path="llava-hf/llava-1.5-7b-hf")
@@ -265,7 +265,7 @@ def test_llava_plugin():
     _check_plugin(**check_inputs)
 
 
-@pytest.mark.runs_on(["cpu"])
+@pytest.mark.runs_on(["cpu", "mps"])
 def test_llava_next_plugin():
     image_seqlen = 1176
     tokenizer_module = _load_tokenizer_module(model_name_or_path="llava-hf/llava-v1.6-vicuna-7b-hf")
@@ -279,7 +279,7 @@ def test_llava_next_plugin():
     _check_plugin(**check_inputs)
 
 
-@pytest.mark.runs_on(["cpu"])
+@pytest.mark.runs_on(["cpu", "mps"])
 def test_llava_next_video_plugin():
     image_seqlen = 1176
     tokenizer_module = _load_tokenizer_module(model_name_or_path="llava-hf/LLaVA-NeXT-Video-7B-hf")
@@ -293,7 +293,7 @@ def test_llava_next_video_plugin():
     _check_plugin(**check_inputs)
 
 
-@pytest.mark.runs_on(["cpu"])
+@pytest.mark.runs_on(["cpu", "mps"])
 @pytest.mark.skipif(not HF_TOKEN, reason="Gated model.")
 def test_paligemma_plugin():
     image_seqlen = 256
@@ -313,7 +313,7 @@ def test_paligemma_plugin():
     _check_plugin(**check_inputs)
 
 
-@pytest.mark.runs_on(["cpu"])
+@pytest.mark.runs_on(["cpu", "mps"])
 @pytest.mark.skipif(not is_transformers_version_greater_than("4.50.0"), reason="Requires transformers>=4.50.0")
 def test_pixtral_plugin():
     image_slice_height, image_slice_width = 2, 2
@@ -336,7 +336,7 @@ def test_pixtral_plugin():
     _check_plugin(**check_inputs)
 
 
-@pytest.mark.runs_on(["cpu"])
+@pytest.mark.runs_on(["cpu", "mps"])
 @pytest.mark.skipif(not is_transformers_version_greater_than("4.52.0"), reason="Requires transformers>=4.52.0")
 def test_qwen2_omni_plugin():
     image_seqlen, audio_seqlen = 4, 2
@@ -367,7 +367,7 @@ def test_qwen2_omni_plugin():
     _check_plugin(**check_inputs)
 
 
-@pytest.mark.runs_on(["cpu"])
+@pytest.mark.runs_on(["cpu", "mps"])
 def test_qwen2_vl_plugin():
     image_seqlen = 4
     tokenizer_module = _load_tokenizer_module(model_name_or_path="Qwen/Qwen2-VL-7B-Instruct")
@@ -384,7 +384,7 @@ def test_qwen2_vl_plugin():
     _check_plugin(**check_inputs)
 
 
-@pytest.mark.runs_on(["cpu"])
+@pytest.mark.runs_on(["cpu", "mps"])
 @pytest.mark.skipif(not is_transformers_version_greater_than("4.57.0"), reason="Requires transformers>=4.57.0")
 def test_qwen3_vl_plugin():
     frame_seqlen = 1
@@ -406,7 +406,7 @@ def test_qwen3_vl_plugin():
     _check_plugin(**check_inputs)
 
 
-@pytest.mark.runs_on(["cpu"])
+@pytest.mark.runs_on(["cpu", "mps"])
 @pytest.mark.skipif(not is_transformers_version_greater_than("4.47.0"), reason="Requires transformers>=4.47.0")
 def test_video_llava_plugin():
     image_seqlen = 256

View File

@@ -89,7 +89,7 @@ def _check_template(
     _check_tokenization(tokenizer, (prompt_ids, answer_ids), (prompt_str, answer_str))
 
 
-@pytest.mark.runs_on(["cpu"])
+@pytest.mark.runs_on(["cpu", "mps"])
 @pytest.mark.parametrize("use_fast", [True, False])
 def test_encode_oneturn(use_fast: bool):
     tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA3, use_fast=use_fast)
@@ -105,7 +105,7 @@ def test_encode_oneturn(use_fast: bool):
     _check_tokenization(tokenizer, (prompt_ids, answer_ids), (prompt_str, answer_str))
 
 
-@pytest.mark.runs_on(["cpu"])
+@pytest.mark.runs_on(["cpu", "mps"])
 @pytest.mark.parametrize("use_fast", [True, False])
 def test_encode_multiturn(use_fast: bool):
     tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA3, use_fast=use_fast)
@@ -127,7 +127,7 @@ def test_encode_multiturn(use_fast: bool):
     )
 
 
-@pytest.mark.runs_on(["cpu"])
+@pytest.mark.runs_on(["cpu", "mps"])
 @pytest.mark.parametrize("use_fast", [True, False])
 @pytest.mark.parametrize("cot_messages", [True, False])
 @pytest.mark.parametrize("enable_thinking", [True, False, None])
@@ -154,7 +154,7 @@ def test_reasoning_encode_oneturn(use_fast: bool, cot_messages: bool, enable_thi
     _check_tokenization(tokenizer, (prompt_ids, answer_ids), (prompt_str, answer_str))
 
 
-@pytest.mark.runs_on(["cpu"])
+@pytest.mark.runs_on(["cpu", "mps"])
 @pytest.mark.parametrize("use_fast", [True, False])
 @pytest.mark.parametrize("cot_messages", [True, False])
 @pytest.mark.parametrize("enable_thinking", [True, False, None])
@@ -184,7 +184,7 @@ def test_reasoning_encode_multiturn(use_fast: bool, cot_messages: bool, enable_t
     )
 
 
-@pytest.mark.runs_on(["cpu"])
+@pytest.mark.runs_on(["cpu", "mps"])
 @pytest.mark.parametrize("use_fast", [True, False])
 def test_jinja_template(use_fast: bool):
     tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA3, use_fast=use_fast)
@@ -195,7 +195,7 @@ def test_jinja_template(use_fast: bool):
     assert tokenizer.apply_chat_template(MESSAGES) == ref_tokenizer.apply_chat_template(MESSAGES)
 
 
-@pytest.mark.runs_on(["cpu"])
+@pytest.mark.runs_on(["cpu", "mps"])
 def test_ollama_modelfile():
     tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA3)
     template = get_template_and_fix_tokenizer(tokenizer, DataArguments(template="llama3"))
@@ -213,14 +213,14 @@ def test_ollama_modelfile():
     )
 
 
-@pytest.mark.runs_on(["cpu"])
+@pytest.mark.runs_on(["cpu", "mps"])
 def test_get_stop_token_ids():
     tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA3)
     template = get_template_and_fix_tokenizer(tokenizer, DataArguments(template="llama3"))
     assert set(template.get_stop_token_ids(tokenizer)) == {128008, 128009}
 
 
-@pytest.mark.runs_on(["cpu"])
+@pytest.mark.runs_on(["cpu", "mps"])
 @pytest.mark.skipif(not HF_TOKEN, reason="Gated model.")
 @pytest.mark.parametrize("use_fast", [True, False])
 def test_gemma_template(use_fast: bool):
@@ -234,7 +234,7 @@ def test_gemma_template(use_fast: bool):
     _check_template("google/gemma-3-4b-it", "gemma", prompt_str, answer_str, use_fast)
 
 
-@pytest.mark.runs_on(["cpu"])
+@pytest.mark.runs_on(["cpu", "mps"])
 @pytest.mark.skipif(not HF_TOKEN, reason="Gated model.")
 @pytest.mark.parametrize("use_fast", [True, False])
 def test_gemma2_template(use_fast: bool):
@@ -248,7 +248,7 @@ def test_gemma2_template(use_fast: bool):
     _check_template("google/gemma-2-2b-it", "gemma2", prompt_str, answer_str, use_fast)
 
 
-@pytest.mark.runs_on(["cpu"])
+@pytest.mark.runs_on(["cpu", "mps"])
 @pytest.mark.skipif(not HF_TOKEN, reason="Gated model.")
 @pytest.mark.parametrize("use_fast", [True, False])
 def test_llama3_template(use_fast: bool):
@@ -262,7 +262,7 @@ def test_llama3_template(use_fast: bool):
     _check_template("meta-llama/Meta-Llama-3-8B-Instruct", "llama3", prompt_str, answer_str, use_fast)
 
 
-@pytest.mark.runs_on(["cpu"])
+@pytest.mark.runs_on(["cpu", "mps"])
 @pytest.mark.parametrize(
     "use_fast", [True, pytest.param(False, marks=pytest.mark.xfail(reason="Llama 4 has no slow tokenizer."))]
 )
@@ -284,7 +284,7 @@ def test_llama4_template(use_fast: bool):
         pytest.param(False, marks=pytest.mark.xfail(reason="Phi-4 slow tokenizer is broken.")),
     ],
 )
-@pytest.mark.runs_on(["cpu"])
+@pytest.mark.runs_on(["cpu", "mps"])
 def test_phi4_template(use_fast: bool):
     prompt_str = (
         f"<|im_start|>user<|im_sep|>{MESSAGES[0]['content']}<|im_end|>"
@@ -296,7 +296,7 @@ def test_phi4_template(use_fast: bool):
     _check_template("microsoft/phi-4", "phi4", prompt_str, answer_str, use_fast)
 
 
-@pytest.mark.runs_on(["cpu"])
+@pytest.mark.runs_on(["cpu", "mps"])
 @pytest.mark.xfail(not HF_TOKEN, reason="Authorization.")
 @pytest.mark.parametrize("use_fast", [True, False])
 def test_qwen2_5_template(use_fast: bool):
@@ -311,7 +311,7 @@ def test_qwen2_5_template(use_fast: bool):
     _check_template("Qwen/Qwen2.5-7B-Instruct", "qwen", prompt_str, answer_str, use_fast)
 
 
-@pytest.mark.runs_on(["cpu"])
+@pytest.mark.runs_on(["cpu", "mps"])
 @pytest.mark.parametrize("use_fast", [True, False])
 @pytest.mark.parametrize("cot_messages", [True, False])
 def test_qwen3_template(use_fast: bool, cot_messages: bool):
@@ -331,7 +331,7 @@ def test_qwen3_template(use_fast: bool, cot_messages: bool):
     _check_template("Qwen/Qwen3-8B", "qwen3", prompt_str, answer_str, use_fast, messages=messages)
 
 
-@pytest.mark.runs_on(["cpu"])
+@pytest.mark.runs_on(["cpu", "mps"])
 def test_parse_llama3_template():
     tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA3, token=HF_TOKEN)
     template = parse_template(tokenizer)
@@ -345,7 +345,7 @@ def test_parse_llama3_template():
     assert template.default_system == ""
 
 
-@pytest.mark.runs_on(["cpu"])
+@pytest.mark.runs_on(["cpu", "mps"])
 @pytest.mark.xfail(not HF_TOKEN, reason="Authorization.")
 def test_parse_qwen_template():
     tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct", token=HF_TOKEN)
@@ -358,7 +358,7 @@ def test_parse_qwen_template():
     assert template.default_system == "You are Qwen, created by Alibaba Cloud. You are a helpful assistant."
 
 
-@pytest.mark.runs_on(["cpu"])
+@pytest.mark.runs_on(["cpu", "mps"])
 @pytest.mark.xfail(not HF_TOKEN, reason="Authorization.")
 def test_parse_qwen3_template():
     tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B", token=HF_TOKEN)

View File

@@ -37,13 +37,13 @@ MESSAGES = [
EXPECTED_RESPONSE = "_rho" EXPECTED_RESPONSE = "_rho"
@pytest.mark.runs_on(["cpu"]) @pytest.mark.runs_on(["cpu", "mps"])
def test_chat(): def test_chat():
chat_model = ChatModel(INFER_ARGS) chat_model = ChatModel(INFER_ARGS)
assert chat_model.chat(MESSAGES)[0].response_text == EXPECTED_RESPONSE assert chat_model.chat(MESSAGES)[0].response_text == EXPECTED_RESPONSE
@pytest.mark.runs_on(["cpu"]) @pytest.mark.runs_on(["cpu", "mps"])
def test_stream_chat(): def test_stream_chat():
chat_model = ChatModel(INFER_ARGS) chat_model = ChatModel(INFER_ARGS)
response = "" response = ""

View File

@@ -39,7 +39,7 @@ MESSAGES = [
] ]
@pytest.mark.runs_on(["cpu"]) @pytest.mark.runs_on(["cuda"])
@pytest.mark.skipif(not is_sglang_available(), reason="SGLang is not installed") @pytest.mark.skipif(not is_sglang_available(), reason="SGLang is not installed")
def test_chat(): def test_chat():
r"""Test the SGLang engine's basic chat functionality.""" r"""Test the SGLang engine's basic chat functionality."""
@@ -49,7 +49,7 @@ def test_chat():
print(response.response_text) print(response.response_text)
@pytest.mark.runs_on(["cpu"]) @pytest.mark.runs_on(["cuda"])
@pytest.mark.skipif(not is_sglang_available(), reason="SGLang is not installed") @pytest.mark.skipif(not is_sglang_available(), reason="SGLang is not installed")
def test_stream_chat(): def test_stream_chat():
r"""Test the SGLang engine's streaming chat functionality.""" r"""Test the SGLang engine's streaming chat functionality."""

View File

@@ -49,7 +49,7 @@ INFER_ARGS = {
OS_NAME = os.getenv("OS_NAME", "") OS_NAME = os.getenv("OS_NAME", "")
@pytest.mark.runs_on(["cpu"]) @pytest.mark.runs_on(["cpu", "mps"])
@pytest.mark.parametrize( @pytest.mark.parametrize(
"stage,dataset", "stage,dataset",
[ [
@@ -66,7 +66,7 @@ def test_run_exp(stage: str, dataset: str):
assert os.path.exists(output_dir) assert os.path.exists(output_dir)
@pytest.mark.runs_on(["cpu"]) @pytest.mark.runs_on(["cpu", "mps"])
def test_export(): def test_export():
export_dir = os.path.join("output", "llama3_export") export_dir = os.path.join("output", "llama3_export")
export_model({"export_dir": export_dir, **INFER_ARGS}) export_model({"export_dir": export_dir, **INFER_ARGS})

View File

@@ -17,7 +17,7 @@ import pytest
from llamafactory.eval.template import get_eval_template from llamafactory.eval.template import get_eval_template
@pytest.mark.runs_on(["cpu"]) @pytest.mark.runs_on(["cpu", "mps"])
def test_eval_template_en(): def test_eval_template_en():
support_set = [ support_set = [
{ {
@@ -56,7 +56,7 @@ def test_eval_template_en():
] ]
@pytest.mark.runs_on(["cpu"]) @pytest.mark.runs_on(["cpu", "mps"])
def test_eval_template_zh(): def test_eval_template_zh():
support_set = [ support_set = [
{ {

View File

@@ -25,7 +25,6 @@ TINY_LLAMA3 = os.getenv("TINY_LLAMA3", "llamafactory/tiny-random-Llama-3")
UNUSED_TOKEN = "<|UNUSED_TOKEN|>" UNUSED_TOKEN = "<|UNUSED_TOKEN|>"
@pytest.mark.runs_on(["cpu", "npu", "cuda"])
@pytest.mark.parametrize("special_tokens", [False, True]) @pytest.mark.parametrize("special_tokens", [False, True])
def test_add_tokens(special_tokens: bool): def test_add_tokens(special_tokens: bool):
if special_tokens: if special_tokens:

View File

@@ -39,7 +39,6 @@ INFER_ARGS = {
} }
@pytest.mark.runs_on(["cpu", "npu", "cuda"])
@pytest.mark.xfail(is_transformers_version_greater_than("4.48"), reason="Attention refactor.") @pytest.mark.xfail(is_transformers_version_greater_than("4.48"), reason="Attention refactor.")
def test_attention(): def test_attention():
attention_available = ["disabled"] attention_available = ["disabled"]

View File

@@ -39,7 +39,6 @@ TRAIN_ARGS = {
} }
@pytest.mark.runs_on(["cpu", "npu", "cuda"])
@pytest.mark.parametrize("disable_gradient_checkpointing", [False, True]) @pytest.mark.parametrize("disable_gradient_checkpointing", [False, True])
def test_vanilla_checkpointing(disable_gradient_checkpointing: bool): def test_vanilla_checkpointing(disable_gradient_checkpointing: bool):
model = load_train_model(disable_gradient_checkpointing=disable_gradient_checkpointing, **TRAIN_ARGS) model = load_train_model(disable_gradient_checkpointing=disable_gradient_checkpointing, **TRAIN_ARGS)
@@ -47,14 +46,12 @@ def test_vanilla_checkpointing(disable_gradient_checkpointing: bool):
assert getattr(module, "gradient_checkpointing") != disable_gradient_checkpointing assert getattr(module, "gradient_checkpointing") != disable_gradient_checkpointing
@pytest.mark.runs_on(["cpu", "npu", "cuda"])
def test_unsloth_gradient_checkpointing(): def test_unsloth_gradient_checkpointing():
model = load_train_model(use_unsloth_gc=True, **TRAIN_ARGS) model = load_train_model(use_unsloth_gc=True, **TRAIN_ARGS)
for module in filter(lambda m: hasattr(m, "gradient_checkpointing"), model.modules()): for module in filter(lambda m: hasattr(m, "gradient_checkpointing"), model.modules()):
assert module._gradient_checkpointing_func.__self__.__name__ == "UnslothGradientCheckpointing" assert module._gradient_checkpointing_func.__self__.__name__ == "UnslothGradientCheckpointing"
@pytest.mark.runs_on(["cpu", "npu", "cuda"])
def test_upcast_layernorm(): def test_upcast_layernorm():
model = load_train_model(upcast_layernorm=True, **TRAIN_ARGS) model = load_train_model(upcast_layernorm=True, **TRAIN_ARGS)
for name, param in model.named_parameters(): for name, param in model.named_parameters():
@@ -62,7 +59,6 @@ def test_upcast_layernorm():
assert param.dtype == torch.float32 assert param.dtype == torch.float32
@pytest.mark.runs_on(["cpu", "npu", "cuda"])
def test_upcast_lmhead_output(): def test_upcast_lmhead_output():
model = load_train_model(upcast_lmhead_output=True, **TRAIN_ARGS) model = load_train_model(upcast_lmhead_output=True, **TRAIN_ARGS)
inputs = torch.randn((1, 16), dtype=torch.float16, device=get_current_device()) inputs = torch.randn((1, 16), dtype=torch.float16, device=get_current_device())

View File

@@ -24,7 +24,6 @@ from llamafactory.model.model_utils.misc import find_expanded_modules
HF_TOKEN = os.getenv("HF_TOKEN") HF_TOKEN = os.getenv("HF_TOKEN")
@pytest.mark.runs_on(["cpu", "npu", "cuda"])
@pytest.mark.skipif(not HF_TOKEN, reason="Gated model.") @pytest.mark.skipif(not HF_TOKEN, reason="Gated model.")
def test_expanded_modules(): def test_expanded_modules():
config = AutoConfig.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct") config = AutoConfig.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")

View File

@@ -18,7 +18,6 @@ import torch
from llamafactory.model.model_utils.packing import get_seqlens_in_batch, get_unpad_data from llamafactory.model.model_utils.packing import get_seqlens_in_batch, get_unpad_data
@pytest.mark.runs_on(["cpu", "npu", "cuda"])
@pytest.mark.parametrize( @pytest.mark.parametrize(
"attention_mask,golden_seq_lens", "attention_mask,golden_seq_lens",
[ [

View File

@@ -23,7 +23,6 @@ from llamafactory.hparams import FinetuningArguments, ModelArguments
from llamafactory.model.adapter import init_adapter from llamafactory.model.adapter import init_adapter
@pytest.mark.runs_on(["cpu", "npu", "cuda"])
@pytest.mark.parametrize("freeze_vision_tower", (False, True)) @pytest.mark.parametrize("freeze_vision_tower", (False, True))
@pytest.mark.parametrize("freeze_multi_modal_projector", (False, True)) @pytest.mark.parametrize("freeze_multi_modal_projector", (False, True))
@pytest.mark.parametrize("freeze_language_model", (False, True)) @pytest.mark.parametrize("freeze_language_model", (False, True))
@@ -49,7 +48,6 @@ def test_visual_full(freeze_vision_tower: bool, freeze_multi_modal_projector: bo
assert param.requires_grad != freeze_language_model assert param.requires_grad != freeze_language_model
@pytest.mark.runs_on(["cpu", "npu", "cuda"])
@pytest.mark.parametrize("freeze_vision_tower,freeze_language_model", ((False, False), (False, True), (True, False))) @pytest.mark.parametrize("freeze_vision_tower,freeze_language_model", ((False, False), (False, True), (True, False)))
def test_visual_lora(freeze_vision_tower: bool, freeze_language_model: bool): def test_visual_lora(freeze_vision_tower: bool, freeze_language_model: bool):
model_args = ModelArguments(model_name_or_path="Qwen/Qwen2-VL-2B-Instruct") model_args = ModelArguments(model_name_or_path="Qwen/Qwen2-VL-2B-Instruct")
@@ -82,7 +80,6 @@ def test_visual_lora(freeze_vision_tower: bool, freeze_language_model: bool):
assert (merger_param_name in trainable_params) is False assert (merger_param_name in trainable_params) is False
@pytest.mark.runs_on(["cpu", "npu", "cuda"])
def test_visual_model_save_load(): def test_visual_model_save_load():
# check VLM's state dict: https://github.com/huggingface/transformers/pull/38385 # check VLM's state dict: https://github.com/huggingface/transformers/pull/38385
model_args = ModelArguments(model_name_or_path="Qwen/Qwen2-VL-2B-Instruct") model_args = ModelArguments(model_name_or_path="Qwen/Qwen2-VL-2B-Instruct")

View File

@@ -30,14 +30,12 @@ INFER_ARGS = {
} }
@pytest.mark.runs_on(["cpu", "npu", "cuda"])
def test_base(): def test_base():
model = load_infer_model(**INFER_ARGS) model = load_infer_model(**INFER_ARGS)
ref_model = load_reference_model(TINY_LLAMA3) ref_model = load_reference_model(TINY_LLAMA3)
compare_model(model, ref_model) compare_model(model, ref_model)
@pytest.mark.runs_on(["cpu"])
@pytest.mark.usefixtures("fix_valuehead_cpu_loading") @pytest.mark.usefixtures("fix_valuehead_cpu_loading")
def test_valuehead(): def test_valuehead():
model = load_infer_model(add_valuehead=True, **INFER_ARGS) model = load_infer_model(add_valuehead=True, **INFER_ARGS)

View File

@@ -14,7 +14,6 @@
import os import os
import pytest
import torch import torch
from llamafactory.train.test_utils import load_infer_model, load_train_model from llamafactory.train.test_utils import load_infer_model, load_train_model
@@ -44,7 +43,6 @@ INFER_ARGS = {
} }
@pytest.mark.runs_on(["cpu", "npu", "cuda"])
def test_freeze_train_all_modules(): def test_freeze_train_all_modules():
model = load_train_model(freeze_trainable_layers=1, **TRAIN_ARGS) model = load_train_model(freeze_trainable_layers=1, **TRAIN_ARGS)
for name, param in model.named_parameters(): for name, param in model.named_parameters():
@@ -56,7 +54,6 @@ def test_freeze_train_all_modules():
assert param.dtype == torch.float16 assert param.dtype == torch.float16
@pytest.mark.runs_on(["cpu", "npu", "cuda"])
def test_freeze_train_extra_modules(): def test_freeze_train_extra_modules():
model = load_train_model(freeze_trainable_layers=1, freeze_extra_modules="embed_tokens,lm_head", **TRAIN_ARGS) model = load_train_model(freeze_trainable_layers=1, freeze_extra_modules="embed_tokens,lm_head", **TRAIN_ARGS)
for name, param in model.named_parameters(): for name, param in model.named_parameters():
@@ -68,7 +65,6 @@ def test_freeze_train_extra_modules():
assert param.dtype == torch.float16 assert param.dtype == torch.float16
@pytest.mark.runs_on(["cpu", "npu", "cuda"])
def test_freeze_inference(): def test_freeze_inference():
model = load_infer_model(**INFER_ARGS) model = load_infer_model(**INFER_ARGS)
for param in model.parameters(): for param in model.parameters():

View File

@@ -14,7 +14,6 @@
import os import os
import pytest
import torch import torch
from llamafactory.train.test_utils import load_infer_model, load_train_model from llamafactory.train.test_utils import load_infer_model, load_train_model
@@ -44,7 +43,6 @@ INFER_ARGS = {
} }
@pytest.mark.runs_on(["cpu", "npu", "cuda"])
def test_full_train(): def test_full_train():
model = load_train_model(**TRAIN_ARGS) model = load_train_model(**TRAIN_ARGS)
for param in model.parameters(): for param in model.parameters():
@@ -52,7 +50,6 @@ def test_full_train():
assert param.dtype == torch.float32 assert param.dtype == torch.float32
@pytest.mark.runs_on(["cpu", "npu", "cuda"])
def test_full_inference(): def test_full_inference():
model = load_infer_model(**INFER_ARGS) model = load_infer_model(**INFER_ARGS)
for param in model.parameters(): for param in model.parameters():

View File

@@ -55,35 +55,30 @@ INFER_ARGS = {
} }
@pytest.mark.runs_on(["cpu", "npu", "cuda"])
def test_lora_train_qv_modules(): def test_lora_train_qv_modules():
model = load_train_model(lora_target="q_proj,v_proj", **TRAIN_ARGS) model = load_train_model(lora_target="q_proj,v_proj", **TRAIN_ARGS)
linear_modules, _ = check_lora_model(model) linear_modules, _ = check_lora_model(model)
assert linear_modules == {"q_proj", "v_proj"} assert linear_modules == {"q_proj", "v_proj"}
@pytest.mark.runs_on(["cpu", "npu", "cuda"])
def test_lora_train_all_modules(): def test_lora_train_all_modules():
model = load_train_model(lora_target="all", **TRAIN_ARGS) model = load_train_model(lora_target="all", **TRAIN_ARGS)
linear_modules, _ = check_lora_model(model) linear_modules, _ = check_lora_model(model)
assert linear_modules == {"q_proj", "k_proj", "v_proj", "o_proj", "up_proj", "gate_proj", "down_proj"} assert linear_modules == {"q_proj", "k_proj", "v_proj", "o_proj", "up_proj", "gate_proj", "down_proj"}
@pytest.mark.runs_on(["cpu", "npu", "cuda"])
def test_lora_train_extra_modules(): def test_lora_train_extra_modules():
model = load_train_model(additional_target="embed_tokens,lm_head", **TRAIN_ARGS) model = load_train_model(additional_target="embed_tokens,lm_head", **TRAIN_ARGS)
_, extra_modules = check_lora_model(model) _, extra_modules = check_lora_model(model)
assert extra_modules == {"embed_tokens", "lm_head"} assert extra_modules == {"embed_tokens", "lm_head"}
@pytest.mark.runs_on(["cpu", "npu", "cuda"])
def test_lora_train_old_adapters(): def test_lora_train_old_adapters():
model = load_train_model(adapter_name_or_path=TINY_LLAMA_ADAPTER, create_new_adapter=False, **TRAIN_ARGS) model = load_train_model(adapter_name_or_path=TINY_LLAMA_ADAPTER, create_new_adapter=False, **TRAIN_ARGS)
ref_model = load_reference_model(TINY_LLAMA3, TINY_LLAMA_ADAPTER, use_lora=True, is_trainable=True) ref_model = load_reference_model(TINY_LLAMA3, TINY_LLAMA_ADAPTER, use_lora=True, is_trainable=True)
compare_model(model, ref_model) compare_model(model, ref_model)
@pytest.mark.runs_on(["cpu", "npu", "cuda"])
def test_lora_train_new_adapters(): def test_lora_train_new_adapters():
model = load_train_model(adapter_name_or_path=TINY_LLAMA_ADAPTER, create_new_adapter=True, **TRAIN_ARGS) model = load_train_model(adapter_name_or_path=TINY_LLAMA_ADAPTER, create_new_adapter=True, **TRAIN_ARGS)
ref_model = load_reference_model(TINY_LLAMA3, TINY_LLAMA_ADAPTER, use_lora=True, is_trainable=True) ref_model = load_reference_model(TINY_LLAMA3, TINY_LLAMA_ADAPTER, use_lora=True, is_trainable=True)
@@ -92,7 +87,6 @@ def test_lora_train_new_adapters():
) )
@pytest.mark.runs_on(["cpu", "npu", "cuda"])
@pytest.mark.usefixtures("fix_valuehead_cpu_loading") @pytest.mark.usefixtures("fix_valuehead_cpu_loading")
def test_lora_train_valuehead(): def test_lora_train_valuehead():
model = load_train_model(add_valuehead=True, **TRAIN_ARGS) model = load_train_model(add_valuehead=True, **TRAIN_ARGS)
@@ -103,7 +97,6 @@ def test_lora_train_valuehead():
assert torch.allclose(state_dict["v_head.summary.bias"], ref_state_dict["v_head.summary.bias"]) assert torch.allclose(state_dict["v_head.summary.bias"], ref_state_dict["v_head.summary.bias"])
@pytest.mark.runs_on(["cpu", "npu", "cuda"])
def test_lora_inference(): def test_lora_inference():
model = load_infer_model(**INFER_ARGS) model = load_infer_model(**INFER_ARGS)
ref_model = load_reference_model(TINY_LLAMA3, TINY_LLAMA_ADAPTER, use_lora=True).merge_and_unload() ref_model = load_reference_model(TINY_LLAMA3, TINY_LLAMA_ADAPTER, use_lora=True).merge_and_unload()

View File

@@ -49,7 +49,6 @@ INFER_ARGS = {
} }
@pytest.mark.runs_on(["cpu", "npu", "cuda"])
@pytest.mark.xfail(reason="PiSSA initialization is not stable in different platform.") @pytest.mark.xfail(reason="PiSSA initialization is not stable in different platform.")
def test_pissa_train(): def test_pissa_train():
model = load_train_model(**TRAIN_ARGS) model = load_train_model(**TRAIN_ARGS)
@@ -57,7 +56,6 @@ def test_pissa_train():
compare_model(model, ref_model) compare_model(model, ref_model)
@pytest.mark.runs_on(["cpu", "npu", "cuda"])
@pytest.mark.xfail(reason="Known connection error.") @pytest.mark.xfail(reason="Known connection error.")
def test_pissa_inference(): def test_pissa_inference():
model = load_infer_model(**INFER_ARGS) model = load_infer_model(**INFER_ARGS)

View File

@@ -59,7 +59,6 @@ class DataCollatorWithVerbose(DataCollatorWithPadding):
return {k: v[:, :1] for k, v in batch.items()} # truncate input length return {k: v[:, :1] for k, v in batch.items()} # truncate input length
@pytest.mark.runs_on(["cpu", "npu", "cuda"])
@pytest.mark.parametrize("disable_shuffling", [False, True]) @pytest.mark.parametrize("disable_shuffling", [False, True])
def test_shuffle(disable_shuffling: bool): def test_shuffle(disable_shuffling: bool):
model_args, data_args, training_args, finetuning_args, _ = get_train_args( model_args, data_args, training_args, finetuning_args, _ = get_train_args(

View File

@@ -1,18 +0,0 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
runs_on = pytest.mark.runs_on

View File

@@ -1,2 +1,2 @@
# change if test fails or cache is outdated # change if test fails or cache is outdated
0.9.4.104 0.9.4.105

View File

@@ -1,93 +0,0 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from llamafactory.v1.accelerator.helper import ReduceOp, all_reduce, is_torch_cuda_available, is_torch_npu_available
from llamafactory.v1.utils.utils import find_available_port
def _dist_worker(rank, world_size):
if is_torch_cuda_available():
backend = "nccl"
device = torch.device(f"cuda:{rank}")
torch.cuda.set_device(rank)
elif is_torch_npu_available():
backend = "hccl"
device = torch.device(f"npu:{rank}")
torch.npu.set_device(rank)
else:
backend = "gloo"
device = torch.device("cpu")
dist.init_process_group(
backend=backend,
rank=rank,
world_size=world_size,
)
# --------------------
# Test all_reduce SUM
# --------------------
y = torch.tensor(rank + 1.0, device=device)
y_sum = all_reduce(y.clone(), op=ReduceOp.SUM)
assert y_sum.item() == 3.0
# --------------------
# Test all_reduce MEAN
# --------------------
y_mean = all_reduce(y.clone(), op=ReduceOp.MEAN)
assert y_mean.item() == pytest.approx(1.5)
# --------------------
# Test all_reduce MAX
# --------------------
y_max = all_reduce(y.clone(), op=ReduceOp.MAX)
assert y_max.item() == 2.0
dist.destroy_process_group()
@pytest.mark.runs_on(["npu", "cuda"])
@pytest.mark.require_distributed(2)
def test_distributed_ops(monkeypatch):
monkeypatch.setenv("MASTER_ADDR", "127.0.0.1")
monkeypatch.setenv("MASTER_PORT", str(find_available_port()))
WORLD_SIZE = 2
mp.spawn(
_dist_worker,
args=(WORLD_SIZE,),
nprocs=WORLD_SIZE,
join=True,
)
@pytest.mark.runs_on(["npu", "cuda"])
@pytest.mark.require_distributed(4)
def test_required_multi():
# test require_distributed mark ok
pass
@pytest.mark.runs_on(["npu", "cuda"])
@pytest.mark.require_distributed(999)
def test_required_invalid():
# test that the require_distributed mark correctly skips this case
raise RuntimeError(
"this case should not run; check whether the require_distributed mark is implemented correctly"
)

View File

@@ -12,15 +12,48 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import os import os
import pytest
import torch.multiprocessing as mp
from llamafactory.v1.accelerator.helper import ReduceOp
from llamafactory.v1.accelerator.interface import DistributedInterface from llamafactory.v1.accelerator.interface import DistributedInterface
from llamafactory.v1.utils.env import find_available_port
from llamafactory.v1.utils.pytest import dist_env
def test_distributed_interface(): def _all_reduce_tests(local_rank: int, world_size: int, master_port: int):
DistributedInterface() with dist_env(local_rank, world_size, master_port):
assert DistributedInterface.get_rank() == int(os.getenv("RANK", "0")) rank = DistributedInterface().get_rank()
assert DistributedInterface.get_world_size() == int(os.getenv("WORLD_SIZE", "1")) world_size = DistributedInterface().get_world_size()
assert DistributedInterface.get_local_rank() == int(os.getenv("LOCAL_RANK", "0")) assert world_size == 2
assert DistributedInterface.get_local_world_size() == int(os.getenv("LOCAL_WORLD_SIZE", "1"))
y_sum = DistributedInterface().all_reduce(rank + 1.0, op=ReduceOp.SUM)
assert y_sum == pytest.approx(3.0)
y_mean = DistributedInterface().all_reduce(rank + 1.0, op=ReduceOp.MEAN)
assert y_mean == pytest.approx(1.5)
y_max = DistributedInterface().all_reduce(rank + 1.0, op=ReduceOp.MAX)
assert y_max == pytest.approx(2.0)
z = DistributedInterface().all_gather(rank + 1.0)
assert z == pytest.approx([1.0, 2.0])
z = DistributedInterface().broadcast(rank + 1.0)
assert z == pytest.approx(1.0)
def test_all_device():
assert DistributedInterface().get_rank() == int(os.getenv("RANK", "0"))
assert DistributedInterface().get_world_size() == int(os.getenv("WORLD_SIZE", "1"))
assert DistributedInterface().get_local_rank() == int(os.getenv("LOCAL_RANK", "0"))
assert DistributedInterface().get_local_world_size() == int(os.getenv("LOCAL_WORLD_SIZE", "1"))
@pytest.mark.runs_on(["cuda", "npu"])
@pytest.mark.require_distributed(2)
def test_multi_device():
master_port = find_available_port()
mp.spawn(_all_reduce_tests, args=(2, master_port), nprocs=2)
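
The spawned workers rely on the dist_env context manager imported above from llamafactory.v1.utils.pytest, which is not part of this diff. A minimal sketch of such a helper, assuming it only exports the torchrun-style variables that DistributedInterface reads (the real helper may also create and tear down the process group):

# Hypothetical dist_env sketch -- the real helper in llamafactory.v1.utils.pytest may differ.
import os
from contextlib import contextmanager


@contextmanager
def dist_env(local_rank: int, world_size: int, master_port: int):
    """Export torchrun-style env vars so each mp.spawn worker can bootstrap DistributedInterface."""
    env = {
        "RANK": str(local_rank),
        "LOCAL_RANK": str(local_rank),
        "WORLD_SIZE": str(world_size),
        "LOCAL_WORLD_SIZE": str(world_size),
        "MASTER_ADDR": "127.0.0.1",
        "MASTER_PORT": str(master_port),
    }
    backup = {key: os.environ.get(key) for key in env}
    os.environ.update(env)
    try:
        yield
    finally:  # restore the parent environment so later tests are unaffected
        for key, value in backup.items():
            if value is None:
                os.environ.pop(key, None)
            else:
                os.environ[key] = value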

View File

@@ -18,20 +18,17 @@ Contains shared fixtures, pytest configuration, and custom markers.
""" """
import os import os
from typing import Optional
import pytest import pytest
from pytest import Config, Item from pytest import Config, FixtureRequest, Item, MonkeyPatch
from llamafactory.train.test_utils import patch_valuehead_model from llamafactory.v1.accelerator.helper import get_current_accelerator, get_device_count
from llamafactory.v1.accelerator.helper import get_current_device, get_device_count from llamafactory.v1.utils.env import is_env_enabled
from llamafactory.v1.utils.packages import is_transformers_version_greater_than from llamafactory.v1.utils.packages import is_transformers_version_greater_than
from llamafactory.v1.utils.utils import is_env_enabled
try: CURRENT_DEVICE = get_current_accelerator().type
CURRENT_DEVICE = get_current_device().type # cpu | cuda | npu
except Exception:
CURRENT_DEVICE = "cpu"
def pytest_configure(config: Config): def pytest_configure(config: Config):
@@ -67,26 +64,27 @@ def _handle_runs_on(items: list[Item]):
def _handle_slow_tests(items: list[Item]): def _handle_slow_tests(items: list[Item]):
"""Skip slow tests unless RUN_SLOW is enabled.""" """Skip slow tests unless RUN_SLOW is enabled."""
if not is_env_enabled("RUN_SLOW", "0"): if not is_env_enabled("RUN_SLOW"):
skip_slow = pytest.mark.skip(reason="slow test (set RUN_SLOW=1 to run)") skip_slow = pytest.mark.skip(reason="slow test (set RUN_SLOW=1 to run)")
for item in items: for item in items:
if "slow" in item.keywords: if "slow" in item.keywords:
item.add_marker(skip_slow) item.add_marker(skip_slow)
def _get_visible_devices_env(): def _get_visible_devices_env() -> Optional[str]:
"""Return device visibility env var name.""" """Return device visibility env var name."""
if CURRENT_DEVICE == "cuda": if CURRENT_DEVICE == "cuda":
return "CUDA_VISIBLE_DEVICES" return "CUDA_VISIBLE_DEVICES"
if CURRENT_DEVICE == "npu": elif CURRENT_DEVICE == "npu":
return "ASCEND_RT_VISIBLE_DEVICES" return "ASCEND_RT_VISIBLE_DEVICES"
return None else:
return None
def _handle_device_visibility(items: list[Item]): def _handle_device_visibility(items: list[Item]):
"""Handle device visibility based on test markers.""" """Handle device visibility based on test markers."""
env_key = _get_visible_devices_env() env_key = _get_visible_devices_env()
if env_key is None or CURRENT_DEVICE == "cpu": if env_key is None or CURRENT_DEVICE in ("cpu", "mps"):
return return
# Parse visible devices # Parse visible devices
@@ -122,7 +120,7 @@ def pytest_collection_modifyitems(config: Config, items: list[Item]):
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def _manage_distributed_env(request, monkeypatch): def _manage_distributed_env(request: FixtureRequest, monkeypatch: MonkeyPatch) -> None:
"""Set environment variables for distributed tests if specific devices are requested.""" """Set environment variables for distributed tests if specific devices are requested."""
env_key = _get_visible_devices_env() env_key = _get_visible_devices_env()
if not env_key: if not env_key:
@@ -132,8 +130,7 @@ def _manage_distributed_env(request, monkeypatch):
old_value = os.environ.get(env_key) old_value = os.environ.get(env_key)
marker = request.node.get_closest_marker("require_distributed") marker = request.node.get_closest_marker("require_distributed")
if marker: if marker: # distributed test
# Distributed test
required = marker.args[0] if marker.args else 2 required = marker.args[0] if marker.args else 2
specific_devices = marker.args[1] if len(marker.args) > 1 else None specific_devices = marker.args[1] if len(marker.args) > 1 else None
@@ -143,16 +140,9 @@ def _manage_distributed_env(request, monkeypatch):
devices_str = ",".join(str(i) for i in range(required)) devices_str = ",".join(str(i) for i in range(required))
monkeypatch.setenv(env_key, devices_str) monkeypatch.setenv(env_key, devices_str)
else: else: # non-distributed test
# Non-distributed test
if old_value: if old_value:
visible_devices = [v for v in old_value.split(",") if v != ""] visible_devices = [v for v in old_value.split(",") if v != ""]
monkeypatch.setenv(env_key, visible_devices[0] if visible_devices else "0") monkeypatch.setenv(env_key, visible_devices[0] if visible_devices else "0")
else: else:
monkeypatch.setenv(env_key, "0") monkeypatch.setenv(env_key, "0")
@pytest.fixture
def fix_valuehead_cpu_loading():
"""Fix valuehead model loading."""
patch_valuehead_model()
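
The _handle_runs_on hook referenced in the hunk header above is not shown in this diff. A minimal sketch of how a runs_on marker can gate tests on the detected accelerator, assuming the marker carries a list of allowed device types; the signature here takes current_device explicitly for self-containment, whereas the real hook reads the module-level CURRENT_DEVICE, and the skip-reason wording is invented:

# Illustrative sketch only -- the real _handle_runs_on in conftest.py may differ.
import pytest
from pytest import Item


def _handle_runs_on(items: list[Item], current_device: str) -> None:
    """Skip tests whose runs_on marker does not list the current accelerator type."""
    for item in items:
        marker = item.get_closest_marker("runs_on")
        if marker is None:
            continue  # unmarked tests run on every device
        allowed = marker.args[0] if marker.args else []
        if current_device not in allowed:
            item.add_marker(pytest.mark.skip(reason=f"requires one of {allowed}, current device is {current_device}"))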

View File

@@ -28,10 +28,10 @@ from transformers import AutoTokenizer
from llamafactory.v1.config.data_args import DataArguments from llamafactory.v1.config.data_args import DataArguments
from llamafactory.v1.core.data_engine import DataEngine from llamafactory.v1.core.data_engine import DataEngine
from llamafactory.v1.core.data_loader import DataLoader
from llamafactory.v1.core.trainer_utils.data_collator import ( from llamafactory.v1.core.trainer_utils.data_collator import (
DefaultCollator, DefaultCollator,
) )
from llamafactory.v1.core.trainer_utils.data_loader import DataLoader
from llamafactory.v1.plugins.data_plugins.template import QwenTemplate from llamafactory.v1.plugins.data_plugins.template import QwenTemplate
from llamafactory.v1.utils.batching_queue import TextBatchingQueue from llamafactory.v1.utils.batching_queue import TextBatchingQueue

View File

@@ -12,57 +12,56 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import unittest
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
import pytest
from transformers import AutoModelForCausalLM from transformers import AutoModelForCausalLM
from llamafactory.v1.accelerator.helper import get_current_accelerator from llamafactory.v1.accelerator.helper import get_current_accelerator
from llamafactory.v1.plugins.model_plugins.kernels.mlp import npu_swiglu
from llamafactory.v1.plugins.model_plugins.kernels.registry import apply_available_kernels, apply_kernel
from llamafactory.v1.plugins.model_plugins.kernels.rms_norm import npu_rms_norm
from llamafactory.v1.plugins.model_plugins.kernels.rope import npu_rope
class TestKernelPlugin(unittest.TestCase): @pytest.fixture(autouse=True)
@patch("torch.accelerator.current_accelerator") def clear_accelerator_cache():
def test_apply_kernel(self, mock_get_accelerator): get_current_accelerator.cache_clear()
get_current_accelerator.cache_clear()
mock_device = MagicMock()
mock_device.type = "npu"
mock_get_accelerator.return_value = mock_device
model = AutoModelForCausalLM.from_pretrained("llamafactory/tiny-random-qwen2.5")
original_rmsnorm_forward = model.model.layers[0].input_layernorm.forward
original_swiglu_forward = model.model.layers[0].mlp.forward
from llamafactory.v1.plugins.model_plugins.kernels.mlp import npu_swiglu
from llamafactory.v1.plugins.model_plugins.kernels.registry import apply_kernel
from llamafactory.v1.plugins.model_plugins.kernels.rms_norm import npu_rms_norm
from llamafactory.v1.plugins.model_plugins.kernels.rope import npu_rope
apply_kernel(model, npu_rope.NpuRoPEKernel)
model = apply_kernel(model, npu_rms_norm.NpuRMSNormKernel)
assert model.model.layers[0].input_layernorm is not original_rmsnorm_forward
model = apply_kernel(model, npu_swiglu.NpuSwiGluKernel)
assert model.model.layers[0].mlp.forward is not original_swiglu_forward
class Test_Use_V1_Kernels(unittest.TestCase): @patch("torch.accelerator.current_accelerator")
@patch("torch.accelerator.current_accelerator") def test_apply_kernel(mock_get_accelerator: MagicMock):
def test_use_v1_kernels(self, mock_get_accelerator): mock_device = MagicMock()
get_current_accelerator.cache_clear() setattr(mock_device, "type", "npu")
mock_device = MagicMock() mock_get_accelerator.return_value = mock_device
mock_device.type = "npu"
mock_get_accelerator.return_value = mock_device
model = AutoModelForCausalLM.from_pretrained("llamafactory/tiny-random-qwen2.5") model = AutoModelForCausalLM.from_pretrained("llamafactory/tiny-random-qwen2.5")
original_rmsnorm_forward = model.model.layers[0].input_layernorm.forward original_rmsnorm_forward = model.model.layers[0].input_layernorm.forward
original_swiglu_forward = model.model.layers[0].mlp.forward original_swiglu_forward = model.model.layers[0].mlp.forward
from llamafactory.v1.plugins.model_plugins.kernels.registry import apply_available_kernels apply_kernel(model, npu_rope.NpuRoPEKernel)
model = apply_available_kernels(model) model = apply_kernel(model, npu_rms_norm.NpuRMSNormKernel)
assert model.model.layers[0].input_layernorm is not original_rmsnorm_forward
assert model.model.layers[0].input_layernorm is not original_rmsnorm_forward model = apply_kernel(model, npu_swiglu.NpuSwiGluKernel)
assert model.model.layers[0].mlp.forward is not original_swiglu_forward assert model.model.layers[0].mlp.forward is not original_swiglu_forward
@patch("torch.accelerator.current_accelerator")
def test_apply_all_kernels(mock_get_accelerator: MagicMock):
get_current_accelerator.cache_clear()
mock_device = MagicMock()
setattr(mock_device, "type", "npu")
mock_get_accelerator.return_value = mock_device
model = AutoModelForCausalLM.from_pretrained("llamafactory/tiny-random-qwen2.5")
original_rmsnorm_forward = model.model.layers[0].input_layernorm.forward
original_swiglu_forward = model.model.layers[0].mlp.forward
model = apply_available_kernels(model)
assert model.model.layers[0].input_layernorm is not original_rmsnorm_forward
assert model.model.layers[0].mlp.forward is not original_swiglu_forward
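
Because get_current_accelerator is cached with lru_cache, a patched torch.accelerator.current_accelerator would otherwise leak between tests through the cache; the autouse clear_accelerator_cache fixture above resets it before every test. A standalone illustration of that failure mode, using hypothetical names (probe_hardware, current_backend) rather than anything from the repo:

# Standalone illustration of why cached device probes must be cleared when mocking.
from functools import lru_cache
from unittest.mock import patch


def probe_hardware() -> str:
    return "cpu"  # stand-in for torch.accelerator.current_accelerator()


@lru_cache
def current_backend() -> str:
    return probe_hardware()


with patch(f"{__name__}.probe_hardware", return_value="npu"):
    assert current_backend() == "npu"  # first call caches the mocked value

assert current_backend() == "npu"  # stale: the cache still holds the mocked result
current_backend.cache_clear()
assert current_backend() == "cpu"  # after clearing, the real probe runs again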