mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2026-04-27 18:29:08 +08:00)
[train] add qwen35 patch for neat_packing (#10436)
@@ -60,6 +60,191 @@ def patch_qwen3_omni_moe_thinker_text_sparse_moe_block():
    modeling_qwen3_omni_moe.Qwen3OmniMoeThinkerTextSparseMoeBlock = Qwen3OmniMoeThinkerTextSparseMoeBlock


def _check_fla_dependencies() -> None:
    """Check that the FLA dependencies required for varlen GDN forwarding are available.

    Requires ``flash-linear-attention >= 0.4.1`` (which exposes the varlen
    ``causal_conv1d`` under ``fla.modules.convolution`` and the
    ``chunk_gated_delta_rule`` / ``fused_recurrent_gated_delta_rule`` kernels
    under ``fla.ops.gated_delta_rule``). Raises ``ImportError`` with an
    actionable message otherwise.
    """
    try:
        from fla.modules.convolution import causal_conv1d  # noqa: F401
        from fla.ops.gated_delta_rule import (  # noqa: F401
            chunk_gated_delta_rule,
            fused_recurrent_gated_delta_rule,
        )
    except ImportError as exc:
        raise ImportError(
            "Qwen3.5 packing-seq forwarding requires `flash-linear-attention>=0.4.1` "
            "(provides `fla.modules.convolution.causal_conv1d` and "
            "`fla.ops.gated_delta_rule.{chunk,fused_recurrent}_gated_delta_rule`). "
            "Please install/upgrade it."
        ) from exc
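
# If this check fails, the remedy implied by the error message is
# `pip install -U "flash-linear-attention>=0.4.1"` (command shown for
# convenience; it is not part of this commit).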


def patch_qwen3_5_forward(model: "PreTrainedModel") -> None:
    """Patch the forward method of Qwen3_5ForConditionalGeneration to support cu_seqlens input (the patch is applied only during training).

    Refer to: https://github.com/axolotl-ai-cloud/axolotl/blob/main/src/axolotl/monkeypatch/models/qwen3_5/modeling.py.
    """
    if is_transformers_version_greater_than("5.2.0"):
        from transformers.models.qwen3_5.modeling_qwen3_5 import apply_mask_to_padding_states

        from torch.nn import functional as F
        from transformers.modeling_flash_attention_utils import prepare_fa_kwargs_from_position_ids

        _check_fla_dependencies()
        from fla.modules.convolution import causal_conv1d as fla_causal_conv1d
        from fla.ops.gated_delta_rule import chunk_gated_delta_rule

        def _patched_decoder_forward(
            self,
            hidden_states: torch.Tensor,
            position_embeddings: tuple[torch.Tensor, torch.Tensor],
            attention_mask: torch.Tensor | None = None,
            position_ids: torch.LongTensor | None = None,
            past_key_values=None,
            cache_position: torch.LongTensor | None = None,
            **kwargs,
        ) -> torch.FloatTensor:
            """Decoder layer forward that passes position_ids through to linear attention."""
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)

            if self.layer_type == "linear_attention":
                hidden_states = self.linear_attn(
                    hidden_states=hidden_states,
                    cache_params=past_key_values,
                    cache_position=cache_position,
                    attention_mask=attention_mask,
                    position_ids=position_ids,  # pass position_ids through to linear attention
                )
            elif self.layer_type == "full_attention":
                hidden_states, _ = self.self_attn(
                    hidden_states=hidden_states,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    past_key_values=past_key_values,
                    cache_position=cache_position,
                    position_embeddings=position_embeddings,
                    **kwargs,
                )

            hidden_states = residual + hidden_states

            residual = hidden_states
            hidden_states = self.post_attention_layernorm(hidden_states)
            hidden_states = self.mlp(hidden_states)
            if isinstance(hidden_states, tuple):  # MoE returns (hidden_states, router_logits)
                hidden_states, _ = hidden_states

            hidden_states = residual + hidden_states

            return hidden_states

        # GDN forward (training only; cache_params is always None)
        def _patch_gdn_forward(
            self,
            hidden_states: torch.Tensor,
            cache_params=None,
            cache_position: torch.LongTensor | None = None,
            attention_mask: torch.Tensor | None = None,
            position_ids: torch.LongTensor | None = None,
        ):
            # @kuangdd fix: attention_mask is None here
            hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask)
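            # With packed sequences there is no padding in the batch, so
            # attention_mask is None and the call above is a no-op; it is kept
            # for parity with the upstream forward.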

            batch_size, seq_len, _ = hidden_states.shape

            # Qwen3.5 VL passes 3-D MRoPE position_ids ([axes, B, T]); collapse to [B, T].
            if position_ids is not None and position_ids.ndim == 3:
                position_ids = position_ids[0]

            # `prepare_fa_kwargs_from_position_ids` would crash on None; guard for safety.
            cu_seqlens = (
                prepare_fa_kwargs_from_position_ids(position_ids)[0][0]
                if position_ids is not None
                else None
            )
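            # e.g. packed lengths 3 and 5 -> position_ids [[0, 1, 2, 0, 1, 2, 3, 4]]
            # and cu_seqlens [0, 3, 8] (boundaries where position_ids reset to 0).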

            # FLA varlen kernels expect [B, T, D] layout, not [B, D, T] like the
            # standard causal-conv1d path that the upstream forward uses.
            mixed_qkv = self.in_proj_qkv(hidden_states)

            z = self.in_proj_z(hidden_states)
            z = z.reshape(batch_size, seq_len, -1, self.head_v_dim)
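            # z is consumed later as the gate input of the gated RMSNorm applied
            # to the attention output.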

            b = self.in_proj_b(hidden_states)
            a = self.in_proj_a(hidden_states)

            # FLA's causal_conv1d returns (out, final_state); we don't use the state here.
            mixed_qkv, _ = fla_causal_conv1d(
                x=mixed_qkv,
                weight=self.conv1d.weight.squeeze(1),
                bias=self.conv1d.bias,
                activation=self.activation,
                cu_seqlens=cu_seqlens,
            )
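            # Passing cu_seqlens makes the convolution restart at every sequence
            # boundary, so packed samples cannot leak into each other through
            # the conv window.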

            query, key, value = torch.split(
                mixed_qkv,
                [
                    self.key_dim,
                    self.key_dim,
                    self.value_dim,
                ],
                dim=-1,
            )

            query = query.reshape(batch_size, seq_len, -1, self.head_k_dim)
            key = key.reshape(batch_size, seq_len, -1, self.head_k_dim)
            value = value.reshape(batch_size, seq_len, -1, self.head_v_dim)

            beta = b.sigmoid()
            # If the model is loaded in fp16, without the .float() here, A might be -inf
            g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias)
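            # softplus(...) > 0 and exp(A_log) > 0, so g < 0: exp(g) acts as a
            # per-token decay factor inside the gated delta rule kernel.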
            if self.num_v_heads // self.num_k_heads > 1:
                query = query.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2)
                key = key.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2)
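            # GQA-style head sharing: replicate query/key heads so head counts
            # match the value heads expected by the kernel.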

            core_attn_out, _ = chunk_gated_delta_rule(
                query,
                key,
                value,
                g=g,
                beta=beta,
                initial_state=None,
                output_final_state=False,
                use_qk_l2norm_in_kernel=True,
                **({"cu_seqlens": cu_seqlens} if cu_seqlens is not None else {}),
            )
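            # As with the convolution above, cu_seqlens (when available) makes the
            # chunked delta-rule recurrence reset its state at sequence boundaries.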

            core_attn_out = core_attn_out.reshape(-1, self.head_v_dim)
            z = z.reshape(-1, self.head_v_dim)
            core_attn_out = self.norm(core_attn_out, z)
            core_attn_out = core_attn_out.reshape(batch_size, seq_len, -1)

            output = self.out_proj(core_attn_out)

            return output

        if model.config.architectures[0] == "Qwen3_5ForConditionalGeneration":
            from transformers.models.qwen3_5.modeling_qwen3_5 import Qwen3_5DecoderLayer, Qwen3_5GatedDeltaNet

            Qwen3_5DecoderLayer.forward = _patched_decoder_forward
            Qwen3_5GatedDeltaNet.forward = _patch_gdn_forward
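            # Class-level monkeypatch: every instantiated decoder layer and GDN
            # module picks up the patched forwards immediately.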
        elif model.config.architectures[0] == "Qwen3_5MoeForConditionalGeneration":
            from transformers.models.qwen3_5_moe.modeling_qwen3_5_moe import (
                Qwen3_5MoeDecoderLayer,
                Qwen3_5MoeGatedDeltaNet,
            )

            Qwen3_5MoeDecoderLayer.forward = _patched_decoder_forward
            Qwen3_5MoeGatedDeltaNet.forward = _patch_gdn_forward

        logger.info_rank0("Patched Qwen3.5 decoder forward to support cu_seqlens input (training only).")


def patch_youtu_vl_model(model: "PreTrainedModel") -> None:
    original_forward = model.forward

@@ -232,6 +417,9 @@ def patch_model(
    autocast_projector_dtype(model, model_args)
    add_z3_leaf_module(model)

    if getattr(model.config, "model_type", None) in ["qwen3_5", "qwen3_5_moe"] and model_args.flash_attn == "fa2":
        patch_qwen3_5_forward(model)

    if not model_args.use_unsloth:
        print_attn_implementation(model.config)
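
For reference, here is a minimal sketch (helper name and values are assumed for illustration; it is not part of this commit) of the cu_seqlens derivation that `prepare_fa_kwargs_from_position_ids` performs on packed batches, where each sequence start is a position whose position_id resets to 0:

    import torch

    def cu_seqlens_from_position_ids(position_ids: torch.LongTensor) -> torch.LongTensor:
        # Toy re-derivation for illustration only.
        flat = position_ids.flatten()
        starts = torch.nonzero(flat == 0).flatten()
        return torch.cat([starts, torch.tensor([flat.numel()])]).to(torch.int32)

    pos = torch.tensor([[0, 1, 2, 0, 1, 2, 3, 4]])  # two packed sequences (len 3 and 5)
    print(cu_seqlens_from_position_ids(pos))  # tensor([0, 3, 8], dtype=torch.int32)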