From 988231026acfdf711ea2514a7e47b122ad8de5b6 Mon Sep 17 00:00:00 2001
From: ancv
Date: Sun, 16 Jun 2024 02:25:47 +0700
Subject: [PATCH] update packing with sdpa and eager attention mode

Former-commit-id: 238f5c3d99809c6ae2571b59bdce8d8ea3c700b9
---
 src/llamafactory/extras/constants.py          |  15 ++
 src/llamafactory/model/model_utils/packing.py | 204 +++++++++++-------
 src/llamafactory/train/sft/workflow.py        |   2 +-
 3 files changed, 148 insertions(+), 73 deletions(-)

diff --git a/src/llamafactory/extras/constants.py b/src/llamafactory/extras/constants.py
index 7d96fb5f..d70922c1 100644
--- a/src/llamafactory/extras/constants.py
+++ b/src/llamafactory/extras/constants.py
@@ -66,6 +66,21 @@ STAGES_USE_PAIR_DATA = {"rm", "dpo"}
 
 SUPPORTED_CLASS_FOR_S2ATTN = {"llama"}
 
+SUPPORTED_CLASS_FOR_MULTIPACK = [
+    "llama",
+    "mistral",
+    "mixtral",
+    "qwen2",
+    "qwen2_moe",
+    "falcon",
+    "phi",
+    "phi3",
+    "gemma",
+    "gemmoe",
+    "starcoder2",
+    "jamba"
+]
+
 V_HEAD_WEIGHTS_NAME = "value_head.bin"
 
 V_HEAD_SAFE_WEIGHTS_NAME = "value_head.safetensors"
diff --git a/src/llamafactory/model/model_utils/packing.py b/src/llamafactory/model/model_utils/packing.py
index 9b7359be..ce156728 100644
--- a/src/llamafactory/model/model_utils/packing.py
+++ b/src/llamafactory/model/model_utils/packing.py
@@ -12,7 +12,14 @@ import importlib
 import transformers
 from accelerate import init_empty_weights
 from transformers import AutoConfig, AutoModelForCausalLM
+from transformers.modeling_attn_mask_utils import (
+    _prepare_4d_causal_attention_mask,
+    _prepare_4d_causal_attention_mask_for_sdpa,
+)
+from transformers.utils import is_torch_bf16_gpu_available
+
 from ...extras.logging import get_logger
+from ...extras.constants import SUPPORTED_CLASS_FOR_MULTIPACK
 
 if TYPE_CHECKING:
     from transformers import PretrainedConfig
@@ -20,19 +27,7 @@ if TYPE_CHECKING:
     from ...hparams import ModelArguments, DataArguments
 
 
-SUPPORTED_MULTIPACK_MODEL_TYPES = [
-    "llama",
-    "mistral",
-    "mixtral",
-    "qwen2",
-    "qwen2_moe",
-    "falcon",
-    "phi",
-    "phi3",
-    "gemma",
-    "gemmoe",
-    "starcoder2",
-]
+logger = get_logger(__name__)
 
 @torch.jit.script
 def get_unpad_data(attention_mask: torch.Tensor):
@@ -67,6 +62,64 @@ def get_unpad_data(attention_mask: torch.Tensor):
         max_seqlen_in_batch,
     )
 
+def mask_2d_to_4d(
+    mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None
+):
+    """
+    Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
+    Positions that share the same non-zero integer value in the 2D mask belong to the same packed
+    sequence and may attend to each other; positions from different packed sequences may not.
+    The expanded mask is also made lower triangular so that no position can attend to future positions.
+    """
+    bsz, src_len = mask.size()
+    tgt_len = tgt_len if tgt_len is not None else src_len
+
+    mask = mask.unsqueeze(1).unsqueeze(2)
+    mask = mask.expand(bsz, 1, tgt_len, src_len)
+
+    # Create a binary mask from the original mask where zeros remain zeros and all other values are set to one
+    binary_mask = torch.where(
+        mask != 0,
+        torch.tensor(1, device=mask.device).to(dtype),
+        torch.tensor(0, device=mask.device).to(dtype),
+    )
+
+    # Create a block-diagonal mask.
+ # we multiply by the binary mask so that 0's in the original mask are correctly excluded + zero_one_mask = torch.eq(mask, mask.transpose(-1, -2)).int() * binary_mask + + # Now let's create a lower triangular mask of ones that will zero out the upper triangular part + lower_triangular_ones = torch.tril(torch.ones((tgt_len, src_len), dtype=dtype)).to( + mask.device + ) + + # Use the lower triangular mask to zero out the upper triangular part of the zero_one_mask + masked_zero_one_mask = zero_one_mask * lower_triangular_ones + + return masked_zero_one_mask + + +def patched_prepare_4d_causal_attention_mask( + attention_mask: Optional[torch.Tensor], + *args, +): + dtype = torch.bfloat16 if is_torch_bf16_gpu_available() else torch.float32 + return _prepare_4d_causal_attention_mask( + mask_2d_to_4d(attention_mask, dtype=dtype), + *args, + ) + + +def patched_prepare_4d_causal_attention_mask_for_sdpa( + attention_mask: Optional[torch.Tensor], + *args, +): + dtype = torch.bfloat16 if is_torch_bf16_gpu_available() else torch.float32 + return _prepare_4d_causal_attention_mask_for_sdpa( + mask_2d_to_4d(attention_mask, dtype=dtype), + *args, + ) + def set_module_name(model, name, value): if "." in name: @@ -169,57 +222,65 @@ def load_balancing_loss_func( return overall_loss * num_experts -def patch_for_multipack(model_type, model_name=None): - if model_type == "llama": - transformers.models.llama.modeling_llama._get_unpad_data = ( # pylint: disable=protected-access - get_unpad_data +def patch_for_multipack(model_type, model_name, attn_implementation): + if attn_implementation == "flash_attention_2": + if model_type == "llama": + transformers.models.llama.modeling_llama._get_unpad_data = ( # pylint: disable=protected-access + get_unpad_data + ) + elif model_type == "mistral": + transformers.models.mistral.modeling_mistral._get_unpad_data = ( # pylint: disable=protected-access + get_unpad_data + ) + elif model_type == "mixtral": + transformers.models.mixtral.modeling_mixtral._get_unpad_data = ( # pylint: disable=protected-access + get_unpad_data + ) + transformers.models.mixtral.modeling_mixtral.load_balancing_loss_func = ( # pylint: disable=protected-access + load_balancing_loss_func + ) + elif model_type == "qwen2": + transformers.models.qwen2.modeling_qwen2._get_unpad_data = ( # pylint: disable=protected-access + get_unpad_data + ) + elif model_type == "qwen2_moe": + transformers.models.qwen2_moe.modeling_qwen2_moe._get_unpad_data = ( # pylint: disable=protected-access + get_unpad_data + ) + transformers.models.qwen2_moe.modeling_qwen2_moe.load_balancing_loss_func = ( # pylint: disable=protected-access + load_balancing_loss_func + ) + elif model_type == "falcon": + transformers.models.falcon.modeling_falcon._get_unpad_data = ( # pylint: disable=protected-access + get_unpad_data + ) + elif model_type == "phi": + transformers.models.phi.modeling_phi._get_unpad_data = ( # pylint: disable=protected-access + get_unpad_data + ) + elif model_type == "phi3": + transformers.models.phi3.modeling_phi3._get_unpad_data = ( # pylint: disable=protected-access + get_unpad_data + ) + elif model_type == "gemma": + transformers.models.gemma.modeling_gemma._get_unpad_data = ( # pylint: disable=protected-access + get_unpad_data + ) + elif model_type == "starcoder2": + transformers.models.starcoder2.modeling_starcoder2._get_unpad_data = ( # pylint: disable=protected-access + get_unpad_data + ) + elif model_type == "gemmoe": + patch_remote(model_name, ".configuration_gemmoe", ".modeling_gemmoe") + elif model_type == 
"jamba": + patch_remote(model_name, ".configuration_jamba", ".modeling_jamba") + else: + transformers.modeling_attn_mask_utils._prepare_4d_causal_attention_mask_for_sdpa = ( # pylint: disable=protected-access + patched_prepare_4d_causal_attention_mask_for_sdpa ) - elif model_type == "mistral": - transformers.models.mistral.modeling_mistral._get_unpad_data = ( # pylint: disable=protected-access - get_unpad_data + transformers.modeling_attn_mask_utils._prepare_4d_causal_attention_mask = ( # pylint: disable=protected-access + patched_prepare_4d_causal_attention_mask ) - elif model_type == "mixtral": - transformers.models.mixtral.modeling_mixtral._get_unpad_data = ( # pylint: disable=protected-access - get_unpad_data - ) - transformers.models.mixtral.modeling_mixtral.load_balancing_loss_func = ( # pylint: disable=protected-access - load_balancing_loss_func - ) - elif model_type == "qwen2": - transformers.models.qwen2.modeling_qwen2._get_unpad_data = ( # pylint: disable=protected-access - get_unpad_data - ) - elif model_type == "qwen2_moe": - transformers.models.qwen2_moe.modeling_qwen2_moe._get_unpad_data = ( # pylint: disable=protected-access - get_unpad_data - ) - transformers.models.qwen2_moe.modeling_qwen2_moe.load_balancing_loss_func = ( # pylint: disable=protected-access - load_balancing_loss_func - ) - elif model_type == "falcon": - transformers.models.falcon.modeling_falcon._get_unpad_data = ( # pylint: disable=protected-access - get_unpad_data - ) - elif model_type == "phi": - transformers.models.phi.modeling_phi._get_unpad_data = ( # pylint: disable=protected-access - get_unpad_data - ) - elif model_type == "phi3": - transformers.models.phi3.modeling_phi3._get_unpad_data = ( # pylint: disable=protected-access - get_unpad_data - ) - elif model_type == "gemma": - transformers.models.gemma.modeling_gemma._get_unpad_data = ( # pylint: disable=protected-access - get_unpad_data - ) - elif model_type == "starcoder2": - transformers.models.starcoder2.modeling_starcoder2._get_unpad_data = ( # pylint: disable=protected-access - get_unpad_data - ) - elif model_type == "gemmoe": - patch_remote(model_name, ".configuration_gemmoe", ".modeling_gemmoe") - elif model_type == "jamba": - patch_remote(model_name, ".configuration_jamba", ".modeling_jamba") def patch_remote(model_name, config_name, modeling_name): @@ -231,20 +292,19 @@ def patch_remote(model_name, config_name, modeling_name): modeling_arch = importlib.import_module(module_name) modeling_arch._get_unpad_data = get_unpad_data # pylint: disable=protected-access + # check exist load_balancing_loss_func for moe model + if hasattr(modeling_arch, "load_balancing_loss_func"): + modeling_arch.load_balancing_loss_func = load_balancing_loss_func -def configure_packing(config: "PretrainedConfig") -> None: + +def configure_packing(config: "PretrainedConfig", model_args: "ModelArguments") -> None: if getattr(config, "model_type", None) == "internlm2": # special case for custom models - attn_implementation = getattr(config, "attn_implementation", None) + attn_implementation = getattr(config, "attn_implementation", "") else: - attn_implementation = getattr(config, "_attn_implementation", None) + attn_implementation = getattr(config, "_attn_implementation", "") - if attn_implementation != "flash_attention_2": - raise ValueError("Efficient packing only supports for flash_attention_2. 
Please set config `flash_attn` is fa2") - - logger = get_logger(__name__) - - if getattr(config, "model_type", None) in SUPPORTED_MULTIPACK_MODEL_TYPES: - patch_for_multipack(getattr(config, "model_type", None)) + if getattr(config, "model_type", None) in SUPPORTED_CLASS_FOR_MULTIPACK: + patch_for_multipack(getattr(config, "model_type", None), model_args.model_name_or_path, attn_implementation) logger.info("Using packing sequences without cross-contamination attention for efficient training.") else: raise ValueError("Current model does not support packing sequences for efficient training. Please set config `efficient_packing` is False") \ No newline at end of file diff --git a/src/llamafactory/train/sft/workflow.py b/src/llamafactory/train/sft/workflow.py index d4965393..d7c29743 100644 --- a/src/llamafactory/train/sft/workflow.py +++ b/src/llamafactory/train/sft/workflow.py @@ -34,7 +34,7 @@ def run_sft( model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train) if data_args.efficient_packing: - configure_packing(model.config) + configure_packing(model.config, model_args) if training_args.predict_with_generate: tokenizer.padding_side = "left" # use left-padding in generation
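
Illustration of the new mask handling (a minimal sketch, not part of the patch; it assumes the patched package is importable as llamafactory and runs on CPU). mask_2d_to_4d expands a packed 2D attention mask into a block-diagonal, lower-triangular 4D mask, which the patched _prepare_4d_causal_attention_mask wrappers then pass on to the stock transformers helpers:

import torch

from llamafactory.model.model_utils.packing import mask_2d_to_4d

# One row packing two sequences of length 2; the integers 1 and 2 label the sequences.
packed_mask = torch.tensor([[1, 1, 2, 2]])
mask_4d = mask_2d_to_4d(packed_mask, dtype=torch.float32)
print(mask_4d[0, 0])
# tensor([[1., 0., 0., 0.],
#         [1., 1., 0., 0.],
#         [0., 0., 1., 0.],
#         [0., 0., 1., 1.]])
# Tokens of the second sequence never attend to the first, and no token attends to a future position.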
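
A second sketch covering the mode dispatch (assumptions: AutoConfig.for_model builds a bare Llama config, and a SimpleNamespace stands in for ModelArguments since configure_packing only reads model_name_or_path). With an sdpa or eager attention implementation, configure_packing monkey-patches the transformers 4D mask builders; with flash_attention_2 it patches the per-model _get_unpad_data instead:

from types import SimpleNamespace

import transformers
from transformers import AutoConfig

from llamafactory.model.model_utils import packing

config = AutoConfig.for_model("llama")                     # bare config, no weights needed
config._attn_implementation = "sdpa"                       # take the non-flash-attention path
model_args = SimpleNamespace(model_name_or_path="dummy")   # stand-in; only model_name_or_path is read
packing.configure_packing(config, model_args)

# The sdpa/eager path replaces transformers' mask builders with the patched wrappers.
assert (
    transformers.modeling_attn_mask_utils._prepare_4d_causal_attention_mask
    is packing.patched_prepare_4d_causal_attention_mask
)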