mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-10-14 15:52:49 +08:00
[model] add dots ocr (#9176)
This commit is contained in:
parent
800934b507
commit
80fe3a172d
@ -1397,6 +1397,9 @@ class Qwen2AudioPlugin(BasePlugin):
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Qwen2VLPlugin(BasePlugin):
|
class Qwen2VLPlugin(BasePlugin):
|
||||||
|
start_token: str = "<|vision_start|>"
|
||||||
|
end_token: str = "<|vision_end|>"
|
||||||
|
|
||||||
@override
|
@override
|
||||||
def _preprocess_image(self, image: "ImageObject", **kwargs) -> "ImageObject":
|
def _preprocess_image(self, image: "ImageObject", **kwargs) -> "ImageObject":
|
||||||
image = super()._preprocess_image(image, **kwargs)
|
image = super()._preprocess_image(image, **kwargs)
|
||||||
@ -1512,14 +1515,14 @@ class Qwen2VLPlugin(BasePlugin):
|
|||||||
while IMAGE_PLACEHOLDER in content:
|
while IMAGE_PLACEHOLDER in content:
|
||||||
image_seqlen = image_grid_thw[num_image_tokens].prod() // merge_length if self.expand_mm_tokens else 1
|
image_seqlen = image_grid_thw[num_image_tokens].prod() // merge_length if self.expand_mm_tokens else 1
|
||||||
content = content.replace(
|
content = content.replace(
|
||||||
IMAGE_PLACEHOLDER, f"<|vision_start|>{self.image_token * image_seqlen}<|vision_end|>", 1
|
IMAGE_PLACEHOLDER, f"{self.start_token}{self.image_token * image_seqlen}{self.end_token}", 1
|
||||||
)
|
)
|
||||||
num_image_tokens += 1
|
num_image_tokens += 1
|
||||||
|
|
||||||
while VIDEO_PLACEHOLDER in content:
|
while VIDEO_PLACEHOLDER in content:
|
||||||
video_seqlen = video_grid_thw[num_video_tokens].prod() // merge_length if self.expand_mm_tokens else 1
|
video_seqlen = video_grid_thw[num_video_tokens].prod() // merge_length if self.expand_mm_tokens else 1
|
||||||
content = content.replace(
|
content = content.replace(
|
||||||
VIDEO_PLACEHOLDER, f"<|vision_start|>{self.video_token * video_seqlen}<|vision_end|>", 1
|
VIDEO_PLACEHOLDER, f"{self.start_token}{self.video_token * video_seqlen}{self.end_token}", 1
|
||||||
)
|
)
|
||||||
num_video_tokens += 1
|
num_video_tokens += 1
|
||||||
|
|
||||||
@ -1907,9 +1910,10 @@ def get_mm_plugin(
|
|||||||
image_token: Optional[str] = None,
|
image_token: Optional[str] = None,
|
||||||
video_token: Optional[str] = None,
|
video_token: Optional[str] = None,
|
||||||
audio_token: Optional[str] = None,
|
audio_token: Optional[str] = None,
|
||||||
|
**kwargs,
|
||||||
) -> "BasePlugin":
|
) -> "BasePlugin":
|
||||||
r"""Get plugin for multimodal inputs."""
|
r"""Get plugin for multimodal inputs."""
|
||||||
if name not in PLUGINS:
|
if name not in PLUGINS:
|
||||||
raise ValueError(f"Multimodal plugin `{name}` not found.")
|
raise ValueError(f"Multimodal plugin `{name}` not found.")
|
||||||
|
|
||||||
return PLUGINS[name](image_token, video_token, audio_token)
|
return PLUGINS[name](image_token, video_token, audio_token, **kwargs)
|
||||||
|
@ -911,6 +911,23 @@ register_template(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
register_template(
|
||||||
|
name="dots_ocr",
|
||||||
|
format_user=StringFormatter(slots=["<|user|>{{content}}<|endofuser|><|assistant|>"]),
|
||||||
|
format_assistant=StringFormatter(slots=["{{content}}<|endofassistant|>"]),
|
||||||
|
format_system=StringFormatter(slots=["<|system|>{{content}}<|endofsystem|>\n"]),
|
||||||
|
stop_words=["<|endofassistant|>"],
|
||||||
|
efficient_eos=True,
|
||||||
|
mm_plugin=get_mm_plugin(
|
||||||
|
name="qwen2_vl",
|
||||||
|
image_token="<|imgpad|>",
|
||||||
|
video_token="<|vidpad|>",
|
||||||
|
start_token="<|img|>",
|
||||||
|
end_token="<|endofimg|>",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
register_template(
|
register_template(
|
||||||
name="empty",
|
name="empty",
|
||||||
format_assistant=StringFormatter(slots=["{{content}}"]),
|
format_assistant=StringFormatter(slots=["{{content}}"]),
|
||||||
|
@ -601,6 +601,18 @@ register_model_group(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
register_model_group(
|
||||||
|
models={
|
||||||
|
"dots.ocr": {
|
||||||
|
DownloadSource.DEFAULT: "rednote-hilab/dots.ocr",
|
||||||
|
DownloadSource.MODELSCOPE: "rednote-hilab/dots.ocr",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
template="dots_ocr",
|
||||||
|
multimodal=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
register_model_group(
|
register_model_group(
|
||||||
models={
|
models={
|
||||||
"ERNIE-4.5-21B-A3B-Thinking": {
|
"ERNIE-4.5-21B-A3B-Thinking": {
|
||||||
|
@ -199,6 +199,15 @@ def patch_target_modules(
|
|||||||
return target_modules
|
return target_modules
|
||||||
|
|
||||||
|
|
||||||
|
_register_composite_model(
|
||||||
|
model_type="dots_ocr",
|
||||||
|
projector_key="vision_tower.merger",
|
||||||
|
vision_model_keys=["vision_tower"],
|
||||||
|
language_model_keys=["model", "lm_head"],
|
||||||
|
lora_conflict_keys=["merger"],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
_register_composite_model(
|
_register_composite_model(
|
||||||
model_type="gemma3",
|
model_type="gemma3",
|
||||||
)
|
)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user