diff --git a/src/llamafactory/model/model_utils/liger_kernel.py b/src/llamafactory/model/model_utils/liger_kernel.py index 486c54e6..3f467752 100644 --- a/src/llamafactory/model/model_utils/liger_kernel.py +++ b/src/llamafactory/model/model_utils/liger_kernel.py @@ -45,16 +45,24 @@ def apply_liger_kernel( from liger_kernel.transformers import apply_liger_kernel_to_gemma3 as apply_liger_kernel elif model_type == "gemma3_text": from liger_kernel.transformers import apply_liger_kernel_to_gemma3_text as apply_liger_kernel - elif model_type == "paligemma": - from liger_kernel.transformers import apply_liger_kernel_to_paligemma as apply_liger_kernel + elif model_type == "glm4": + from liger_kernel.transformers import apply_liger_kernel_to_glm4 as apply_liger_kernel + elif model_type == "granite": + from liger_kernel.transformers import apply_liger_kernel_to_granite as apply_liger_kernel elif model_type == "llama": from liger_kernel.transformers import apply_liger_kernel_to_llama as apply_liger_kernel + elif model_type == "llava": + from liger_kernel.transformers import apply_liger_kernel_to_llava as apply_liger_kernel elif model_type == "mistral": from liger_kernel.transformers import apply_liger_kernel_to_mistral as apply_liger_kernel elif model_type == "mixtral": from liger_kernel.transformers import apply_liger_kernel_to_mixtral as apply_liger_kernel elif model_type == "mllama": from liger_kernel.transformers import apply_liger_kernel_to_mllama as apply_liger_kernel + elif model_type == "olmo2": + from liger_kernel.transformers import apply_liger_kernel_to_olmo2 as apply_liger_kernel + elif model_type == "paligemma": + from liger_kernel.transformers import apply_liger_kernel_to_paligemma as apply_liger_kernel elif model_type == "phi3": from liger_kernel.transformers import apply_liger_kernel_to_phi3 as apply_liger_kernel elif model_type == "qwen2": @@ -63,6 +71,8 @@ def apply_liger_kernel( from liger_kernel.transformers import apply_liger_kernel_to_qwen2_vl as apply_liger_kernel elif model_type == "qwen2_5_vl": from liger_kernel.transformers import apply_liger_kernel_to_qwen2_5_vl as apply_liger_kernel + elif model_type == "qwen3": + from liger_kernel.transformers import apply_liger_kernel_to_qwen3 as apply_liger_kernel else: logger.warning_rank0("Current model does not support liger kernel.") return