mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2026-05-28 02:48:54 +08:00
[v1] support liger_kernel (#10493)
This commit is contained in:
28
examples/v1/train_full/train_full_liger_kernel.yaml
Normal file
28
examples/v1/train_full/train_full_liger_kernel.yaml
Normal file
@@ -0,0 +1,28 @@
|
||||
model: Qwen/Qwen3-0.6B
|
||||
model_class: llm
|
||||
|
||||
template: qwen3_nothink
|
||||
|
||||
kernel_config:
|
||||
name: liger_kernel
|
||||
include_kernels: auto # choice: null/true/false/auto/kernel_id1,kernel_id2,kernel_id3, default is null
|
||||
|
||||
quant_config: null
|
||||
|
||||
dist_config:
|
||||
name: fsdp2
|
||||
dcp_path: null # /mnt/f/pretrain_models/Qwen3-0.6B-dcp
|
||||
|
||||
### data
|
||||
train_dataset: data/v1_sft_demo.yaml
|
||||
|
||||
### training
|
||||
output_dir: outputs/test_fsdp2
|
||||
micro_batch_size: 1
|
||||
cutoff_len: 2048
|
||||
learning_rate: 1.0e-4
|
||||
max_steps: 10
|
||||
|
||||
### sample
|
||||
sample_backend: hf
|
||||
max_new_tokens: 128
|
||||
@@ -188,9 +188,12 @@ class ModelEngine:
|
||||
if self.args.kernel_config is not None:
|
||||
from ..plugins.model_plugins.kernels.interface import KernelPlugin
|
||||
|
||||
model = KernelPlugin(self.args.kernel_config.name)(
|
||||
model, include_kernels=self.args.kernel_config.get("include_kernels")
|
||||
)
|
||||
kernel_config = self.args.kernel_config
|
||||
kernel_kwargs: dict = {"model": model, "include_kernels": kernel_config.get("include_kernels")}
|
||||
if kernel_config.name == "liger_kernel":
|
||||
# Fused linear CE omits logits; SFT stage needs logits for loss_weights.
|
||||
kernel_kwargs["require_logits"] = self.is_train
|
||||
model = KernelPlugin(kernel_config.name)(**kernel_kwargs)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
@@ -34,7 +34,7 @@ class BaseKernel(ABC):
|
||||
"""
|
||||
|
||||
_kernel_id: Any = "" # kernel ID, any hashable value to identify a kernel implementation
|
||||
_device: DeviceType = DeviceType.CPU # "cuda", "npu", "cpu", etc.
|
||||
_device: list[DeviceType] = [DeviceType.CPU] # "cuda", "npu", "cpu", etc.
|
||||
|
||||
@classmethod
|
||||
def get_kernel_id(cls) -> str:
|
||||
@@ -42,8 +42,8 @@ class BaseKernel(ABC):
|
||||
return cls._kernel_id
|
||||
|
||||
@classmethod
|
||||
def get_device(cls) -> str:
|
||||
"""Returns the device type associated with the kernel (e.g., "cuda", "npu", "cpu")."""
|
||||
def get_device(cls) -> list[DeviceType]:
|
||||
"""Returns the device type list associated with the kernel (e.g., ["cuda", "npu", "cpu"])."""
|
||||
return cls._device
|
||||
|
||||
@classmethod
|
||||
@@ -58,7 +58,7 @@ class BaseKernel(ABC):
|
||||
it should raise an error instead of silently switching.
|
||||
Kernels can override this method to implement custom dependency checks.
|
||||
"""
|
||||
if cls._device != get_current_accelerator().type:
|
||||
if get_current_accelerator().type not in cls._device:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
@@ -138,3 +138,48 @@ def apply_default_kernels(model: HFModel, include_kernels: str = None) -> HFMode
|
||||
apply_kernel(kernel, model=model)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
@KernelPlugin("liger_kernel").register()
def apply_liger_kernels(
    model: HFModel,
    include_kernels: "str | bool | list[str] | None" = None,
    require_logits: bool = False,
) -> HFModel:
    """Applies Liger kernel to the model.

    Args:
        model (HFModel): The model instance to apply kernels to.
        include_kernels (str | bool | list[str], optional): Selects which Liger ops to enable.
            ``"auto"`` or ``True`` applies Liger with the library defaults. A
            comma-separated string (e.g. ``rope,rms_norm``) or a sequence of names
            enables only those ops; names match the ``apply_liger_kernel_to_*``
            kwargs: ``rope``, ``rms_norm``, ``swiglu``, ``cross_entropy``,
            ``fused_linear_cross_entropy``. ``None``, ``False`` or an empty value
            does nothing. Defaults to ``None``.
        require_logits (bool, optional): When true, ``fused_linear_cross_entropy`` is not
            used so the forward pass still returns ``logits``. Needed by trainers
            that compute a weighted loss from logits (e.g. SFT with
            ``loss_weights``). This function does not infer the value; the caller
            decides it. Defaults to ``False``.

    Returns:
        HFModel: The model with Liger kernel applied, or the unchanged model when
        nothing was requested or the Liger ops module cannot be imported.

    Raises:
        ValueError: If ``include_kernels`` is neither a string, a bool nor an
            iterable of op names.
    """
    if not include_kernels:
        return model

    if include_kernels is True or include_kernels == "auto":
        use_kernels = "auto"
    else:
        if isinstance(include_kernels, str):
            raw_names = include_kernels.split(",")
        else:
            # YAML lists arrive as sequences rather than a comma-separated string.
            try:
                raw_names = [str(name) for name in include_kernels]
            except TypeError as err:
                raise ValueError(
                    f"include_kernels must be 'auto', a bool, a comma-separated string "
                    f"or a list of op names, got {include_kernels!r}."
                ) from err

        use_kernels = [name.strip() for name in raw_names if name.strip()]
        if not use_kernels:
            return model

    # Imported lazily so this module stays importable when the ops module cannot be loaded.
    try:
        from .liger_kernel_ops import LigerKernel
    except ImportError as e:
        logger.warning_rank0(f"[Kernel] Failed to import liger_kernel ops, skip. Error: {e}")
        return model

    return LigerKernel.apply(use_kernels=use_kernels, model=model, require_logits=require_logits)
|
||||
|
||||
@@ -0,0 +1,148 @@
|
||||
# Copyright 2025 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""The definition of Liger Kernel.
|
||||
|
||||
Init Phase:
|
||||
1. Define LigerKernel class.
|
||||
2. Register Liger kernel.
|
||||
|
||||
"""
|
||||
|
||||
import inspect
|
||||
|
||||
from ....accelerator.helper import DeviceType, get_current_accelerator
|
||||
from ....utils.logging import get_logger
|
||||
from ....utils.types import HFModel
|
||||
from .base import BaseKernel
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
_LIGER_FN_BY_MODEL_TYPE: dict[str, str] = {
|
||||
"qwen3": "apply_liger_kernel_to_qwen3",
|
||||
"qwen3_moe": "apply_liger_kernel_to_qwen3_moe",
|
||||
"qwen3_next": "apply_liger_kernel_to_qwen3_next",
|
||||
"qwen3_5": "apply_liger_kernel_to_qwen3_5",
|
||||
"qwen3_5_text": "apply_liger_kernel_to_qwen3_5_text",
|
||||
"qwen3_5_moe": "apply_liger_kernel_to_qwen3_5_moe",
|
||||
"qwen3_5_moe_text": "apply_liger_kernel_to_qwen3_5_moe_text",
|
||||
}
|
||||
|
||||
|
||||
class LigerKernel(BaseKernel):
    """Liger Kernel for optimized model training.

    Looks up the ``liger_kernel.transformers`` patch function registered for the
    model's ``config.model_type`` in ``_LIGER_FN_BY_MODEL_TYPE`` and calls it with
    the requested set of ops.
    """

    _device = [DeviceType.CUDA, DeviceType.NPU]

    # Short spellings accepted in ``use_kernels`` mapped to the Liger keyword names.
    _OP_ALIASES: dict[str, str] = {
        "rmsnorm": "rms_norm",
        "flce": "fused_linear_cross_entropy",
        "lce": "fused_linear_cross_entropy",
        "fused_ce": "fused_linear_cross_entropy",
    }

    @classmethod
    def check_deps(cls) -> bool:
        """Checks if the required dependencies for the kernel are available.

        Returns:
            bool: ``False`` when ``liger_kernel`` cannot be imported; otherwise the
            result of the base-class check.
        """
        try:
            import liger_kernel  # noqa: F401
        except ImportError:
            logger.warning_rank0(
                "Liger kernel is not installed, the kernel_config liger_kernel will be ignored. Please install it from https://github.com/linkedin/Liger-Kernel."
            )
            return False

        return super().check_deps()

    @classmethod
    def _normalize_op_name(cls, raw: str) -> str:
        """Lower-cases an op name, maps ``-`` to ``_`` and resolves known aliases."""
        key = raw.strip().lower().replace("-", "_")
        return cls._OP_ALIASES.get(key, key)

    @classmethod
    def apply(cls, **kwargs) -> "HFModel":
        """Applies the Liger kernel to the model.

        Args:
            **kwargs: Must include ``model``. Optional ``use_kernels`` is a list of Liger op
                names (or a comma-separated string) to enable exclusively, or the
                string ``"auto"`` to use each ``apply_liger_kernel_to_*`` function's
                signature defaults (same as calling upstream with only ``model``).
                When omitted it is treated as ``"auto"``. Optional ``require_logits``
                replaces fused linear cross entropy with non-fused cross entropy
                whenever the fused op would otherwise be enabled.

        Returns:
            HFModel: The model with Liger kernel applied. The model is returned
            unchanged when its ``model_type`` is not supported, when the installed
            ``liger_kernel`` lacks the patch function, or when ``use_kernels`` is empty.

        Raises:
            ValueError: If the model is not provided, an unknown op is requested, or
                both ``cross_entropy`` and ``fused_linear_cross_entropy`` are requested.
            RuntimeError: If dependencies are not met.
        """
        model = kwargs.get("model")
        if model is None:
            raise ValueError(f"HFModel instance is required for {cls.__name__}.")

        if not cls.check_deps():
            # ``check_deps`` fails for two different reasons; report the real one.
            if get_current_accelerator().type not in cls.get_device():
                raise RuntimeError(
                    f"current device is not supported by liger_kernel. Current device is {get_current_accelerator().type}, supported devices are {cls.get_device()}"
                )
            raise RuntimeError(
                "liger_kernel is not installed or could not be imported. "
                "Please install it from https://github.com/linkedin/Liger-Kernel."
            )

        use_kernels = kwargs.get("use_kernels")
        require_logits = kwargs.get("require_logits", False)

        # An omitted selection means "apply with upstream defaults"; a plain string
        # other than "auto" is a comma-separated op list.
        if use_kernels is None:
            use_kernels = "auto"
        elif isinstance(use_kernels, str) and use_kernels != "auto":
            use_kernels = [part for part in use_kernels.split(",") if part.strip()]

        model_type = getattr(model.config, "model_type", None)
        if model_type not in _LIGER_FN_BY_MODEL_TYPE:
            logger.warning_rank0("Current model does not support liger kernel.")
            return model

        import liger_kernel.transformers as liger_transformers

        fn_name = _LIGER_FN_BY_MODEL_TYPE[model_type]
        apply_liger_kernel = getattr(liger_transformers, fn_name, None)
        if apply_liger_kernel is None:
            logger.warning_rank0(
                f"The installed liger_kernel does not provide {fn_name}, skip. "
                "Please upgrade liger_kernel to enable it for this model."
            )
            return model

        sig = inspect.signature(apply_liger_kernel).parameters
        # Only real keyword-addressable switches can be toggled; skip ``*args``/``**kwargs``.
        keyword_kinds = (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY)
        togglable = [name for name, param in sig.items() if name != "model" and param.kind in keyword_kinds]

        if len(use_kernels) == 0:
            return model

        if use_kernels != "auto":
            selected = {cls._normalize_op_name(op) for op in use_kernels}
            unknown = selected - set(togglable)
            if unknown:
                raise ValueError(
                    f"Unknown Liger op(s) {sorted(unknown)} for model_type={model_type}. Valid: {sorted(togglable)}"
                )
            if "cross_entropy" in selected and "fused_linear_cross_entropy" in selected:
                raise ValueError("cross_entropy and fused_linear_cross_entropy cannot both be enabled.")
            # Every switch is passed explicitly so unselected ops are turned off.
            call_kwargs = {name: (name in selected) for name in togglable}
            call_kwargs["model"] = model
        else:
            # Mirror ``liger_kernel`` signature defaults so patches match upstream defaults
            # and logging reflects enabled ops (omitted kwargs only live in the callee).
            call_kwargs = {"model": model}
            for name in togglable:
                default = sig[name].default
                if default is not inspect.Parameter.empty:
                    call_kwargs[name] = default

        # Only swap to non-fused CE when the fused op would really be enabled, so an
        # explicit op list without it is respected.
        if require_logits and call_kwargs.get("fused_linear_cross_entropy"):
            logger.warning_rank0("Current training stage does not support chunked cross entropy.")
            call_kwargs["fused_linear_cross_entropy"] = False
            if "cross_entropy" in togglable:
                call_kwargs["cross_entropy"] = True

        apply_liger_kernel(**call_kwargs)

        applied = sorted(name for name, on in call_kwargs.items() if name != "model" and on)
        logger.info_rank0(f"These Liger ops are applied to the model: {applied}")

        return model
|
||||
@@ -58,7 +58,7 @@ class Registry:
|
||||
device = kernel_cls.get_device()
|
||||
|
||||
# The device type of the current accelerator does not match the device type required by the kernel, skip registration
|
||||
if device != get_current_accelerator().type:
|
||||
if get_current_accelerator().type not in device:
|
||||
return
|
||||
|
||||
if not kernel_id:
|
||||
|
||||
Reference in New Issue
Block a user