[model] support gemma3 (#7273)

This commit is contained in:
hoshi-hiyouga
2025-03-13 01:35:23 +08:00
committed by GitHub
parent e6159ad730
commit 4b9d8da5a4
9 changed files with 356 additions and 274 deletions

View File

@@ -19,6 +19,7 @@ import torch
from transformers import (
AutoConfig,
AutoModelForCausalLM,
AutoModelForImageTextToText,
AutoModelForSeq2SeqLM,
AutoModelForVision2Seq,
AutoProcessor,
@@ -72,7 +73,6 @@ def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
Note: including inplace operation of model_args.
"""
init_kwargs = _get_init_kwargs(model_args)
config = load_config(model_args)
try:
tokenizer = AutoTokenizer.from_pretrained(
model_args.model_name_or_path,
@@ -94,7 +94,7 @@ def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
patch_tokenizer(tokenizer, model_args)
try:
processor = AutoProcessor.from_pretrained(model_args.model_name_or_path, **init_kwargs)
patch_processor(processor, config, tokenizer, model_args)
patch_processor(processor, tokenizer, model_args)
except Exception as e:
logger.debug(f"Processor was not found: {e}.")
processor = None
@@ -141,9 +141,11 @@ def load_model(
if model_args.mixture_of_depths == "load":
model = load_mod_pretrained_model(**init_kwargs)
else:
if type(config) in AutoModelForVision2Seq._model_mapping.keys(): # assume built-in models
if type(config) in AutoModelForVision2Seq._model_mapping.keys(): # image-text
load_class = AutoModelForVision2Seq
elif type(config) in AutoModelForSeq2SeqLM._model_mapping.keys():
elif type(config) in AutoModelForImageTextToText._model_mapping.keys(): # image-text
load_class = AutoModelForImageTextToText
elif type(config) in AutoModelForSeq2SeqLM._model_mapping.keys(): # audio-text
load_class = AutoModelForSeq2SeqLM
else:
load_class = AutoModelForCausalLM