[model] fix non-packing batch (bsz>1) for Qwen3.5 with flash attention (#10529)

This commit is contained in:
jiaqiw09
2026-05-30 21:41:41 +08:00
committed by GitHub
parent 01398eb18d
commit 7d719182c9

View File

@@ -162,8 +162,14 @@ def patch_qwen3_5_forward(model: "PreTrainedModel") -> None:
if position_ids is not None and position_ids.ndim == 3:
position_ids = position_ids[0]
# `prepare_fa_kwargs_from_position_ids` would crash on None; guard for safety.
cu_seqlens = prepare_fa_kwargs_from_position_ids(position_ids)[0][0] if position_ids is not None else None
# cu_seqlens for the FLA varlen path is only needed when batch_size == 1:
# packing / neat-packing: always folded into a single sequence (bsz == 1) -> varlen
# non-packing, bsz == 1: single segment, equivalent to a standard single sequence
# non-packing, bsz > 1: not packed, use cu_seqlens=None and standard batched kernels
if position_ids is not None and batch_size == 1:
cu_seqlens = prepare_fa_kwargs_from_position_ids(position_ids)[0][0]
else:
cu_seqlens = None
# FLA varlen kernels expect [B, T, D] layout, not [B, D, T] like the
# standard causal-conv1d path that the upstream forward uses.