diff --git a/src/llamafactory/data/collator.py b/src/llamafactory/data/collator.py index 162f432c9..d1e562d93 100644 --- a/src/llamafactory/data/collator.py +++ b/src/llamafactory/data/collator.py @@ -213,6 +213,7 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq): and getattr(self.model.config, "model_type", None) in [ "glm4v", + "glm_ocr", "Keye", "qwen2_vl", "qwen2_5_vl", diff --git a/src/llamafactory/data/template.py b/src/llamafactory/data/template.py index 7a923de3e..1132d7111 100644 --- a/src/llamafactory/data/template.py +++ b/src/llamafactory/data/template.py @@ -1061,6 +1061,22 @@ register_template( ) +# copied from glm4 template +register_template( + name="glm_ocr", + format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]), + format_assistant=StringFormatter(slots=["\n{{content}}"]), + format_system=StringFormatter(slots=["<|system|>\n{{content}}"]), + format_function=FunctionFormatter(slots=["{{content}}"], tool_format="glm4"), + format_observation=StringFormatter(slots=["<|observation|>\n{{content}}<|assistant|>"]), + format_tools=ToolFormatter(tool_format="glm4"), + format_prefix=EmptyFormatter(slots=["[gMASK]"]), + stop_words=["<|user|>", "<|observation|>"], + efficient_eos=True, + mm_plugin=get_mm_plugin(name="glm4v", image_token="<|image|>", video_token="<|video|>"), +) + + # copied from glm4_moe template register_template( name="glm4_7", diff --git a/src/llamafactory/extras/constants.py b/src/llamafactory/extras/constants.py index cd9791e13..6d5afdb5f 100644 --- a/src/llamafactory/extras/constants.py +++ b/src/llamafactory/extras/constants.py @@ -950,6 +950,18 @@ register_model_group( ) +register_model_group( + models={ + "GLM-OCR": { + DownloadSource.DEFAULT: "zai-org/GLM-OCR", + DownloadSource.MODELSCOPE: "ZhipuAI/GLM-OCR", + }, + }, + template="glm_ocr", + multimodal=True, +) + + register_model_group( models={ "GLM-Z1-0414-9B-Chat": { diff --git a/src/llamafactory/model/model_utils/visual.py b/src/llamafactory/model/model_utils/visual.py index 0d23b6e23..ae091b76d 100644 --- a/src/llamafactory/model/model_utils/visual.py +++ b/src/llamafactory/model/model_utils/visual.py @@ -239,6 +239,15 @@ _register_composite_model( ) +_register_composite_model( + model_type="glm_ocr", + projector_key="visual.merger", + vision_model_keys=["visual.patch_embed", "visual.blocks"], + language_model_keys=["language_model", "lm_head"], + lora_conflict_keys=["patch_embed"], +) + + _register_composite_model( model_type="internvl", )