Mirror of https://github.com/hiyouga/LLaMA-Factory.git, synced 2025-12-30 02:30:35 +08:00
[model] support gemma3 (#7273)
@@ -19,6 +19,7 @@ import torch
 from transformers import (
     AutoConfig,
     AutoModelForCausalLM,
+    AutoModelForImageTextToText,
     AutoModelForSeq2SeqLM,
     AutoModelForVision2Seq,
     AutoProcessor,
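The new import pulls in AutoModelForImageTextToText, the transformers auto class that serves Gemma 3's image-text architecture. Below is a minimal sketch (not part of the commit) of probing which auto class serves a given config, using the same private _model_mapping lookup the loader relies on; the checkpoint name is an assumed example.

# Sketch: find the first auto class whose mapping contains this config type.
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoModelForImageTextToText,
    AutoModelForVision2Seq,
)

config = AutoConfig.from_pretrained("google/gemma-3-4b-it")  # assumed checkpoint
for auto_cls in (AutoModelForVision2Seq, AutoModelForImageTextToText, AutoModelForCausalLM):
    if type(config) in auto_cls._model_mapping.keys():
        print(f"{type(config).__name__} -> {auto_cls.__name__}")
        break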
@@ -72,7 +73,6 @@ def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
     Note: including inplace operation of model_args.
     """
     init_kwargs = _get_init_kwargs(model_args)
-    config = load_config(model_args)
     try:
         tokenizer = AutoTokenizer.from_pretrained(
             model_args.model_name_or_path,
@@ -94,7 +94,7 @@ def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
     patch_tokenizer(tokenizer, model_args)
     try:
         processor = AutoProcessor.from_pretrained(model_args.model_name_or_path, **init_kwargs)
-        patch_processor(processor, config, tokenizer, model_args)
+        patch_processor(processor, tokenizer, model_args)
     except Exception as e:
         logger.debug(f"Processor was not found: {e}.")
         processor = None
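Together, the two hunks above remove the model config from load_tokenizer: patch_processor now works from the tokenizer alone, so the extra load_config call goes away. A minimal standalone sketch of the tolerant processor-loading pattern follows, with the LLaMA-Factory-internal patch_* helpers omitted and the checkpoint name assumed.

# Sketch: load a tokenizer, and a processor if the checkpoint ships one.
from transformers import AutoProcessor, AutoTokenizer

model_path = "google/gemma-3-4b-it"  # assumed example checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_path)
try:
    processor = AutoProcessor.from_pretrained(model_path)
except Exception as e:  # text-only checkpoints may ship no processor
    print(f"Processor was not found: {e}.")
    processor = None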
@@ -141,9 +141,11 @@ def load_model(
     if model_args.mixture_of_depths == "load":
         model = load_mod_pretrained_model(**init_kwargs)
     else:
-        if type(config) in AutoModelForVision2Seq._model_mapping.keys():  # assume built-in models
+        if type(config) in AutoModelForVision2Seq._model_mapping.keys():  # image-text
             load_class = AutoModelForVision2Seq
-        elif type(config) in AutoModelForSeq2SeqLM._model_mapping.keys():
+        elif type(config) in AutoModelForImageTextToText._model_mapping.keys():  # image-text
+            load_class = AutoModelForImageTextToText
+        elif type(config) in AutoModelForSeq2SeqLM._model_mapping.keys():  # audio-text
             load_class = AutoModelForSeq2SeqLM
         else:
             load_class = AutoModelForCausalLM
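The dispatch in load_model now tries AutoModelForVision2Seq first, then the new AutoModelForImageTextToText (which covers Gemma 3), then AutoModelForSeq2SeqLM, and falls back to AutoModelForCausalLM. A self-contained sketch of the resulting order; the function name resolve_load_class is illustrative, not from the repo.

# Sketch: pick the auto class for a loaded config, mirroring the diff's order.
from transformers import (
    AutoModelForCausalLM,
    AutoModelForImageTextToText,
    AutoModelForSeq2SeqLM,
    AutoModelForVision2Seq,
)

def resolve_load_class(config):
    if type(config) in AutoModelForVision2Seq._model_mapping.keys():  # image-text
        return AutoModelForVision2Seq
    elif type(config) in AutoModelForImageTextToText._model_mapping.keys():  # image-text (gemma3)
        return AutoModelForImageTextToText
    elif type(config) in AutoModelForSeq2SeqLM._model_mapping.keys():  # audio-text
        return AutoModelForSeq2SeqLM
    else:
        return AutoModelForCausalLM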