[fix] fix liger kernel patch for npu (#10583)

This commit is contained in:
jiaqiw09
2026-06-16 18:21:52 +08:00
committed by GitHub
parent 897a44386c
commit 8669a22e9c

View File

@@ -16,6 +16,7 @@ import inspect
from typing import TYPE_CHECKING
from ...extras import logging
from ...extras.misc import get_device_name
if TYPE_CHECKING:
@@ -99,5 +100,12 @@ def apply_liger_kernel(
else:
kwargs = {}
if get_device_name() == "npu":
import torch
if "Ascend910" not in torch.npu.get_device_name(0):
kwargs["swiglu"] = False
kwargs["fused_linear_cross_entropy"] = False
apply_liger_kernel(**kwargs)
logger.info_rank0("Liger kernel has been applied to the model.")