diff --git a/src/llamafactory/data/mm_plugin.py b/src/llamafactory/data/mm_plugin.py index 91b801dc..20ea1417 100644 --- a/src/llamafactory/data/mm_plugin.py +++ b/src/llamafactory/data/mm_plugin.py @@ -465,6 +465,38 @@ class BasePlugin(MMPluginMixin): self._validate_input(processor, images, videos, audios) return self._get_mm_inputs(images, videos, audios, processor) +@dataclass +class ErnieVLPlugin(BasePlugin): + @override + def process_messages( + self, + messages: list[dict[str, str]], + images: list["ImageInput"], + videos: list["VideoInput"], + audios: list["AudioInput"], + processor: Optional["MMProcessor"], + ) -> list[dict[str, str]]: + self._validate_input(processor, images, videos, audios) + self._validate_messages(messages, images, videos, audios) + messages = deepcopy(messages) + image_idx, video_idx = 0, 0 + for message in messages: + content = message["content"] + image_token = self.image_token or "<|image@placeholder|>" + video_token = self.video_token or "<|video@placeholder|>" + while IMAGE_PLACEHOLDER in content: + image_idx += 1 + content = content.replace( + IMAGE_PLACEHOLDER, f"Picture {image_idx}:<|IMAGE_START|>{image_token}<|IMAGE_END|>", 1 + ) + while VIDEO_PLACEHOLDER in content: + video_idx += 1 + content = content.replace( + VIDEO_PLACEHOLDER, f"Video {video_idx}:<|VIDEO_START|>{video_token}<|VIDEO_END|>", 1 + ) + message["content"] = content + return messages + @dataclass class Gemma3Plugin(BasePlugin): @@ -2039,6 +2071,7 @@ class VideoLlavaPlugin(BasePlugin): PLUGINS = { "base": BasePlugin, + "ernie_vl": ErnieVLPlugin, "gemma3": Gemma3Plugin, "glm4v": GLM4VPlugin, "gemma3n": Gemma3nPlugin, diff --git a/src/llamafactory/data/template.py b/src/llamafactory/data/template.py index 56e32dd2..604c4301 100644 --- a/src/llamafactory/data/template.py +++ b/src/llamafactory/data/template.py @@ -963,6 +963,19 @@ register_template( ) +register_template( + name="ernie_vl", + format_user=StringFormatter(slots=["User: {{content}}"]), + format_assistant=StringFormatter(slots=["\nAssistant: {{content}}<|end_of_sentence|>"]), + format_system=StringFormatter(slots=["{{content}}\n"]), + stop_words=["<|end_of_sentence|>"], + replace_eos=True, + replace_jinja_template=True, + template_class=ReasoningTemplate, + mm_plugin=get_mm_plugin(name="ernie_vl", image_token="<|image@placeholder|>", video_token="<|video@placeholder|>"), +) + + register_template( name="exaone", format_user=StringFormatter(slots=["[|user|]{{content}}\n[|assistant|]"]), diff --git a/src/llamafactory/extras/constants.py b/src/llamafactory/extras/constants.py index c1137ac2..e03740a1 100644 --- a/src/llamafactory/extras/constants.py +++ b/src/llamafactory/extras/constants.py @@ -657,6 +657,26 @@ register_model_group( ) +register_model_group( + models={ + "ERNIE-4.5-VL-28B-A3B-PT": { + DownloadSource.DEFAULT: "baidu/ERNIE-4.5-VL-28B-A3B-PT", + DownloadSource.MODELSCOPE: "PaddlePaddle/ERNIE-4.5-VL-28B-A3B-PT", + }, + "ERNIE-4.5-VL-28B-A3B-Thinking": { + DownloadSource.DEFAULT: "baidu/ERNIE-4.5-VL-28B-A3B-Thinking", + DownloadSource.MODELSCOPE: "PaddlePaddle/ERNIE-4.5-VL-28B-A3B-Thinking", + }, + "ERNIE-4.5-VL-424B-A47B-Base-PT": { + DownloadSource.DEFAULT: "baidu/ERNIE-4.5-VL-424B-A47B-PT", + DownloadSource.MODELSCOPE: "PaddlePaddle/ERNIE-4.5-VL-424B-A47B-PT", + }, + }, + template="ernie_vl", + multimodal=True, +) + + register_model_group( models={ "EXAONE-3.0-7.8B-Instruct": {