[data] fix qwen omni plugin (#9204)

Co-authored-by: kingsley <kingsleydodonow@gmail.com>
Yaowei Zheng
2025-09-28 01:02:29 +08:00
committed by GitHub
parent 0761a4448f
commit 6ffebe5ff7
15 changed files with 292 additions and 210 deletions


@@ -162,7 +162,7 @@ def load_model(
                 load_class = AutoModelForVision2Seq
             elif type(config) in AutoModelForSeq2SeqLM._model_mapping.keys():  # audio-text
                 load_class = AutoModelForSeq2SeqLM
-            elif type(config) in AutoModelForTextToWaveform._model_mapping.keys():  # audio hack for qwen2_5_omni
+            elif type(config) in AutoModelForTextToWaveform._model_mapping.keys():  # audio hack for qwen omni
                 load_class = AutoModelForTextToWaveform
             else:
                 load_class = AutoModelForCausalLM
@@ -171,8 +171,8 @@ def load_model(
                 model = load_class.from_config(config, trust_remote_code=model_args.trust_remote_code)
             else:
                 model = load_class.from_pretrained(**init_kwargs)
 
-            if getattr(model.config, "model_type", None) == "qwen2_5_omni":
-                model = model.thinker  # use part of Omni model
+            if getattr(model.config, "model_type", None) in ["qwen2_5_omni", "qwen3_omni_moe"]:
+                model = getattr(model, "thinker")
         if model_args.mixture_of_depths == "convert":
             model = convert_pretrained_model_to_mod(model, config, model_args)
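
For context, a minimal, self-contained sketch of the code path this commit touches, assuming the Hugging Face transformers auto-classes; the checkpoint name and variable names are illustrative, not taken from the repo. Qwen omni configs are registered under AutoModelForTextToWaveform rather than AutoModelForCausalLM (the "audio hack" above), and after loading, only the text-generating thinker submodule is kept while the speech talker is discarded:

    # Sketch only, not LLaMA-Factory's exact code; the checkpoint name is illustrative.
    from transformers import AutoConfig, AutoModelForTextToWaveform

    model_id = "Qwen/Qwen2.5-Omni-7B"  # assumption: any qwen2_5_omni / qwen3_omni_moe checkpoint
    config = AutoConfig.from_pretrained(model_id)

    # Omni configs map to AutoModelForTextToWaveform, so the dispatch above
    # selects that auto-class instead of AutoModelForCausalLM.
    assert type(config) in AutoModelForTextToWaveform._model_mapping.keys()

    model = AutoModelForTextToWaveform.from_pretrained(model_id)
    if getattr(model.config, "model_type", None) in ["qwen2_5_omni", "qwen3_omni_moe"]:
        # Keep only the text-generating "thinker" submodule; the speech
        # "talker" is not needed for fine-tuning.
        model = getattr(model, "thinker")

Note that getattr(model, "thinker") is functionally identical to model.thinker; the substantive change is that the model_type check now also matches qwen3_omni_moe, so Qwen3-Omni checkpoints take the same thinker-extraction path as Qwen2.5-Omni.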