Update packing.py

Former-commit-id: a36e8f2dd50e0f1c589457a7e785fdbc905d561d
hoshi-hiyouga 2024-07-03 23:36:01 +08:00 committed by GitHub
parent 13cec0cc2f
commit 51c75985b8


@@ -257,7 +257,7 @@ def load_balancing_loss_func(
     return overall_loss * num_experts
 
 
-def patch_for_multipack(model_type, model_name, attn_implementation):
+def patch_for_block_diag_attn(model_type, model_name, attn_implementation):
     if attn_implementation == "flash_attention_2":
         if model_type == "llama":
             transformers.models.llama.modeling_llama._get_unpad_data = (  # pylint: disable=protected-access
@@ -305,10 +305,6 @@ def patch_for_multipack(model_type, model_name, attn_implementation):
             transformers.models.starcoder2.modeling_starcoder2._get_unpad_data = (  # pylint: disable=protected-access
                 get_unpad_data
             )
-        elif model_type == "gemmoe":
-            patch_remote(model_name, ".configuration_gemmoe", ".modeling_gemmoe")
-        elif model_type == "jamba":
-            patch_remote(model_name, ".configuration_jamba", ".modeling_jamba")
     else:
         transformers.modeling_attn_mask_utils._prepare_4d_causal_attention_mask_for_sdpa = (  # pylint: disable=protected-access
             patched_prepare_4d_causal_attention_mask_for_sdpa
@@ -318,7 +314,6 @@ def patch_for_multipack(model_type, model_name, attn_implementation):
         )
 
 
 def configure_packing(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
     if not is_trainable or not model_args.block_diag_attn:
         return
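For context, the `_get_unpad_data` hook that this patch swaps in exists so that FlashAttention 2 treats every packed sub-sequence as its own sequence, which is equivalent to applying a block-diagonal attention mask (hence the new name `patch_for_block_diag_attn`). The sketch below shows one way such a replacement can be written; it assumes the packed attention mask stores a per-token sub-sequence index (e.g. `[1, 1, 1, 2, 2, 0]` for two samples of lengths 3 and 2 plus one padding token), and the tensor handling is illustrative rather than a copy of packing.py.

import torch
import torch.nn.functional as F


def get_unpad_data(attention_mask: torch.Tensor):
    # Assumed mask convention: 0 marks padding, k > 0 marks tokens of the k-th packed sub-sequence.
    max_index = int(attention_mask.max().item())

    # Length of every sub-sequence in every row of the batch, flattened and filtered to non-empty ones.
    counts = torch.stack(
        [(attention_mask == i + 1).sum(dim=-1) for i in range(max_index)], dim=-1
    ).flatten()
    seqlens_in_batch = counts[counts > 0].to(torch.int32)

    # Flat indices of all non-padding tokens, used to gather/scatter the unpadded hidden states.
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()

    # Cumulative sequence lengths with a leading zero, the format expected by FlashAttention's varlen kernels.
    max_seqlen_in_batch = int(seqlens_in_batch.max().item())
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
    return indices, cu_seqlens, max_seqlen_in_batch

Assigning such a function to `transformers.models.llama.modeling_llama._get_unpad_data` (and to the equivalents for the other model types in the diff) makes the FlashAttention 2 path compute per-sub-sequence offsets, so tokens from different packed samples never attend to each other; on the SDPA path the `else` branch achieves the same effect by patching the 4D causal mask preparation.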