[model] add gemma3n (#8509)

This commit is contained in:
Kingsley 2025-07-01 22:37:24 +08:00 committed by GitHub
parent 544b7dc2ed
commit c5a08291f4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 82 additions and 1 deletions

View File

@ -388,7 +388,7 @@ class MMPluginMixin:
return_tensors="pt",
)
)
mm_inputs["feature_attention_mask"] = mm_inputs.pop("attention_mask") # prevent conflicts
mm_inputs["feature_attention_mask"] = mm_inputs.pop("attention_mask", None) # prevent conflicts
return mm_inputs
@ -509,6 +509,36 @@ class Gemma3Plugin(BasePlugin):
return mm_inputs
class Gemma3nPlugin(Gemma3Plugin):
@override
def process_messages(
self,
messages: list[dict[str, str]],
images: list["ImageInput"],
videos: list["VideoInput"],
audios: list["AudioInput"],
processor: Optional["MMProcessor"],
) -> list[dict[str, str]]:
self._validate_input(processor, images, videos, audios)
self._validate_messages(messages, images, videos, audios)
messages = deepcopy(messages)
boi_token: str = getattr(processor, "boi_token")
full_image_sequence: str = getattr(processor, "full_image_sequence")
full_audio_sequence: str = getattr(processor, "full_audio_sequence")
image_str = full_image_sequence if self.expand_mm_tokens else boi_token
audio_str = full_audio_sequence if self.expand_mm_tokens else boi_token
for message in messages:
content = message["content"]
while IMAGE_PLACEHOLDER in content:
content = content.replace(IMAGE_PLACEHOLDER, image_str, 1)
while AUDIO_PLACEHOLDER in content:
content = content.replace(AUDIO_PLACEHOLDER, audio_str, 1)
return messages
@dataclass
class InternVLPlugin(BasePlugin):
@override
@ -1845,6 +1875,7 @@ PLUGINS = {
"base": BasePlugin,
"gemma3": Gemma3Plugin,
"glm4v": GLM4VPlugin,
"gemma3n": Gemma3nPlugin,
"intern_vl": InternVLPlugin,
"kimi_vl": KimiVLPlugin,
"llama4": Llama4Plugin,

View File

@ -984,6 +984,22 @@ register_template(
)
register_template(
name="gemma3n",
format_user=StringFormatter(slots=["<start_of_turn>user\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]),
format_assistant=StringFormatter(slots=["{{content}}<end_of_turn>\n"]),
format_system=StringFormatter(slots=["{{content}}\n\n"]),
format_observation=StringFormatter(
slots=["<start_of_turn>tool\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]
),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
stop_words=["<end_of_turn>"],
replace_eos=True,
mm_plugin=get_mm_plugin("gemma3n", image_token="<image_soft_token>", audio_token="<audio_soft_token>"),
template_class=Llama2Template,
)
register_template(
name="glm4",
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]),

View File

@ -802,6 +802,30 @@ register_model_group(
)
register_model_group(
models={
"Gemma-3n-E2B": {
DownloadSource.DEFAULT: "google/gemma-3n-E2B",
DownloadSource.MODELSCOPE: "LLM-Research/gemma-3n-E2B",
},
"Gemma-3n-E4B": {
DownloadSource.DEFAULT: "google/gemma-3n-E4B",
DownloadSource.MODELSCOPE: "LLM-Research/gemma-3n-E4B",
},
"Gemma-3n-E2B-Instruct": {
DownloadSource.DEFAULT: "google/gemma-3n-E2B-it",
DownloadSource.MODELSCOPE: "LLM-Research/gemma-3n-E2B-it",
},
"Gemma-3n-E4B-Instruct": {
DownloadSource.DEFAULT: "google/gemma-3n-E4B-it",
DownloadSource.MODELSCOPE: "LLM-Research/gemma-3n-E4B-it",
},
},
template="gemma3n",
multimodal=True,
)
register_model_group(
models={
"GLM-4-9B": {

View File

@ -204,6 +204,13 @@ _register_composite_model(
)
_register_composite_model(
model_type="gemma3n",
vision_model_keys=["vision_tower", "audio_tower"],
lora_conflict_keys=["timm_model", "subsample_conv_projection"],
)
# copied from qwen2vl
_register_composite_model(
model_type="glm4v",

View File

@ -178,6 +178,9 @@ def patch_model(
resize_embedding_layer(model, tokenizer)
if is_trainable:
if getattr(model.config, "model_type", None) == "gemma3n":
setattr(model_args, "disable_gradient_checkpointing", True)
prepare_model_for_training(model, model_args)
autocast_projector_dtype(model, model_args)
add_z3_leaf_module(model)