mirror of
				https://github.com/hiyouga/LLaMA-Factory.git
				synced 2025-11-04 18:02:19 +08:00 
			
		
		
		
	[model] add gemma3n (#8509)
This commit is contained in:
		
							parent
							
								
									cbb65567a9
								
							
						
					
					
						commit
						e9f70daabe
					
				@ -388,7 +388,7 @@ class MMPluginMixin:
 | 
			
		||||
                    return_tensors="pt",
 | 
			
		||||
                )
 | 
			
		||||
            )
 | 
			
		||||
            mm_inputs["feature_attention_mask"] = mm_inputs.pop("attention_mask")  # prevent conflicts
 | 
			
		||||
            mm_inputs["feature_attention_mask"] = mm_inputs.pop("attention_mask", None)  # prevent conflicts
 | 
			
		||||
 | 
			
		||||
        return mm_inputs
 | 
			
		||||
 | 
			
		||||
@ -509,6 +509,36 @@ class Gemma3Plugin(BasePlugin):
 | 
			
		||||
        return mm_inputs
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Gemma3nPlugin(Gemma3Plugin):
    r"""Multimodal plugin for Gemma-3n, which accepts both image and audio inputs."""

    @override
    def process_messages(
        self,
        messages: list[dict[str, str]],
        images: list["ImageInput"],
        videos: list["VideoInput"],
        audios: list["AudioInput"],
        processor: Optional["MMProcessor"],
    ) -> list[dict[str, str]]:
        r"""Expand image/audio placeholders in ``messages`` into model token sequences.

        Returns a deep copy of ``messages`` in which every ``IMAGE_PLACEHOLDER`` is
        replaced by the processor's ``full_image_sequence`` and every
        ``AUDIO_PLACEHOLDER`` by its ``full_audio_sequence``. When
        ``self.expand_mm_tokens`` is false, both collapse to the single ``boi_token``
        (expansion is deferred to the processor).
        """
        self._validate_input(processor, images, videos, audios)
        self._validate_messages(messages, images, videos, audios)
        messages = deepcopy(messages)  # never mutate the caller's message list
        boi_token: str = getattr(processor, "boi_token")
        full_image_sequence: str = getattr(processor, "full_image_sequence")
        full_audio_sequence: str = getattr(processor, "full_audio_sequence")
        image_str = full_image_sequence if self.expand_mm_tokens else boi_token
        # NOTE(review): the non-expanded form uses the *image* boi_token for audio as
        # well — confirm this matches the Gemma-3n processor's chat template.
        audio_str = full_audio_sequence if self.expand_mm_tokens else boi_token

        for message in messages:
            content = message["content"]
            while IMAGE_PLACEHOLDER in content:
                content = content.replace(IMAGE_PLACEHOLDER, image_str, 1)

            while AUDIO_PLACEHOLDER in content:
                content = content.replace(AUDIO_PLACEHOLDER, audio_str, 1)

            # Fix: write the substituted text back. Previously the replacements were
            # computed into the local `content` and silently discarded, so the
            # returned messages still contained the raw placeholders.
            message["content"] = content

        return messages
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@dataclass
 | 
			
		||||
class InternVLPlugin(BasePlugin):
 | 
			
		||||
    @override
 | 
			
		||||
@ -1845,6 +1875,7 @@ PLUGINS = {
 | 
			
		||||
    "base": BasePlugin,
 | 
			
		||||
    "gemma3": Gemma3Plugin,
 | 
			
		||||
    "glm4v": GLM4VPlugin,
 | 
			
		||||
    "gemma3n": Gemma3nPlugin,
 | 
			
		||||
    "intern_vl": InternVLPlugin,
 | 
			
		||||
    "kimi_vl": KimiVLPlugin,
 | 
			
		||||
    "llama4": Llama4Plugin,
 | 
			
		||||
 | 
			
		||||
@ -984,6 +984,22 @@ register_template(
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# Chat template for Gemma-3n: Gemma-style <start_of_turn>/<end_of_turn> turn markers
# plus the multimodal plugin that expands image/audio soft tokens.
register_template(
    name="gemma3n",
    format_user=StringFormatter(slots=["<start_of_turn>user\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]),
    format_assistant=StringFormatter(slots=["{{content}}<end_of_turn>\n"]),
    # System text is emitted as plain text with no role marker, separated by a blank line.
    format_system=StringFormatter(slots=["{{content}}\n\n"]),
    format_observation=StringFormatter(
        slots=["<start_of_turn>tool\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]
    ),
    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),  # prepend BOS before the first turn
    stop_words=["<end_of_turn>"],
    replace_eos=True,  # treat <end_of_turn> as the EOS token
    mm_plugin=get_mm_plugin("gemma3n", image_token="<image_soft_token>", audio_token="<audio_soft_token>"),
    # NOTE(review): reuses Llama2Template — presumably for its system-prompt merging
    # behavior, matching the plain-text format_system above; confirm.
    template_class=Llama2Template,
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
register_template(
 | 
			
		||||
    name="glm4",
 | 
			
		||||
    format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]),
 | 
			
		||||
 | 
			
		||||
@ -802,6 +802,30 @@ register_model_group(
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# Gemma-3n checkpoints (E2B/E4B effective-parameter sizes, base and instruction-tuned),
# downloadable from Hugging Face (DEFAULT) or ModelScope, paired with the "gemma3n"
# chat template and flagged multimodal (image + audio inputs).
register_model_group(
    models={
        "Gemma-3n-E2B": {
            DownloadSource.DEFAULT: "google/gemma-3n-E2B",
            DownloadSource.MODELSCOPE: "LLM-Research/gemma-3n-E2B",
        },
        "Gemma-3n-E4B": {
            DownloadSource.DEFAULT: "google/gemma-3n-E4B",
            DownloadSource.MODELSCOPE: "LLM-Research/gemma-3n-E4B",
        },
        "Gemma-3n-E2B-Instruct": {
            DownloadSource.DEFAULT: "google/gemma-3n-E2B-it",
            DownloadSource.MODELSCOPE: "LLM-Research/gemma-3n-E2B-it",
        },
        "Gemma-3n-E4B-Instruct": {
            DownloadSource.DEFAULT: "google/gemma-3n-E4B-it",
            DownloadSource.MODELSCOPE: "LLM-Research/gemma-3n-E4B-it",
        },
    },
    template="gemma3n",
    multimodal=True,
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
register_model_group(
 | 
			
		||||
    models={
 | 
			
		||||
        "GLM-4-9B": {
 | 
			
		||||
 | 
			
		||||
@ -204,6 +204,13 @@ _register_composite_model(
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# Gemma-3n is a composite model with two non-text encoders.
_register_composite_model(
    model_type="gemma3n",
    # Both the image and the audio encoder are registered under vision_model_keys,
    # so they are frozen/tuned together as the "vision" part of the model.
    vision_model_keys=["vision_tower", "audio_tower"],
    # NOTE(review): these submodules are excluded from LoRA targeting — presumably
    # because the timm backbone / conv subsampling layers conflict with LoRA
    # injection; confirm against the adapter-attachment code.
    lora_conflict_keys=["timm_model", "subsample_conv_projection"],
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# copied from qwen2vl
 | 
			
		||||
_register_composite_model(
 | 
			
		||||
    model_type="glm4v",
 | 
			
		||||
 | 
			
		||||
@ -178,6 +178,9 @@ def patch_model(
 | 
			
		||||
        resize_embedding_layer(model, tokenizer)
 | 
			
		||||
 | 
			
		||||
    if is_trainable:
 | 
			
		||||
        if getattr(model.config, "model_type", None) == "gemma3n":
 | 
			
		||||
            setattr(model_args, "disable_gradient_checkpointing", True)
 | 
			
		||||
 | 
			
		||||
        prepare_model_for_training(model, model_args)
 | 
			
		||||
        autocast_projector_dtype(model, model_args)
 | 
			
		||||
        add_z3_leaf_module(model)
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user