Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-12-29 10:10:35 +08:00)
[model] add Qwen2.5-Omni model (#7537)
* preserve image_sizes
* preserve image_sizes
* init plugin
* support audio-text2text lora
* nit
* support image/video-text2text, audio-text2text
* remove args
* remove lines
* add docs && nit
* remove some comments
* fix && add merge part script
* add license
@@ -21,6 +21,7 @@ from transformers import (
     AutoModelForCausalLM,
     AutoModelForImageTextToText,
     AutoModelForSeq2SeqLM,
+    AutoModelForTextToWaveform,
     AutoModelForVision2Seq,
     AutoProcessor,
     AutoTokenizer,
@@ -147,6 +148,8 @@ def load_model(
             load_class = AutoModelForImageTextToText
         elif type(config) in AutoModelForSeq2SeqLM._model_mapping.keys():  # audio-text
             load_class = AutoModelForSeq2SeqLM
+        elif type(config) in AutoModelForTextToWaveform._model_mapping.keys():  # audio hack for qwen2_5_omni
+            load_class = AutoModelForTextToWaveform
         else:
             load_class = AutoModelForCausalLM
 
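For context on the new branch: each transformers Auto class keeps a _model_mapping from config classes to model classes, and Qwen2.5-Omni registers under the text-to-waveform mapping rather than a text-generation one, hence the "audio hack" comment. A minimal standalone sketch of the same dispatch (the checkpoint name is illustrative, and this assumes a transformers release that ships Qwen2.5-Omni):

from transformers import AutoConfig, AutoModelForTextToWaveform

# Illustrative checkpoint; any Qwen2.5-Omni config resolves the same way.
config = AutoConfig.from_pretrained("Qwen/Qwen2.5-Omni-7B")

# Qwen2.5-Omni's config class lives in the text-to-waveform mapping,
# so that Auto class is the one able to instantiate the full model.
if type(config) in AutoModelForTextToWaveform._model_mapping.keys():
    load_class = AutoModelForTextToWaveform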
@@ -154,6 +157,8 @@ def load_model(
             model = load_class.from_config(config, trust_remote_code=model_args.trust_remote_code)
         else:
             model = load_class.from_pretrained(**init_kwargs)
+        if load_class is AutoModelForTextToWaveform:
+            model = model.thinker  # use part of Omni model
 
     if model_args.mixture_of_depths == "convert":
         model = convert_pretrained_model_to_mod(model, config, model_args)
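The second new block then keeps only part of the loaded model. Qwen2.5-Omni pairs a "thinker" (the multimodal transformer that produces text) with a "talker" (the speech synthesizer), and text-output fine-tuning such as the LoRA setup in this commit only needs the former. A hedged sketch of the effect (checkpoint name illustrative; the thinker attribute is the one used in the diff above):

from transformers import AutoModelForTextToWaveform

# Loads the full Omni model: thinker (text generator) plus talker (speech head).
omni = AutoModelForTextToWaveform.from_pretrained("Qwen/Qwen2.5-Omni-7B")

# Keep only the thinker, exactly as load_model() does above; the talker is
# never trained, so dropping it saves memory during fine-tuning.
model = omni.thinker

This split is presumably also why the commit message mentions a "merge part script": after training, the tuned thinker has to be recombined with the untouched talker to restore a complete Omni checkpoint.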