diff --git a/examples/v1/train_full/train_full_liger_kernel.yaml b/examples/v1/train_full/train_full_liger_kernel.yaml new file mode 100644 index 000000000..b6face99c --- /dev/null +++ b/examples/v1/train_full/train_full_liger_kernel.yaml @@ -0,0 +1,28 @@ +model: Qwen/Qwen3-0.6B +model_class: llm + +template: qwen3_nothink + +kernel_config: + name: liger_kernel + include_kernels: auto # choice: null/true/false/auto/kernel_id1,kernel_id2,kernel_id3, default is null + +quant_config: null + +dist_config: + name: fsdp2 + dcp_path: null # /mnt/f/pretrain_models/Qwen3-0.6B-dcp + +### data +train_dataset: data/v1_sft_demo.yaml + +### training +output_dir: outputs/test_fsdp2 +micro_batch_size: 1 +cutoff_len: 2048 +learning_rate: 1.0e-4 +max_steps: 10 + +### sample +sample_backend: hf +max_new_tokens: 128 diff --git a/src/llamafactory/v1/core/model_engine.py b/src/llamafactory/v1/core/model_engine.py index 629e76f5b..24e97b0d7 100644 --- a/src/llamafactory/v1/core/model_engine.py +++ b/src/llamafactory/v1/core/model_engine.py @@ -188,9 +188,12 @@ class ModelEngine: if self.args.kernel_config is not None: from ..plugins.model_plugins.kernels.interface import KernelPlugin - model = KernelPlugin(self.args.kernel_config.name)( - model, include_kernels=self.args.kernel_config.get("include_kernels") - ) + kernel_config = self.args.kernel_config + kernel_kwargs: dict = {"model": model, "include_kernels": kernel_config.get("include_kernels")} + if kernel_config.name == "liger_kernel": + # Fused linear CE omits logits; SFT stage needs logits for loss_weights. + kernel_kwargs["require_logits"] = self.is_train + model = KernelPlugin(kernel_config.name)(**kernel_kwargs) return model diff --git a/src/llamafactory/v1/plugins/model_plugins/kernels/base.py b/src/llamafactory/v1/plugins/model_plugins/kernels/base.py index 265986ccc..3916562b1 100644 --- a/src/llamafactory/v1/plugins/model_plugins/kernels/base.py +++ b/src/llamafactory/v1/plugins/model_plugins/kernels/base.py @@ -34,7 +34,7 @@ class BaseKernel(ABC): """ _kernel_id: Any = "" # kernel ID, any hashable value to identify a kernel implementation - _device: DeviceType = DeviceType.CPU # "cuda", "npu", "cpu", etc. + _device: list[DeviceType] = [DeviceType.CPU] # "cuda", "npu", "cpu", etc. @classmethod def get_kernel_id(cls) -> str: @@ -42,8 +42,8 @@ class BaseKernel(ABC): return cls._kernel_id @classmethod - def get_device(cls) -> str: - """Returns the device type associated with the kernel (e.g., "cuda", "npu", "cpu").""" + def get_device(cls) -> list[DeviceType]: + """Returns the device type list associated with the kernel (e.g., ["cuda", "npu", "cpu"]).""" return cls._device @classmethod @@ -58,7 +58,7 @@ class BaseKernel(ABC): it should raise an error instead of silently switching. Kernels can override this method to implement custom dependency checks. """ - if cls._device != get_current_accelerator().type: + if get_current_accelerator().type not in cls._device: return False return True diff --git a/src/llamafactory/v1/plugins/model_plugins/kernels/interface.py b/src/llamafactory/v1/plugins/model_plugins/kernels/interface.py index 7967a4328..1646448e9 100644 --- a/src/llamafactory/v1/plugins/model_plugins/kernels/interface.py +++ b/src/llamafactory/v1/plugins/model_plugins/kernels/interface.py @@ -138,3 +138,48 @@ def apply_default_kernels(model: HFModel, include_kernels: str = None) -> HFMode apply_kernel(kernel, model=model) return model + + +@KernelPlugin("liger_kernel").register() +def apply_liger_kernels( + model: HFModel, + include_kernels: str = None, + require_logits: bool = False, +) -> HFModel: + """Applies Liger kernel to the model. + + Args: + model (HFModel): The model instance to apply kernels to. + include_kernels (str, optional): If ``"auto"`` or ``True``, apply Liger with + library defaults. If a comma-separated list (e.g. + ``rope,rms_norm``), enable only those ops; names match + ``apply_liger_kernel_to_*`` kwargs: ``rope``, ``rms_norm``, + ``swiglu``, ``cross_entropy``, ``fused_linear_cross_entropy``. + If ``None`` or ``False``, do nothing. Defaults to ``None``. + require_logits (bool, optional): When true, disables ``fused_linear_cross_entropy`` in favor + of non-fused CE so the forward pass returns ``logits``. Needed + for trainers that compute weighted loss from logits (e.g. v1 + SFT with ``loss_weights``). Defaults to ``False`` (fused CE + when supported). The v1 ``run_sft`` entrypoint sets + ``require_logits`` to true for ``liger_kernel`` when the key + is omitted so SFT weighted loss keeps working. + + Returns: + HFModel: The model with Liger kernel applied. + """ + if not include_kernels: + return model + if include_kernels == "auto" or include_kernels is True: + use_kernels = "auto" + else: + use_kernels = [k.strip() for k in include_kernels.split(",") if k.strip()] + if not use_kernels: + return model + + try: + from .liger_kernel_ops import LigerKernel + except ImportError as e: + logger.warning_rank0(f"[Kernel] Failed to import liger_kernel ops, skip. Error: {e}") + return model + + return LigerKernel.apply(use_kernels=use_kernels, model=model, require_logits=require_logits) diff --git a/src/llamafactory/v1/plugins/model_plugins/kernels/liger_kernel_ops.py b/src/llamafactory/v1/plugins/model_plugins/kernels/liger_kernel_ops.py new file mode 100644 index 000000000..1d6fd0dbd --- /dev/null +++ b/src/llamafactory/v1/plugins/model_plugins/kernels/liger_kernel_ops.py @@ -0,0 +1,148 @@ +# Copyright 2025 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""The definition of Liger Kernel. + +Init Phase: +1. Define LigerKernel class. +2. Register Liger kernel. + +""" + +import inspect + +from ....accelerator.helper import DeviceType, get_current_accelerator +from ....utils.logging import get_logger +from ....utils.types import HFModel +from .base import BaseKernel + + +logger = get_logger(__name__) + +_LIGER_FN_BY_MODEL_TYPE: dict[str, str] = { + "qwen3": "apply_liger_kernel_to_qwen3", + "qwen3_moe": "apply_liger_kernel_to_qwen3_moe", + "qwen3_next": "apply_liger_kernel_to_qwen3_next", + "qwen3_5": "apply_liger_kernel_to_qwen3_5", + "qwen3_5_text": "apply_liger_kernel_to_qwen3_5_text", + "qwen3_5_moe": "apply_liger_kernel_to_qwen3_5_moe", + "qwen3_5_moe_text": "apply_liger_kernel_to_qwen3_5_moe_text", +} + + +class LigerKernel(BaseKernel): + """Liger Kernel for optimized model training.""" + + _device = [DeviceType.CUDA, DeviceType.NPU] + + @classmethod + def check_deps(cls) -> bool: + """Checks if the required dependencies for the kernel are available.""" + try: + import liger_kernel # noqa: F401 + + return super().check_deps() + except ImportError: + logger.warning_rank0( + "Liger kernel is not installed, the kernel_config liger_kernel will be ignored. Please install it from https://github.com/linkedin/Liger-Kernel." + ) + return False + + @classmethod + def apply(cls, **kwargs) -> "HFModel": + """Applies the Liger kernel to the model. + + Args: + **kwargs: Must include ``model``. Optional ``use_kernels`` is a list of Liger op + names to enable exclusively, or the string ``"auto"`` to use each + ``apply_liger_kernel_to_*`` function's signature defaults (same as calling + upstream with only ``model``). Optional ``require_logits`` forces non-fused + cross entropy when supported. + + Returns: + HFModel: The model with Liger kernel applied. + + Raises: + ValueError: If the model is not provided. + RuntimeError: If dependencies are not met. + """ + model = kwargs.get("model") + use_kernels = kwargs.get("use_kernels", None) + if model is None: + raise ValueError(f"HFModel instance is required for {cls.__name__}.") + + if not cls.check_deps(): + raise RuntimeError( + f"current device is not supported by liger_kernel. Current device is {get_current_accelerator().type}, supported devices are {cls.get_device()}" + ) + + require_logits = kwargs.get("require_logits", False) + + model_type = getattr(model.config, "model_type", None) + + if model_type not in _LIGER_FN_BY_MODEL_TYPE: + logger.warning_rank0("Current model does not support liger kernel.") + return model + + import liger_kernel.transformers as liger_transformers + + apply_liger_kernel = getattr(liger_transformers, _LIGER_FN_BY_MODEL_TYPE[model_type]) + + sig = inspect.signature(apply_liger_kernel).parameters + togglable = [name for name in sig if name != "model"] + + def _normalize_op_name(raw: str) -> str: + key = raw.strip().lower().replace("-", "_") + aliases = { + "rmsnorm": "rms_norm", + "flce": "fused_linear_cross_entropy", + "lce": "fused_linear_cross_entropy", + "fused_ce": "fused_linear_cross_entropy", + } + return aliases.get(key, key) + + if use_kernels is not None and len(use_kernels) == 0: + return model + + if use_kernels != "auto": + selected = {_normalize_op_name(k) for k in use_kernels} + ops = selected - set(togglable) + if ops: + raise ValueError( + f"Unknown Liger op(s) {sorted(ops)} for model_type={model_type}. Valid: {sorted(togglable)}" + ) + if "cross_entropy" in selected and "fused_linear_cross_entropy" in selected: + raise ValueError("cross_entropy and fused_linear_cross_entropy cannot both be enabled.") + call_kwargs = {name: (name in selected) for name in togglable} + call_kwargs["model"] = model + else: + # Mirror ``liger_kernel`` signature defaults so patches match upstream defaults + # and logging reflects enabled ops (omitted kwargs only live in the callee). + call_kwargs = {"model": model} + for name in togglable: + param = sig[name] + if param.default is not inspect.Parameter.empty: + call_kwargs[name] = param.default + + if require_logits and "fused_linear_cross_entropy" in sig: + logger.warning_rank0("Current training stage does not support chunked cross entropy.") + call_kwargs["fused_linear_cross_entropy"] = False + call_kwargs["cross_entropy"] = True + + apply_liger_kernel(**call_kwargs) + + applied = sorted(name for name, on in call_kwargs.items() if name != "model" and on) + logger.info_rank0(f"These Liger ops are applied to the model: {applied}") + + return model diff --git a/src/llamafactory/v1/plugins/model_plugins/kernels/registry.py b/src/llamafactory/v1/plugins/model_plugins/kernels/registry.py index 2621e4bad..2921f0aed 100644 --- a/src/llamafactory/v1/plugins/model_plugins/kernels/registry.py +++ b/src/llamafactory/v1/plugins/model_plugins/kernels/registry.py @@ -58,7 +58,7 @@ class Registry: device = kernel_cls.get_device() # The device type of the current accelerator does not match the device type required by the kernel, skip registration - if device != get_current_accelerator().type: + if get_current_accelerator().type not in device: return if not kernel_id: