[misc] fix accelerator (#9661)

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Author: Yaowei Zheng
Date: 2025-12-25 02:11:04 +08:00
Committed by: GitHub
Parent: 6a2eafbae3
Commit: a754604c11
44 changed files with 396 additions and 448 deletions

View File

@@ -94,9 +94,8 @@ def configure_quantization(
quantization_config: dict[str, Any] = getattr(config, "quantization_config", None)
quant_method = quantization_config.get("quant_method", "")
if (
quant_method not in (QuantizationMethod.MXFP4 and QuantizationMethod.FP8)
and (is_deepspeed_zero3_enabled() or is_fsdp_enabled())
if quant_method not in (QuantizationMethod.MXFP4, QuantizationMethod.FP8) and (
is_deepspeed_zero3_enabled() or is_fsdp_enabled()
):
# mxfp4 will dequant the model weights
raise ValueError("DeepSpeed ZeRO-3 or FSDP is incompatible with PTQ-quantized models.")

View File

@@ -15,11 +15,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utility functions used by the distributed interface.
Including:
- Environment info (rank, world_size, local_rank, etc.)
- Accelerator info (device type, device count, etc.)
- Collective communication operations (all_gather, all_reduce, broadcast)
- Process synchronization and main-process-first execution ordering
"""
import os
from contextlib import contextmanager
from enum import Enum, unique
from functools import lru_cache
from typing import Optional
from functools import lru_cache, wraps
from typing import Callable, Optional
import numpy as np
import torch
@@ -46,6 +55,22 @@ class ReduceOp(str, Enum):
MIN = "min"
def requires_accelerator(fn):
"""Decorator to check if torch.accelerator is available.
Note: this api requires torch>=2.7.0, otherwise it will raise an AttributeError or RuntimeError
"""
@wraps(fn)
def wrapper(*args, **kwargs):
if not hasattr(torch, "accelerator"):
raise RuntimeError("torch.accelerator is not available, please upgrade torch to 2.7.0 or higher.")
return fn(*args, **kwargs)
return wrapper
def is_distributed() -> bool:
"""Check if distributed environment is available."""
return os.getenv("RANK") is not None
@@ -72,105 +97,105 @@ def get_local_world_size() -> int:
@lru_cache
@requires_accelerator
def get_current_accelerator(check_available: bool = True) -> torch.device:
"""Get current accelerator.
Note: this api requires torch>=2.7.0, otherwise it will raise an AttributeError or RuntimeError
"""
if not hasattr(torch, "accelerator"):
raise RuntimeError("torch.accelerator is not available, please upgrade torch to 2.7.0 or higher.")
"""Get current accelerator."""
accelerator = torch.accelerator.current_accelerator(check_available=check_available)
if accelerator is None:
return torch.device(DeviceType.CPU.value)
return accelerator or torch.device(DeviceType.CPU.value)
return accelerator
@lru_cache
@requires_accelerator
def get_device_count() -> int:
"""Get the number of available devices."""
return torch.accelerator.device_count()
@requires_accelerator
def synchronize() -> None:
"""Synchronize all processes."""
torch.accelerator.synchronize()
@requires_accelerator
def set_device() -> None:
"""Set current accelerator."""
torch.accelerator.set_device_index(get_local_rank())
def is_torch_cuda_available():
"""Check if CUDA is available."""
return get_current_accelerator().type == DeviceType.CUDA
def is_torch_mps_available():
"""Check if MPS is available."""
return get_current_accelerator().type == DeviceType.MPS
def is_torch_npu_available():
"""Check if NPU is available."""
return get_current_accelerator().type == DeviceType.NPU
def is_torch_xpu_available():
"""Check if XPU is available."""
return get_current_accelerator().type == DeviceType.XPU
def get_current_device() -> "torch.device":
r"""Get the current available device."""
if is_torch_xpu_available():
device = "xpu:{}".format(os.getenv("LOCAL_RANK", "0"))
elif is_torch_npu_available():
device = "npu:{}".format(os.getenv("LOCAL_RANK", "0"))
elif is_torch_mps_available():
device = "mps:{}".format(os.getenv("LOCAL_RANK", "0"))
elif is_torch_cuda_available():
device = "cuda:{}".format(os.getenv("LOCAL_RANK", "0"))
def operate_tensorlike(fn: Callable[..., Tensor], data: TensorLike, **kwargs) -> TensorLike:
"""Run ``fn`` on the current accelerator and convert the result back to the input's type."""
device = get_current_accelerator()
is_tensor = isinstance(data, torch.Tensor)
is_ndarray = isinstance(data, np.ndarray)
if is_tensor:
orig_device = data.device
data = data.to(device=device)
elif is_ndarray:
data = torch.from_numpy(data).to(device=device, dtype=torch.float)
else:
device = "cpu"
data = torch.tensor(data, dtype=torch.float, device=device)
return torch.device(device)
result = fn(data, **kwargs)
def get_device_count() -> int:
r"""Get the number of available devices."""
if is_torch_xpu_available():
return torch.xpu.device_count()
elif is_torch_npu_available():
return torch.npu.device_count()
elif is_torch_mps_available():
return torch.mps.device_count()
elif is_torch_cuda_available():
return torch.cuda.device_count()
if is_tensor:
return result.to(orig_device)
elif is_ndarray:
return result.cpu().numpy()
elif result.numel() == 1:
return result.item()
else:
return 0
return result.tolist()
def all_gather(tensor: Tensor, group: Optional[ProcessGroup] = None) -> Tensor:
"""Gathers the tensor from all ranks and concats them along the first dim."""
"""Gathers the tensor from all ranks and stacks them at the first dim."""
world_size = get_world_size()
device = get_current_accelerator()
output_tensor = torch.empty(world_size * tensor.numel(), dtype=tensor.dtype, device=device)
output_tensor = torch.empty(world_size * tensor.numel(), dtype=tensor.dtype, device=tensor.device)
dist.all_gather_into_tensor(output_tensor, tensor, group=group)
return output_tensor.view(-1, *tensor.size()[1:])
return output_tensor.view(-1, *tensor.size())
def all_reduce(data: TensorLike, op: ReduceOp = ReduceOp.MEAN, group: Optional[ProcessGroup] = None) -> TensorLike:
def all_reduce(tensor: Tensor, op: ReduceOp = ReduceOp.MEAN, group: Optional[ProcessGroup] = None) -> Tensor:
"""Performs all reduce in the given process group."""
device = get_current_accelerator()
is_ndarray = isinstance(data, np.ndarray)
is_tensor = isinstance(data, torch.Tensor)
if is_ndarray:
data = torch.from_numpy(data).to(device=device, dtype=torch.float)
elif not is_tensor:
data = torch.tensor(data, dtype=torch.float, device=device)
reduce_ops = {
ReduceOp.MEAN: dist.ReduceOp.SUM,
ReduceOp.SUM: dist.ReduceOp.SUM,
ReduceOp.MAX: dist.ReduceOp.MAX,
ReduceOp.MIN: dist.ReduceOp.MIN,
}
dist.all_reduce(data, op=reduce_ops[op], group=group)
dist.all_reduce(tensor, op=reduce_ops[op], group=group)
if op == ReduceOp.MEAN: # ReduceOp.AVG is not supported by the NPU backend
data /= dist.get_world_size(group=group)
tensor /= dist.get_world_size(group=group)
if is_tensor:
return data
elif is_ndarray:
return data.cpu().numpy()
elif data.numel() == 1:
return data.item()
else:
return data.tolist()
return tensor
def broadcast(tensor: Tensor, src: int = 0, group: Optional[ProcessGroup] = None) -> Tensor:
"""Broadcasts the tensor from the src process to all other processes."""
dist.broadcast(tensor, src=src, group=group)
return tensor
@contextmanager
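The operate_tensorlike wrapper above is what lets the interface accept Python scalars, lists, NumPy arrays, or tensors uniformly. A single-process sketch (assuming torch>=2.7.0, no process group needed) that uses a trivial element-wise op in place of a collective:

import numpy as np
import torch

from llamafactory.v1.accelerator.helper import operate_tensorlike


def double(tensor: torch.Tensor) -> torch.Tensor:
    # Stand-in for a collective such as all_reduce; only the type round-trip matters here.
    return tensor * 2


print(operate_tensorlike(double, 3.0))            # 6.0 -> scalar in, scalar out
print(operate_tensorlike(double, [1.0, 2.0]))     # [2.0, 4.0] -> list in, list out
print(operate_tensorlike(double, np.ones(2)))     # ndarray in, ndarray out
print(operate_tensorlike(double, torch.ones(2)))  # tensor in, tensor returned on its original device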

View File

@@ -15,26 +15,27 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""A unified interface for model parallelism and data parallelism.
Supported model parallelism types:
- mp_replicate: Replicate the model across multiple devices.
- mp_shard: Shard the model across multiple devices.
Supported data parallelism types:
- dp: Data parallelism.
- cp: Context parallelism.
"""
from dataclasses import dataclass
from datetime import timedelta
from enum import Enum
from typing import Any, Optional
from torch.distributed import init_process_group
from torch.distributed import barrier, destroy_process_group, init_process_group
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
from ..utils.types import DistributedConfig, ProcessGroup, Tensor, TensorLike
from .helper import (
ReduceOp,
all_gather,
all_reduce,
get_current_accelerator,
get_local_rank,
get_local_world_size,
get_rank,
get_world_size,
is_distributed,
)
from . import helper
class Dim(str, Enum):
@@ -60,24 +61,24 @@ class DistributedStrategy:
"""Context parallel size, default to 1."""
def __post_init__(self) -> None:
if not is_distributed():
if not helper.is_distributed():
self.mp_shard_size = 1
elif self.mp_shard_size is None:
self.mp_shard_size = get_world_size() // self.mp_replicate_size
elif self.mp_replicate_size * self.mp_shard_size != get_world_size():
self.mp_shard_size = helper.get_world_size() // self.mp_replicate_size
elif self.mp_replicate_size * self.mp_shard_size != helper.get_world_size():
raise ValueError(
f"mp_replicate_size * mp_shard_size must equal to world_size, "
f"got {self.mp_replicate_size} * {self.mp_shard_size} != {get_world_size()}."
f"got {self.mp_replicate_size} * {self.mp_shard_size} != {helper.get_world_size()}."
)
if not is_distributed():
if not helper.is_distributed():
self.dp_size = 1
elif self.dp_size is None:
self.dp_size = get_world_size() // self.cp_size
elif self.dp_size * self.cp_size != get_world_size():
self.dp_size = helper.get_world_size() // self.cp_size
elif self.dp_size * self.cp_size != helper.get_world_size():
raise ValueError(
f"dp_size * cp_size must equal to world_size, "
f"got {self.dp_size} * {self.cp_size} != {get_world_size()}."
f"got {self.dp_size} * {self.cp_size} != {helper.get_world_size()}."
)
@property
@@ -106,20 +107,6 @@ class DistributedInterface:
_instance: Optional["DistributedInterface"] = None
_initialized: bool = False
_is_distributed = is_distributed()
_rank = get_rank()
_world_size = get_world_size()
_local_rank = get_local_rank()
_local_world_size = get_local_world_size()
strategy: Optional[DistributedStrategy] = None
"""Distributed strategy."""
model_device_mesh: Optional[DeviceMesh] = None
"""Model parallel device mesh."""
data_device_mesh: Optional[DeviceMesh] = None
"""Data parallel device mesh."""
current_accelerator = get_current_accelerator()
"""Current accelerator."""
def __new__(cls, *args: Any, **kwargs: Any) -> "DistributedInterface":
"""Singleton pattern."""
@@ -132,6 +119,14 @@ class DistributedInterface:
if self._initialized:
return
self._is_distributed = helper.is_distributed()
self._rank = helper.get_rank()
self._world_size = helper.get_world_size()
self._local_rank = helper.get_local_rank()
self._local_world_size = helper.get_local_world_size()
self.current_accelerator = helper.get_current_accelerator()
self.device_count = helper.get_device_count()
if config is None:
self.strategy = DistributedStrategy()
timeout = 18000
@@ -145,6 +140,7 @@ class DistributedInterface:
timeout = config.get("timeout", 18000)
if self._is_distributed:
helper.set_device()
init_process_group(timeout=timedelta(seconds=timeout))
self.model_device_mesh = init_device_mesh(
device_type=self.current_accelerator.type,
@@ -169,65 +165,84 @@ class DistributedInterface:
f"model_device_mesh={self.model_device_mesh}, data_device_mesh={self.data_device_mesh}"
)
@classmethod
def get_device_mesh(cls, dim: Optional[Dim] = None) -> Optional[DeviceMesh]:
def get_device_mesh(self, dim: Optional[Dim] = None) -> Optional[DeviceMesh]:
"""Get device mesh for specified dimension."""
if dim is None:
raise ValueError("dim must be specified.")
elif cls.model_device_mesh is None:
elif self.model_device_mesh is None:
return None
elif dim in cls.strategy.data_mesh_dim_names:
return cls.data_device_mesh[dim.value]
elif dim in self.strategy.data_mesh_dim_names:
return self.data_device_mesh[dim.value]
else:
return cls.model_device_mesh[dim.value]
return self.model_device_mesh[dim.value]
@classmethod
def get_group(cls, dim: Optional[Dim] = None) -> Optional[ProcessGroup]:
def get_group(self, dim: Optional[Dim] = None) -> Optional[ProcessGroup]:
"""Get process group for specified dimension."""
if cls.model_device_mesh is None or dim is None:
if self.model_device_mesh is None or dim is None:
return None
else:
return cls.get_device_mesh(dim).get_group()
return self.get_device_mesh(dim).get_group()
@classmethod
def get_rank(cls, dim: Optional[Dim] = None) -> int:
def get_rank(self, dim: Optional[Dim] = None) -> int:
"""Get parallel rank for specified dimension."""
if cls.model_device_mesh is None:
if self.model_device_mesh is None:
return 0
elif dim is None:
return cls._rank
return self._rank
else:
return cls.get_device_mesh(dim).get_local_rank()
return self.get_device_mesh(dim).get_local_rank()
@classmethod
def get_world_size(cls, dim: Optional[Dim] = None) -> int:
def get_world_size(self, dim: Optional[Dim] = None) -> int:
"""Get parallel size for specified dimension."""
if cls.model_device_mesh is None:
if self.model_device_mesh is None:
return 1
elif dim is None:
return cls._world_size
return self._world_size
else:
return cls.get_device_mesh(dim).size()
return self.get_device_mesh(dim).size()
@classmethod
def get_local_rank(cls) -> int:
def get_local_rank(self) -> int:
"""Get parallel local rank."""
return cls._local_rank
return self._local_rank
@classmethod
def get_local_world_size(cls) -> int:
def get_local_world_size(self) -> int:
"""Get parallel local world size."""
return cls._local_world_size
return self._local_world_size
@classmethod
def all_gather(cls, data: Tensor, dim: Optional[Dim] = Dim.DP) -> Tensor:
def all_gather(self, data: Tensor, dim: Optional[Dim] = Dim.DP) -> Tensor:
"""Gather tensor across specified parallel group."""
return all_gather(data, cls.get_group(dim)) if cls.model_device_mesh is not None else data
if self.model_device_mesh is not None:
return helper.operate_tensorlike(helper.all_gather, data, group=self.get_group(dim))
else:
return data
@classmethod
def all_reduce(cls, data: TensorLike, op: ReduceOp = ReduceOp.MEAN, dim: Optional[Dim] = Dim.DP) -> TensorLike:
def all_reduce(
self, data: TensorLike, op: helper.ReduceOp = helper.ReduceOp.MEAN, dim: Optional[Dim] = Dim.DP
) -> TensorLike:
"""Reduce tensor across specified parallel group."""
return all_reduce(data, op, cls.get_group(dim)) if cls.model_device_mesh is not None else data
if self.model_device_mesh is not None:
return helper.operate_tensorlike(helper.all_reduce, data, op=op, group=self.get_group(dim))
else:
return data
def broadcast(self, data: TensorLike, src: int = 0, dim: Optional[Dim] = Dim.DP) -> TensorLike:
"""Broadcast tensor across specified parallel group."""
if self.model_device_mesh is not None:
return helper.operate_tensorlike(helper.broadcast, data, src=src, group=self.get_group(dim))
else:
return data
def sync(self) -> None:
"""Synchronize all processes."""
helper.synchronize()
def barrier(self) -> None:
"""Barrier all processes."""
barrier()
def destroy(self) -> None:
"""Destroy all processes."""
destroy_process_group()
if __name__ == "__main__":
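A condensed sketch of how the reworked DistributedInterface is used after this change (assuming torch>=2.7.0): per-process state such as rank and current_accelerator now lives on the singleton instance rather than on the class, so callers construct the instance first. The values in the comments are for a plain single-process run:

from llamafactory.v1.accelerator.interface import Dim, DistributedInterface

dist = DistributedInterface()  # singleton: later calls return the same initialized object

print(dist.current_accelerator)                # e.g. cuda:0, npu:0, or cpu
print(dist.get_rank(), dist.get_world_size())  # 0 1 when launched without torchrun

# Collectives degrade to no-ops when no device mesh was built (single process),
# so the same call sites work with and without a distributed launcher.
loss = dist.all_reduce(0.5, dim=Dim.DP)  # 0.5 here; mean across DP ranks when distributed
print(dist.all_gather(loss))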

View File

@@ -97,7 +97,7 @@ class ModelLoader:
self.args.model,
config=self.model_config,
dtype="auto",
device_map=DistributedInterface.current_accelerator,
device_map=DistributedInterface().current_accelerator,
trust_remote_code=self.args.trust_remote_code,
)

View File

@@ -22,10 +22,10 @@ from typing import Optional
from torchdata.stateful_dataloader import StatefulDataLoader
from torchdata.stateful_dataloader.sampler import StatefulDistributedSampler
from ..utils.batching_queue import BaseBatchingQueue
from ..utils.logging import get_logger
from ..utils.types import Processor, TorchDataset
from .trainer_utils.data_collator import DataCollator
from ...utils.batching_queue import BaseBatchingQueue
from ...utils.logging import get_logger
from ...utils.types import Processor, TorchDataset
from .data_collator import DataCollator
logger = get_logger(__name__)

View File

@@ -17,7 +17,7 @@ import socket
def find_available_port() -> int:
r"""Find an available port on the local machine."""
"""Find an available port on the local machine."""
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.bind(("", 0))
port = sock.getsockname()[1]
@@ -26,9 +26,5 @@ def find_available_port() -> int:
def is_env_enabled(env_var: str, default: str = "0") -> bool:
r"""Check if the environment variable is enabled."""
return os.getenv(env_var, default).lower() in ["true", "y", "1"]
if __name__ == "__main__":
print(find_available_port())
"""Check if the environment variable is enabled."""
return os.getenv(env_var, default).lower() in ["true", "yes", "on", "t", "y", "1"]
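The broadened truthy list means common shell-style values now enable a flag. A self-contained illustration mirroring the updated helper (RUN_SLOW is just an example variable used by the test suite):

import os


def is_env_enabled(env_var: str, default: str = "0") -> bool:
    # Same logic as the updated helper above.
    return os.getenv(env_var, default).lower() in ["true", "yes", "on", "t", "y", "1"]


os.environ["RUN_SLOW"] = "yes"     # previously only "true", "y" and "1" counted
print(is_env_enabled("RUN_SLOW"))  # True
os.environ["RUN_SLOW"] = "off"
print(is_env_enabled("RUN_SLOW"))  # False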

View File

@@ -0,0 +1,35 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from contextlib import contextmanager
@contextmanager
def dist_env(local_rank: int = 0, world_size: int = 1, master_port: int = 25595):
"""Set distributed environment variables."""
env_vars = {
"MASTER_ADDR": "127.0.0.1",
"MASTER_PORT": str(master_port),
"RANK": str(local_rank),
"LOCAL_RANK": str(local_rank),
"WORLD_SIZE": str(world_size),
"LOCAL_WORLD_SIZE": str(world_size),
}
os.environ.update(env_vars)
try:
yield
finally:
for key in env_vars.keys():
os.environ.pop(key, None)
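A short usage sketch of the new dist_env test helper; the import path matches the updated tests below, and every variable it sets is removed again when the block exits:

import os

from llamafactory.v1.utils.pytest import dist_env

with dist_env(local_rank=0, world_size=1, master_port=25595):
    # Inside the block, code under test sees a fake single-process rendezvous.
    assert os.environ["RANK"] == "0"
    assert os.environ["WORLD_SIZE"] == "1"

assert "RANK" not in os.environ  # cleaned up on exit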

View File

@@ -18,19 +18,17 @@ Contains shared fixtures, pytest configuration, and custom markers.
"""
import os
from typing import Optional
import pytest
from pytest import Config, Item
from pytest import Config, FixtureRequest, Item, MonkeyPatch
from llamafactory.extras.misc import get_current_device, get_device_count, is_env_enabled
from llamafactory.extras.packages import is_transformers_version_greater_than
from llamafactory.train.test_utils import patch_valuehead_model
try:
CURRENT_DEVICE = get_current_device().type # cpu | cuda | npu
except Exception:
CURRENT_DEVICE = "cpu"
CURRENT_DEVICE = get_current_device().type
def pytest_configure(config: Config):
@@ -66,26 +64,27 @@ def _handle_runs_on(items: list[Item]):
def _handle_slow_tests(items: list[Item]):
"""Skip slow tests unless RUN_SLOW is enabled."""
if not is_env_enabled("RUN_SLOW", "0"):
if not is_env_enabled("RUN_SLOW"):
skip_slow = pytest.mark.skip(reason="slow test (set RUN_SLOW=1 to run)")
for item in items:
if "slow" in item.keywords:
item.add_marker(skip_slow)
def _get_visible_devices_env():
def _get_visible_devices_env() -> Optional[str]:
"""Return device visibility env var name."""
if CURRENT_DEVICE == "cuda":
return "CUDA_VISIBLE_DEVICES"
if CURRENT_DEVICE == "npu":
elif CURRENT_DEVICE == "npu":
return "ASCEND_RT_VISIBLE_DEVICES"
return None
else:
return None
def _handle_device_visibility(items: list[Item]):
"""Handle device visibility based on test markers."""
env_key = _get_visible_devices_env()
if env_key is None or CURRENT_DEVICE == "cpu":
if env_key is None or CURRENT_DEVICE in ("cpu", "mps"):
return
# Parse visible devices
@@ -121,7 +120,7 @@ def pytest_collection_modifyitems(config: Config, items: list[Item]):
@pytest.fixture(autouse=True)
def _manage_distributed_env(request, monkeypatch):
def _manage_distributed_env(request: FixtureRequest, monkeypatch: MonkeyPatch) -> None:
"""Set environment variables for distributed tests if specific devices are requested."""
env_key = _get_visible_devices_env()
if not env_key:
@@ -131,8 +130,7 @@ def _manage_distributed_env(request, monkeypatch):
old_value = os.environ.get(env_key)
marker = request.node.get_closest_marker("require_distributed")
if marker:
# Distributed test
if marker: # distributed test
required = marker.args[0] if marker.args else 2
specific_devices = marker.args[1] if len(marker.args) > 1 else None
@@ -142,8 +140,7 @@ def _manage_distributed_env(request, monkeypatch):
devices_str = ",".join(str(i) for i in range(required))
monkeypatch.setenv(env_key, devices_str)
else:
# Non-distributed test
else: # non-distributed test
if old_value:
visible_devices = [v for v in old_value.split(",") if v != ""]
monkeypatch.setenv(env_key, visible_devices[0] if visible_devices else "0")
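A hypothetical test showing how the two custom markers handled by this conftest combine; the device list and required count are placeholders:

import pytest


@pytest.mark.runs_on(["cuda", "npu"])  # collected only on these accelerator types
@pytest.mark.require_distributed(2)    # skipped unless at least 2 devices are visible
def test_two_device_smoke():
    # The autouse fixture above widens CUDA_VISIBLE_DEVICES / ASCEND_RT_VISIBLE_DEVICES
    # to cover the requested device count before this body runs.
    assert True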

View File

@@ -42,7 +42,7 @@ TRAIN_ARGS = {
}
@pytest.mark.runs_on(["cpu", "npu", "cuda"])
@pytest.mark.runs_on(["cpu", "mps"])
@pytest.mark.parametrize("num_samples", [16])
def test_feedback_data(num_samples: int):
train_dataset = load_dataset_module(**TRAIN_ARGS)["train_dataset"]

View File

@@ -51,7 +51,7 @@ def _convert_sharegpt_to_openai(messages: list[dict[str, str]]) -> list[dict[str
return new_messages
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
@pytest.mark.parametrize("num_samples", [16])
def test_pairwise_data(num_samples: int):
train_dataset = load_dataset_module(**TRAIN_ARGS)["train_dataset"]

View File

@@ -18,7 +18,7 @@ import pytest
from llamafactory.data.processor.processor_utils import infer_seqlen
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
@pytest.mark.parametrize(
"test_input,test_output",
[

View File

@@ -42,7 +42,7 @@ TRAIN_ARGS = {
}
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
@pytest.mark.parametrize("num_samples", [16])
def test_supervised_single_turn(num_samples: int):
train_dataset = load_dataset_module(dataset_dir="ONLINE", dataset=TINY_DATA, **TRAIN_ARGS)["train_dataset"]
@@ -62,7 +62,7 @@ def test_supervised_single_turn(num_samples: int):
assert train_dataset["input_ids"][index] == ref_input_ids
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
@pytest.mark.parametrize("num_samples", [8])
def test_supervised_multi_turn(num_samples: int):
train_dataset = load_dataset_module(dataset_dir="REMOTE:" + DEMO_DATA, dataset="system_chat", **TRAIN_ARGS)[
@@ -76,7 +76,7 @@ def test_supervised_multi_turn(num_samples: int):
assert train_dataset["input_ids"][index] == ref_input_ids
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
@pytest.mark.parametrize("num_samples", [4])
def test_supervised_train_on_prompt(num_samples: int):
train_dataset = load_dataset_module(
@@ -91,7 +91,7 @@ def test_supervised_train_on_prompt(num_samples: int):
assert train_dataset["labels"][index] == ref_ids
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
@pytest.mark.parametrize("num_samples", [4])
def test_supervised_mask_history(num_samples: int):
train_dataset = load_dataset_module(

View File

@@ -46,7 +46,7 @@ TRAIN_ARGS = {
}
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
@pytest.mark.parametrize("num_samples", [16])
def test_unsupervised_data(num_samples: int):
train_dataset = load_dataset_module(**TRAIN_ARGS)["train_dataset"]

View File

@@ -29,7 +29,7 @@ from llamafactory.model import load_tokenizer
TINY_LLAMA3 = os.getenv("TINY_LLAMA3", "llamafactory/tiny-random-Llama-3")
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
def test_base_collator():
model_args, data_args, *_ = get_infer_args({"model_name_or_path": TINY_LLAMA3, "template": "default"})
tokenizer_module = load_tokenizer(model_args)
@@ -73,7 +73,7 @@ def test_base_collator():
assert batch_input[k].eq(torch.tensor(expected_input[k])).all()
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
def test_multimodal_collator():
model_args, data_args, *_ = get_infer_args(
{"model_name_or_path": "Qwen/Qwen2-VL-2B-Instruct", "template": "qwen2_vl"}

View File

@@ -20,7 +20,7 @@ from llamafactory.data.parser import DatasetAttr
from llamafactory.hparams import DataArguments
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
def test_alpaca_converter():
dataset_attr = DatasetAttr("hf_hub", "llamafactory/tiny-supervised-dataset")
data_args = DataArguments()
@@ -41,7 +41,7 @@ def test_alpaca_converter():
}
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
def test_sharegpt_converter():
dataset_attr = DatasetAttr("hf_hub", "llamafactory/tiny-supervised-dataset")
data_args = DataArguments()

View File

@@ -38,19 +38,19 @@ TOOLS = [
]
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
def test_empty_formatter():
formatter = EmptyFormatter(slots=["\n"])
assert formatter.apply() == ["\n"]
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
def test_string_formatter():
formatter = StringFormatter(slots=["<s>", "Human: {{content}}\nAssistant:"])
assert formatter.apply(content="Hi") == ["<s>", "Human: Hi\nAssistant:"]
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
def test_function_formatter():
formatter = FunctionFormatter(slots=["{{content}}", "</s>"], tool_format="default")
tool_calls = json.dumps(FUNCTION)
@@ -60,7 +60,7 @@ def test_function_formatter():
]
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
def test_multi_function_formatter():
formatter = FunctionFormatter(slots=["{{content}}", "</s>"], tool_format="default")
tool_calls = json.dumps([FUNCTION] * 2)
@@ -71,7 +71,7 @@ def test_multi_function_formatter():
]
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
def test_default_tool_formatter():
formatter = ToolFormatter(tool_format="default")
assert formatter.apply(content=json.dumps(TOOLS)) == [
@@ -90,14 +90,14 @@ def test_default_tool_formatter():
]
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
def test_default_tool_extractor():
formatter = ToolFormatter(tool_format="default")
result = """Action: test_tool\nAction Input: {"foo": "bar", "size": 10}"""
assert formatter.extract(result) == [("test_tool", """{"foo": "bar", "size": 10}""")]
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
def test_default_multi_tool_extractor():
formatter = ToolFormatter(tool_format="default")
result = (
@@ -110,14 +110,14 @@ def test_default_multi_tool_extractor():
]
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
def test_glm4_function_formatter():
formatter = FunctionFormatter(slots=["{{content}}"], tool_format="glm4")
tool_calls = json.dumps(FUNCTION)
assert formatter.apply(content=tool_calls) == ["""tool_name\n{"foo": "bar", "size": 10}"""]
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
def test_glm4_tool_formatter():
formatter = ToolFormatter(tool_format="glm4")
assert formatter.apply(content=json.dumps(TOOLS)) == [
@@ -128,14 +128,14 @@ def test_glm4_tool_formatter():
]
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
def test_glm4_tool_extractor():
formatter = ToolFormatter(tool_format="glm4")
result = """test_tool\n{"foo": "bar", "size": 10}\n"""
assert formatter.extract(result) == [("test_tool", """{"foo": "bar", "size": 10}""")]
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
def test_llama3_function_formatter():
formatter = FunctionFormatter(slots=["{{content}}<|eot_id|>"], tool_format="llama3")
tool_calls = json.dumps(FUNCTION)
@@ -144,7 +144,7 @@ def test_llama3_function_formatter():
]
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
def test_llama3_multi_function_formatter():
formatter = FunctionFormatter(slots=["{{content}}<|eot_id|>"], tool_format="llama3")
tool_calls = json.dumps([FUNCTION] * 2)
@@ -155,7 +155,7 @@ def test_llama3_multi_function_formatter():
]
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
def test_llama3_tool_formatter():
formatter = ToolFormatter(tool_format="llama3")
date = datetime.now().strftime("%d %b %Y")
@@ -169,14 +169,14 @@ def test_llama3_tool_formatter():
]
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
def test_llama3_tool_extractor():
formatter = ToolFormatter(tool_format="llama3")
result = """{"name": "test_tool", "parameters": {"foo": "bar", "size": 10}}\n"""
assert formatter.extract(result) == [("test_tool", """{"foo": "bar", "size": 10}""")]
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
def test_llama3_multi_tool_extractor():
formatter = ToolFormatter(tool_format="llama3")
result = (
@@ -189,7 +189,7 @@ def test_llama3_multi_tool_extractor():
]
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
def test_mistral_function_formatter():
formatter = FunctionFormatter(slots=["[TOOL_CALLS] {{content}}", "</s>"], tool_format="mistral")
tool_calls = json.dumps(FUNCTION)
@@ -199,7 +199,7 @@ def test_mistral_function_formatter():
]
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
def test_mistral_multi_function_formatter():
formatter = FunctionFormatter(slots=["[TOOL_CALLS] {{content}}", "</s>"], tool_format="mistral")
tool_calls = json.dumps([FUNCTION] * 2)
@@ -211,7 +211,7 @@ def test_mistral_multi_function_formatter():
]
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
def test_mistral_tool_formatter():
formatter = ToolFormatter(tool_format="mistral")
wrapped_tool = {"type": "function", "function": TOOLS[0]}
@@ -220,14 +220,14 @@ def test_mistral_tool_formatter():
]
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
def test_mistral_tool_extractor():
formatter = ToolFormatter(tool_format="mistral")
result = """{"name": "test_tool", "arguments": {"foo": "bar", "size": 10}}"""
assert formatter.extract(result) == [("test_tool", """{"foo": "bar", "size": 10}""")]
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
def test_mistral_multi_tool_extractor():
formatter = ToolFormatter(tool_format="mistral")
result = (
@@ -240,7 +240,7 @@ def test_mistral_multi_tool_extractor():
]
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
def test_qwen_function_formatter():
formatter = FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen")
tool_calls = json.dumps(FUNCTION)
@@ -249,7 +249,7 @@ def test_qwen_function_formatter():
]
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
def test_qwen_multi_function_formatter():
formatter = FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen")
tool_calls = json.dumps([FUNCTION] * 2)
@@ -260,7 +260,7 @@ def test_qwen_multi_function_formatter():
]
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
def test_qwen_tool_formatter():
formatter = ToolFormatter(tool_format="qwen")
wrapped_tool = {"type": "function", "function": TOOLS[0]}
@@ -274,14 +274,14 @@ def test_qwen_tool_formatter():
]
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
def test_qwen_tool_extractor():
formatter = ToolFormatter(tool_format="qwen")
result = """<tool_call>\n{"name": "test_tool", "arguments": {"foo": "bar", "size": 10}}\n</tool_call>"""
assert formatter.extract(result) == [("test_tool", """{"foo": "bar", "size": 10}""")]
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
def test_qwen_multi_tool_extractor():
formatter = ToolFormatter(tool_format="qwen")
result = (

View File

@@ -40,21 +40,21 @@ TRAIN_ARGS = {
}
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
def test_load_train_only():
dataset_module = load_dataset_module(**TRAIN_ARGS)
assert dataset_module.get("train_dataset") is not None
assert dataset_module.get("eval_dataset") is None
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
def test_load_val_size():
dataset_module = load_dataset_module(val_size=0.1, **TRAIN_ARGS)
assert dataset_module.get("train_dataset") is not None
assert dataset_module.get("eval_dataset") is not None
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
def test_load_eval_data():
dataset_module = load_dataset_module(eval_dataset=TINY_DATA, **TRAIN_ARGS)
assert dataset_module.get("train_dataset") is not None

View File

@@ -179,7 +179,7 @@ def _check_plugin(
)
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
def test_base_plugin():
tokenizer_module = _load_tokenizer_module(model_name_or_path=TINY_LLAMA3)
base_plugin = get_mm_plugin(name="base")
@@ -187,7 +187,7 @@ def test_base_plugin():
_check_plugin(**check_inputs)
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
@pytest.mark.skipif(not HF_TOKEN, reason="Gated model.")
@pytest.mark.skipif(not is_transformers_version_greater_than("4.50.0"), reason="Requires transformers>=4.50.0")
def test_gemma3_plugin():
@@ -210,7 +210,7 @@ def test_gemma3_plugin():
_check_plugin(**check_inputs)
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
@pytest.mark.skipif(not is_transformers_version_greater_than("4.52.0"), reason="Requires transformers>=4.52.0")
def test_internvl_plugin():
image_seqlen = 256
@@ -229,7 +229,7 @@ def test_internvl_plugin():
_check_plugin(**check_inputs)
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
@pytest.mark.skipif(not is_transformers_version_greater_than("4.51.0"), reason="Requires transformers>=4.51.0")
def test_llama4_plugin():
tokenizer_module = _load_tokenizer_module(model_name_or_path=TINY_LLAMA4)
@@ -251,7 +251,7 @@ def test_llama4_plugin():
_check_plugin(**check_inputs)
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
def test_llava_plugin():
image_seqlen = 576
tokenizer_module = _load_tokenizer_module(model_name_or_path="llava-hf/llava-1.5-7b-hf")
@@ -265,7 +265,7 @@ def test_llava_plugin():
_check_plugin(**check_inputs)
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
def test_llava_next_plugin():
image_seqlen = 1176
tokenizer_module = _load_tokenizer_module(model_name_or_path="llava-hf/llava-v1.6-vicuna-7b-hf")
@@ -279,7 +279,7 @@ def test_llava_next_plugin():
_check_plugin(**check_inputs)
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
def test_llava_next_video_plugin():
image_seqlen = 1176
tokenizer_module = _load_tokenizer_module(model_name_or_path="llava-hf/LLaVA-NeXT-Video-7B-hf")
@@ -293,7 +293,7 @@ def test_llava_next_video_plugin():
_check_plugin(**check_inputs)
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
@pytest.mark.skipif(not HF_TOKEN, reason="Gated model.")
def test_paligemma_plugin():
image_seqlen = 256
@@ -313,7 +313,7 @@ def test_paligemma_plugin():
_check_plugin(**check_inputs)
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
@pytest.mark.skipif(not is_transformers_version_greater_than("4.50.0"), reason="Requires transformers>=4.50.0")
def test_pixtral_plugin():
image_slice_height, image_slice_width = 2, 2
@@ -336,7 +336,7 @@ def test_pixtral_plugin():
_check_plugin(**check_inputs)
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
@pytest.mark.skipif(not is_transformers_version_greater_than("4.52.0"), reason="Requires transformers>=4.52.0")
def test_qwen2_omni_plugin():
image_seqlen, audio_seqlen = 4, 2
@@ -367,7 +367,7 @@ def test_qwen2_omni_plugin():
_check_plugin(**check_inputs)
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
def test_qwen2_vl_plugin():
image_seqlen = 4
tokenizer_module = _load_tokenizer_module(model_name_or_path="Qwen/Qwen2-VL-7B-Instruct")
@@ -384,7 +384,7 @@ def test_qwen2_vl_plugin():
_check_plugin(**check_inputs)
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
@pytest.mark.skipif(not is_transformers_version_greater_than("4.57.0"), reason="Requires transformers>=4.57.0")
def test_qwen3_vl_plugin():
frame_seqlen = 1
@@ -406,7 +406,7 @@ def test_qwen3_vl_plugin():
_check_plugin(**check_inputs)
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
@pytest.mark.skipif(not is_transformers_version_greater_than("4.47.0"), reason="Requires transformers>=4.47.0")
def test_video_llava_plugin():
image_seqlen = 256

View File

@@ -89,7 +89,7 @@ def _check_template(
_check_tokenization(tokenizer, (prompt_ids, answer_ids), (prompt_str, answer_str))
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
@pytest.mark.parametrize("use_fast", [True, False])
def test_encode_oneturn(use_fast: bool):
tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA3, use_fast=use_fast)
@@ -105,7 +105,7 @@ def test_encode_oneturn(use_fast: bool):
_check_tokenization(tokenizer, (prompt_ids, answer_ids), (prompt_str, answer_str))
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
@pytest.mark.parametrize("use_fast", [True, False])
def test_encode_multiturn(use_fast: bool):
tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA3, use_fast=use_fast)
@@ -127,7 +127,7 @@ def test_encode_multiturn(use_fast: bool):
)
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
@pytest.mark.parametrize("use_fast", [True, False])
@pytest.mark.parametrize("cot_messages", [True, False])
@pytest.mark.parametrize("enable_thinking", [True, False, None])
@@ -154,7 +154,7 @@ def test_reasoning_encode_oneturn(use_fast: bool, cot_messages: bool, enable_thi
_check_tokenization(tokenizer, (prompt_ids, answer_ids), (prompt_str, answer_str))
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
@pytest.mark.parametrize("use_fast", [True, False])
@pytest.mark.parametrize("cot_messages", [True, False])
@pytest.mark.parametrize("enable_thinking", [True, False, None])
@@ -184,7 +184,7 @@ def test_reasoning_encode_multiturn(use_fast: bool, cot_messages: bool, enable_t
)
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
@pytest.mark.parametrize("use_fast", [True, False])
def test_jinja_template(use_fast: bool):
tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA3, use_fast=use_fast)
@@ -195,7 +195,7 @@ def test_jinja_template(use_fast: bool):
assert tokenizer.apply_chat_template(MESSAGES) == ref_tokenizer.apply_chat_template(MESSAGES)
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
def test_ollama_modelfile():
tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA3)
template = get_template_and_fix_tokenizer(tokenizer, DataArguments(template="llama3"))
@@ -213,14 +213,14 @@ def test_ollama_modelfile():
)
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
def test_get_stop_token_ids():
tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA3)
template = get_template_and_fix_tokenizer(tokenizer, DataArguments(template="llama3"))
assert set(template.get_stop_token_ids(tokenizer)) == {128008, 128009}
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
@pytest.mark.skipif(not HF_TOKEN, reason="Gated model.")
@pytest.mark.parametrize("use_fast", [True, False])
def test_gemma_template(use_fast: bool):
@@ -234,7 +234,7 @@ def test_gemma_template(use_fast: bool):
_check_template("google/gemma-3-4b-it", "gemma", prompt_str, answer_str, use_fast)
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
@pytest.mark.skipif(not HF_TOKEN, reason="Gated model.")
@pytest.mark.parametrize("use_fast", [True, False])
def test_gemma2_template(use_fast: bool):
@@ -248,7 +248,7 @@ def test_gemma2_template(use_fast: bool):
_check_template("google/gemma-2-2b-it", "gemma2", prompt_str, answer_str, use_fast)
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
@pytest.mark.skipif(not HF_TOKEN, reason="Gated model.")
@pytest.mark.parametrize("use_fast", [True, False])
def test_llama3_template(use_fast: bool):
@@ -262,7 +262,7 @@ def test_llama3_template(use_fast: bool):
_check_template("meta-llama/Meta-Llama-3-8B-Instruct", "llama3", prompt_str, answer_str, use_fast)
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
@pytest.mark.parametrize(
"use_fast", [True, pytest.param(False, marks=pytest.mark.xfail(reason="Llama 4 has no slow tokenizer."))]
)
@@ -284,7 +284,7 @@ def test_llama4_template(use_fast: bool):
pytest.param(False, marks=pytest.mark.xfail(reason="Phi-4 slow tokenizer is broken.")),
],
)
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
def test_phi4_template(use_fast: bool):
prompt_str = (
f"<|im_start|>user<|im_sep|>{MESSAGES[0]['content']}<|im_end|>"
@@ -296,7 +296,7 @@ def test_phi4_template(use_fast: bool):
_check_template("microsoft/phi-4", "phi4", prompt_str, answer_str, use_fast)
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
@pytest.mark.xfail(not HF_TOKEN, reason="Authorization.")
@pytest.mark.parametrize("use_fast", [True, False])
def test_qwen2_5_template(use_fast: bool):
@@ -311,7 +311,7 @@ def test_qwen2_5_template(use_fast: bool):
_check_template("Qwen/Qwen2.5-7B-Instruct", "qwen", prompt_str, answer_str, use_fast)
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
@pytest.mark.parametrize("use_fast", [True, False])
@pytest.mark.parametrize("cot_messages", [True, False])
def test_qwen3_template(use_fast: bool, cot_messages: bool):
@@ -331,7 +331,7 @@ def test_qwen3_template(use_fast: bool, cot_messages: bool):
_check_template("Qwen/Qwen3-8B", "qwen3", prompt_str, answer_str, use_fast, messages=messages)
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
def test_parse_llama3_template():
tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA3, token=HF_TOKEN)
template = parse_template(tokenizer)
@@ -345,7 +345,7 @@ def test_parse_llama3_template():
assert template.default_system == ""
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
@pytest.mark.xfail(not HF_TOKEN, reason="Authorization.")
def test_parse_qwen_template():
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct", token=HF_TOKEN)
@@ -358,7 +358,7 @@ def test_parse_qwen_template():
assert template.default_system == "You are Qwen, created by Alibaba Cloud. You are a helpful assistant."
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
@pytest.mark.xfail(not HF_TOKEN, reason="Authorization.")
def test_parse_qwen3_template():
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B", token=HF_TOKEN)

View File

@@ -37,13 +37,13 @@ MESSAGES = [
EXPECTED_RESPONSE = "_rho"
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
def test_chat():
chat_model = ChatModel(INFER_ARGS)
assert chat_model.chat(MESSAGES)[0].response_text == EXPECTED_RESPONSE
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
def test_stream_chat():
chat_model = ChatModel(INFER_ARGS)
response = ""

View File

@@ -39,7 +39,7 @@ MESSAGES = [
]
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cuda"])
@pytest.mark.skipif(not is_sglang_available(), reason="SGLang is not installed")
def test_chat():
r"""Test the SGLang engine's basic chat functionality."""
@@ -49,7 +49,7 @@ def test_chat():
print(response.response_text)
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cuda"])
@pytest.mark.skipif(not is_sglang_available(), reason="SGLang is not installed")
def test_stream_chat():
r"""Test the SGLang engine's streaming chat functionality."""

View File

@@ -49,7 +49,7 @@ INFER_ARGS = {
OS_NAME = os.getenv("OS_NAME", "")
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
@pytest.mark.parametrize(
"stage,dataset",
[
@@ -66,7 +66,7 @@ def test_run_exp(stage: str, dataset: str):
assert os.path.exists(output_dir)
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
def test_export():
export_dir = os.path.join("output", "llama3_export")
export_model({"export_dir": export_dir, **INFER_ARGS})

View File

@@ -17,7 +17,7 @@ import pytest
from llamafactory.eval.template import get_eval_template
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
def test_eval_template_en():
support_set = [
{
@@ -56,7 +56,7 @@ def test_eval_template_en():
]
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
def test_eval_template_zh():
support_set = [
{

View File

@@ -25,7 +25,6 @@ TINY_LLAMA3 = os.getenv("TINY_LLAMA3", "llamafactory/tiny-random-Llama-3")
UNUSED_TOKEN = "<|UNUSED_TOKEN|>"
@pytest.mark.runs_on(["cpu", "npu", "cuda"])
@pytest.mark.parametrize("special_tokens", [False, True])
def test_add_tokens(special_tokens: bool):
if special_tokens:

View File

@@ -39,7 +39,6 @@ INFER_ARGS = {
}
@pytest.mark.runs_on(["cpu", "npu", "cuda"])
@pytest.mark.xfail(is_transformers_version_greater_than("4.48"), reason="Attention refactor.")
def test_attention():
attention_available = ["disabled"]

View File

@@ -39,7 +39,6 @@ TRAIN_ARGS = {
}
@pytest.mark.runs_on(["cpu", "npu", "cuda"])
@pytest.mark.parametrize("disable_gradient_checkpointing", [False, True])
def test_vanilla_checkpointing(disable_gradient_checkpointing: bool):
model = load_train_model(disable_gradient_checkpointing=disable_gradient_checkpointing, **TRAIN_ARGS)
@@ -47,14 +46,12 @@ def test_vanilla_checkpointing(disable_gradient_checkpointing: bool):
assert getattr(module, "gradient_checkpointing") != disable_gradient_checkpointing
@pytest.mark.runs_on(["cpu", "npu", "cuda"])
def test_unsloth_gradient_checkpointing():
model = load_train_model(use_unsloth_gc=True, **TRAIN_ARGS)
for module in filter(lambda m: hasattr(m, "gradient_checkpointing"), model.modules()):
assert module._gradient_checkpointing_func.__self__.__name__ == "UnslothGradientCheckpointing"
@pytest.mark.runs_on(["cpu", "npu", "cuda"])
def test_upcast_layernorm():
model = load_train_model(upcast_layernorm=True, **TRAIN_ARGS)
for name, param in model.named_parameters():
@@ -62,7 +59,6 @@ def test_upcast_layernorm():
assert param.dtype == torch.float32
@pytest.mark.runs_on(["cpu", "npu", "cuda"])
def test_upcast_lmhead_output():
model = load_train_model(upcast_lmhead_output=True, **TRAIN_ARGS)
inputs = torch.randn((1, 16), dtype=torch.float16, device=get_current_device())

View File

@@ -24,7 +24,6 @@ from llamafactory.model.model_utils.misc import find_expanded_modules
HF_TOKEN = os.getenv("HF_TOKEN")
@pytest.mark.runs_on(["cpu", "npu", "cuda"])
@pytest.mark.skipif(not HF_TOKEN, reason="Gated model.")
def test_expanded_modules():
config = AutoConfig.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")

View File

@@ -18,7 +18,6 @@ import torch
from llamafactory.model.model_utils.packing import get_seqlens_in_batch, get_unpad_data
@pytest.mark.runs_on(["cpu", "npu", "cuda"])
@pytest.mark.parametrize(
"attention_mask,golden_seq_lens",
[

View File

@@ -23,7 +23,6 @@ from llamafactory.hparams import FinetuningArguments, ModelArguments
from llamafactory.model.adapter import init_adapter
@pytest.mark.runs_on(["cpu", "npu", "cuda"])
@pytest.mark.parametrize("freeze_vision_tower", (False, True))
@pytest.mark.parametrize("freeze_multi_modal_projector", (False, True))
@pytest.mark.parametrize("freeze_language_model", (False, True))
@@ -49,7 +48,6 @@ def test_visual_full(freeze_vision_tower: bool, freeze_multi_modal_projector: bo
assert param.requires_grad != freeze_language_model
@pytest.mark.runs_on(["cpu", "npu", "cuda"])
@pytest.mark.parametrize("freeze_vision_tower,freeze_language_model", ((False, False), (False, True), (True, False)))
def test_visual_lora(freeze_vision_tower: bool, freeze_language_model: bool):
model_args = ModelArguments(model_name_or_path="Qwen/Qwen2-VL-2B-Instruct")
@@ -82,7 +80,6 @@ def test_visual_lora(freeze_vision_tower: bool, freeze_language_model: bool):
assert (merger_param_name in trainable_params) is False
@pytest.mark.runs_on(["cpu", "npu", "cuda"])
def test_visual_model_save_load():
# check VLM's state dict: https://github.com/huggingface/transformers/pull/38385
model_args = ModelArguments(model_name_or_path="Qwen/Qwen2-VL-2B-Instruct")

View File

@@ -30,14 +30,12 @@ INFER_ARGS = {
}
@pytest.mark.runs_on(["cpu", "npu", "cuda"])
def test_base():
model = load_infer_model(**INFER_ARGS)
ref_model = load_reference_model(TINY_LLAMA3)
compare_model(model, ref_model)
@pytest.mark.runs_on(["cpu"])
@pytest.mark.usefixtures("fix_valuehead_cpu_loading")
def test_valuehead():
model = load_infer_model(add_valuehead=True, **INFER_ARGS)

View File

@@ -14,7 +14,6 @@
import os
import pytest
import torch
from llamafactory.train.test_utils import load_infer_model, load_train_model
@@ -44,7 +43,6 @@ INFER_ARGS = {
}
@pytest.mark.runs_on(["cpu", "npu", "cuda"])
def test_freeze_train_all_modules():
model = load_train_model(freeze_trainable_layers=1, **TRAIN_ARGS)
for name, param in model.named_parameters():
@@ -56,7 +54,6 @@ def test_freeze_train_all_modules():
assert param.dtype == torch.float16
@pytest.mark.runs_on(["cpu", "npu", "cuda"])
def test_freeze_train_extra_modules():
model = load_train_model(freeze_trainable_layers=1, freeze_extra_modules="embed_tokens,lm_head", **TRAIN_ARGS)
for name, param in model.named_parameters():
@@ -68,7 +65,6 @@ def test_freeze_train_extra_modules():
assert param.dtype == torch.float16
@pytest.mark.runs_on(["cpu", "npu", "cuda"])
def test_freeze_inference():
model = load_infer_model(**INFER_ARGS)
for param in model.parameters():

View File

@@ -14,7 +14,6 @@
import os
import pytest
import torch
from llamafactory.train.test_utils import load_infer_model, load_train_model
@@ -44,7 +43,6 @@ INFER_ARGS = {
}
@pytest.mark.runs_on(["cpu", "npu", "cuda"])
def test_full_train():
model = load_train_model(**TRAIN_ARGS)
for param in model.parameters():
@@ -52,7 +50,6 @@ def test_full_train():
assert param.dtype == torch.float32
@pytest.mark.runs_on(["cpu", "npu", "cuda"])
def test_full_inference():
model = load_infer_model(**INFER_ARGS)
for param in model.parameters():

View File

@@ -55,35 +55,30 @@ INFER_ARGS = {
}
@pytest.mark.runs_on(["cpu", "npu", "cuda"])
def test_lora_train_qv_modules():
model = load_train_model(lora_target="q_proj,v_proj", **TRAIN_ARGS)
linear_modules, _ = check_lora_model(model)
assert linear_modules == {"q_proj", "v_proj"}
@pytest.mark.runs_on(["cpu", "npu", "cuda"])
def test_lora_train_all_modules():
model = load_train_model(lora_target="all", **TRAIN_ARGS)
linear_modules, _ = check_lora_model(model)
assert linear_modules == {"q_proj", "k_proj", "v_proj", "o_proj", "up_proj", "gate_proj", "down_proj"}
@pytest.mark.runs_on(["cpu", "npu", "cuda"])
def test_lora_train_extra_modules():
model = load_train_model(additional_target="embed_tokens,lm_head", **TRAIN_ARGS)
_, extra_modules = check_lora_model(model)
assert extra_modules == {"embed_tokens", "lm_head"}
@pytest.mark.runs_on(["cpu", "npu", "cuda"])
def test_lora_train_old_adapters():
model = load_train_model(adapter_name_or_path=TINY_LLAMA_ADAPTER, create_new_adapter=False, **TRAIN_ARGS)
ref_model = load_reference_model(TINY_LLAMA3, TINY_LLAMA_ADAPTER, use_lora=True, is_trainable=True)
compare_model(model, ref_model)
@pytest.mark.runs_on(["cpu", "npu", "cuda"])
def test_lora_train_new_adapters():
model = load_train_model(adapter_name_or_path=TINY_LLAMA_ADAPTER, create_new_adapter=True, **TRAIN_ARGS)
ref_model = load_reference_model(TINY_LLAMA3, TINY_LLAMA_ADAPTER, use_lora=True, is_trainable=True)
@@ -92,7 +87,6 @@ def test_lora_train_new_adapters():
)
@pytest.mark.runs_on(["cpu", "npu", "cuda"])
@pytest.mark.usefixtures("fix_valuehead_cpu_loading")
def test_lora_train_valuehead():
model = load_train_model(add_valuehead=True, **TRAIN_ARGS)
@@ -103,7 +97,6 @@ def test_lora_train_valuehead():
assert torch.allclose(state_dict["v_head.summary.bias"], ref_state_dict["v_head.summary.bias"])
@pytest.mark.runs_on(["cpu", "npu", "cuda"])
def test_lora_inference():
model = load_infer_model(**INFER_ARGS)
ref_model = load_reference_model(TINY_LLAMA3, TINY_LLAMA_ADAPTER, use_lora=True).merge_and_unload()

View File

@@ -49,7 +49,6 @@ INFER_ARGS = {
}
@pytest.mark.runs_on(["cpu", "npu", "cuda"])
@pytest.mark.xfail(reason="PiSSA initialization is not stable in different platform.")
def test_pissa_train():
model = load_train_model(**TRAIN_ARGS)
@@ -57,7 +56,6 @@ def test_pissa_train():
compare_model(model, ref_model)
@pytest.mark.runs_on(["cpu", "npu", "cuda"])
@pytest.mark.xfail(reason="Known connection error.")
def test_pissa_inference():
model = load_infer_model(**INFER_ARGS)

View File

@@ -59,7 +59,6 @@ class DataCollatorWithVerbose(DataCollatorWithPadding):
return {k: v[:, :1] for k, v in batch.items()} # truncate input length
@pytest.mark.runs_on(["cpu", "npu", "cuda"])
@pytest.mark.parametrize("disable_shuffling", [False, True])
def test_shuffle(disable_shuffling: bool):
model_args, data_args, training_args, finetuning_args, _ = get_train_args(

View File

@@ -1,18 +0,0 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
runs_on = pytest.mark.runs_on

View File

@@ -1,2 +1,2 @@
# change if test fails or cache is outdated
0.9.4.104
0.9.4.105

View File

@@ -1,93 +0,0 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from llamafactory.v1.accelerator.helper import ReduceOp, all_reduce, is_torch_cuda_available, is_torch_npu_available
from llamafactory.v1.utils.utils import find_available_port
def _dist_worker(rank, world_size):
if is_torch_cuda_available():
backend = "nccl"
device = torch.device(f"cuda:{rank}")
torch.cuda.set_device(rank)
elif is_torch_npu_available():
backend = "hccl"
device = torch.device(f"npu:{rank}")
torch.npu.set_device(rank)
else:
backend = "gloo"
device = torch.device("cpu")
dist.init_process_group(
backend=backend,
rank=rank,
world_size=world_size,
)
# --------------------
# Test all_reduce SUM
# --------------------
y = torch.tensor(rank + 1.0, device=device)
y_sum = all_reduce(y.clone(), op=ReduceOp.SUM)
assert y_sum.item() == 3.0
# --------------------
# Test all_reduce MEAN
# --------------------
y_mean = all_reduce(y.clone(), op=ReduceOp.MEAN)
assert y_mean.item() == pytest.approx(1.5)
# --------------------
# Test all_reduce MAX
# --------------------
y_max = all_reduce(y.clone(), op=ReduceOp.MAX)
assert y_max.item() == 2.0
dist.destroy_process_group()
@pytest.mark.runs_on(["npu", "cuda"])
@pytest.mark.require_distributed(2)
def test_distributed_ops(monkeypatch):
monkeypatch.setenv("MASTER_ADDR", "127.0.0.1")
monkeypatch.setenv("MASTER_PORT", str(find_available_port()))
WORLD_SIZE = 2
mp.spawn(
_dist_worker,
args=(WORLD_SIZE,),
nprocs=WORLD_SIZE,
join=True,
)
@pytest.mark.runs_on(["npu", "cuda"])
@pytest.mark.require_distributed(4)
def test_required_multi():
# test require_distributed mark ok
pass
@pytest.mark.runs_on(["npu", "cuda"])
@pytest.mark.require_distributed(999)
def test_required_invalid():
# test require_distributed mark not ok,
raise RuntimeError(
"this case should not be run, please check whether the require_distributed mark implementation is correct"
)

View File

@@ -12,15 +12,48 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import pytest
import torch.multiprocessing as mp
from llamafactory.v1.accelerator.helper import ReduceOp
from llamafactory.v1.accelerator.interface import DistributedInterface
from llamafactory.v1.utils.env import find_available_port
from llamafactory.v1.utils.pytest import dist_env
def test_distributed_interface():
DistributedInterface()
assert DistributedInterface.get_rank() == int(os.getenv("RANK", "0"))
assert DistributedInterface.get_world_size() == int(os.getenv("WORLD_SIZE", "1"))
assert DistributedInterface.get_local_rank() == int(os.getenv("LOCAL_RANK", "0"))
assert DistributedInterface.get_local_world_size() == int(os.getenv("LOCAL_WORLD_SIZE", "1"))
def _all_reduce_tests(local_rank: int, world_size: int, master_port: int):
with dist_env(local_rank, world_size, master_port):
rank = DistributedInterface().get_rank()
world_size = DistributedInterface().get_world_size()
assert world_size == 2
y_sum = DistributedInterface().all_reduce(rank + 1.0, op=ReduceOp.SUM)
assert y_sum == pytest.approx(3.0)
y_mean = DistributedInterface().all_reduce(rank + 1.0, op=ReduceOp.MEAN)
assert y_mean == pytest.approx(1.5)
y_max = DistributedInterface().all_reduce(rank + 1.0, op=ReduceOp.MAX)
assert y_max == pytest.approx(2.0)
z = DistributedInterface().all_gather(rank + 1.0)
assert z == pytest.approx([1.0, 2.0])
z = DistributedInterface().broadcast(rank + 1.0)
assert z == pytest.approx(1.0)
def test_all_device():
assert DistributedInterface().get_rank() == int(os.getenv("RANK", "0"))
assert DistributedInterface().get_world_size() == int(os.getenv("WORLD_SIZE", "1"))
assert DistributedInterface().get_local_rank() == int(os.getenv("LOCAL_RANK", "0"))
assert DistributedInterface().get_local_world_size() == int(os.getenv("LOCAL_WORLD_SIZE", "1"))
@pytest.mark.runs_on(["cuda", "npu"])
@pytest.mark.require_distributed(2)
def test_multi_device():
master_port = find_available_port()
mp.spawn(_all_reduce_tests, args=(2, master_port), nprocs=2)
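
The dist_env helper imported from llamafactory.v1.utils.pytest is not shown in this commit. Judging from its use in _all_reduce_tests, it presumably exports the usual torchrun-style environment variables for the spawned rank before DistributedInterface initializes; the sketch below is an assumption for illustration only.

import os
from contextlib import contextmanager

import torch.distributed as dist

@contextmanager
def sketch_dist_env(rank: int, world_size: int, master_port: int):
    # Illustrative stand-in for llamafactory.v1.utils.pytest.dist_env.
    old_env = dict(os.environ)
    os.environ.update(
        {
            "RANK": str(rank),
            "LOCAL_RANK": str(rank),
            "WORLD_SIZE": str(world_size),
            "LOCAL_WORLD_SIZE": str(world_size),
            "MASTER_ADDR": "127.0.0.1",
            "MASTER_PORT": str(master_port),
        }
    )
    try:
        yield
    finally:
        if dist.is_available() and dist.is_initialized():
            dist.destroy_process_group()
        os.environ.clear()
        os.environ.update(old_env)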

View File

@@ -18,20 +18,17 @@ Contains shared fixtures, pytest configuration, and custom markers.
"""
import os
from typing import Optional
import pytest
from pytest import Config, Item
from pytest import Config, FixtureRequest, Item, MonkeyPatch
from llamafactory.train.test_utils import patch_valuehead_model
from llamafactory.v1.accelerator.helper import get_current_device, get_device_count
from llamafactory.v1.accelerator.helper import get_current_accelerator, get_device_count
from llamafactory.v1.utils.env import is_env_enabled
from llamafactory.v1.utils.packages import is_transformers_version_greater_than
from llamafactory.v1.utils.utils import is_env_enabled
try:
CURRENT_DEVICE = get_current_device().type # cpu | cuda | npu
except Exception:
CURRENT_DEVICE = "cpu"
CURRENT_DEVICE = get_current_accelerator().type
def pytest_configure(config: Config):
@@ -67,26 +64,27 @@ def _handle_runs_on(items: list[Item]):
def _handle_slow_tests(items: list[Item]):
"""Skip slow tests unless RUN_SLOW is enabled."""
if not is_env_enabled("RUN_SLOW", "0"):
if not is_env_enabled("RUN_SLOW"):
skip_slow = pytest.mark.skip(reason="slow test (set RUN_SLOW=1 to run)")
for item in items:
if "slow" in item.keywords:
item.add_marker(skip_slow)
def _get_visible_devices_env():
def _get_visible_devices_env() -> Optional[str]:
"""Return device visibility env var name."""
if CURRENT_DEVICE == "cuda":
return "CUDA_VISIBLE_DEVICES"
if CURRENT_DEVICE == "npu":
elif CURRENT_DEVICE == "npu":
return "ASCEND_RT_VISIBLE_DEVICES"
return None
else:
return None
def _handle_device_visibility(items: list[Item]):
"""Handle device visibility based on test markers."""
env_key = _get_visible_devices_env()
if env_key is None or CURRENT_DEVICE == "cpu":
if env_key is None or CURRENT_DEVICE in ("cpu", "mps"):
return
# Parse visible devices
@@ -122,7 +120,7 @@ def pytest_collection_modifyitems(config: Config, items: list[Item]):
@pytest.fixture(autouse=True)
def _manage_distributed_env(request, monkeypatch):
def _manage_distributed_env(request: FixtureRequest, monkeypatch: MonkeyPatch) -> None:
"""Set environment variables for distributed tests if specific devices are requested."""
env_key = _get_visible_devices_env()
if not env_key:
@@ -132,8 +130,7 @@ def _manage_distributed_env(request, monkeypatch):
old_value = os.environ.get(env_key)
marker = request.node.get_closest_marker("require_distributed")
if marker:
# Distributed test
if marker: # distributed test
required = marker.args[0] if marker.args else 2
specific_devices = marker.args[1] if len(marker.args) > 1 else None
@@ -143,16 +140,9 @@ def _manage_distributed_env(request, monkeypatch):
devices_str = ",".join(str(i) for i in range(required))
monkeypatch.setenv(env_key, devices_str)
else:
# Non-distributed test
else: # non-distributed test
if old_value:
visible_devices = [v for v in old_value.split(",") if v != ""]
monkeypatch.setenv(env_key, visible_devices[0] if visible_devices else "0")
else:
monkeypatch.setenv(env_key, "0")
@pytest.fixture
def fix_valuehead_cpu_loading():
"""Fix valuehead model loading."""
patch_valuehead_model()
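
For context, the markers referenced above (runs_on, require_distributed) are enforced elsewhere in this conftest. A plausible shape for the require_distributed check, skipping tests that ask for more devices than get_device_count() reports, is sketched below with hypothetical names; it is not the hook touched by this commit.

import pytest
from pytest import Item

def sketch_handle_require_distributed(items: list[Item], device_count: int) -> None:
    # Illustrative marker handling only; the project's pytest_collection_modifyitems may differ.
    for item in items:
        marker = item.get_closest_marker("require_distributed")
        if marker is None:
            continue
        required = marker.args[0] if marker.args else 2
        if device_count < required:
            item.add_marker(
                pytest.mark.skip(reason=f"requires {required} devices, only {device_count} available")
            )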

View File

@@ -28,10 +28,10 @@ from transformers import AutoTokenizer
from llamafactory.v1.config.data_args import DataArguments
from llamafactory.v1.core.data_engine import DataEngine
from llamafactory.v1.core.data_loader import DataLoader
from llamafactory.v1.core.trainer_utils.data_collator import (
DefaultCollator,
)
from llamafactory.v1.core.trainer_utils.data_loader import DataLoader
from llamafactory.v1.plugins.data_plugins.template import QwenTemplate
from llamafactory.v1.utils.batching_queue import TextBatchingQueue

View File

@@ -12,57 +12,56 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from unittest.mock import MagicMock, patch
import pytest
from transformers import AutoModelForCausalLM
from llamafactory.v1.accelerator.helper import get_current_accelerator
from llamafactory.v1.plugins.model_plugins.kernels.mlp import npu_swiglu
from llamafactory.v1.plugins.model_plugins.kernels.registry import apply_available_kernels, apply_kernel
from llamafactory.v1.plugins.model_plugins.kernels.rms_norm import npu_rms_norm
from llamafactory.v1.plugins.model_plugins.kernels.rope import npu_rope
class TestKernelPlugin(unittest.TestCase):
@patch("torch.accelerator.current_accelerator")
def test_apply_kernel(self, mock_get_accelerator):
get_current_accelerator.cache_clear()
mock_device = MagicMock()
mock_device.type = "npu"
mock_get_accelerator.return_value = mock_device
model = AutoModelForCausalLM.from_pretrained("llamafactory/tiny-random-qwen2.5")
original_rmsnorm_forward = model.model.layers[0].input_layernorm.forward
original_swiglu_forward = model.model.layers[0].mlp.forward
from llamafactory.v1.plugins.model_plugins.kernels.mlp import npu_swiglu
from llamafactory.v1.plugins.model_plugins.kernels.registry import apply_kernel
from llamafactory.v1.plugins.model_plugins.kernels.rms_norm import npu_rms_norm
from llamafactory.v1.plugins.model_plugins.kernels.rope import npu_rope
apply_kernel(model, npu_rope.NpuRoPEKernel)
model = apply_kernel(model, npu_rms_norm.NpuRMSNormKernel)
assert model.model.layers[0].input_layernorm is not original_rmsnorm_forward
model = apply_kernel(model, npu_swiglu.NpuSwiGluKernel)
assert model.model.layers[0].mlp.forward is not original_swiglu_forward
@pytest.fixture(autouse=True)
def clear_accelerator_cache():
get_current_accelerator.cache_clear()
class Test_Use_V1_Kernels(unittest.TestCase):
@patch("torch.accelerator.current_accelerator")
def test_use_v1_kernels(self, mock_get_accelerator):
get_current_accelerator.cache_clear()
mock_device = MagicMock()
mock_device.type = "npu"
mock_get_accelerator.return_value = mock_device
@patch("torch.accelerator.current_accelerator")
def test_apply_kernel(mock_get_accelerator: MagicMock):
mock_device = MagicMock()
setattr(mock_device, "type", "npu")
mock_get_accelerator.return_value = mock_device
model = AutoModelForCausalLM.from_pretrained("llamafactory/tiny-random-qwen2.5")
model = AutoModelForCausalLM.from_pretrained("llamafactory/tiny-random-qwen2.5")
original_rmsnorm_forward = model.model.layers[0].input_layernorm.forward
original_swiglu_forward = model.model.layers[0].mlp.forward
original_rmsnorm_forward = model.model.layers[0].input_layernorm.forward
original_swiglu_forward = model.model.layers[0].mlp.forward
from llamafactory.v1.plugins.model_plugins.kernels.registry import apply_available_kernels
apply_kernel(model, npu_rope.NpuRoPEKernel)
model = apply_available_kernels(model)
model = apply_kernel(model, npu_rms_norm.NpuRMSNormKernel)
assert model.model.layers[0].input_layernorm is not original_rmsnorm_forward
    assert model.model.layers[0].input_layernorm.forward is not original_rmsnorm_forward
assert model.model.layers[0].mlp.forward is not original_swiglu_forward
model = apply_kernel(model, npu_swiglu.NpuSwiGluKernel)
assert model.model.layers[0].mlp.forward is not original_swiglu_forward
@patch("torch.accelerator.current_accelerator")
def test_apply_all_kernels(mock_get_accelerator: MagicMock):
get_current_accelerator.cache_clear()
mock_device = MagicMock()
setattr(mock_device, "type", "npu")
mock_get_accelerator.return_value = mock_device
model = AutoModelForCausalLM.from_pretrained("llamafactory/tiny-random-qwen2.5")
original_rmsnorm_forward = model.model.layers[0].input_layernorm.forward
original_swiglu_forward = model.model.layers[0].mlp.forward
model = apply_available_kernels(model)
    assert model.model.layers[0].input_layernorm.forward is not original_rmsnorm_forward
assert model.model.layers[0].mlp.forward is not original_swiglu_forward
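
These tests only assert that apply_kernel / apply_available_kernels rebind a layer's forward once the (mocked) accelerator reports an npu device, which is also why the autouse fixture clears the get_current_accelerator cache between tests. As a rough illustration of that dispatch, under the assumption that kernels are applied by swapping module forwards (the names below are hypothetical, not the registry's API), one might write:

import types

import torch
import torch.nn as nn

def sketch_apply_kernel(model: nn.Module, target_cls_name: str, kernel_forward) -> nn.Module:
    # Illustrative only; the real registry in
    # llamafactory.v1.plugins.model_plugins.kernels.registry may work differently.
    accelerator = torch.accelerator.current_accelerator()  # mocked to an "npu" device in the tests
    if accelerator is None or accelerator.type != "npu":
        return model  # NPU kernels are a no-op on other accelerators

    for module in model.modules():
        if module.__class__.__name__ == target_cls_name:
            # Rebind the replacement forward to this module instance.
            module.forward = types.MethodType(kernel_forward, module)
    return model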