From d38c402f63bb6804b340e7cac1540e8113bb6231 Mon Sep 17 00:00:00 2001 From: Xiaosu Zhu Date: Tue, 25 Mar 2025 11:58:52 +0800 Subject: [PATCH] [misc] update liger-kernel's monkey patch (#7453) * Update liger_kernel.py * Update setup.py --- setup.py | 2 +- .../model/model_utils/liger_kernel.py | 35 +------------------ 2 files changed, 2 insertions(+), 35 deletions(-) diff --git a/setup.py b/setup.py index 6fe7180d..9053d483 100644 --- a/setup.py +++ b/setup.py @@ -46,7 +46,7 @@ extra_require = { "torch-npu": ["torch==2.4.0", "torch-npu==2.4.0.post2", "decorator"], "metrics": ["nltk", "jieba", "rouge-chinese"], "deepspeed": ["deepspeed>=0.10.0,<=0.16.4"], - "liger-kernel": ["liger-kernel"], + "liger-kernel": ["liger-kernel>=0.5.5"], "bitsandbytes": ["bitsandbytes>=0.39.0"], "hqq": ["hqq"], "eetq": ["eetq"], diff --git a/src/llamafactory/model/model_utils/liger_kernel.py b/src/llamafactory/model/model_utils/liger_kernel.py index 84a69535..4f6cbf2f 100644 --- a/src/llamafactory/model/model_utils/liger_kernel.py +++ b/src/llamafactory/model/model_utils/liger_kernel.py @@ -27,39 +27,6 @@ if TYPE_CHECKING: logger = logging.get_logger(__name__) -def apply_liger_kernel_to_qwen2_5_vl( - rope: bool = True, - cross_entropy: bool = False, - fused_linear_cross_entropy: bool = True, - rms_norm: bool = True, - swiglu: bool = True, -) -> None: - from liger_kernel.transformers import LigerCrossEntropyLoss, LigerRMSNorm, LigerSwiGLUMLP - from liger_kernel.transformers.model.qwen2_vl import lce_forward as qwen2_vl_lce_forward - from liger_kernel.transformers.qwen2vl_mrope import liger_multimodal_rotary_pos_emb - from transformers.models.qwen2_5_vl import modeling_qwen2_5_vl - - def get_dtype(self: "modeling_qwen2_5_vl.Qwen2_5_VisionTransformerPretrainedModel"): - return self.dtype - - modeling_qwen2_5_vl.Qwen2_5_VisionTransformerPretrainedModel.get_dtype = get_dtype - - if rope: - modeling_qwen2_5_vl.apply_multimodal_rotary_pos_emb = liger_multimodal_rotary_pos_emb - - if rms_norm: - modeling_qwen2_5_vl.Qwen2RMSNorm = LigerRMSNorm - - if cross_entropy: - modeling_qwen2_5_vl.CrossEntropyLoss = LigerCrossEntropyLoss - - if fused_linear_cross_entropy: - modeling_qwen2_5_vl.Qwen2_5_VLForConditionalGeneration.forward = qwen2_vl_lce_forward - - if swiglu: - modeling_qwen2_5_vl.Qwen2MLP = LigerSwiGLUMLP - - def apply_liger_kernel( config: "PretrainedConfig", model_args: "ModelArguments", @@ -95,7 +62,7 @@ def apply_liger_kernel( elif model_type == "qwen2_vl": from liger_kernel.transformers import apply_liger_kernel_to_qwen2_vl as apply_liger_kernel elif model_type == "qwen2_5_vl": - apply_liger_kernel = apply_liger_kernel_to_qwen2_5_vl + from liger_kernel.transformers import apply_liger_kernel_to_qwen2_5_vl as apply_liger_kernel else: logger.warning_rank0("Current model does not support liger kernel.") return