[v1] support liger_kernel (#10493)

This commit is contained in:
sunyi0505
2026-05-21 11:44:56 +08:00
committed by GitHub
parent 2322bf1cc2
commit 7e20db5735
6 changed files with 232 additions and 8 deletions

View File

@@ -0,0 +1,28 @@
model: Qwen/Qwen3-0.6B
model_class: llm
template: qwen3_nothink
kernel_config:
name: liger_kernel
include_kernels: auto # choices: null/true/false/auto, or a comma-separated list of Liger op names (e.g. rope,rms_norm); default is null
quant_config: null
dist_config:
name: fsdp2
dcp_path: null # /mnt/f/pretrain_models/Qwen3-0.6B-dcp
### data
train_dataset: data/v1_sft_demo.yaml
### training
output_dir: outputs/test_fsdp2
micro_batch_size: 1
cutoff_len: 2048
learning_rate: 1.0e-4
max_steps: 10
### sample
sample_backend: hf
max_new_tokens: 128

View File

@@ -188,9 +188,12 @@ class ModelEngine:
if self.args.kernel_config is not None: if self.args.kernel_config is not None:
from ..plugins.model_plugins.kernels.interface import KernelPlugin from ..plugins.model_plugins.kernels.interface import KernelPlugin
model = KernelPlugin(self.args.kernel_config.name)( kernel_config = self.args.kernel_config
model, include_kernels=self.args.kernel_config.get("include_kernels") kernel_kwargs: dict = {"model": model, "include_kernels": kernel_config.get("include_kernels")}
) if kernel_config.name == "liger_kernel":
# Fused linear CE omits logits; SFT stage needs logits for loss_weights.
kernel_kwargs["require_logits"] = self.is_train
model = KernelPlugin(kernel_config.name)(**kernel_kwargs)
return model return model

View File

@@ -34,7 +34,7 @@ class BaseKernel(ABC):
""" """
_kernel_id: Any = "" # kernel ID, any hashable value to identify a kernel implementation _kernel_id: Any = "" # kernel ID, any hashable value to identify a kernel implementation
_device: DeviceType = DeviceType.CPU # "cuda", "npu", "cpu", etc. _device: list[DeviceType] = [DeviceType.CPU] # "cuda", "npu", "cpu", etc.
@classmethod @classmethod
def get_kernel_id(cls) -> str: def get_kernel_id(cls) -> str:
@@ -42,8 +42,8 @@ class BaseKernel(ABC):
return cls._kernel_id return cls._kernel_id
@classmethod @classmethod
def get_device(cls) -> str: def get_device(cls) -> list[DeviceType]:
"""Returns the device type associated with the kernel (e.g., "cuda", "npu", "cpu").""" """Returns the device type list associated with the kernel (e.g., ["cuda", "npu", "cpu"])."""
return cls._device return cls._device
@classmethod @classmethod
@@ -58,7 +58,7 @@ class BaseKernel(ABC):
it should raise an error instead of silently switching. it should raise an error instead of silently switching.
Kernels can override this method to implement custom dependency checks. Kernels can override this method to implement custom dependency checks.
""" """
if cls._device != get_current_accelerator().type: if get_current_accelerator().type not in cls._device:
return False return False
return True return True

View File

@@ -138,3 +138,48 @@ def apply_default_kernels(model: HFModel, include_kernels: str = None) -> HFMode
apply_kernel(kernel, model=model) apply_kernel(kernel, model=model)
return model return model
@KernelPlugin("liger_kernel").register()
def apply_liger_kernels(
    model: HFModel,
    include_kernels: str = None,
    require_logits: bool = False,
) -> HFModel:
    """Applies Liger kernel to the model.

    Args:
        model (HFModel): The model instance to apply kernels to.
        include_kernels (str, optional): If ``"auto"`` or ``True``, apply Liger with
            library defaults. If a comma-separated string (e.g. ``rope,rms_norm``)
            or a list/tuple of names, enable only those ops; names match
            ``apply_liger_kernel_to_*`` kwargs: ``rope``, ``rms_norm``,
            ``swiglu``, ``cross_entropy``, ``fused_linear_cross_entropy``.
            If ``None``, ``False`` or empty, do nothing. Defaults to ``None``.
        require_logits (bool, optional): When true, disables ``fused_linear_cross_entropy`` in favor
            of non-fused CE so the forward pass returns ``logits``. Needed
            for trainers that compute weighted loss from logits (e.g. v1
            SFT with ``loss_weights``); such callers must pass ``True``
            explicitly. Defaults to ``False`` (fused CE when supported).

    Returns:
        HFModel: The model with Liger kernel applied, or the untouched model when
        nothing was requested or the Liger ops module cannot be imported.

    Raises:
        TypeError: If ``include_kernels`` is truthy but is neither ``True``, a string,
            a list nor a tuple.
    """
    # None / False / "" / [] all mean "no kernels requested".
    if not include_kernels:
        return model

    if include_kernels is True or include_kernels == "auto":
        use_kernels = "auto"
    elif isinstance(include_kernels, str):
        use_kernels = [name.strip() for name in include_kernels.split(",") if name.strip()]
    elif isinstance(include_kernels, (list, tuple)):
        # A YAML sequence is parsed into a list rather than a comma-separated string.
        use_kernels = [str(name).strip() for name in include_kernels if str(name).strip()]
    else:
        raise TypeError(
            f"include_kernels must be a bool, a string or a list of op names, got {type(include_kernels).__name__}."
        )

    # Input such as "," or [" "] leaves nothing to enable.
    if not use_kernels:
        return model

    # Imported lazily so that a broken or missing ops module only disables this plugin.
    try:
        from .liger_kernel_ops import LigerKernel
    except ImportError as e:
        logger.warning_rank0(f"[Kernel] Failed to import liger_kernel ops, skip. Error: {e}")
        return model

    return LigerKernel.apply(use_kernels=use_kernels, model=model, require_logits=require_logits)

View File

@@ -0,0 +1,148 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The definition of Liger Kernel.
Init Phase:
1. Define LigerKernel class.
2. Register Liger kernel.
"""
import inspect
from ....accelerator.helper import DeviceType, get_current_accelerator
from ....utils.logging import get_logger
from ....utils.types import HFModel
from .base import BaseKernel
logger = get_logger(__name__)
# Maps a Hugging Face ``config.model_type`` to the name of the matching patch function
# looked up on ``liger_kernel.transformers``. Every entry follows the
# ``apply_liger_kernel_to_<model_type>`` naming scheme, so the table is derived from
# the list of supported model types.
_LIGER_FN_BY_MODEL_TYPE: dict[str, str] = {
    supported_type: f"apply_liger_kernel_to_{supported_type}"
    for supported_type in (
        "qwen3",
        "qwen3_moe",
        "qwen3_next",
        "qwen3_5",
        "qwen3_5_text",
        "qwen3_5_moe",
        "qwen3_5_moe_text",
    )
}
class LigerKernel(BaseKernel):
    """Liger Kernel for optimized model training.

    Dispatches to the ``apply_liger_kernel_to_*`` function of
    ``liger_kernel.transformers`` selected by the model's ``config.model_type``.
    """

    # Accelerator types on which this kernel may be applied.
    _device = [DeviceType.CUDA, DeviceType.NPU]

    # Short spellings accepted in ``use_kernels``, mapped to the upstream keyword name.
    _OP_ALIASES = {
        "rmsnorm": "rms_norm",
        "flce": "fused_linear_cross_entropy",
        "lce": "fused_linear_cross_entropy",
        "fused_ce": "fused_linear_cross_entropy",
    }

    @classmethod
    def check_deps(cls) -> bool:
        """Checks if the required dependencies for the kernel are available.

        Returns:
            bool: ``False`` when ``liger_kernel`` cannot be imported (a warning is
            logged); otherwise the result of the base-class device check.
        """
        try:
            import liger_kernel  # noqa: F401
        except ImportError:
            logger.warning_rank0(
                "Liger kernel is not installed, the kernel_config liger_kernel will be ignored. Please install it from https://github.com/linkedin/Liger-Kernel."
            )
            return False

        return super().check_deps()

    @staticmethod
    def _normalize_op_name(raw: str) -> str:
        """Canonicalizes a user-supplied op name (case, dashes, known aliases)."""
        key = raw.strip().lower().replace("-", "_")
        return LigerKernel._OP_ALIASES.get(key, key)

    @classmethod
    def apply(cls, **kwargs) -> "HFModel":
        """Applies the Liger kernel to the model.

        Args:
            **kwargs: Must include ``model``. Optional ``use_kernels`` is a list of Liger op
                names (or a comma-separated string of them) to enable exclusively, or the
                string ``"auto"`` to use each ``apply_liger_kernel_to_*`` function's
                signature defaults (same as calling upstream with only ``model``).
                Omitting it or passing ``None`` behaves like ``"auto"``. Optional
                ``require_logits`` forces non-fused cross entropy when supported.

        Returns:
            HFModel: The model with Liger kernel applied. The model is returned unchanged
            when its ``model_type`` has no Liger entry point or when ``use_kernels`` is empty.

        Raises:
            ValueError: If the model is not provided, if an op name is unknown for the
                model type, or if both cross-entropy variants are selected together.
            RuntimeError: If dependencies are not met.
        """
        model = kwargs.get("model")
        if model is None:
            raise ValueError(f"HFModel instance is required for {cls.__name__}.")

        if not cls.check_deps():
            # check_deps() fails both for a missing package and for an unsupported device,
            # so the message must not blame the device alone.
            raise RuntimeError(
                "liger_kernel is unavailable: the package is not installed or the current device is not supported. "
                f"Current device is {get_current_accelerator().type}, supported devices are {cls.get_device()}"
            )

        use_kernels = kwargs.get("use_kernels")
        require_logits = kwargs.get("require_logits", False)

        model_type = getattr(model.config, "model_type", None)
        if model_type not in _LIGER_FN_BY_MODEL_TYPE:
            logger.warning_rank0("Current model does not support liger kernel.")
            return model

        import liger_kernel.transformers as liger_transformers

        apply_liger_kernel = getattr(liger_transformers, _LIGER_FN_BY_MODEL_TYPE[model_type])
        sig = inspect.signature(apply_liger_kernel).parameters
        # Only named parameters are switches; ``*args`` / ``**kwargs`` catch-alls are not ops.
        named_kinds = (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY)
        togglable = [name for name, param in sig.items() if name != "model" and param.kind in named_kinds]

        if use_kernels is None:
            # An omitted selection means "library defaults" rather than a crash on iteration.
            use_kernels = "auto"
        elif isinstance(use_kernels, str) and use_kernels != "auto":
            # Accept "rope,rms_norm" directly instead of iterating over its characters.
            use_kernels = [part.strip() for part in use_kernels.split(",") if part.strip()]

        if len(use_kernels) == 0:
            return model

        if use_kernels != "auto":
            selected = {cls._normalize_op_name(k) for k in use_kernels}
            unknown = selected - set(togglable)
            if unknown:
                raise ValueError(
                    f"Unknown Liger op(s) {sorted(unknown)} for model_type={model_type}. Valid: {sorted(togglable)}"
                )

            if "cross_entropy" in selected and "fused_linear_cross_entropy" in selected:
                raise ValueError("cross_entropy and fused_linear_cross_entropy cannot both be enabled.")

            # Exclusive selection: every op that was not listed is explicitly switched off.
            call_kwargs = {name: (name in selected) for name in togglable}
            call_kwargs["model"] = model
        else:
            # Mirror ``liger_kernel`` signature defaults so patches match upstream defaults
            # and logging reflects enabled ops (omitted kwargs only live in the callee).
            call_kwargs = {"model": model}
            for name in togglable:
                default = sig[name].default
                if default is not inspect.Parameter.empty:
                    call_kwargs[name] = default

        if require_logits and "fused_linear_cross_entropy" in togglable:
            logger.warning_rank0("Current training stage does not support chunked cross entropy.")
            call_kwargs["fused_linear_cross_entropy"] = False
            # Pass ``cross_entropy`` only when the target function actually accepts it.
            if "cross_entropy" in togglable:
                call_kwargs["cross_entropy"] = True

        # The return value of the upstream function is not used; ``model`` itself is handed back.
        apply_liger_kernel(**call_kwargs)

        applied = sorted(name for name, on in call_kwargs.items() if name != "model" and on)
        logger.info_rank0(f"These Liger ops are applied to the model: {applied}")

        return model

View File

@@ -58,7 +58,7 @@ class Registry:
device = kernel_cls.get_device() device = kernel_cls.get_device()
# The device type of the current accelerator does not match the device type required by the kernel, skip registration # The device type of the current accelerator does not match the device type required by the kernel, skip registration
if device != get_current_accelerator().type: if get_current_accelerator().type not in device:
return return
if not kernel_id: if not kernel_id: