diff --git a/src/llamafactory/data/collator.py b/src/llamafactory/data/collator.py index 45275392..cfeecd86 100644 --- a/src/llamafactory/data/collator.py +++ b/src/llamafactory/data/collator.py @@ -211,10 +211,10 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq): if ( self.model is not None and getattr(self.model.config, "model_type", None) - in ["glm4v", "qwen2_vl", "qwen2_5_vl", "qwen2_5_omni_thinker"] + in ["glm4v", "Keye", "qwen2_vl", "qwen2_5_vl", "qwen2_5_omni_thinker"] and ("position_ids" not in features or features["position_ids"].dim() != 3) ): - raise ValueError("Qwen2-VL/Qwen2.5-Omni model requires 3D position ids for mrope.") + raise ValueError(f"{self.model.config.model_type} requires 3D position ids for mrope.") if "cross_attention_mask" in mm_inputs: # for mllama inputs when pad_to_multiple_of is enabled cross_attention_mask = mm_inputs.pop("cross_attention_mask") diff --git a/src/llamafactory/data/template.py b/src/llamafactory/data/template.py index a7868abc..7e9cb3bf 100644 --- a/src/llamafactory/data/template.py +++ b/src/llamafactory/data/template.py @@ -1171,6 +1171,24 @@ register_template( ) + +# copied from qwen template +register_template( + name="keye_vl", + format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), + format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]), + format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]), + format_function=FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen"), + format_observation=StringFormatter( + slots=["<|im_start|>user\n<tool_response>\n{{content}}\n</tool_response><|im_end|>\n<|im_start|>assistant\n"] + ), + format_tools=ToolFormatter(tool_format="qwen"), + stop_words=["<|im_end|>"], + replace_eos=True, + mm_plugin=get_mm_plugin(name="qwen2_vl", image_token="<|image_pad|>", video_token="<|video_pad|>"), + template_class=ReasoningTemplate, +) + + register_template( name="kimi_vl",
format_user=StringFormatter( diff --git a/src/llamafactory/extras/constants.py b/src/llamafactory/extras/constants.py index 0cf91b60..439963f7 100644 --- a/src/llamafactory/extras/constants.py +++ b/src/llamafactory/extras/constants.py @@ -1235,6 +1235,18 @@ register_model_group( ) +register_model_group( + models={ + "Keye-VL-8B-Chat": { + DownloadSource.DEFAULT: "Kwai-Keye/Keye-VL-8B-Preview", + DownloadSource.MODELSCOPE: "Kwai-Keye/Keye-VL-8B-Preview", + }, + }, + template="keye_vl", + multimodal=True, +) + + register_model_group( models={ "Kimi-Dev-72B-Instruct": { diff --git a/src/llamafactory/model/model_utils/visual.py b/src/llamafactory/model/model_utils/visual.py index f66c2e76..1e228855 100644 --- a/src/llamafactory/model/model_utils/visual.py +++ b/src/llamafactory/model/model_utils/visual.py @@ -226,6 +226,15 @@ _register_composite_model( ) +_register_composite_model( + model_type="Keye", + projector_key="mlp_AR", + vision_model_keys=["visual.vision_model.patch_embedding", "visual.vision_model.encoder"], + language_model_keys=["model", "lm_head"], + lora_conflict_keys=["patch_embedding"], +) + + _register_composite_model( model_type="llama4", vision_model_keys=["vision_model"],