From 80fe3a172d697277a403dd7da6a4b7013070f31a Mon Sep 17 00:00:00 2001 From: Yaowei Zheng Date: Sun, 21 Sep 2025 23:34:19 +0800 Subject: [PATCH] [model] add dots ocr (#9176) --- src/llamafactory/data/mm_plugin.py | 10 +++++++--- src/llamafactory/data/template.py | 17 +++++++++++++++++ src/llamafactory/extras/constants.py | 12 ++++++++++++ src/llamafactory/model/model_utils/visual.py | 9 +++++++++ 4 files changed, 45 insertions(+), 3 deletions(-) diff --git a/src/llamafactory/data/mm_plugin.py b/src/llamafactory/data/mm_plugin.py index a3a6a4d9..607fe325 100644 --- a/src/llamafactory/data/mm_plugin.py +++ b/src/llamafactory/data/mm_plugin.py @@ -1397,6 +1397,9 @@ class Qwen2AudioPlugin(BasePlugin): @dataclass class Qwen2VLPlugin(BasePlugin): + start_token: str = "<|vision_start|>" + end_token: str = "<|vision_end|>" + @override def _preprocess_image(self, image: "ImageObject", **kwargs) -> "ImageObject": image = super()._preprocess_image(image, **kwargs) @@ -1512,14 +1515,14 @@ class Qwen2VLPlugin(BasePlugin): while IMAGE_PLACEHOLDER in content: image_seqlen = image_grid_thw[num_image_tokens].prod() // merge_length if self.expand_mm_tokens else 1 content = content.replace( - IMAGE_PLACEHOLDER, f"<|vision_start|>{self.image_token * image_seqlen}<|vision_end|>", 1 + IMAGE_PLACEHOLDER, f"{self.start_token}{self.image_token * image_seqlen}{self.end_token}", 1 ) num_image_tokens += 1 while VIDEO_PLACEHOLDER in content: video_seqlen = video_grid_thw[num_video_tokens].prod() // merge_length if self.expand_mm_tokens else 1 content = content.replace( - VIDEO_PLACEHOLDER, f"<|vision_start|>{self.video_token * video_seqlen}<|vision_end|>", 1 + VIDEO_PLACEHOLDER, f"{self.start_token}{self.video_token * video_seqlen}{self.end_token}", 1 ) num_video_tokens += 1 @@ -1907,9 +1910,10 @@ def get_mm_plugin( image_token: Optional[str] = None, video_token: Optional[str] = None, audio_token: Optional[str] = None, + **kwargs, ) -> "BasePlugin": r"""Get plugin for multimodal inputs.""" if name not in PLUGINS: raise ValueError(f"Multimodal plugin `{name}` not found.") - return PLUGINS[name](image_token, video_token, audio_token) + return PLUGINS[name](image_token, video_token, audio_token, **kwargs) diff --git a/src/llamafactory/data/template.py b/src/llamafactory/data/template.py index 596e7006..e2d62188 100644 --- a/src/llamafactory/data/template.py +++ b/src/llamafactory/data/template.py @@ -911,6 +911,23 @@ register_template( ) +register_template( + name="dots_ocr", + format_user=StringFormatter(slots=["<|user|>{{content}}<|endofuser|><|assistant|>"]), + format_assistant=StringFormatter(slots=["{{content}}<|endofassistant|>"]), + format_system=StringFormatter(slots=["<|system|>{{content}}<|endofsystem|>\n"]), + stop_words=["<|endofassistant|>"], + efficient_eos=True, + mm_plugin=get_mm_plugin( + name="qwen2_vl", + image_token="<|imgpad|>", + video_token="<|vidpad|>", + start_token="<|img|>", + end_token="<|endofimg|>", + ), +) + + register_template( name="empty", format_assistant=StringFormatter(slots=["{{content}}"]), diff --git a/src/llamafactory/extras/constants.py b/src/llamafactory/extras/constants.py index 9f87f60b..ed7811e7 100644 --- a/src/llamafactory/extras/constants.py +++ b/src/llamafactory/extras/constants.py @@ -601,6 +601,18 @@ register_model_group( ) +register_model_group( + models={ + "dots.ocr": { + DownloadSource.DEFAULT: "rednote-hilab/dots.ocr", + DownloadSource.MODELSCOPE: "rednote-hilab/dots.ocr", + }, + }, + template="dots_ocr", + multimodal=True, +) + + register_model_group( models={ "ERNIE-4.5-21B-A3B-Thinking": { diff --git a/src/llamafactory/model/model_utils/visual.py b/src/llamafactory/model/model_utils/visual.py index ebca1d52..e5c39280 100644 --- a/src/llamafactory/model/model_utils/visual.py +++ b/src/llamafactory/model/model_utils/visual.py @@ -199,6 +199,15 @@ def patch_target_modules( return target_modules +_register_composite_model( + model_type="dots_ocr", + projector_key="vision_tower.merger", + vision_model_keys=["vision_tower"], + language_model_keys=["model", "lm_head"], + lora_conflict_keys=["merger"], +) + + _register_composite_model( model_type="gemma3", )