Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-11-04 09:52:14 +08:00)
			
		
		
		
	[model] add dots ocr (#9176)
This commit is contained in:
		
parent 800934b507
commit 80fe3a172d
					
@@ -1397,6 +1397,9 @@ class Qwen2AudioPlugin(BasePlugin):
 | 
			
		||||
 | 
			
		||||
@dataclass
 | 
			
		||||
class Qwen2VLPlugin(BasePlugin):
 | 
			
		||||
    start_token: str = "<|vision_start|>"
 | 
			
		||||
    end_token: str = "<|vision_end|>"
 | 
			
		||||
 | 
			
		||||
    @override
 | 
			
		||||
    def _preprocess_image(self, image: "ImageObject", **kwargs) -> "ImageObject":
 | 
			
		||||
        image = super()._preprocess_image(image, **kwargs)
 | 
			
		||||
@@ -1512,14 +1515,14 @@ class Qwen2VLPlugin(BasePlugin):
 | 
			
		||||
            while IMAGE_PLACEHOLDER in content:
 | 
			
		||||
                image_seqlen = image_grid_thw[num_image_tokens].prod() // merge_length if self.expand_mm_tokens else 1
 | 
			
		||||
                content = content.replace(
 | 
			
		||||
-                    IMAGE_PLACEHOLDER, f"<|vision_start|>{self.image_token * image_seqlen}<|vision_end|>", 1
+                    IMAGE_PLACEHOLDER, f"{self.start_token}{self.image_token * image_seqlen}{self.end_token}", 1
 | 
			
		||||
                )
 | 
			
		||||
                num_image_tokens += 1
 | 
			
		||||
 | 
			
		||||
            while VIDEO_PLACEHOLDER in content:
 | 
			
		||||
                video_seqlen = video_grid_thw[num_video_tokens].prod() // merge_length if self.expand_mm_tokens else 1
 | 
			
		||||
                content = content.replace(
 | 
			
		||||
-                    VIDEO_PLACEHOLDER, f"<|vision_start|>{self.video_token * video_seqlen}<|vision_end|>", 1
+                    VIDEO_PLACEHOLDER, f"{self.start_token}{self.video_token * video_seqlen}{self.end_token}", 1
 | 
			
		||||
                )
 | 
			
		||||
                num_video_tokens += 1
 | 
			
		||||
 | 
			
		||||
@@ -1907,9 +1910,10 @@ def get_mm_plugin(
 | 
			
		||||
    image_token: Optional[str] = None,
 | 
			
		||||
    video_token: Optional[str] = None,
 | 
			
		||||
    audio_token: Optional[str] = None,
 | 
			
		||||
    **kwargs,
 | 
			
		||||
) -> "BasePlugin":
 | 
			
		||||
    r"""Get plugin for multimodal inputs."""
 | 
			
		||||
    if name not in PLUGINS:
 | 
			
		||||
        raise ValueError(f"Multimodal plugin `{name}` not found.")
 | 
			
		||||
 | 
			
		||||
-    return PLUGINS[name](image_token, video_token, audio_token)
+    return PLUGINS[name](image_token, video_token, audio_token, **kwargs)
 | 
			
		||||
 | 
			
		||||
@@ -911,6 +911,23 @@ register_template(
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
register_template(
 | 
			
		||||
    name="dots_ocr",
 | 
			
		||||
    format_user=StringFormatter(slots=["<|user|>{{content}}<|endofuser|><|assistant|>"]),
 | 
			
		||||
    format_assistant=StringFormatter(slots=["{{content}}<|endofassistant|>"]),
 | 
			
		||||
    format_system=StringFormatter(slots=["<|system|>{{content}}<|endofsystem|>\n"]),
 | 
			
		||||
    stop_words=["<|endofassistant|>"],
 | 
			
		||||
    efficient_eos=True,
 | 
			
		||||
    mm_plugin=get_mm_plugin(
 | 
			
		||||
        name="qwen2_vl",
 | 
			
		||||
        image_token="<|imgpad|>",
 | 
			
		||||
        video_token="<|vidpad|>",
 | 
			
		||||
        start_token="<|img|>",
 | 
			
		||||
        end_token="<|endofimg|>",
 | 
			
		||||
    ),
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
register_template(
 | 
			
		||||
    name="empty",
 | 
			
		||||
    format_assistant=StringFormatter(slots=["{{content}}"]),
 | 
			
		||||
 | 
			
		||||
@@ -601,6 +601,18 @@ register_model_group(
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
register_model_group(
 | 
			
		||||
    models={
 | 
			
		||||
        "dots.ocr": {
 | 
			
		||||
            DownloadSource.DEFAULT: "rednote-hilab/dots.ocr",
 | 
			
		||||
            DownloadSource.MODELSCOPE: "rednote-hilab/dots.ocr",
 | 
			
		||||
        },
 | 
			
		||||
    },
 | 
			
		||||
    template="dots_ocr",
 | 
			
		||||
    multimodal=True,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
register_model_group(
 | 
			
		||||
    models={
 | 
			
		||||
        "ERNIE-4.5-21B-A3B-Thinking": {
 | 
			
		||||
 | 
			
		||||
@@ -199,6 +199,15 @@ def patch_target_modules(
 | 
			
		||||
        return target_modules
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
_register_composite_model(
 | 
			
		||||
    model_type="dots_ocr",
 | 
			
		||||
    projector_key="vision_tower.merger",
 | 
			
		||||
    vision_model_keys=["vision_tower"],
 | 
			
		||||
    language_model_keys=["model", "lm_head"],
 | 
			
		||||
    lora_conflict_keys=["merger"],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
_register_composite_model(
 | 
			
		||||
    model_type="gemma3",
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user