mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-12-29 02:00:36 +08:00
[data] fix qwen3omni moe model (#9501)
Co-authored-by: frozenleaves <frozen@Mac.local>
This commit is contained in:
@@ -43,10 +43,20 @@ if TYPE_CHECKING:
|
||||
|
||||
from ..hparams import ModelArguments
|
||||
|
||||
if is_transformers_version_greater_than("4.57.0"):
|
||||
from transformers.models.qwen3_omni_moe import modeling_qwen3_omni_moe
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def patch_qwen3_omni_moe_thinker_text_sparse_moe_block():
    """Replace the Qwen3-Omni-MoE thinker text sparse MoE block with the local one.

    Rebinds ``Qwen3OmniMoeThinkerTextSparseMoeBlock`` on the
    ``modeling_qwen3_omni_moe`` module to the class of the same name from
    ``.model_utils.moe``. The function is a no-op when
    ``is_transformers_version_greater_than("4.57.0")`` is false.
    """
    if not is_transformers_version_greater_than("4.57.0"):
        return

    # Function-scope import: the replacement class is only loaded when the patch applies.
    from .model_utils.moe import Qwen3OmniMoeThinkerTextSparseMoeBlock as replacement_block

    modeling_qwen3_omni_moe.Qwen3OmniMoeThinkerTextSparseMoeBlock = replacement_block
|
||||
|
||||
|
||||
def patch_tokenizer(tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments") -> None:
|
||||
if "PreTrainedTokenizerBase" not in str(tokenizer._pad.__func__):
|
||||
tokenizer._pad = MethodType(PreTrainedTokenizerBase._pad, tokenizer)
|
||||
@@ -136,6 +146,9 @@ def patch_config(
|
||||
if getattr(config, "model_type", None) == "internlm3" and not is_transformers_version_greater_than("4.47.1"):
|
||||
raise RuntimeError("InternLM3 model requires transformers>=4.47.1, please upgrade it.")
|
||||
|
||||
if getattr(config, "model_type", None) == "qwen3_omni_moe":
|
||||
patch_qwen3_omni_moe_thinker_text_sparse_moe_block()
|
||||
|
||||
# deepspeed zero3 is not compatible with low_cpu_mem_usage
|
||||
init_kwargs["low_cpu_mem_usage"] = model_args.low_cpu_mem_usage and (not is_deepspeed_zero3_enabled())
|
||||
|
||||
|
||||
Reference in New Issue
Block a user