[fix] fit neat_packing & mrope model packing (#10283)

Co-authored-by: Yaowei Zheng <hiyouga@buaa.edu.cn>
This commit is contained in:
Kingsley
2026-03-20 16:50:11 +08:00
committed by GitHub
parent d91d8af89e
commit 833f6027b1
15 changed files with 520 additions and 93 deletions

View File

@@ -37,7 +37,6 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
from typing import TYPE_CHECKING
import torch
import torch.nn.functional as F
@@ -45,10 +44,6 @@ import torch.nn.functional as F
from ...extras import logging
if TYPE_CHECKING:
from ...hparams import ModelArguments
logger = logging.get_logger(__name__)
@@ -105,13 +100,3 @@ def get_unpad_data(attention_mask: "torch.Tensor") -> tuple["torch.Tensor", "tor
max_seqlen_in_batch = seqlens_in_batch.max().item()
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
return indices, cu_seqlens, max_seqlen_in_batch
def configure_packing(model_args: "ModelArguments", is_trainable: bool) -> None:
    r"""Patch transformers' flash-attention unpad helper for sequence packing.

    Nothing happens unless ``is_trainable`` is true and
    ``model_args.block_diag_attn`` is set. When both hold,
    ``transformers.modeling_flash_attention_utils._get_unpad_data`` is replaced
    with this module's ``get_unpad_data`` and the change is reported through
    ``logger.info_rank0``.
    """
    if is_trainable and model_args.block_diag_attn:
        # Function-scope import: the transformers flash-attention utilities are
        # only imported on the path that actually applies the patch.
        import transformers.modeling_flash_attention_utils as flash_attn_utils

        flash_attn_utils._get_unpad_data = get_unpad_data
        logger.info_rank0("Using block diagonal attention for sequence packing without cross-attention.")

View File

@@ -24,7 +24,6 @@ import transformers.models
from transformers.activations import ACT2FN
from ...extras import logging
from ...extras.packages import is_transformers_version_greater_than
if TYPE_CHECKING:
@@ -344,9 +343,7 @@ _register_composite_model(
model_type="qwen2_vl",
projector_key="visual.merger",
vision_model_keys=["visual.patch_embed", "visual.blocks"],
language_model_keys=["language_model", "lm_head"]
if is_transformers_version_greater_than("4.52.0")
else ["model", "lm_head"],
language_model_keys=["language_model", "lm_head"],
lora_conflict_keys=["patch_embed"],
)
@@ -355,9 +352,7 @@ _register_composite_model(
model_type="qwen2_5_vl",
projector_key="visual.merger",
vision_model_keys=["visual.patch_embed", "visual.blocks"],
language_model_keys=["language_model", "lm_head"]
if is_transformers_version_greater_than("4.52.0")
else ["model", "lm_head"],
language_model_keys=["language_model", "lm_head"],
lora_conflict_keys=["patch_embed"],
)

View File

@@ -30,7 +30,6 @@ from .model_utils.embedding import resize_embedding_layer
from .model_utils.kv_cache import configure_kv_cache
from .model_utils.longlora import configure_longlora
from .model_utils.moe import add_z3_leaf_module, configure_moe
from .model_utils.packing import configure_packing
from .model_utils.quantization import configure_quantization
from .model_utils.rope import configure_rope
from .model_utils.valuehead import prepare_valuehead_model
@@ -142,7 +141,6 @@ def patch_config(
configure_quantization(config, tokenizer, model_args, is_trainable, init_kwargs)
configure_moe(config, model_args, is_trainable)
configure_visual_model(config)
configure_packing(model_args, is_trainable)
configure_kv_cache(config, model_args, is_trainable)
if getattr(config, "model_type", None) == "qwen":