[train] add qwen35 patch for neat_packing (#10436)

This commit is contained in:
Kingsley
2026-04-27 00:31:49 +08:00
committed by GitHub
parent e0bc3c1971
commit 79c8332e4c
2 changed files with 190 additions and 2 deletions

View File

@@ -60,6 +60,191 @@ def patch_qwen3_omni_moe_thinker_text_sparse_moe_block():
modeling_qwen3_omni_moe.Qwen3OmniMoeThinkerTextSparseMoeBlock = Qwen3OmniMoeThinkerTextSparseMoeBlock
def _check_fla_dependencies() -> None:
"""Check that the FLA dependencies required for varlen GDN forwarding are available.
Requires ``flash-linear-attention >= 0.4.1`` (which exposes the varlen
``causal_conv1d`` under ``fla.modules.convolution`` and the
``chunk_gated_delta_rule`` / ``fused_recurrent_gated_delta_rule`` kernels
under ``fla.ops.gated_delta_rule``). Raises ``ImportError`` with an
actionable message otherwise.
"""
try:
from fla.modules.convolution import causal_conv1d # noqa: F401
from fla.ops.gated_delta_rule import ( # noqa: F401
chunk_gated_delta_rule,
fused_recurrent_gated_delta_rule,
)
except ImportError as exc:
raise ImportError(
"Qwen3.5 packing-seq forwarding requires `flash-linear-attention>=0.4.1` "
"(provides `fla.modules.convolution.causal_conv1d` and "
"`fla.ops.gated_delta_rule.{chunk,fused_recurrent}_gated_delta_rule`). "
"Please install/upgrade it."
) from exc
def patch_qwen3_5_forward(model: "PreTrainedModel") -> None:
    """Patch Qwen3.5 forwards so packed sequences (cu_seqlens) are respected during training.

    Replaces ``DecoderLayer.forward`` (to thread ``position_ids`` into the linear-attention
    branch) and ``GatedDeltaNet.forward`` (to run FLA varlen kernels driven by ``cu_seqlens``
    derived from ``position_ids``). Training-only: the cache path is never exercised.

    Refer to: https://github.com/axolotl-ai-cloud/axolotl/blob/main/src/axolotl/monkeypatch/models/qwen3_5/modeling.py.
    """
    # The qwen3_5 model classes only ship with transformers >= 5.2.0; on older
    # versions this function is a no-op. (NOTE(review): reconstructed scoping —
    # everything below is guarded by the version check; confirm against upstream.)
    if is_transformers_version_greater_than("5.2.0"):
        from transformers.models.qwen3_5.modeling_qwen3_5 import apply_mask_to_padding_states

        from torch.nn import functional as F
        from transformers.modeling_flash_attention_utils import prepare_fa_kwargs_from_position_ids

        # Fail fast with an actionable message before defining patched forwards
        # that depend on the FLA kernels.
        _check_fla_dependencies()
        from fla.modules.convolution import causal_conv1d as fla_causal_conv1d
        from fla.ops.gated_delta_rule import chunk_gated_delta_rule

        def _patched_decoder_forward(
            self,
            hidden_states: torch.Tensor,
            position_embeddings: tuple[torch.Tensor, torch.Tensor],
            attention_mask: torch.Tensor | None = None,
            position_ids: torch.LongTensor | None = None,
            past_key_values=None,
            cache_position: torch.LongTensor | None = None,
            **kwargs,
        ) -> torch.FloatTensor:
            """Decoder layer forward that passes position_ids through to linear attention.

            Mirrors the upstream layer (pre-norm attention + residual, then pre-norm
            MLP + residual); the only intended difference is forwarding ``position_ids``
            to ``self.linear_attn`` so the patched GDN forward can build ``cu_seqlens``.
            """
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
            if self.layer_type == "linear_attention":
                hidden_states = self.linear_attn(
                    hidden_states=hidden_states,
                    cache_params=past_key_values,
                    cache_position=cache_position,
                    attention_mask=attention_mask,
                    position_ids=position_ids,  # passing position_ids to linear attention
                )
            elif self.layer_type == "full_attention":
                hidden_states, _ = self.self_attn(
                    hidden_states=hidden_states,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    past_key_values=past_key_values,
                    cache_position=cache_position,
                    position_embeddings=position_embeddings,
                    **kwargs,
                )
            hidden_states = residual + hidden_states
            residual = hidden_states
            hidden_states = self.post_attention_layernorm(hidden_states)
            hidden_states = self.mlp(hidden_states)
            if isinstance(hidden_states, tuple):  # MoE returns (hidden_states, router_logits)
                hidden_states, _ = hidden_states
            hidden_states = residual + hidden_states
            return hidden_states

        # gdn forward (training only, cache_params is always None)
        def _patch_gdn_forward(
            self,
            hidden_states: torch.Tensor,
            cache_params=None,
            cache_position: torch.LongTensor | None = None,
            attention_mask: torch.Tensor | None = None,
            position_ids: torch.LongTensor | None = None,
        ):
            """Gated-delta-net forward using FLA varlen kernels keyed on cu_seqlens.

            Training-only replacement: ignores the cache arguments and derives
            ``cu_seqlens`` from ``position_ids`` so packed samples do not attend
            across sequence boundaries.
            """
            # NOTE(review): upstream comment ("@kuangdd fix") states attention_mask is
            # None here during packed training, making this masking a no-op — confirm.
            hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask)
            batch_size, seq_len, _ = hidden_states.shape

            # Qwen3.5 VL passes 3-D MRoPE position_ids ([axes, B, T]); collapse to [B, T].
            if position_ids is not None and position_ids.ndim == 3:
                position_ids = position_ids[0]
            # `prepare_fa_kwargs_from_position_ids` would crash on None; guard for safety.
            cu_seqlens = (
                prepare_fa_kwargs_from_position_ids(position_ids)[0][0]
                if position_ids is not None
                else None
            )

            # FLA varlen kernels expect [B, T, D] layout, not [B, D, T] like the
            # standard causal-conv1d path that the upstream forward uses.
            mixed_qkv = self.in_proj_qkv(hidden_states)
            z = self.in_proj_z(hidden_states)
            z = z.reshape(batch_size, seq_len, -1, self.head_v_dim)
            b = self.in_proj_b(hidden_states)
            a = self.in_proj_a(hidden_states)

            # FLA's causal_conv1d returns (out, final_state); we don't use the state here.
            mixed_qkv, _ = fla_causal_conv1d(
                x=mixed_qkv,
                weight=self.conv1d.weight.squeeze(1),
                bias=self.conv1d.bias,
                activation=self.activation,
                cu_seqlens=cu_seqlens,
            )

            # Split the fused projection back into per-role tensors along the feature dim.
            query, key, value = torch.split(
                mixed_qkv,
                [
                    self.key_dim,
                    self.key_dim,
                    self.value_dim,
                ],
                dim=-1,
            )
            query = query.reshape(batch_size, seq_len, -1, self.head_k_dim)
            key = key.reshape(batch_size, seq_len, -1, self.head_k_dim)
            value = value.reshape(batch_size, seq_len, -1, self.head_v_dim)

            beta = b.sigmoid()
            # If the model is loaded in fp16, without the .float() here, A might be -inf
            g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias)
            # GQA-style head expansion: replicate k/q heads to match the value-head count.
            if self.num_v_heads // self.num_k_heads > 1:
                query = query.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2)
                key = key.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2)

            # Chunked gated delta rule; cu_seqlens (when present) keeps packed
            # sequences independent. No recurrent state is carried (training only).
            core_attn_out, _ = chunk_gated_delta_rule(
                query,
                key,
                value,
                g=g,
                beta=beta,
                initial_state=None,
                output_final_state=False,
                use_qk_l2norm_in_kernel=True,
                **({"cu_seqlens": cu_seqlens} if cu_seqlens is not None else {}),
            )

            # Flatten to per-head rows for the gated norm, then restore [B, T, D].
            core_attn_out = core_attn_out.reshape(-1, self.head_v_dim)
            z = z.reshape(-1, self.head_v_dim)
            core_attn_out = self.norm(core_attn_out, z)
            core_attn_out = core_attn_out.reshape(batch_size, seq_len, -1)
            output = self.out_proj(core_attn_out)
            return output

        # Install the patched forwards on the class matching this checkpoint's
        # architecture (dense vs. MoE variants live in different modules).
        if model.config.architectures[0] == "Qwen3_5ForConditionalGeneration":
            from transformers.models.qwen3_5.modeling_qwen3_5 import Qwen3_5DecoderLayer, Qwen3_5GatedDeltaNet

            Qwen3_5DecoderLayer.forward = _patched_decoder_forward
            Qwen3_5GatedDeltaNet.forward = _patch_gdn_forward
        elif model.config.architectures[0] == "Qwen3_5MoeForConditionalGeneration":
            from transformers.models.qwen3_5_moe.modeling_qwen3_5_moe import (
                Qwen3_5MoeDecoderLayer,
                Qwen3_5MoeGatedDeltaNet,
            )

            Qwen3_5MoeDecoderLayer.forward = _patched_decoder_forward
            Qwen3_5MoeGatedDeltaNet.forward = _patch_gdn_forward

        logger.info_rank0("Patched Qwen3.5 decoder forward to support cu_seqlens input only patch when do training.")
def patch_youtu_vl_model(model: "PreTrainedModel") -> None:
original_forward = model.forward
@@ -232,6 +417,9 @@ def patch_model(
autocast_projector_dtype(model, model_args)
add_z3_leaf_module(model)
if getattr(model.config, "model_type", None) in ["qwen3_5", "qwen3_5_moe"] and model_args.flash_attn == "fa2":
patch_qwen3_5_forward(model)
if not model_args.use_unsloth:
print_attn_implementation(model.config)