From 79c8332e4ca2880df474f7b1475a74b8dab4010e Mon Sep 17 00:00:00 2001
From: Kingsley
Date: Mon, 27 Apr 2026 00:31:49 +0800
Subject: [PATCH] [train] add qwen35 patch for neat_packing (#10436)

---
 src/llamafactory/data/collator.py |   4 +-
 src/llamafactory/model/patcher.py | 188 ++++++++++++++++++++++++++++++
 2 files changed, 190 insertions(+), 2 deletions(-)

diff --git a/src/llamafactory/data/collator.py b/src/llamafactory/data/collator.py
index 677045cd0..5dd157d84 100644
--- a/src/llamafactory/data/collator.py
+++ b/src/llamafactory/data/collator.py
@@ -471,8 +471,8 @@ class SFTDataCollatorWith4DAttentionMask(MultiModalDataCollatorForSeq2Seq):
     def __post_init__(self):
         super().__post_init__()
         if self.neat_packing and self.attn_implementation == "flash_attention_2":
-            if self.model is not None and getattr(self.model.config, "model_type", None) in ["qwen3_5", "qwen3_5_moe", "gpt_oss"]:
-                raise ValueError("Neat packing is not supported for qwen3_5, qwen3_5_moe, gpt_oss models for now.")
+            if self.model is not None and getattr(self.model.config, "model_type", None) in ["gemma4", "gpt_oss"]:
+                raise ValueError("Neat packing is not supported for gemma4, gpt_oss models for now.")
 
     @staticmethod
     def _unpad_packed_features(features: dict[str, Any]) -> None:
diff --git a/src/llamafactory/model/patcher.py b/src/llamafactory/model/patcher.py
index 375216f16..684d26df4 100644
--- a/src/llamafactory/model/patcher.py
+++ b/src/llamafactory/model/patcher.py
@@ -60,6 +60,191 @@ def patch_qwen3_omni_moe_thinker_text_sparse_moe_block():
     modeling_qwen3_omni_moe.Qwen3OmniMoeThinkerTextSparseMoeBlock = Qwen3OmniMoeThinkerTextSparseMoeBlock
 
 
+def _check_fla_dependencies() -> None:
+    """Check that the FLA dependencies required for varlen GDN forwarding are available.
+
+    Requires ``flash-linear-attention >= 0.4.1`` (which exposes the varlen
+    ``causal_conv1d`` under ``fla.modules.convolution`` and the
+    ``chunk_gated_delta_rule`` / ``fused_recurrent_gated_delta_rule`` kernels
+    under ``fla.ops.gated_delta_rule``). Raises ``ImportError`` with an
+    actionable message otherwise.
+    """
+    try:
+        from fla.modules.convolution import causal_conv1d  # noqa: F401
+        from fla.ops.gated_delta_rule import (  # noqa: F401
+            chunk_gated_delta_rule,
+            fused_recurrent_gated_delta_rule,
+        )
+    except ImportError as exc:
+        raise ImportError(
+            "Qwen3.5 packing-seq forwarding requires `flash-linear-attention>=0.4.1` "
+            "(provides `fla.modules.convolution.causal_conv1d` and "
+            "`fla.ops.gated_delta_rule.{chunk,fused_recurrent}_gated_delta_rule`). "
+            "Please install/upgrade it."
+        ) from exc
+
+
+def patch_qwen3_5_forward(model: "PreTrainedModel") -> None:
+    """Patch the forward method of Qwen3_5ForConditionalGeneration to support cu_seqlens input; the patch is applied only during training.
+
+    Refer to: https://github.com/axolotl-ai-cloud/axolotl/blob/main/src/axolotl/monkeypatch/models/qwen3_5/modeling.py.
+ """ + if is_transformers_version_greater_than("5.2.0"): + from transformers.models.qwen3_5.modeling_qwen3_5 import apply_mask_to_padding_states + + from torch.nn import functional as F + from transformers.modeling_flash_attention_utils import prepare_fa_kwargs_from_position_ids + + _check_fla_dependencies() + from fla.modules.convolution import causal_conv1d as fla_causal_conv1d + from fla.ops.gated_delta_rule import chunk_gated_delta_rule + + def _patched_decoder_forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values=None, + cache_position: torch.LongTensor | None = None, + **kwargs, + ) -> torch.FloatTensor: + """Decoder layer forward that passes position_ids through to linear attention.""" + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + + if self.layer_type == "linear_attention": + hidden_states = self.linear_attn( + hidden_states=hidden_states, + cache_params=past_key_values, + cache_position=cache_position, + attention_mask=attention_mask, + position_ids=position_ids, # passing position_ids to linear attention + ) + elif self.layer_type == "full_attention": + hidden_states, _ = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + if isinstance(hidden_states, tuple): # MoE returns (hidden_states, router_logits) + hidden_states, _ = hidden_states + + hidden_states = residual + hidden_states + + return hidden_states + + # gdn forward (training only, cache_params is always None) + def _patch_gdn_forward( + self, + hidden_states: torch.Tensor, + cache_params=None, + cache_position: torch.LongTensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + ): + # @kuangdd fix: here attention_mask is None + hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask) + + batch_size, seq_len, _ = hidden_states.shape + + # Qwen3.5 VL passes 3-D MRoPE position_ids ([axes, B, T]); collapse to [B, T]. + if position_ids is not None and position_ids.ndim == 3: + position_ids = position_ids[0] + + # `prepare_fa_kwargs_from_position_ids` would crash on None; guard for safety. + cu_seqlens = ( + prepare_fa_kwargs_from_position_ids(position_ids)[0][0] + if position_ids is not None + else None + ) + + # FLA varlen kernels expect [B, T, D] layout, not [B, D, T] like the + # standard causal-conv1d path that the upstream forward uses. + mixed_qkv = self.in_proj_qkv(hidden_states) + + z = self.in_proj_z(hidden_states) + z = z.reshape(batch_size, seq_len, -1, self.head_v_dim) + + b = self.in_proj_b(hidden_states) + a = self.in_proj_a(hidden_states) + + # FLA's causal_conv1d returns (out, final_state); we don't use the state here. 
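+        # For example, three packed sequences of lengths 3, 5 and 4 have position_ids
+        # [0, 1, 2, 0, 1, 2, 3, 4, 0, 1, 2, 3], giving cu_seqlens = [0, 3, 8, 12];
+        # these offsets let the varlen conv and delta-rule kernels below reset their
+        # state at every packed-sequence boundary.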
+        mixed_qkv, _ = fla_causal_conv1d(
+            x=mixed_qkv,
+            weight=self.conv1d.weight.squeeze(1),
+            bias=self.conv1d.bias,
+            activation=self.activation,
+            cu_seqlens=cu_seqlens,
+        )
+
+        query, key, value = torch.split(
+            mixed_qkv,
+            [
+                self.key_dim,
+                self.key_dim,
+                self.value_dim,
+            ],
+            dim=-1,
+        )
+
+        query = query.reshape(batch_size, seq_len, -1, self.head_k_dim)
+        key = key.reshape(batch_size, seq_len, -1, self.head_k_dim)
+        value = value.reshape(batch_size, seq_len, -1, self.head_v_dim)
+
+        beta = b.sigmoid()
+        # If the model is loaded in fp16, without the .float() here, A might be -inf
+        g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias)
+        if self.num_v_heads // self.num_k_heads > 1:
+            query = query.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2)
+            key = key.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2)
+
+        core_attn_out, _ = chunk_gated_delta_rule(
+            query,
+            key,
+            value,
+            g=g,
+            beta=beta,
+            initial_state=None,
+            output_final_state=False,
+            use_qk_l2norm_in_kernel=True,
+            **({"cu_seqlens": cu_seqlens} if cu_seqlens is not None else {}),
+        )
+
+        core_attn_out = core_attn_out.reshape(-1, self.head_v_dim)
+        z = z.reshape(-1, self.head_v_dim)
+        core_attn_out = self.norm(core_attn_out, z)
+        core_attn_out = core_attn_out.reshape(batch_size, seq_len, -1)
+
+        output = self.out_proj(core_attn_out)
+
+        return output
+
+    if model.config.architectures[0] == "Qwen3_5ForConditionalGeneration":
+        from transformers.models.qwen3_5.modeling_qwen3_5 import Qwen3_5DecoderLayer, Qwen3_5GatedDeltaNet
+        Qwen3_5DecoderLayer.forward = _patched_decoder_forward
+        Qwen3_5GatedDeltaNet.forward = _patch_gdn_forward
+    elif model.config.architectures[0] == "Qwen3_5MoeForConditionalGeneration":
+        from transformers.models.qwen3_5_moe.modeling_qwen3_5_moe import (
+            Qwen3_5MoeDecoderLayer,
+            Qwen3_5MoeGatedDeltaNet,
+        )
+        Qwen3_5MoeDecoderLayer.forward = _patched_decoder_forward
+        Qwen3_5MoeGatedDeltaNet.forward = _patch_gdn_forward
+
+    logger.info_rank0("Patched Qwen3.5 decoder forward to support cu_seqlens input (training only).")
+
+
 def patch_youtu_vl_model(model: "PreTrainedModel") -> None:
     original_forward = model.forward
@@ -232,6 +417,9 @@ def patch_model(
     autocast_projector_dtype(model, model_args)
     add_z3_leaf_module(model)
 
+    if getattr(model.config, "model_type", None) in ["qwen3_5", "qwen3_5_moe"] and model_args.flash_attn == "fa2":
+        patch_qwen3_5_forward(model)
+
     if not model_args.use_unsloth:
         print_attn_implementation(model.config)