diff --git a/src/llamafactory/model/model_utils/liger_kernel.py b/src/llamafactory/model/model_utils/liger_kernel.py
index 16623873..7912c6fe 100644
--- a/src/llamafactory/model/model_utils/liger_kernel.py
+++ b/src/llamafactory/model/model_utils/liger_kernel.py
@@ -27,6 +27,39 @@ if TYPE_CHECKING:
 logger = logging.get_logger(__name__)
 
 
+def apply_liger_kernel_to_qwen2_5_vl(
+    rope: bool = True,
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
+    rms_norm: bool = True,
+    swiglu: bool = True,
+) -> None:
+    from liger_kernel.transformers import LigerCrossEntropyLoss, LigerRMSNorm, LigerSwiGLUMLP
+    from liger_kernel.transformers.model.qwen2_vl import lce_forward as qwen2_vl_lce_forward
+    from liger_kernel.transformers.qwen2vl_mrope import liger_multimodal_rotary_pos_emb
+    from transformers.models.qwen2_5_vl import modeling_qwen2_5_vl
+
+    def get_dtype(self: "modeling_qwen2_5_vl.Qwen2_5_VisionTransformerPretrainedModel"):
+        return self.dtype
+
+    modeling_qwen2_5_vl.Qwen2_5_VisionTransformerPretrainedModel.get_dtype = get_dtype
+
+    if rope:
+        modeling_qwen2_5_vl.apply_multimodal_rotary_pos_emb = liger_multimodal_rotary_pos_emb
+
+    if rms_norm:
+        modeling_qwen2_5_vl.Qwen2RMSNorm = LigerRMSNorm
+
+    if cross_entropy:
+        modeling_qwen2_5_vl.CrossEntropyLoss = LigerCrossEntropyLoss
+
+    if fused_linear_cross_entropy:
+        modeling_qwen2_5_vl.Qwen2_5_VLForConditionalGeneration.forward = qwen2_vl_lce_forward
+
+    if swiglu:
+        modeling_qwen2_5_vl.Qwen2MLP = LigerSwiGLUMLP
+
+
 def apply_liger_kernel(
     config: "PretrainedConfig",
     model_args: "ModelArguments",
@@ -47,19 +80,23 @@ def apply_liger_kernel(
         from liger_kernel.transformers import apply_liger_kernel_to_mistral as apply_liger_kernel
     elif model_type == "mixtral":
         from liger_kernel.transformers import apply_liger_kernel_to_mixtral as apply_liger_kernel
+    elif model_type == "mllama":
+        from liger_kernel.transformers import apply_liger_kernel_to_mllama as apply_liger_kernel
     elif model_type == "phi3":
         from liger_kernel.transformers import apply_liger_kernel_to_phi3 as apply_liger_kernel
     elif model_type == "qwen2":
         from liger_kernel.transformers import apply_liger_kernel_to_qwen2 as apply_liger_kernel
     elif model_type == "qwen2_vl":
         from liger_kernel.transformers import apply_liger_kernel_to_qwen2_vl as apply_liger_kernel
+    elif model_type == "qwen2_5_vl":
+        apply_liger_kernel = apply_liger_kernel_to_qwen2_5_vl
     else:
         logger.warning_rank0("Current model does not support liger kernel.")
         return
 
     if require_logits and "fused_linear_cross_entropy" in inspect.signature(apply_liger_kernel).parameters:
         logger.info_rank0("Current training stage does not support chunked cross entropy.")
-        kwargs = {"fused_linear_cross_entropy": False}
+        kwargs = {"fused_linear_cross_entropy": False, "cross_entropy": True}
     else:
         kwargs = {}
 
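A note on how the new patcher works: every assignment above mutates the transformers modeling module in place, so the patch must run before the model is instantiated; swapped classes such as Qwen2RMSNorm -> LigerRMSNorm only take effect for modules constructed afterwards. The get_dtype shim appears to exist because the reused qwen2_vl lce_forward was written against Qwen2-VL, whose vision tower defines get_dtype(), a method Qwen2.5-VL's vision transformer no longer provides. A minimal sketch of calling the patcher directly, outside the apply_liger_kernel dispatcher (the checkpoint name is an assumption for illustration, not part of this diff):

    from transformers.models.qwen2_5_vl import modeling_qwen2_5_vl

    from llamafactory.model.model_utils.liger_kernel import apply_liger_kernel_to_qwen2_5_vl

    # Patch first: with these arguments the call rebinds the multimodal RoPE
    # helper, the RMSNorm/SwiGLU classes, and CrossEntropyLoss inside
    # modeling_qwen2_5_vl, while keeping the stock forward (no chunked loss).
    # This mirrors the require_logits branch above: disable the memory-saving
    # fused_linear_cross_entropy and fall back to LigerCrossEntropyLoss so
    # full logits stay available for stages that need them.
    apply_liger_kernel_to_qwen2_5_vl(
        fused_linear_cross_entropy=False,
        cross_entropy=True,
    )

    # Any model created after the patch picks up the Liger kernels.
    model = modeling_qwen2_5_vl.Qwen2_5_VLForConditionalGeneration.from_pretrained(
        "Qwen/Qwen2.5-VL-7B-Instruct"  # assumed checkpoint name
    )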