[data] fix qwen3omni moe model (#9501)

Co-authored-by: frozenleaves <frozen@Mac.local>
浮梦 2025-11-18 13:43:22 +08:00 committed by GitHub
parent 10a446e373
commit d4e120423d
2 changed files with 83 additions and 0 deletions


@@ -14,9 +14,13 @@
from typing import TYPE_CHECKING, Union
import torch
from torch import nn
from torch.nn import functional as F
from transformers.integrations import is_deepspeed_zero3_enabled
from ...extras.misc import check_version
from ...extras.packages import is_transformers_version_greater_than
if TYPE_CHECKING:
@@ -25,6 +29,9 @@ if TYPE_CHECKING:
    from ...hparams import ModelArguments


if is_transformers_version_greater_than("4.57.0"):
    from transformers.models.qwen3_omni_moe import modeling_qwen3_omni_moe


def _set_z3_leaf_modules(model: "PreTrainedModel", leaf_modules: list[Union["nn.Module", str]]) -> None:
    check_version("deepspeed>=0.13.0")
@@ -175,3 +182,66 @@ def configure_moe(config: "PretrainedConfig", model_args: "ModelArguments", is_t
    elif model_type == "jetmoe":
        setattr(config, "aux_loss_coef", model_args.moe_aux_loss_coef)


class Qwen3OmniMoeThinkerTextSparseMoeBlock(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.num_experts = config.num_experts
        self.top_k = config.num_experts_per_tok
        self.norm_topk_prob = config.norm_topk_prob

        # gating
        self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False)
        self.experts = nn.ModuleList(
            [
                modeling_qwen3_omni_moe.Qwen3OmniMoeThinkerTextMLP(
                    config, intermediate_size=config.moe_intermediate_size
                )
                for _ in range(self.num_experts)
            ]
        )

    def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        batch_size, sequence_length, hidden_dim = hidden_states.shape
        hidden_states = hidden_states.view(-1, hidden_dim)
        # router_logits: (batch * sequence_length, n_experts)
        router_logits = self.gate(hidden_states)
        # compute the routing weights over all experts
        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
        # keep the top_k weights and zero out the rest (instead of dispatching to only the top_k experts)
        top_k_weights, top_k_indices = torch.topk(routing_weights, self.top_k, dim=-1)
        # start from an all-zero weight matrix with one column per expert
        full_routing_weights = torch.zeros_like(routing_weights)
        # scatter the top_k weights back in; every other expert keeps a weight of 0
        full_routing_weights.scatter_(1, top_k_indices, top_k_weights)
        # normalize the top_k weights (consistent with the original implementation)
        if self.norm_topk_prob:
            # per-row sum of the top_k weights (for normalization)
            top_k_sum = full_routing_weights.sum(dim=-1, keepdim=True)
            # avoid division by zero
            top_k_sum = torch.clamp(top_k_sum, min=1e-9)
            full_routing_weights /= top_k_sum

        # cast back to the input dtype
        full_routing_weights = full_routing_weights.to(hidden_states.dtype)
        final_hidden_states = torch.zeros(
            (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
        )
        # run every expert, not just the selected ones
        for expert_idx in range(self.num_experts):
            expert_layer = self.experts[expert_idx]
            # weights for the current expert (0 for tokens that did not route to it)
            expert_weights = full_routing_weights[:, expert_idx, None]  # shape: (batch * seq, 1)
            # every token passes through the current expert, possibly with a weight of 0
            current_hidden_states = expert_layer(hidden_states) * expert_weights
            # accumulate the expert outputs (experts with weight 0 do not change the result)
            final_hidden_states += current_hidden_states

        final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
        return final_hidden_states, router_logits
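
The class above replaces sparse top-k dispatch with an equivalent dense formulation: every token runs through every expert, and unselected experts contribute with a weight of exactly zero, so the output matches the original block while keeping all expert parameters in the forward graph. A minimal equivalence check, as a sketch with toy sizes and plain nn.Linear experts standing in for Qwen3OmniMoeThinkerTextMLP (all names below are illustrative):

import torch
import torch.nn.functional as F
from torch import nn

torch.manual_seed(0)
num_experts, top_k, hidden = 4, 2, 8  # toy sizes, not the real config values
tokens = torch.randn(5, hidden)
experts = nn.ModuleList(nn.Linear(hidden, hidden) for _ in range(num_experts))
router_logits = torch.randn(5, num_experts)

routing_weights = F.softmax(router_logits, dim=-1)
top_w, top_i = torch.topk(routing_weights, top_k, dim=-1)
top_w = top_w / top_w.sum(dim=-1, keepdim=True)  # as with norm_topk_prob=True

# standard sparse dispatch: only the selected experts run for each token
sparse_out = torch.zeros_like(tokens)
for t in range(tokens.size(0)):
    for k in range(top_k):
        sparse_out[t] += top_w[t, k] * experts[int(top_i[t, k])](tokens[t])

# dense formulation used by the patch: all experts run, unselected weights are 0
full_w = torch.zeros_like(routing_weights).scatter_(1, top_i, top_w)
dense_out = sum(full_w[:, e, None] * experts[e](tokens) for e in range(num_experts))

assert torch.allclose(sparse_out, dense_out, atol=1e-5)

The block also returns router_logits unchanged, so upstream consumers of the router logits see the same interface as before.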


@@ -43,10 +43,20 @@ if TYPE_CHECKING:
    from ..hparams import ModelArguments


if is_transformers_version_greater_than("4.57.0"):
    from transformers.models.qwen3_omni_moe import modeling_qwen3_omni_moe


logger = logging.get_logger(__name__)


def patch_qwen3_omni_moe_thinker_text_sparse_moe_block():
    if is_transformers_version_greater_than("4.57.0"):
        from .model_utils.moe import Qwen3OmniMoeThinkerTextSparseMoeBlock

        # replace the upstream sparse MoE block with the dense-routing variant from model_utils.moe
        modeling_qwen3_omni_moe.Qwen3OmniMoeThinkerTextSparseMoeBlock = Qwen3OmniMoeThinkerTextSparseMoeBlock


def patch_tokenizer(tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments") -> None:
    if "PreTrainedTokenizerBase" not in str(tokenizer._pad.__func__):
        tokenizer._pad = MethodType(PreTrainedTokenizerBase._pad, tokenizer)
@@ -136,6 +146,9 @@ def patch_config(
    if getattr(config, "model_type", None) == "internlm3" and not is_transformers_version_greater_than("4.47.1"):
        raise RuntimeError("InternLM3 model requires transformers>=4.47.1, please upgrade it.")

    if getattr(config, "model_type", None) == "qwen3_omni_moe":
        patch_qwen3_omni_moe_thinker_text_sparse_moe_block()

    # deepspeed zero3 is not compatible with low_cpu_mem_usage
    init_kwargs["low_cpu_mem_usage"] = model_args.low_cpu_mem_usage and (not is_deepspeed_zero3_enabled())
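
Because patch_config runs before the checkpoint is instantiated, the class rebinding takes effect for every decoder layer the model builds. A hand-applied sketch of the same monkey-patch, assuming transformers>=4.57.0 and that the replacement class is importable from llamafactory.model.model_utils.moe (inferred from the relative imports in this commit):

# a minimal sketch; module paths are assumptions mirroring this commit's imports
from transformers.models.qwen3_omni_moe import modeling_qwen3_omni_moe
from llamafactory.model.model_utils.moe import Qwen3OmniMoeThinkerTextSparseMoeBlock

# rebind the class attribute before building the model; layers constructed
# afterwards resolve the name on the module and pick up the dense-routing block
modeling_qwen3_omni_moe.Qwen3OmniMoeThinkerTextSparseMoeBlock = Qwen3OmniMoeThinkerTextSparseMoeBlock

Models constructed before the rebinding keep the stock sparse block, which is why the hook lives in patch_config rather than at import time.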