mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2026-06-17 04:38:53 +08:00
[fix] fix liger kernel patch for npu (#10583)
This commit is contained in:
@@ -16,6 +16,7 @@ import inspect
|
|||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from ...extras import logging
|
from ...extras import logging
|
||||||
|
from ...extras.misc import get_device_name
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@@ -99,5 +100,12 @@ def apply_liger_kernel(
|
|||||||
else:
|
else:
|
||||||
kwargs = {}
|
kwargs = {}
|
||||||
|
|
||||||
|
if get_device_name() == "npu":
|
||||||
|
import torch
|
||||||
|
|
||||||
|
if "Ascend910" not in torch.npu.get_device_name(0):
|
||||||
|
kwargs["swiglu"] = False
|
||||||
|
kwargs["fused_linear_cross_entropy"] = False
|
||||||
|
|
||||||
apply_liger_kernel(**kwargs)
|
apply_liger_kernel(**kwargs)
|
||||||
logger.info_rank0("Liger kernel has been applied to the model.")
|
logger.info_rank0("Liger kernel has been applied to the model.")
|
||||||
|
|||||||
Reference in New Issue
Block a user