From c5a08291f4f1498735ae2237ed166bc5e93bc66f Mon Sep 17 00:00:00 2001
From: Kingsley
Date: Tue, 1 Jul 2025 22:37:24 +0800
Subject: [PATCH] [model] add gemma3n (#8509)

---
 src/llamafactory/data/mm_plugin.py           | 35 +++++++++++++++++++++-
 src/llamafactory/data/template.py            | 16 ++++++++++
 src/llamafactory/extras/constants.py         | 24 ++++++++++++++
 src/llamafactory/model/model_utils/visual.py |  7 +++++
 src/llamafactory/model/patcher.py            |  3 ++
 5 files changed, 84 insertions(+), 1 deletion(-)

diff --git a/src/llamafactory/data/mm_plugin.py b/src/llamafactory/data/mm_plugin.py
index 21ff1304..9da81682 100644
--- a/src/llamafactory/data/mm_plugin.py
+++ b/src/llamafactory/data/mm_plugin.py
@@ -388,7 +388,7 @@ class MMPluginMixin:
                     return_tensors="pt",
                 )
             )
-            mm_inputs["feature_attention_mask"] = mm_inputs.pop("attention_mask")  # prevent conflicts
+            mm_inputs["feature_attention_mask"] = mm_inputs.pop("attention_mask", None)  # prevent conflicts
 
         return mm_inputs
 
@@ -509,6 +509,38 @@ class Gemma3Plugin(BasePlugin):
         return mm_inputs
 
 
+class Gemma3nPlugin(Gemma3Plugin):
+    @override
+    def process_messages(
+        self,
+        messages: list[dict[str, str]],
+        images: list["ImageInput"],
+        videos: list["VideoInput"],
+        audios: list["AudioInput"],
+        processor: Optional["MMProcessor"],
+    ) -> list[dict[str, str]]:
+        self._validate_input(processor, images, videos, audios)
+        self._validate_messages(messages, images, videos, audios)
+        messages = deepcopy(messages)
+        boi_token: str = getattr(processor, "boi_token")
+        full_image_sequence: str = getattr(processor, "full_image_sequence")
+        full_audio_sequence: str = getattr(processor, "full_audio_sequence")
+        image_str = full_image_sequence if self.expand_mm_tokens else boi_token
+        audio_str = full_audio_sequence if self.expand_mm_tokens else boi_token
+
+        for message in messages:
+            content = message["content"]
+            while IMAGE_PLACEHOLDER in content:
+                content = content.replace(IMAGE_PLACEHOLDER, image_str, 1)
+
+            while AUDIO_PLACEHOLDER in content:
+                content = content.replace(AUDIO_PLACEHOLDER, audio_str, 1)
+
+            message["content"] = content
+
+        return messages
+
+
 @dataclass
 class InternVLPlugin(BasePlugin):
     @override
@@ -1845,6 +1877,7 @@ PLUGINS = {
     "base": BasePlugin,
     "gemma3": Gemma3Plugin,
     "glm4v": GLM4VPlugin,
+    "gemma3n": Gemma3nPlugin,
     "intern_vl": InternVLPlugin,
     "kimi_vl": KimiVLPlugin,
     "llama4": Llama4Plugin,
diff --git a/src/llamafactory/data/template.py b/src/llamafactory/data/template.py
index 5ffa186f..d8fae09d 100644
--- a/src/llamafactory/data/template.py
+++ b/src/llamafactory/data/template.py
@@ -984,6 +984,22 @@ register_template(
 )
 
 
+register_template(
+    name="gemma3n",
+    format_user=StringFormatter(slots=["<start_of_turn>user\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]),
+    format_assistant=StringFormatter(slots=["{{content}}<end_of_turn>\n"]),
+    format_system=StringFormatter(slots=["{{content}}\n\n"]),
+    format_observation=StringFormatter(
+        slots=["<start_of_turn>tool\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]
+    ),
+    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
+    stop_words=["<end_of_turn>"],
+    replace_eos=True,
+    mm_plugin=get_mm_plugin("gemma3n", image_token="<image_soft_token>", audio_token="<audio_soft_token>"),
+    template_class=Llama2Template,
+)
+
+
 register_template(
     name="glm4",
     format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]),
diff --git a/src/llamafactory/extras/constants.py b/src/llamafactory/extras/constants.py
index 1729dcab..10b54813 100644
--- a/src/llamafactory/extras/constants.py
+++ b/src/llamafactory/extras/constants.py
@@ -802,6 +802,30 @@ register_model_group(
 )
 
 
+register_model_group(
+    models={
+        "Gemma-3n-E2B": {
+            DownloadSource.DEFAULT: "google/gemma-3n-E2B",
+            DownloadSource.MODELSCOPE: "LLM-Research/gemma-3n-E2B",
+        },
+        "Gemma-3n-E4B": {
+            DownloadSource.DEFAULT: "google/gemma-3n-E4B",
+            DownloadSource.MODELSCOPE: "LLM-Research/gemma-3n-E4B",
+        },
+        "Gemma-3n-E2B-Instruct": {
+            DownloadSource.DEFAULT: "google/gemma-3n-E2B-it",
+            DownloadSource.MODELSCOPE: "LLM-Research/gemma-3n-E2B-it",
+        },
+        "Gemma-3n-E4B-Instruct": {
+            DownloadSource.DEFAULT: "google/gemma-3n-E4B-it",
+            DownloadSource.MODELSCOPE: "LLM-Research/gemma-3n-E4B-it",
+        },
+    },
+    template="gemma3n",
+    multimodal=True,
+)
+
+
 register_model_group(
     models={
         "GLM-4-9B": {
diff --git a/src/llamafactory/model/model_utils/visual.py b/src/llamafactory/model/model_utils/visual.py
index f12a56ec..f66c2e76 100644
--- a/src/llamafactory/model/model_utils/visual.py
+++ b/src/llamafactory/model/model_utils/visual.py
@@ -204,6 +204,13 @@ _register_composite_model(
 )
 
 
+_register_composite_model(
+    model_type="gemma3n",
+    vision_model_keys=["vision_tower", "audio_tower"],
+    lora_conflict_keys=["timm_model", "subsample_conv_projection"],
+)
+
+
 # copied from qwen2vl
 _register_composite_model(
     model_type="glm4v",
diff --git a/src/llamafactory/model/patcher.py b/src/llamafactory/model/patcher.py
index 4bf1d21d..d6075db4 100644
--- a/src/llamafactory/model/patcher.py
+++ b/src/llamafactory/model/patcher.py
@@ -178,6 +178,9 @@ def patch_model(
         resize_embedding_layer(model, tokenizer)
 
     if is_trainable:
+        if getattr(model.config, "model_type", None) == "gemma3n":
+            setattr(model_args, "disable_gradient_checkpointing", True)
+
         prepare_model_for_training(model, model_args)
         autocast_projector_dtype(model, model_args)
         add_z3_leaf_module(model)
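
Reviewer note: the snippet below is an illustrative smoke test for the new plugin and
template registrations; it is not part of the commit. It assumes a checkout with this
patch applied and Hub access to google/gemma-3n-E2B-it; "cat.jpg" is a hypothetical
local image path.

# Illustrative smoke test for this patch; not part of the commit.
from transformers import AutoProcessor

from llamafactory.data.mm_plugin import get_mm_plugin
from llamafactory.data.template import TEMPLATES

assert "gemma3n" in TEMPLATES  # registered by the template.py hunk above

# Same token arguments that the template registration passes to get_mm_plugin().
plugin = get_mm_plugin("gemma3n", image_token="<image_soft_token>", audio_token="<audio_soft_token>")
processor = AutoProcessor.from_pretrained("google/gemma-3n-E2B-it")

# One <image> placeholder must be matched by exactly one image input,
# otherwise message validation fails.
messages = [{"role": "user", "content": "<image>Describe the image."}]
processed = plugin.process_messages(messages, ["cat.jpg"], [], [], processor)

# With expand_mm_tokens enabled (the default), the placeholder should be replaced
# by the processor's full_image_sequence, which repeats the image soft token.
assert "<image_soft_token>" in processed[0]["content"]
print(processed[0]["content"][:120])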