mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-11-28 11:14:18 +08:00
[data] fix qwen3omni moe model (#9501)
Co-authored-by: frozenleaves <frozen@Mac.local>
This commit is contained in:
parent
10a446e373
commit
d4e120423d
@ -14,9 +14,13 @@
|
||||
|
||||
from typing import TYPE_CHECKING, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
from transformers.integrations import is_deepspeed_zero3_enabled
|
||||
|
||||
from ...extras.misc import check_version
|
||||
from ...extras.packages import is_transformers_version_greater_than
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -25,6 +29,9 @@ if TYPE_CHECKING:
|
||||
|
||||
from ...hparams import ModelArguments
|
||||
|
||||
if is_transformers_version_greater_than("4.57.0"):
|
||||
from transformers.models.qwen3_omni_moe import modeling_qwen3_omni_moe
|
||||
|
||||
|
||||
def _set_z3_leaf_modules(model: "PreTrainedModel", leaf_modules: list[Union["nn.Module", str]]) -> None:
|
||||
check_version("deepspeed>=0.13.0")
|
||||
@ -175,3 +182,66 @@ def configure_moe(config: "PretrainedConfig", model_args: "ModelArguments", is_t
|
||||
|
||||
elif model_type == "jetmoe":
|
||||
setattr(config, "aux_loss_coef", model_args.moe_aux_loss_coef)
|
||||
|
||||
|
||||
class Qwen3OmniMoeThinkerTextSparseMoeBlock(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.num_experts = config.num_experts
|
||||
self.top_k = config.num_experts_per_tok
|
||||
self.norm_topk_prob = config.norm_topk_prob
|
||||
|
||||
# gating
|
||||
self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False)
|
||||
self.experts = nn.ModuleList(
|
||||
[
|
||||
modeling_qwen3_omni_moe.Qwen3OmniMoeThinkerTextMLP(
|
||||
config, intermediate_size=config.moe_intermediate_size
|
||||
)
|
||||
for _ in range(self.num_experts)
|
||||
]
|
||||
)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
batch_size, sequence_length, hidden_dim = hidden_states.shape
|
||||
hidden_states = hidden_states.view(-1, hidden_dim)
|
||||
# router_logits: (batch * sequence_length, n_experts)
|
||||
router_logits = self.gate(hidden_states)
|
||||
|
||||
# Calculate the routing weights for all experts
|
||||
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
|
||||
|
||||
# Retain the weight of the top_k and reset the rest of the expert rights to 0 (instead of retaining only top_k experts)
|
||||
top_k_weights, top_k_indices = torch.topk(routing_weights, self.top_k, dim=-1)
|
||||
# Initialize the all-zero weight matrix (same shape as all experts)
|
||||
full_routing_weights = torch.zeros_like(routing_weights)
|
||||
# Only the weight of top_k experts is retained, and the weight of the rest of the experts remains at 0
|
||||
full_routing_weights.scatter_(1, top_k_indices, top_k_weights)
|
||||
|
||||
# Normalized top_k weights (keep the original logic consistent)
|
||||
if self.norm_topk_prob:
|
||||
# Calculate the sum of the weights top_k each row (for normalization)
|
||||
top_k_sum = full_routing_weights.sum(dim=-1, keepdim=True)
|
||||
# Avoid dividing by zero
|
||||
top_k_sum = torch.clamp(top_k_sum, min=1e-9)
|
||||
full_routing_weights /= top_k_sum
|
||||
|
||||
# Convert back to the input data type
|
||||
full_routing_weights = full_routing_weights.to(hidden_states.dtype)
|
||||
|
||||
final_hidden_states = torch.zeros(
|
||||
(batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
|
||||
)
|
||||
|
||||
# Go through all the experts (not just the selected ones)
|
||||
for expert_idx in range(self.num_experts):
|
||||
expert_layer = self.experts[expert_idx]
|
||||
# Get the weight of the current expert (inactive expert has a weight of 0 here)
|
||||
expert_weights = full_routing_weights[:, expert_idx, None] # shape: (batch*seq, 1)
|
||||
# All samples participate in the calculations of the current expert, the weight may be equal to 0
|
||||
current_hidden_states = expert_layer(hidden_states) * expert_weights
|
||||
# Add-up to all expert outputs (experts with a weight of 0 do not affect the result)
|
||||
final_hidden_states += current_hidden_states
|
||||
|
||||
final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
|
||||
return final_hidden_states, router_logits
|
||||
|
||||
@ -43,10 +43,20 @@ if TYPE_CHECKING:
|
||||
|
||||
from ..hparams import ModelArguments
|
||||
|
||||
if is_transformers_version_greater_than("4.57.0"):
|
||||
from transformers.models.qwen3_omni_moe import modeling_qwen3_omni_moe
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def patch_qwen3_omni_moe_thinker_text_sparse_moe_block():
|
||||
if is_transformers_version_greater_than("4.57.0"):
|
||||
from .model_utils.moe import Qwen3OmniMoeThinkerTextSparseMoeBlock
|
||||
|
||||
modeling_qwen3_omni_moe.Qwen3OmniMoeThinkerTextSparseMoeBlock = Qwen3OmniMoeThinkerTextSparseMoeBlock
|
||||
|
||||
|
||||
def patch_tokenizer(tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments") -> None:
|
||||
if "PreTrainedTokenizerBase" not in str(tokenizer._pad.__func__):
|
||||
tokenizer._pad = MethodType(PreTrainedTokenizerBase._pad, tokenizer)
|
||||
@ -136,6 +146,9 @@ def patch_config(
|
||||
if getattr(config, "model_type", None) == "internlm3" and not is_transformers_version_greater_than("4.47.1"):
|
||||
raise RuntimeError("InternLM3 model requires transformers>=4.47.1, please upgrade it.")
|
||||
|
||||
if getattr(config, "model_type", None) == "qwen3_omni_moe":
|
||||
patch_qwen3_omni_moe_thinker_text_sparse_moe_block()
|
||||
|
||||
# deepspeed zero3 is not compatible with low_cpu_mem_usage
|
||||
init_kwargs["low_cpu_mem_usage"] = model_args.low_cpu_mem_usage and (not is_deepspeed_zero3_enabled())
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user