diff --git a/src/llamafactory/model/model_utils/attention.py b/src/llamafactory/model/model_utils/attention.py index 0f95d3f2..860bf891 100644 --- a/src/llamafactory/model/model_utils/attention.py +++ b/src/llamafactory/model/model_utils/attention.py @@ -69,6 +69,9 @@ def configure_attn_implementation( if getattr(config, "model_type", None) == "internlm2": # special case for custom models setattr(config, "attn_implementation", requested_attn_implementation) + elif getattr(config, "model_type", None) == "kimi_vl": + setattr(config.vision_config, "_attn_implementation", requested_attn_implementation) + setattr(config.text_config, "_attn_implementation", requested_attn_implementation) else: setattr(config, "_attn_implementation", requested_attn_implementation)