[v1] support liger_kernel (#10493)

This commit is contained in:
sunyi0505
2026-05-21 11:44:56 +08:00
committed by GitHub
parent 2322bf1cc2
commit 7e20db5735
6 changed files with 232 additions and 8 deletions

View File

@@ -0,0 +1,28 @@
model: Qwen/Qwen3-0.6B
model_class: llm
template: qwen3_nothink
kernel_config:
name: liger_kernel
include_kernels: auto # choice: null/true/false/auto/kernel_id1,kernel_id2,kernel_id3, default is null
quant_config: null
dist_config:
name: fsdp2
dcp_path: null # /mnt/f/pretrain_models/Qwen3-0.6B-dcp
### data
train_dataset: data/v1_sft_demo.yaml
### training
output_dir: outputs/test_fsdp2
micro_batch_size: 1
cutoff_len: 2048
learning_rate: 1.0e-4
max_steps: 10
### sample
sample_backend: hf
max_new_tokens: 128

View File

@@ -188,9 +188,12 @@ class ModelEngine:
if self.args.kernel_config is not None:
from ..plugins.model_plugins.kernels.interface import KernelPlugin
model = KernelPlugin(self.args.kernel_config.name)(
model, include_kernels=self.args.kernel_config.get("include_kernels")
)
kernel_config = self.args.kernel_config
kernel_kwargs: dict = {"model": model, "include_kernels": kernel_config.get("include_kernels")}
if kernel_config.name == "liger_kernel":
# Fused linear CE omits logits; SFT stage needs logits for loss_weights.
kernel_kwargs["require_logits"] = self.is_train
model = KernelPlugin(kernel_config.name)(**kernel_kwargs)
return model

View File

@@ -34,7 +34,7 @@ class BaseKernel(ABC):
"""
_kernel_id: Any = "" # kernel ID, any hashable value to identify a kernel implementation
_device: DeviceType = DeviceType.CPU # "cuda", "npu", "cpu", etc.
_device: list[DeviceType] = [DeviceType.CPU] # "cuda", "npu", "cpu", etc.
@classmethod
def get_kernel_id(cls) -> str:
@@ -42,8 +42,8 @@ class BaseKernel(ABC):
return cls._kernel_id
@classmethod
def get_device(cls) -> str:
"""Returns the device type associated with the kernel (e.g., "cuda", "npu", "cpu")."""
def get_device(cls) -> list[DeviceType]:
"""Returns the device type list associated with the kernel (e.g., ["cuda", "npu", "cpu"])."""
return cls._device
@classmethod
@@ -58,7 +58,7 @@ class BaseKernel(ABC):
it should raise an error instead of silently switching.
Kernels can override this method to implement custom dependency checks.
"""
if cls._device != get_current_accelerator().type:
if get_current_accelerator().type not in cls._device:
return False
return True

View File

@@ -138,3 +138,48 @@ def apply_default_kernels(model: HFModel, include_kernels: str = None) -> HFMode
apply_kernel(kernel, model=model)
return model
@KernelPlugin("liger_kernel").register()
def apply_liger_kernels(
model: HFModel,
include_kernels: str = None,
require_logits: bool = False,
) -> HFModel:
"""Applies Liger kernel to the model.
Args:
model (HFModel): The model instance to apply kernels to.
include_kernels (str, optional): If ``"auto"`` or ``True``, apply Liger with
library defaults. If a comma-separated list (e.g.
``rope,rms_norm``), enable only those ops; names match
``apply_liger_kernel_to_*`` kwargs: ``rope``, ``rms_norm``,
``swiglu``, ``cross_entropy``, ``fused_linear_cross_entropy``.
If ``None`` or ``False``, do nothing. Defaults to ``None``.
require_logits (bool, optional): When true, disables ``fused_linear_cross_entropy`` in favor
of non-fused CE so the forward pass returns ``logits``. Needed
for trainers that compute weighted loss from logits (e.g. v1
SFT with ``loss_weights``). Defaults to ``False`` (fused CE
when supported). The v1 ``run_sft`` entrypoint sets
``require_logits`` to true for ``liger_kernel`` when the key
is omitted so SFT weighted loss keeps working.
Returns:
HFModel: The model with Liger kernel applied.
"""
if not include_kernels:
return model
if include_kernels == "auto" or include_kernels is True:
use_kernels = "auto"
else:
use_kernels = [k.strip() for k in include_kernels.split(",") if k.strip()]
if not use_kernels:
return model
try:
from .liger_kernel_ops import LigerKernel
except ImportError as e:
logger.warning_rank0(f"[Kernel] Failed to import liger_kernel ops, skip. Error: {e}")
return model
return LigerKernel.apply(use_kernels=use_kernels, model=model, require_logits=require_logits)

View File

@@ -0,0 +1,148 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The definition of Liger Kernel.
Init Phase:
1. Define LigerKernel class.
2. Register Liger kernel.
"""
import inspect
from ....accelerator.helper import DeviceType, get_current_accelerator
from ....utils.logging import get_logger
from ....utils.types import HFModel
from .base import BaseKernel
logger = get_logger(__name__)
_LIGER_FN_BY_MODEL_TYPE: dict[str, str] = {
"qwen3": "apply_liger_kernel_to_qwen3",
"qwen3_moe": "apply_liger_kernel_to_qwen3_moe",
"qwen3_next": "apply_liger_kernel_to_qwen3_next",
"qwen3_5": "apply_liger_kernel_to_qwen3_5",
"qwen3_5_text": "apply_liger_kernel_to_qwen3_5_text",
"qwen3_5_moe": "apply_liger_kernel_to_qwen3_5_moe",
"qwen3_5_moe_text": "apply_liger_kernel_to_qwen3_5_moe_text",
}
class LigerKernel(BaseKernel):
"""Liger Kernel for optimized model training."""
_device = [DeviceType.CUDA, DeviceType.NPU]
@classmethod
def check_deps(cls) -> bool:
"""Checks if the required dependencies for the kernel are available."""
try:
import liger_kernel # noqa: F401
return super().check_deps()
except ImportError:
logger.warning_rank0(
"Liger kernel is not installed, the kernel_config liger_kernel will be ignored. Please install it from https://github.com/linkedin/Liger-Kernel."
)
return False
@classmethod
def apply(cls, **kwargs) -> "HFModel":
"""Applies the Liger kernel to the model.
Args:
**kwargs: Must include ``model``. Optional ``use_kernels`` is a list of Liger op
names to enable exclusively, or the string ``"auto"`` to use each
``apply_liger_kernel_to_*`` function's signature defaults (same as calling
upstream with only ``model``). Optional ``require_logits`` forces non-fused
cross entropy when supported.
Returns:
HFModel: The model with Liger kernel applied.
Raises:
ValueError: If the model is not provided.
RuntimeError: If dependencies are not met.
"""
model = kwargs.get("model")
use_kernels = kwargs.get("use_kernels", None)
if model is None:
raise ValueError(f"HFModel instance is required for {cls.__name__}.")
if not cls.check_deps():
raise RuntimeError(
f"current device is not supported by liger_kernel. Current device is {get_current_accelerator().type}, supported devices are {cls.get_device()}"
)
require_logits = kwargs.get("require_logits", False)
model_type = getattr(model.config, "model_type", None)
if model_type not in _LIGER_FN_BY_MODEL_TYPE:
logger.warning_rank0("Current model does not support liger kernel.")
return model
import liger_kernel.transformers as liger_transformers
apply_liger_kernel = getattr(liger_transformers, _LIGER_FN_BY_MODEL_TYPE[model_type])
sig = inspect.signature(apply_liger_kernel).parameters
togglable = [name for name in sig if name != "model"]
def _normalize_op_name(raw: str) -> str:
key = raw.strip().lower().replace("-", "_")
aliases = {
"rmsnorm": "rms_norm",
"flce": "fused_linear_cross_entropy",
"lce": "fused_linear_cross_entropy",
"fused_ce": "fused_linear_cross_entropy",
}
return aliases.get(key, key)
if use_kernels is not None and len(use_kernels) == 0:
return model
if use_kernels != "auto":
selected = {_normalize_op_name(k) for k in use_kernels}
ops = selected - set(togglable)
if ops:
raise ValueError(
f"Unknown Liger op(s) {sorted(ops)} for model_type={model_type}. Valid: {sorted(togglable)}"
)
if "cross_entropy" in selected and "fused_linear_cross_entropy" in selected:
raise ValueError("cross_entropy and fused_linear_cross_entropy cannot both be enabled.")
call_kwargs = {name: (name in selected) for name in togglable}
call_kwargs["model"] = model
else:
# Mirror ``liger_kernel`` signature defaults so patches match upstream defaults
# and logging reflects enabled ops (omitted kwargs only live in the callee).
call_kwargs = {"model": model}
for name in togglable:
param = sig[name]
if param.default is not inspect.Parameter.empty:
call_kwargs[name] = param.default
if require_logits and "fused_linear_cross_entropy" in sig:
logger.warning_rank0("Current training stage does not support chunked cross entropy.")
call_kwargs["fused_linear_cross_entropy"] = False
call_kwargs["cross_entropy"] = True
apply_liger_kernel(**call_kwargs)
applied = sorted(name for name, on in call_kwargs.items() if name != "model" and on)
logger.info_rank0(f"These Liger ops are applied to the model: {applied}")
return model

View File

@@ -58,7 +58,7 @@ class Registry:
device = kernel_cls.get_device()
# The device type of the current accelerator does not match the device type required by the kernel, skip registration
if device != get_current_accelerator().type:
if get_current_accelerator().type not in device:
return
if not kernel_id: