diff --git a/src/llamafactory/model/model_utils/liger_kernel.py b/src/llamafactory/model/model_utils/liger_kernel.py index 81c1132d..9f9cd20d 100644 --- a/src/llamafactory/model/model_utils/liger_kernel.py +++ b/src/llamafactory/model/model_utils/liger_kernel.py @@ -45,6 +45,8 @@ def configure_liger_kernel(config: "PretrainedConfig", model_args: "ModelArgumen from liger_kernel.transformers import apply_liger_kernel_to_phi3 as apply_liger_kernel elif model_type == "qwen2": from liger_kernel.transformers import apply_liger_kernel_to_qwen2 as apply_liger_kernel + elif model_type == "qwen2_vl": + from liger_kernel.transformers import apply_liger_kernel_to_qwen2_vl as apply_liger_kernel else: logger.warning("Current model does not support liger kernel.") return