Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-12-29 02:00:36 +08:00)
[data] fix qwen omni plugin (#9204)
Co-authored-by: kingsley <kingsleydodonow@gmail.com>
@@ -162,7 +162,7 @@ def load_model(
         load_class = AutoModelForVision2Seq
     elif type(config) in AutoModelForSeq2SeqLM._model_mapping.keys():  # audio-text
         load_class = AutoModelForSeq2SeqLM
-    elif type(config) in AutoModelForTextToWaveform._model_mapping.keys():  # audio hack for qwen2_5_omni
+    elif type(config) in AutoModelForTextToWaveform._model_mapping.keys():  # audio hack for qwen omni
         load_class = AutoModelForTextToWaveform
     else:
         load_class = AutoModelForCausalLM
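For context, the dispatch in this hunk relies on each transformers Auto class exposing a (private) _model_mapping from config classes to model classes; the Omni models register under AutoModelForTextToWaveform rather than AutoModelForCausalLM, which is the "audio hack" the comment refers to. A minimal sketch of the pattern, assuming a transformers version that ships these Auto classes (the model id is illustrative, not part of this commit):

from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoModelForTextToWaveform,
)

# Load only the config; its class decides which Auto loader applies.
config = AutoConfig.from_pretrained("Qwen/Qwen2.5-Omni-7B")  # illustrative id

# Omni configs live in AutoModelForTextToWaveform's mapping, so membership
# of type(config) in that mapping selects the right loader class.
if type(config) in AutoModelForTextToWaveform._model_mapping.keys():
    load_class = AutoModelForTextToWaveform
else:
    load_class = AutoModelForCausalLM

model = load_class.from_pretrained("Qwen/Qwen2.5-Omni-7B", config=config)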
@@ -171,8 +171,8 @@ def load_model(
         model = load_class.from_config(config, trust_remote_code=model_args.trust_remote_code)
     else:
         model = load_class.from_pretrained(**init_kwargs)
-    if getattr(model.config, "model_type", None) == "qwen2_5_omni":
-        model = model.thinker  # use part of Omni model
+    if getattr(model.config, "model_type", None) in ["qwen2_5_omni", "qwen3_omni_moe"]:
+        model = getattr(model, "thinker")

     if model_args.mixture_of_depths == "convert":
         model = convert_pretrained_model_to_mod(model, config, model_args)
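A minimal sketch of what this second hunk changes: the full Omni checkpoint wraps a text-generating "thinker" submodule, and the commit extends the old qwen2_5_omni-only check so qwen3_omni_moe is unwrapped the same way. Sketch assumes a transformers version that ships both model types; the model id is illustrative:

from transformers import AutoModelForTextToWaveform

model = AutoModelForTextToWaveform.from_pretrained("Qwen/Qwen2.5-Omni-7B")  # illustrative id

# Keep only the thinker submodule (the text-generating part of the Omni
# model); both Omni variants expose it under the same attribute name.
if getattr(model.config, "model_type", None) in ["qwen2_5_omni", "qwen3_omni_moe"]:
    model = getattr(model, "thinker")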