mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2026-06-17 04:38:53 +08:00
[model] fix non-packing batch (bsz>1) for Qwen3.5 with flash attention (#10529)
This commit is contained in:
@@ -162,8 +162,14 @@ def patch_qwen3_5_forward(model: "PreTrainedModel") -> None:
|
||||
if position_ids is not None and position_ids.ndim == 3:
|
||||
position_ids = position_ids[0]
|
||||
|
||||
# `prepare_fa_kwargs_from_position_ids` would crash on None; guard for safety.
|
||||
cu_seqlens = prepare_fa_kwargs_from_position_ids(position_ids)[0][0] if position_ids is not None else None
|
||||
# cu_seqlens for the FLA varlen path is only needed when batch_size == 1:
|
||||
# packing / neat-packing: always folded into a single sequence (bsz == 1) -> varlen
|
||||
# non-packing, bsz == 1: single segment, equivalent to a standard single sequence
|
||||
# non-packing, bsz > 1: not packed, use cu_seqlens=None and standard batched kernels
|
||||
if position_ids is not None and batch_size == 1:
|
||||
cu_seqlens = prepare_fa_kwargs_from_position_ids(position_ids)[0][0]
|
||||
else:
|
||||
cu_seqlens = None
|
||||
|
||||
# FLA varlen kernels expect [B, T, D] layout, not [B, D, T] like the
|
||||
# standard causal-conv1d path that the upstream forward uses.
|
||||
|
||||
Reference in New Issue
Block a user