[model] add liger kernel to qwen2_5 vl (#6930)

* add liger kernel to qwen2_5 vl

* fix patch

* fix patch

Former-commit-id: 797043d29cb85a8f90fabf48976908037f07000e
This commit is contained in:
hoshi-hiyouga 2025-02-13 23:05:54 +08:00 committed by GitHub
parent 48173b606c
commit cd493b91de

View File

@ -27,6 +27,39 @@ if TYPE_CHECKING:
logger = logging.get_logger(__name__)
def apply_liger_kernel_to_qwen2_5_vl(
rope: bool = True,
cross_entropy: bool = False,
fused_linear_cross_entropy: bool = True,
rms_norm: bool = True,
swiglu: bool = True,
) -> None:
from liger_kernel.transformers import LigerCrossEntropyLoss, LigerRMSNorm, LigerSwiGLUMLP
from liger_kernel.transformers.model.qwen2_vl import lce_forward as qwen2_vl_lce_forward
from liger_kernel.transformers.qwen2vl_mrope import liger_multimodal_rotary_pos_emb
from transformers.models.qwen2_5_vl import modeling_qwen2_5_vl
def get_dtype(self: "modeling_qwen2_5_vl.Qwen2_5_VisionTransformerPretrainedModel"):
return self.dtype
modeling_qwen2_5_vl.Qwen2_5_VisionTransformerPretrainedModel.get_dtype = get_dtype
if rope:
modeling_qwen2_5_vl.apply_multimodal_rotary_pos_emb = liger_multimodal_rotary_pos_emb
if rms_norm:
modeling_qwen2_5_vl.Qwen2RMSNorm = LigerRMSNorm
if cross_entropy:
modeling_qwen2_5_vl.CrossEntropyLoss = LigerCrossEntropyLoss
if fused_linear_cross_entropy:
modeling_qwen2_5_vl.Qwen2_5_VLForConditionalGeneration.forward = qwen2_vl_lce_forward
if swiglu:
modeling_qwen2_5_vl.Qwen2MLP = LigerSwiGLUMLP
def apply_liger_kernel(
config: "PretrainedConfig",
model_args: "ModelArguments",
@ -47,19 +80,23 @@ def apply_liger_kernel(
from liger_kernel.transformers import apply_liger_kernel_to_mistral as apply_liger_kernel
elif model_type == "mixtral":
from liger_kernel.transformers import apply_liger_kernel_to_mixtral as apply_liger_kernel
elif model_type == "mllama":
from liger_kernel.transformers import apply_liger_kernel_to_mllama as apply_liger_kernel
elif model_type == "phi3":
from liger_kernel.transformers import apply_liger_kernel_to_phi3 as apply_liger_kernel
elif model_type == "qwen2":
from liger_kernel.transformers import apply_liger_kernel_to_qwen2 as apply_liger_kernel
elif model_type == "qwen2_vl":
from liger_kernel.transformers import apply_liger_kernel_to_qwen2_vl as apply_liger_kernel
elif model_type == "qwen2_5_vl":
apply_liger_kernel = apply_liger_kernel_to_qwen2_5_vl
else:
logger.warning_rank0("Current model does not support liger kernel.")
return
if require_logits and "fused_linear_cross_entropy" in inspect.signature(apply_liger_kernel).parameters:
logger.info_rank0("Current training stage does not support chunked cross entropy.")
kwargs = {"fused_linear_cross_entropy": False}
kwargs = {"fused_linear_cross_entropy": False, "cross_entropy": True}
else:
kwargs = {}