mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-07-31 10:42:50 +08:00
[model] add gemma3n (#8509)
This commit is contained in:
parent
544b7dc2ed
commit
c5a08291f4
@ -388,7 +388,7 @@ class MMPluginMixin:
|
||||
return_tensors="pt",
|
||||
)
|
||||
)
|
||||
mm_inputs["feature_attention_mask"] = mm_inputs.pop("attention_mask") # prevent conflicts
|
||||
mm_inputs["feature_attention_mask"] = mm_inputs.pop("attention_mask", None) # prevent conflicts
|
||||
|
||||
return mm_inputs
|
||||
|
||||
@ -509,6 +509,36 @@ class Gemma3Plugin(BasePlugin):
|
||||
return mm_inputs
|
||||
|
||||
|
||||
class Gemma3nPlugin(Gemma3Plugin):
|
||||
@override
|
||||
def process_messages(
|
||||
self,
|
||||
messages: list[dict[str, str]],
|
||||
images: list["ImageInput"],
|
||||
videos: list["VideoInput"],
|
||||
audios: list["AudioInput"],
|
||||
processor: Optional["MMProcessor"],
|
||||
) -> list[dict[str, str]]:
|
||||
self._validate_input(processor, images, videos, audios)
|
||||
self._validate_messages(messages, images, videos, audios)
|
||||
messages = deepcopy(messages)
|
||||
boi_token: str = getattr(processor, "boi_token")
|
||||
full_image_sequence: str = getattr(processor, "full_image_sequence")
|
||||
full_audio_sequence: str = getattr(processor, "full_audio_sequence")
|
||||
image_str = full_image_sequence if self.expand_mm_tokens else boi_token
|
||||
audio_str = full_audio_sequence if self.expand_mm_tokens else boi_token
|
||||
|
||||
for message in messages:
|
||||
content = message["content"]
|
||||
while IMAGE_PLACEHOLDER in content:
|
||||
content = content.replace(IMAGE_PLACEHOLDER, image_str, 1)
|
||||
|
||||
while AUDIO_PLACEHOLDER in content:
|
||||
content = content.replace(AUDIO_PLACEHOLDER, audio_str, 1)
|
||||
|
||||
return messages
|
||||
|
||||
|
||||
@dataclass
|
||||
class InternVLPlugin(BasePlugin):
|
||||
@override
|
||||
@ -1845,6 +1875,7 @@ PLUGINS = {
|
||||
"base": BasePlugin,
|
||||
"gemma3": Gemma3Plugin,
|
||||
"glm4v": GLM4VPlugin,
|
||||
"gemma3n": Gemma3nPlugin,
|
||||
"intern_vl": InternVLPlugin,
|
||||
"kimi_vl": KimiVLPlugin,
|
||||
"llama4": Llama4Plugin,
|
||||
|
@ -984,6 +984,22 @@ register_template(
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
name="gemma3n",
|
||||
format_user=StringFormatter(slots=["<start_of_turn>user\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]),
|
||||
format_assistant=StringFormatter(slots=["{{content}}<end_of_turn>\n"]),
|
||||
format_system=StringFormatter(slots=["{{content}}\n\n"]),
|
||||
format_observation=StringFormatter(
|
||||
slots=["<start_of_turn>tool\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]
|
||||
),
|
||||
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
|
||||
stop_words=["<end_of_turn>"],
|
||||
replace_eos=True,
|
||||
mm_plugin=get_mm_plugin("gemma3n", image_token="<image_soft_token>", audio_token="<audio_soft_token>"),
|
||||
template_class=Llama2Template,
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
name="glm4",
|
||||
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]),
|
||||
|
@ -802,6 +802,30 @@ register_model_group(
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"Gemma-3n-E2B": {
|
||||
DownloadSource.DEFAULT: "google/gemma-3n-E2B",
|
||||
DownloadSource.MODELSCOPE: "LLM-Research/gemma-3n-E2B",
|
||||
},
|
||||
"Gemma-3n-E4B": {
|
||||
DownloadSource.DEFAULT: "google/gemma-3n-E4B",
|
||||
DownloadSource.MODELSCOPE: "LLM-Research/gemma-3n-E4B",
|
||||
},
|
||||
"Gemma-3n-E2B-Instruct": {
|
||||
DownloadSource.DEFAULT: "google/gemma-3n-E2B-it",
|
||||
DownloadSource.MODELSCOPE: "LLM-Research/gemma-3n-E2B-it",
|
||||
},
|
||||
"Gemma-3n-E4B-Instruct": {
|
||||
DownloadSource.DEFAULT: "google/gemma-3n-E4B-it",
|
||||
DownloadSource.MODELSCOPE: "LLM-Research/gemma-3n-E4B-it",
|
||||
},
|
||||
},
|
||||
template="gemma3n",
|
||||
multimodal=True,
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"GLM-4-9B": {
|
||||
|
@ -204,6 +204,13 @@ _register_composite_model(
|
||||
)
|
||||
|
||||
|
||||
_register_composite_model(
|
||||
model_type="gemma3n",
|
||||
vision_model_keys=["vision_tower", "audio_tower"],
|
||||
lora_conflict_keys=["timm_model", "subsample_conv_projection"],
|
||||
)
|
||||
|
||||
|
||||
# copied from qwen2vl
|
||||
_register_composite_model(
|
||||
model_type="glm4v",
|
||||
|
@ -178,6 +178,9 @@ def patch_model(
|
||||
resize_embedding_layer(model, tokenizer)
|
||||
|
||||
if is_trainable:
|
||||
if getattr(model.config, "model_type", None) == "gemma3n":
|
||||
setattr(model_args, "disable_gradient_checkpointing", True)
|
||||
|
||||
prepare_model_for_training(model, model_args)
|
||||
autocast_projector_dtype(model, model_args)
|
||||
add_z3_leaf_module(model)
|
||||
|
Loading…
x
Reference in New Issue
Block a user