mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-08-24 23:02:49 +08:00)

commit 988231026a (parent 9d9f8c6531)
update packing with sdpa and eager attention mode

Former-commit-id: 238f5c3d99809c6ae2571b59bdce8d8ea3c700b9
@@ -66,6 +66,21 @@ STAGES_USE_PAIR_DATA = {"rm", "dpo"}
 
 SUPPORTED_CLASS_FOR_S2ATTN = {"llama"}
 
+SUPPORTED_CLASS_FOR_MULTIPACK = [
+    "llama",
+    "mistral",
+    "mixtral",
+    "qwen2",
+    "qwen2_moe",
+    "falcon",
+    "phi",
+    "phi3",
+    "gemma",
+    "gemmoe",
+    "starcoder2",
+    "jamba"
+]
+
 V_HEAD_WEIGHTS_NAME = "value_head.bin"
 
 V_HEAD_SAFE_WEIGHTS_NAME = "value_head.safetensors"
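For orientation, a minimal sketch (not part of the commit) of the membership check that the updated configure_packing later performs against this constant; the abbreviated list and model_type value below are illustrative only.

# Abbreviated, illustrative copy of the constant added above.
SUPPORTED_CLASS_FOR_MULTIPACK = ["llama", "mistral", "mixtral", "qwen2", "qwen2_moe"]

model_type = "qwen2"  # normally read from config.model_type
if model_type in SUPPORTED_CLASS_FOR_MULTIPACK:
    print("packing without cross-contamination attention can be enabled")
else:
    raise ValueError("Current model does not support packing sequences for efficient training.")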
@@ -12,7 +12,14 @@ import importlib
 import transformers
 from accelerate import init_empty_weights
 from transformers import AutoConfig, AutoModelForCausalLM
+from transformers.modeling_attn_mask_utils import (
+    _prepare_4d_causal_attention_mask,
+    _prepare_4d_causal_attention_mask_for_sdpa,
+)
+from transformers.utils import is_torch_bf16_gpu_available
 
 from ...extras.logging import get_logger
+from ...extras.constants import SUPPORTED_CLASS_FOR_MULTIPACK
 
 if TYPE_CHECKING:
     from transformers import PretrainedConfig
@@ -20,19 +27,7 @@ if TYPE_CHECKING:
     from ...hparams import ModelArguments, DataArguments
 
 
-SUPPORTED_MULTIPACK_MODEL_TYPES = [
-    "llama",
-    "mistral",
-    "mixtral",
-    "qwen2",
-    "qwen2_moe",
-    "falcon",
-    "phi",
-    "phi3",
-    "gemma",
-    "gemmoe",
-    "starcoder2",
-]
+logger = get_logger(__name__)
 
 
 @torch.jit.script
@@ -67,6 +62,64 @@ def get_unpad_data(attention_mask: torch.Tensor):
         max_seqlen_in_batch,
     )
 
 
+def mask_2d_to_4d(
+    mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None
+):
+    """
+    Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
+    This expansion handles packed sequences so that sequences share the same attention mask integer value
+    when they attend to each other within that sequence.
+    This expansion transforms the mask to lower triangular form to prevent future peeking.
+    """
+    bsz, src_len = mask.size()
+    tgt_len = tgt_len if tgt_len is not None else src_len
+
+    mask = mask.unsqueeze(1).unsqueeze(2)
+    mask = mask.expand(bsz, 1, tgt_len, src_len)
+
+    # Create a binary mask from the original mask where zeros remain zeros and all other values are set to one
+    binary_mask = torch.where(
+        mask != 0,
+        torch.tensor(1, device=mask.device).to(dtype),
+        torch.tensor(0, device=mask.device).to(dtype),
+    )
+
+    # Create a block-diagonal mask.
+    # we multiply by the binary mask so that 0's in the original mask are correctly excluded
+    zero_one_mask = torch.eq(mask, mask.transpose(-1, -2)).int() * binary_mask
+
+    # Now let's create a lower triangular mask of ones that will zero out the upper triangular part
+    lower_triangular_ones = torch.tril(torch.ones((tgt_len, src_len), dtype=dtype)).to(
+        mask.device
+    )
+
+    # Use the lower triangular mask to zero out the upper triangular part of the zero_one_mask
+    masked_zero_one_mask = zero_one_mask * lower_triangular_ones
+
+    return masked_zero_one_mask
+
+
+def patched_prepare_4d_causal_attention_mask(
+    attention_mask: Optional[torch.Tensor],
+    *args,
+):
+    dtype = torch.bfloat16 if is_torch_bf16_gpu_available() else torch.float32
+    return _prepare_4d_causal_attention_mask(
+        mask_2d_to_4d(attention_mask, dtype=dtype),
+        *args,
+    )
+
+
+def patched_prepare_4d_causal_attention_mask_for_sdpa(
+    attention_mask: Optional[torch.Tensor],
+    *args,
+):
+    dtype = torch.bfloat16 if is_torch_bf16_gpu_available() else torch.float32
+    return _prepare_4d_causal_attention_mask_for_sdpa(
+        mask_2d_to_4d(attention_mask, dtype=dtype),
+        *args,
+    )
+
+
 def set_module_name(model, name, value):
     if "." in name:
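For intuition, here is a small self-contained sketch (not part of the commit) that mirrors the core of mask_2d_to_4d on a packed batch; it uses an integer dtype instead of the bf16/fp32 dtype chosen above so the result prints cleanly. Sequence ids 1 and 2 mark two packed sequences and 0 marks padding.

import torch

# Two packed sequences (ids 1 and 2) plus one padding token (id 0) in a single row.
mask = torch.tensor([[1, 1, 2, 2, 0]])          # [bsz=1, seq_len=5]
bsz, src_len = mask.size()

expanded = mask.unsqueeze(1).unsqueeze(2).expand(bsz, 1, src_len, src_len)
binary_mask = (expanded != 0).long()            # drop padded key positions
same_seq = torch.eq(expanded, expanded.transpose(-1, -2)).long() * binary_mask
causal = torch.tril(torch.ones(src_len, src_len, dtype=torch.long))
print((same_seq * causal)[0, 0])
# tensor([[1, 0, 0, 0, 0],
#         [1, 1, 0, 0, 0],
#         [0, 0, 1, 0, 0],
#         [0, 0, 1, 1, 0],
#         [0, 0, 0, 0, 0]])

Tokens of the second sequence (rows 2-3) never attend to the first sequence (columns 0-1), and the padded position attends to nothing; this is the cross-contamination that the patched mask builders prevent for sdpa and eager attention.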
@@ -169,57 +222,65 @@ def load_balancing_loss_func(
     return overall_loss * num_experts
 
 
-def patch_for_multipack(model_type, model_name=None):
-    if model_type == "llama":
-        transformers.models.llama.modeling_llama._get_unpad_data = (  # pylint: disable=protected-access
-            get_unpad_data
-        )
-    elif model_type == "mistral":
-        transformers.models.mistral.modeling_mistral._get_unpad_data = (  # pylint: disable=protected-access
-            get_unpad_data
-        )
-    elif model_type == "mixtral":
-        transformers.models.mixtral.modeling_mixtral._get_unpad_data = (  # pylint: disable=protected-access
-            get_unpad_data
-        )
-        transformers.models.mixtral.modeling_mixtral.load_balancing_loss_func = (  # pylint: disable=protected-access
-            load_balancing_loss_func
-        )
-    elif model_type == "qwen2":
-        transformers.models.qwen2.modeling_qwen2._get_unpad_data = (  # pylint: disable=protected-access
-            get_unpad_data
-        )
-    elif model_type == "qwen2_moe":
-        transformers.models.qwen2_moe.modeling_qwen2_moe._get_unpad_data = (  # pylint: disable=protected-access
-            get_unpad_data
-        )
-        transformers.models.qwen2_moe.modeling_qwen2_moe.load_balancing_loss_func = (  # pylint: disable=protected-access
-            load_balancing_loss_func
-        )
-    elif model_type == "falcon":
-        transformers.models.falcon.modeling_falcon._get_unpad_data = (  # pylint: disable=protected-access
-            get_unpad_data
-        )
-    elif model_type == "phi":
-        transformers.models.phi.modeling_phi._get_unpad_data = (  # pylint: disable=protected-access
-            get_unpad_data
-        )
-    elif model_type == "phi3":
-        transformers.models.phi3.modeling_phi3._get_unpad_data = (  # pylint: disable=protected-access
-            get_unpad_data
-        )
-    elif model_type == "gemma":
-        transformers.models.gemma.modeling_gemma._get_unpad_data = (  # pylint: disable=protected-access
-            get_unpad_data
-        )
-    elif model_type == "starcoder2":
-        transformers.models.starcoder2.modeling_starcoder2._get_unpad_data = (  # pylint: disable=protected-access
-            get_unpad_data
-        )
-    elif model_type == "gemmoe":
-        patch_remote(model_name, ".configuration_gemmoe", ".modeling_gemmoe")
-    elif model_type == "jamba":
-        patch_remote(model_name, ".configuration_jamba", ".modeling_jamba")
+def patch_for_multipack(model_type, model_name, attn_implementation):
+    if attn_implementation == "flash_attention_2":
+        if model_type == "llama":
+            transformers.models.llama.modeling_llama._get_unpad_data = (  # pylint: disable=protected-access
+                get_unpad_data
+            )
+        elif model_type == "mistral":
+            transformers.models.mistral.modeling_mistral._get_unpad_data = (  # pylint: disable=protected-access
+                get_unpad_data
+            )
+        elif model_type == "mixtral":
+            transformers.models.mixtral.modeling_mixtral._get_unpad_data = (  # pylint: disable=protected-access
+                get_unpad_data
+            )
+            transformers.models.mixtral.modeling_mixtral.load_balancing_loss_func = (  # pylint: disable=protected-access
+                load_balancing_loss_func
+            )
+        elif model_type == "qwen2":
+            transformers.models.qwen2.modeling_qwen2._get_unpad_data = (  # pylint: disable=protected-access
+                get_unpad_data
+            )
+        elif model_type == "qwen2_moe":
+            transformers.models.qwen2_moe.modeling_qwen2_moe._get_unpad_data = (  # pylint: disable=protected-access
+                get_unpad_data
+            )
+            transformers.models.qwen2_moe.modeling_qwen2_moe.load_balancing_loss_func = (  # pylint: disable=protected-access
+                load_balancing_loss_func
+            )
+        elif model_type == "falcon":
+            transformers.models.falcon.modeling_falcon._get_unpad_data = (  # pylint: disable=protected-access
+                get_unpad_data
+            )
+        elif model_type == "phi":
+            transformers.models.phi.modeling_phi._get_unpad_data = (  # pylint: disable=protected-access
+                get_unpad_data
+            )
+        elif model_type == "phi3":
+            transformers.models.phi3.modeling_phi3._get_unpad_data = (  # pylint: disable=protected-access
+                get_unpad_data
+            )
+        elif model_type == "gemma":
+            transformers.models.gemma.modeling_gemma._get_unpad_data = (  # pylint: disable=protected-access
+                get_unpad_data
+            )
+        elif model_type == "starcoder2":
+            transformers.models.starcoder2.modeling_starcoder2._get_unpad_data = (  # pylint: disable=protected-access
+                get_unpad_data
+            )
+        elif model_type == "gemmoe":
+            patch_remote(model_name, ".configuration_gemmoe", ".modeling_gemmoe")
+        elif model_type == "jamba":
+            patch_remote(model_name, ".configuration_jamba", ".modeling_jamba")
+    else:
+        transformers.modeling_attn_mask_utils._prepare_4d_causal_attention_mask_for_sdpa = (  # pylint: disable=protected-access
+            patched_prepare_4d_causal_attention_mask_for_sdpa
+        )
+        transformers.modeling_attn_mask_utils._prepare_4d_causal_attention_mask = (  # pylint: disable=protected-access
+            patched_prepare_4d_causal_attention_mask
+        )
 
 
 def patch_remote(model_name, config_name, modeling_name):
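A hedged usage sketch of the new dispatch (argument values are illustrative; the function is the one defined in the hunk above, and model_name is only consulted for the remote gemmoe/jamba code).

# sdpa or eager attention: the else branch swaps transformers' 4-D causal mask
# builders so packed sequences cannot attend across sequence boundaries.
patch_for_multipack("llama", model_name=None, attn_implementation="sdpa")

# flash_attention_2: the model's _get_unpad_data is patched instead, so each
# packed sequence is unpadded and attended to separately.
patch_for_multipack("llama", model_name=None, attn_implementation="flash_attention_2")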
@@ -231,20 +292,19 @@ def patch_remote(model_name, config_name, modeling_name):
     modeling_arch = importlib.import_module(module_name)
     modeling_arch._get_unpad_data = get_unpad_data  # pylint: disable=protected-access
+    # check exist load_balancing_loss_func for moe model
+    if hasattr(modeling_arch, "load_balancing_loss_func"):
+        modeling_arch.load_balancing_loss_func = load_balancing_loss_func
 
 
-def configure_packing(config: "PretrainedConfig") -> None:
+def configure_packing(config: "PretrainedConfig", model_args: "ModelArguments") -> None:
     if getattr(config, "model_type", None) == "internlm2":  # special case for custom models
-        attn_implementation = getattr(config, "attn_implementation", None)
+        attn_implementation = getattr(config, "attn_implementation", "")
     else:
-        attn_implementation = getattr(config, "_attn_implementation", None)
+        attn_implementation = getattr(config, "_attn_implementation", "")
 
-    if attn_implementation != "flash_attention_2":
-        raise ValueError("Efficient packing only supports for flash_attention_2. Please set config `flash_attn` is fa2")
-
-    logger = get_logger(__name__)
-
-    if getattr(config, "model_type", None) in SUPPORTED_MULTIPACK_MODEL_TYPES:
-        patch_for_multipack(getattr(config, "model_type", None))
+    if getattr(config, "model_type", None) in SUPPORTED_CLASS_FOR_MULTIPACK:
+        patch_for_multipack(getattr(config, "model_type", None), model_args.model_name_or_path, attn_implementation)
         logger.info("Using packing sequences without cross-contamination attention for efficient training.")
     else:
         raise ValueError("Current model does not support packing sequences for efficient training. Please set config `efficient_packing` is False")
@@ -34,7 +34,7 @@ def run_sft(
     model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)
 
     if data_args.efficient_packing:
-        configure_packing(model.config)
+        configure_packing(model.config, model_args)
 
     if training_args.predict_with_generate:
         tokenizer.padding_side = "left"  # use left-padding in generation