diff --git a/src/llamafactory/model/model_utils/moe.py b/src/llamafactory/model/model_utils/moe.py index baf84066..8a66621d 100644 --- a/src/llamafactory/model/model_utils/moe.py +++ b/src/llamafactory/model/model_utils/moe.py @@ -14,9 +14,13 @@ from typing import TYPE_CHECKING, Union +import torch +from torch import nn +from torch.nn import functional as F from transformers.integrations import is_deepspeed_zero3_enabled from ...extras.misc import check_version +from ...extras.packages import is_transformers_version_greater_than if TYPE_CHECKING: @@ -25,6 +29,9 @@ if TYPE_CHECKING: from ...hparams import ModelArguments +if is_transformers_version_greater_than("4.57.0"): + from transformers.models.qwen3_omni_moe import modeling_qwen3_omni_moe + def _set_z3_leaf_modules(model: "PreTrainedModel", leaf_modules: list[Union["nn.Module", str]]) -> None: check_version("deepspeed>=0.13.0") @@ -175,3 +182,66 @@ def configure_moe(config: "PretrainedConfig", model_args: "ModelArguments", is_t elif model_type == "jetmoe": setattr(config, "aux_loss_coef", model_args.moe_aux_loss_coef) + + +class Qwen3OmniMoeThinkerTextSparseMoeBlock(nn.Module): + def __init__(self, config): + super().__init__() + self.num_experts = config.num_experts + self.top_k = config.num_experts_per_tok + self.norm_topk_prob = config.norm_topk_prob + + # gating + self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False) + self.experts = nn.ModuleList( + [ + modeling_qwen3_omni_moe.Qwen3OmniMoeThinkerTextMLP( + config, intermediate_size=config.moe_intermediate_size + ) + for _ in range(self.num_experts) + ] + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + batch_size, sequence_length, hidden_dim = hidden_states.shape + hidden_states = hidden_states.view(-1, hidden_dim) + # router_logits: (batch * sequence_length, n_experts) + router_logits = self.gate(hidden_states) + + # Calculate the routing weights for all experts + routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) + + # Retain the weight of the top_k and reset the rest of the expert rights to 0 (instead of retaining only top_k experts) + top_k_weights, top_k_indices = torch.topk(routing_weights, self.top_k, dim=-1) + # Initialize the all-zero weight matrix (same shape as all experts) + full_routing_weights = torch.zeros_like(routing_weights) + # Only the weight of top_k experts is retained, and the weight of the rest of the experts remains at 0 + full_routing_weights.scatter_(1, top_k_indices, top_k_weights) + + # Normalized top_k weights (keep the original logic consistent) + if self.norm_topk_prob: + # Calculate the sum of the weights top_k each row (for normalization) + top_k_sum = full_routing_weights.sum(dim=-1, keepdim=True) + # Avoid dividing by zero + top_k_sum = torch.clamp(top_k_sum, min=1e-9) + full_routing_weights /= top_k_sum + + # Convert back to the input data type + full_routing_weights = full_routing_weights.to(hidden_states.dtype) + + final_hidden_states = torch.zeros( + (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device + ) + + # Go through all the experts (not just the selected ones) + for expert_idx in range(self.num_experts): + expert_layer = self.experts[expert_idx] + # Get the weight of the current expert (inactive expert has a weight of 0 here) + expert_weights = full_routing_weights[:, expert_idx, None] # shape: (batch*seq, 1) + # All samples participate in the calculations of the current expert, the weight may be equal to 0 + current_hidden_states = expert_layer(hidden_states) * expert_weights + # Add-up to all expert outputs (experts with a weight of 0 do not affect the result) + final_hidden_states += current_hidden_states + + final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim) + return final_hidden_states, router_logits diff --git a/src/llamafactory/model/patcher.py b/src/llamafactory/model/patcher.py index cbf9aea3..fa2ac832 100644 --- a/src/llamafactory/model/patcher.py +++ b/src/llamafactory/model/patcher.py @@ -43,10 +43,20 @@ if TYPE_CHECKING: from ..hparams import ModelArguments +if is_transformers_version_greater_than("4.57.0"): + from transformers.models.qwen3_omni_moe import modeling_qwen3_omni_moe + logger = logging.get_logger(__name__) +def patch_qwen3_omni_moe_thinker_text_sparse_moe_block(): + if is_transformers_version_greater_than("4.57.0"): + from .model_utils.moe import Qwen3OmniMoeThinkerTextSparseMoeBlock + + modeling_qwen3_omni_moe.Qwen3OmniMoeThinkerTextSparseMoeBlock = Qwen3OmniMoeThinkerTextSparseMoeBlock + + def patch_tokenizer(tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments") -> None: if "PreTrainedTokenizerBase" not in str(tokenizer._pad.__func__): tokenizer._pad = MethodType(PreTrainedTokenizerBase._pad, tokenizer) @@ -136,6 +146,9 @@ def patch_config( if getattr(config, "model_type", None) == "internlm3" and not is_transformers_version_greater_than("4.47.1"): raise RuntimeError("InternLM3 model requires transformers>=4.47.1, please upgrade it.") + if getattr(config, "model_type", None) == "qwen3_omni_moe": + patch_qwen3_omni_moe_thinker_text_sparse_moe_block() + # deepspeed zero3 is not compatible with low_cpu_mem_usage init_kwargs["low_cpu_mem_usage"] = model_args.low_cpu_mem_usage and (not is_deepspeed_zero3_enabled())