From 045eb155a2151f69b19b8ae9463a268232644dee Mon Sep 17 00:00:00 2001 From: ancv Date: Wed, 12 Jun 2024 11:56:01 +0700 Subject: [PATCH 01/12] implement efficient packing without cross-contamination attention Former-commit-id: b2c367bc61c2778dc359613dca496d9e134c2743 --- src/llamafactory/data/preprocess.py | 2 +- .../data/processors/supervised.py | 12 +- src/llamafactory/hparams/data_args.py | 4 + src/llamafactory/model/model_utils/packing.py | 250 ++++++++++++++++++ src/llamafactory/train/kto/workflow.py | 2 +- src/llamafactory/train/sft/workflow.py | 5 +- src/llamafactory/webui/components/train.py | 5 +- src/llamafactory/webui/locales.py | 14 + src/llamafactory/webui/runner.py | 1 + 9 files changed, 287 insertions(+), 8 deletions(-) create mode 100644 src/llamafactory/model/model_utils/packing.py diff --git a/src/llamafactory/data/preprocess.py b/src/llamafactory/data/preprocess.py index 97789c39..cf207d7e 100644 --- a/src/llamafactory/data/preprocess.py +++ b/src/llamafactory/data/preprocess.py @@ -36,7 +36,7 @@ def get_preprocess_and_print_func( ) print_function = partial(print_unsupervised_dataset_example, tokenizer=tokenizer) elif stage == "sft" and not training_args.predict_with_generate: - if data_args.packing: + if data_args.packing or data_args.efficient_packing: preprocess_func = partial( preprocess_packed_supervised_dataset, template=template, diff --git a/src/llamafactory/data/processors/supervised.py b/src/llamafactory/data/processors/supervised.py index 19d60280..3406576b 100644 --- a/src/llamafactory/data/processors/supervised.py +++ b/src/llamafactory/data/processors/supervised.py @@ -10,7 +10,7 @@ if TYPE_CHECKING: from transformers import ProcessorMixin from transformers.tokenization_utils import PreTrainedTokenizer - from ...hparams import DataArguments + from ...hparams import DataArguments, FinetuningArguments from ..template import Template @@ -140,11 +140,12 @@ def preprocess_packed_supervised_dataset( model_inputs = {"input_ids": [], "attention_mask": [], "labels": []} knapsacks = greedy_knapsack(lengths, data_args.cutoff_len) for knapsack in knapsacks: - packed_input_ids, packed_labels = [], [] - for length in knapsack: + packed_input_ids, packed_attention_mask, packed_labels = [], [], [] + for i, length in enumerate(knapsack): index = length2indexes[length].pop() packed_input_ids += batch_input_ids[index] packed_labels += batch_labels[index] + packed_attention_mask += [i+1]*len(batch_input_ids[index]) if len(packed_input_ids) < data_args.cutoff_len: pad_length = data_args.cutoff_len - len(packed_input_ids) @@ -155,7 +156,10 @@ def preprocess_packed_supervised_dataset( raise ValueError("The length of packed example should be identical to the cutoff length.") model_inputs["input_ids"].append(packed_input_ids) - model_inputs["attention_mask"].append([1] * data_args.cutoff_len) + if data_args.efficient_packing: + model_inputs["attention_mask"].append(packed_attention_mask) + else: + model_inputs["attention_mask"].append([1] * data_args.cutoff_len) model_inputs["labels"].append(packed_labels) return model_inputs diff --git a/src/llamafactory/hparams/data_args.py b/src/llamafactory/hparams/data_args.py index 1e0cd08c..2081c6b4 100644 --- a/src/llamafactory/hparams/data_args.py +++ b/src/llamafactory/hparams/data_args.py @@ -84,6 +84,10 @@ class DataArguments: "help": "Whether or not to pack the sequences in training. Will automatically enable in pre-training." 
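# A toy sketch (made-up token ids and cutoff_len, not code from the patch) of the mask
# layout produced by the modified preprocess_packed_supervised_dataset above: every
# sequence inside a packed example gets its own 1-based index in the attention mask,
# and right-padding stays 0, so downstream code can still tell the sequences apart.
cutoff_len = 10
pad_token_id = 0
sequences = [[11, 12, 13], [21, 22], [31, 32, 33, 34]]  # already-tokenized examples

packed_input_ids, packed_attention_mask = [], []
for i, seq in enumerate(sequences):
    packed_input_ids += seq
    packed_attention_mask += [i + 1] * len(seq)  # 1 for the first sequence, 2 for the second, ...

pad_length = cutoff_len - len(packed_input_ids)
packed_input_ids += [pad_token_id] * pad_length
packed_attention_mask += [0] * pad_length

print(packed_attention_mask)  # [1, 1, 1, 2, 2, 3, 3, 3, 3, 0]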
}, ) + efficient_packing: Optional[bool] = field( + default=None, + metadata={"help": "Whether or not to pack the sequences without cross-contamination attention for efficient training."}, + ) tokenized_path: Optional[str] = field( default=None, metadata={"help": "Path to save or load the tokenized datasets."}, diff --git a/src/llamafactory/model/model_utils/packing.py b/src/llamafactory/model/model_utils/packing.py new file mode 100644 index 00000000..fe718ebb --- /dev/null +++ b/src/llamafactory/model/model_utils/packing.py @@ -0,0 +1,250 @@ +# Copy from original implementation of src/axolotl/monkeypatch/multipack.py and src/axolotl/monkeypatch/utils.py from axolotl library with some changes +""" +Shared utils for the monkeypatches +""" +from typing import Optional, TYPE_CHECKING + +import torch +import torch.nn.functional as F + +import importlib + +import transformers +from accelerate import init_empty_weights +from transformers import AutoConfig, AutoModelForCausalLM +from ...extras.logging import get_logger + +if TYPE_CHECKING: + from transformers import PretrainedConfig + + from ...hparams import ModelArguments, DataArguments + + +SUPPORTED_MULTIPACK_MODEL_TYPES = [ + "llama", + "mistral", + "mixtral", + "qwen2", + "qwen2_moe", + "falcon", + "phi", + "phi3", + "gemma", + "gemmoe", + "starcoder2", +] + + +@torch.jit.script +def get_max_seqlen_in_batch(attention_mask: torch.Tensor) -> torch.Tensor: + max_num = int(torch.max(attention_mask).item()) + batch_size, _ = attention_mask.shape + counts = torch.zeros((batch_size, max_num), dtype=torch.int32) + + for i in range(1, max_num + 1): + mask = attention_mask == i + counts[:, i - 1] = torch.sum(mask, dim=-1).to(dtype=torch.int32) + + result = counts.flatten() + nonzero_indices = torch.nonzero(result).squeeze(-1) + return result[nonzero_indices] + + +@torch.jit.script +def get_unpad_data(attention_mask: torch.Tensor): + device = attention_mask.device + seqlens_in_batch = get_max_seqlen_in_batch(attention_mask) + indices = torch.nonzero(attention_mask.flatten()).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = ( + F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) + .to(device=device) + .detach() + ) + return ( + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + +def set_module_name(model, name, value): + if "." in name: + parent_name = name.rsplit(".", 1)[0] + child_name = name[len(parent_name) + 1 :] + parent = model.get_submodule(parent_name) + else: + parent_name = "" + parent = model + child_name = name + + setattr(parent, child_name, value) + + +# Copy from original implementation of modeling_mixtral.py from transformers, Just change a little bit with new_attention_mask +def load_balancing_loss_func( + gate_logits: torch.Tensor, + num_experts: torch.Tensor = None, + top_k=2, + attention_mask: Optional[torch.Tensor] = None, +) -> float: + r""" + Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch. + + See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the loss + function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between + experts is too unbalanced. + + Args: + gate_logits (Union[`torch.Tensor`, Tuple[torch.Tensor]): + Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of + shape [batch_size X sequence_length, num_experts]. 
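# A quick illustration (assuming get_max_seqlen_in_batch and get_unpad_data from this
# file are in scope; the mask below is a toy input) of the values handed to the
# flash-attention varlen kernels once the unpadding helper is packing-aware: lengths
# and cumulative offsets are reported per packed sequence, not per batch row.
import torch

mask = torch.tensor([[1, 1, 2, 2, 2, 0]])  # two sequences packed into one row, one pad token
seqlens = get_max_seqlen_in_batch(mask)    # tensor([2, 3], dtype=torch.int32)
indices, cu_seqlens, max_seqlen = get_unpad_data(mask)
# indices    -> tensor([0, 1, 2, 3, 4])               positions of non-padding tokens
# cu_seqlens -> tensor([0, 2, 5], dtype=torch.int32)  sequence boundaries after unpadding
# max_seqlen -> 3                                     longest packed sequence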
+ attention_mask (`torch.Tensor`, None): + The attention_mask used in forward function + shape [batch_size X sequence_length] if not None. + num_experts (`int`, *optional*): + Number of experts + + Returns: + The auxiliary loss. + """ + if gate_logits is None or not isinstance(gate_logits, tuple): + return 0 + + if isinstance(gate_logits, tuple): + compute_device = gate_logits[0].device + concatenated_gate_logits = torch.cat( + [layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0 + ) + + routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1) + + _, selected_experts = torch.topk(routing_weights, top_k, dim=-1) + + expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts) + + if attention_mask is None: + # Compute the percentage of tokens routed to each experts + tokens_per_expert = torch.mean(expert_mask.float(), dim=0) + + # Compute the average probability of routing to these experts + router_prob_per_expert = torch.mean(routing_weights, dim=0) + else: + # ONLY ADD THIS LINE OF CODE, AND REPLACE attention_mask WITH new_attention_mask + new_attention_mask = (attention_mask != 0).int().to(attention_mask.device) + batch_size, sequence_length = new_attention_mask.shape + num_hidden_layers = concatenated_gate_logits.shape[0] // ( + batch_size * sequence_length + ) + + # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask + expert_attention_mask = ( + new_attention_mask[None, :, :, None, None] + .expand( + (num_hidden_layers, batch_size, sequence_length, top_k, num_experts) + ) + .reshape(-1, top_k, num_experts) + .to(compute_device) + ) + + # Compute the percentage of tokens routed to each experts + tokens_per_expert = torch.sum( + expert_mask.float() * expert_attention_mask, dim=0 + ) / torch.sum(expert_attention_mask, dim=0) + + # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert + router_per_expert_attention_mask = ( + new_attention_mask[None, :, :, None] + .expand((num_hidden_layers, batch_size, sequence_length, num_experts)) + .reshape(-1, num_experts) + .to(compute_device) + ) + + # Compute the average probability of routing to these experts + router_prob_per_expert = torch.sum( + routing_weights * router_per_expert_attention_mask, dim=0 + ) / torch.sum(router_per_expert_attention_mask, dim=0) + + overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0)) + return overall_loss * num_experts + + +def patch_for_multipack(model_type, model_name=None): + if model_type == "llama": + transformers.models.llama.modeling_llama._get_unpad_data = ( # pylint: disable=protected-access + get_unpad_data + ) + elif model_type == "mistral": + transformers.models.mistral.modeling_mistral._get_unpad_data = ( # pylint: disable=protected-access + get_unpad_data + ) + elif model_type == "mixtral": + transformers.models.mixtral.modeling_mixtral._get_unpad_data = ( # pylint: disable=protected-access + get_unpad_data + ) + transformers.models.mixtral.modeling_mixtral.load_balancing_loss_func = ( # pylint: disable=protected-access + load_balancing_loss_func + ) + elif model_type == "qwen2": + transformers.models.qwen2.modeling_qwen2._get_unpad_data = ( # pylint: disable=protected-access + get_unpad_data + ) + elif model_type == "qwen2_moe": + transformers.models.qwen2_moe.modeling_qwen2_moe._get_unpad_data = ( # pylint: disable=protected-access + get_unpad_data + ) + transformers.models.qwen2_moe.modeling_qwen2_moe.load_balancing_loss_func = ( # pylint: 
disable=protected-access + load_balancing_loss_func + ) + elif model_type == "falcon": + transformers.models.falcon.modeling_falcon._get_unpad_data = ( # pylint: disable=protected-access + get_unpad_data + ) + elif model_type == "phi": + transformers.models.phi.modeling_phi._get_unpad_data = ( # pylint: disable=protected-access + get_unpad_data + ) + elif model_type == "phi3": + transformers.models.phi3.modeling_phi3._get_unpad_data = ( # pylint: disable=protected-access + get_unpad_data + ) + elif model_type == "gemma": + transformers.models.gemma.modeling_gemma._get_unpad_data = ( # pylint: disable=protected-access + get_unpad_data + ) + elif model_type == "starcoder2": + transformers.models.starcoder2.modeling_starcoder2._get_unpad_data = ( # pylint: disable=protected-access + get_unpad_data + ) + elif model_type == "gemmoe": + patch_remote(model_name, ".configuration_gemmoe", ".modeling_gemmoe") + elif model_type == "jamba": + patch_remote(model_name, ".configuration_jamba", ".modeling_jamba") + + +def patch_remote(model_name, config_name, modeling_name): + model_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True) + # we need to load the model here in order for modeling_* to be available + with init_empty_weights(): + AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True) + module_name = model_config.__class__.__module__.replace(config_name, modeling_name) + modeling_arch = importlib.import_module(module_name) + modeling_arch._get_unpad_data = get_unpad_data # pylint: disable=protected-access + + +def configure_packing(config: "PretrainedConfig") -> None: + if getattr(config, "model_type", None) == "internlm2": # special case for custom models + attn_implementation = getattr(config, "attn_implementation", None) + else: + attn_implementation = getattr(config, "_attn_implementation", None) + + if attn_implementation != "flash_attention_2": + raise ValueError("Efficient packing only supports for flash_attention_2. Please set config `flash_attn` is fa2" + " " + attn_implementation) + + logger = get_logger(__name__) + + if getattr(config, "model_type", None) in SUPPORTED_MULTIPACK_MODEL_TYPES: + patch_for_multipack(getattr(config, "model_type", None)) + logger.info("Using packing sequences without cross-contamination attention for efficient training.") + else: + raise ValueError("Current model does not support packing sequences for efficient training. 
Please set config `efficient_packing` is False") \ No newline at end of file diff --git a/src/llamafactory/train/kto/workflow.py b/src/llamafactory/train/kto/workflow.py index c79b160b..f003e157 100644 --- a/src/llamafactory/train/kto/workflow.py +++ b/src/llamafactory/train/kto/workflow.py @@ -24,7 +24,7 @@ def run_kto( ): tokenizer_module = load_tokenizer(model_args) tokenizer = tokenizer_module["tokenizer"] - dataset = get_dataset(model_args, data_args, training_args, stage="kto", **tokenizer_module) + dataset = get_dataset(model_args, data_args, training_args, finetuning_args, stage="kto", **tokenizer_module) model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train) data_collator = KTODataCollatorWithPadding( diff --git a/src/llamafactory/train/sft/workflow.py b/src/llamafactory/train/sft/workflow.py index f09b5173..d4965393 100644 --- a/src/llamafactory/train/sft/workflow.py +++ b/src/llamafactory/train/sft/workflow.py @@ -12,7 +12,7 @@ from ...model import load_model, load_tokenizer from ..trainer_utils import create_modelcard_and_push from .metric import ComputeMetrics from .trainer import CustomSeq2SeqTrainer - +from ...model.model_utils.packing import configure_packing if TYPE_CHECKING: from transformers import Seq2SeqTrainingArguments, TrainerCallback @@ -33,6 +33,9 @@ def run_sft( dataset = get_dataset(model_args, data_args, training_args, stage="sft", **tokenizer_module) model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train) + if data_args.efficient_packing: + configure_packing(model.config) + if training_args.predict_with_generate: tokenizer.padding_side = "left" # use left-padding in generation diff --git a/src/llamafactory/webui/components/train.py b/src/llamafactory/webui/components/train.py index 72dfc858..dccc8500 100644 --- a/src/llamafactory/webui/components/train.py +++ b/src/llamafactory/webui/components/train.py @@ -83,6 +83,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]: with gr.Column(): resize_vocab = gr.Checkbox() packing = gr.Checkbox() + efficient_packing = gr.Checkbox() with gr.Column(): upcast_layernorm = gr.Checkbox() @@ -101,6 +102,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]: optim, resize_vocab, packing, + efficient_packing, upcast_layernorm, use_llama_pro, shift_attn, @@ -117,6 +119,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]: optim=optim, resize_vocab=resize_vocab, packing=packing, + efficient_packing=efficient_packing, upcast_layernorm=upcast_layernorm, use_llama_pro=use_llama_pro, shift_attn=shift_attn, @@ -313,7 +316,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]: ) dataset.focus(list_datasets, [dataset_dir, training_stage], [dataset], queue=False) - training_stage.change(change_stage, [training_stage], [dataset, packing], queue=False) + training_stage.change(change_stage, [training_stage], [dataset, packing, efficient_packing], queue=False) reward_model.focus(list_checkpoints, [model_name, finetuning_type], [reward_model], queue=False) model_name.change(list_output_dirs, [model_name, finetuning_type, current_time], [output_dir], queue=False) finetuning_type.change(list_output_dirs, [model_name, finetuning_type, current_time], [output_dir], queue=False) diff --git a/src/llamafactory/webui/locales.py b/src/llamafactory/webui/locales.py index e30feab2..05cf3bed 100644 --- a/src/llamafactory/webui/locales.py +++ b/src/llamafactory/webui/locales.py @@ -494,6 +494,20 @@ LOCALES = { "info": "将序列打包为等长样本。", }, 
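# A minimal sketch (llama only; assumes a transformers version that exposes the private
# _get_unpad_data helper, which is what patch_for_multipack above relies on) of the
# monkey-patching mechanism: the helper used by the flash-attention code path to unpad
# a batch is swapped for the packing-aware version, so tokens from different packed
# sequences are never flattened into one attention span.
import transformers.models.llama.modeling_llama as llama_modeling

_original_get_unpad_data = llama_modeling._get_unpad_data  # keep a reference for restoring
llama_modeling._get_unpad_data = get_unpad_data            # packing-aware version defined above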
}, + "efficient_packing": { + "en": { + "label": "Pack sequences for efficient training", + "info": "Pack sequences into samples of fixed length without cross-contamination attention for efficient training.", + }, + "ru": { + "label": "Пакетные последовательности для эффективного обучения", + "info": "Упакуйте последовательности в образцы фиксированной длины без учета перекрестного загрязнения для эффективного обучения.", + }, + "zh": { + "label": "打包序列以实现高效训练", + "info": "为了提高训练效率,将序列打包成固定长度的样本,无需注意交叉污染。", + }, + }, "upcast_layernorm": { "en": { "label": "Upcast LayerNorm", diff --git a/src/llamafactory/webui/runner.py b/src/llamafactory/webui/runner.py index 35014628..852805da 100644 --- a/src/llamafactory/webui/runner.py +++ b/src/llamafactory/webui/runner.py @@ -120,6 +120,7 @@ class Runner: optim=get("train.optim"), resize_vocab=get("train.resize_vocab"), packing=get("train.packing"), + efficient_packing=get("train.efficient_packing"), upcast_layernorm=get("train.upcast_layernorm"), use_llama_pro=get("train.use_llama_pro"), shift_attn=get("train.shift_attn"), From 9d9f8c6531fdc581d9ac8c26047597341bbb2d01 Mon Sep 17 00:00:00 2001 From: ancv Date: Sat, 15 Jun 2024 23:00:55 +0700 Subject: [PATCH 02/12] remove some unused params Former-commit-id: 04315c3d92ecc25537e45d5807cb38bc290dcb16 --- src/llamafactory/data/processors/supervised.py | 2 +- src/llamafactory/model/model_utils/packing.py | 2 +- src/llamafactory/train/kto/workflow.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/llamafactory/data/processors/supervised.py b/src/llamafactory/data/processors/supervised.py index 3406576b..35640174 100644 --- a/src/llamafactory/data/processors/supervised.py +++ b/src/llamafactory/data/processors/supervised.py @@ -10,7 +10,7 @@ if TYPE_CHECKING: from transformers import ProcessorMixin from transformers.tokenization_utils import PreTrainedTokenizer - from ...hparams import DataArguments, FinetuningArguments + from ...hparams import DataArguments from ..template import Template diff --git a/src/llamafactory/model/model_utils/packing.py b/src/llamafactory/model/model_utils/packing.py index fe718ebb..9b7359be 100644 --- a/src/llamafactory/model/model_utils/packing.py +++ b/src/llamafactory/model/model_utils/packing.py @@ -239,7 +239,7 @@ def configure_packing(config: "PretrainedConfig") -> None: attn_implementation = getattr(config, "_attn_implementation", None) if attn_implementation != "flash_attention_2": - raise ValueError("Efficient packing only supports for flash_attention_2. Please set config `flash_attn` is fa2" + " " + attn_implementation) + raise ValueError("Efficient packing only supports for flash_attention_2. 
Please set config `flash_attn` is fa2") logger = get_logger(__name__) diff --git a/src/llamafactory/train/kto/workflow.py b/src/llamafactory/train/kto/workflow.py index f003e157..c79b160b 100644 --- a/src/llamafactory/train/kto/workflow.py +++ b/src/llamafactory/train/kto/workflow.py @@ -24,7 +24,7 @@ def run_kto( ): tokenizer_module = load_tokenizer(model_args) tokenizer = tokenizer_module["tokenizer"] - dataset = get_dataset(model_args, data_args, training_args, finetuning_args, stage="kto", **tokenizer_module) + dataset = get_dataset(model_args, data_args, training_args, stage="kto", **tokenizer_module) model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train) data_collator = KTODataCollatorWithPadding( From 988231026acfdf711ea2514a7e47b122ad8de5b6 Mon Sep 17 00:00:00 2001 From: ancv Date: Sun, 16 Jun 2024 02:25:47 +0700 Subject: [PATCH 03/12] update packing with sdpa and eager attention mode Former-commit-id: 238f5c3d99809c6ae2571b59bdce8d8ea3c700b9 --- src/llamafactory/extras/constants.py | 15 ++ src/llamafactory/model/model_utils/packing.py | 204 +++++++++++------- src/llamafactory/train/sft/workflow.py | 2 +- 3 files changed, 148 insertions(+), 73 deletions(-) diff --git a/src/llamafactory/extras/constants.py b/src/llamafactory/extras/constants.py index 7d96fb5f..d70922c1 100644 --- a/src/llamafactory/extras/constants.py +++ b/src/llamafactory/extras/constants.py @@ -66,6 +66,21 @@ STAGES_USE_PAIR_DATA = {"rm", "dpo"} SUPPORTED_CLASS_FOR_S2ATTN = {"llama"} +SUPPORTED_CLASS_FOR_MULTIPACK = [ + "llama", + "mistral", + "mixtral", + "qwen2", + "qwen2_moe", + "falcon", + "phi", + "phi3", + "gemma", + "gemmoe", + "starcoder2", + "jamba" +] + V_HEAD_WEIGHTS_NAME = "value_head.bin" V_HEAD_SAFE_WEIGHTS_NAME = "value_head.safetensors" diff --git a/src/llamafactory/model/model_utils/packing.py b/src/llamafactory/model/model_utils/packing.py index 9b7359be..ce156728 100644 --- a/src/llamafactory/model/model_utils/packing.py +++ b/src/llamafactory/model/model_utils/packing.py @@ -12,7 +12,14 @@ import importlib import transformers from accelerate import init_empty_weights from transformers import AutoConfig, AutoModelForCausalLM +from transformers.modeling_attn_mask_utils import ( + _prepare_4d_causal_attention_mask, + _prepare_4d_causal_attention_mask_for_sdpa, +) +from transformers.utils import is_torch_bf16_gpu_available + from ...extras.logging import get_logger +from ...extras.constants import SUPPORTED_CLASS_FOR_MULTIPACK if TYPE_CHECKING: from transformers import PretrainedConfig @@ -20,19 +27,7 @@ if TYPE_CHECKING: from ...hparams import ModelArguments, DataArguments -SUPPORTED_MULTIPACK_MODEL_TYPES = [ - "llama", - "mistral", - "mixtral", - "qwen2", - "qwen2_moe", - "falcon", - "phi", - "phi3", - "gemma", - "gemmoe", - "starcoder2", -] +logger = get_logger(__name__) @torch.jit.script @@ -67,6 +62,64 @@ def get_unpad_data(attention_mask: torch.Tensor): max_seqlen_in_batch, ) +def mask_2d_to_4d( + mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None +): + """ + Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. + This expansion handles packed sequences so that sequences share the same attention mask integer value + when they attend to each other within that sequence. + This expansion transforms the mask to lower triangular form to prevent future peeking. 
+ """ + bsz, src_len = mask.size() + tgt_len = tgt_len if tgt_len is not None else src_len + + mask = mask.unsqueeze(1).unsqueeze(2) + mask = mask.expand(bsz, 1, tgt_len, src_len) + + # Create a binary mask from the original mask where zeros remain zeros and all other values are set to one + binary_mask = torch.where( + mask != 0, + torch.tensor(1, device=mask.device).to(dtype), + torch.tensor(0, device=mask.device).to(dtype), + ) + + # Create a block-diagonal mask. + # we multiply by the binary mask so that 0's in the original mask are correctly excluded + zero_one_mask = torch.eq(mask, mask.transpose(-1, -2)).int() * binary_mask + + # Now let's create a lower triangular mask of ones that will zero out the upper triangular part + lower_triangular_ones = torch.tril(torch.ones((tgt_len, src_len), dtype=dtype)).to( + mask.device + ) + + # Use the lower triangular mask to zero out the upper triangular part of the zero_one_mask + masked_zero_one_mask = zero_one_mask * lower_triangular_ones + + return masked_zero_one_mask + + +def patched_prepare_4d_causal_attention_mask( + attention_mask: Optional[torch.Tensor], + *args, +): + dtype = torch.bfloat16 if is_torch_bf16_gpu_available() else torch.float32 + return _prepare_4d_causal_attention_mask( + mask_2d_to_4d(attention_mask, dtype=dtype), + *args, + ) + + +def patched_prepare_4d_causal_attention_mask_for_sdpa( + attention_mask: Optional[torch.Tensor], + *args, +): + dtype = torch.bfloat16 if is_torch_bf16_gpu_available() else torch.float32 + return _prepare_4d_causal_attention_mask_for_sdpa( + mask_2d_to_4d(attention_mask, dtype=dtype), + *args, + ) + def set_module_name(model, name, value): if "." in name: @@ -169,57 +222,65 @@ def load_balancing_loss_func( return overall_loss * num_experts -def patch_for_multipack(model_type, model_name=None): - if model_type == "llama": - transformers.models.llama.modeling_llama._get_unpad_data = ( # pylint: disable=protected-access - get_unpad_data +def patch_for_multipack(model_type, model_name, attn_implementation): + if attn_implementation == "flash_attention_2": + if model_type == "llama": + transformers.models.llama.modeling_llama._get_unpad_data = ( # pylint: disable=protected-access + get_unpad_data + ) + elif model_type == "mistral": + transformers.models.mistral.modeling_mistral._get_unpad_data = ( # pylint: disable=protected-access + get_unpad_data + ) + elif model_type == "mixtral": + transformers.models.mixtral.modeling_mixtral._get_unpad_data = ( # pylint: disable=protected-access + get_unpad_data + ) + transformers.models.mixtral.modeling_mixtral.load_balancing_loss_func = ( # pylint: disable=protected-access + load_balancing_loss_func + ) + elif model_type == "qwen2": + transformers.models.qwen2.modeling_qwen2._get_unpad_data = ( # pylint: disable=protected-access + get_unpad_data + ) + elif model_type == "qwen2_moe": + transformers.models.qwen2_moe.modeling_qwen2_moe._get_unpad_data = ( # pylint: disable=protected-access + get_unpad_data + ) + transformers.models.qwen2_moe.modeling_qwen2_moe.load_balancing_loss_func = ( # pylint: disable=protected-access + load_balancing_loss_func + ) + elif model_type == "falcon": + transformers.models.falcon.modeling_falcon._get_unpad_data = ( # pylint: disable=protected-access + get_unpad_data + ) + elif model_type == "phi": + transformers.models.phi.modeling_phi._get_unpad_data = ( # pylint: disable=protected-access + get_unpad_data + ) + elif model_type == "phi3": + transformers.models.phi3.modeling_phi3._get_unpad_data = ( # pylint: 
disable=protected-access + get_unpad_data + ) + elif model_type == "gemma": + transformers.models.gemma.modeling_gemma._get_unpad_data = ( # pylint: disable=protected-access + get_unpad_data + ) + elif model_type == "starcoder2": + transformers.models.starcoder2.modeling_starcoder2._get_unpad_data = ( # pylint: disable=protected-access + get_unpad_data + ) + elif model_type == "gemmoe": + patch_remote(model_name, ".configuration_gemmoe", ".modeling_gemmoe") + elif model_type == "jamba": + patch_remote(model_name, ".configuration_jamba", ".modeling_jamba") + else: + transformers.modeling_attn_mask_utils._prepare_4d_causal_attention_mask_for_sdpa = ( # pylint: disable=protected-access + patched_prepare_4d_causal_attention_mask_for_sdpa ) - elif model_type == "mistral": - transformers.models.mistral.modeling_mistral._get_unpad_data = ( # pylint: disable=protected-access - get_unpad_data + transformers.modeling_attn_mask_utils._prepare_4d_causal_attention_mask = ( # pylint: disable=protected-access + patched_prepare_4d_causal_attention_mask ) - elif model_type == "mixtral": - transformers.models.mixtral.modeling_mixtral._get_unpad_data = ( # pylint: disable=protected-access - get_unpad_data - ) - transformers.models.mixtral.modeling_mixtral.load_balancing_loss_func = ( # pylint: disable=protected-access - load_balancing_loss_func - ) - elif model_type == "qwen2": - transformers.models.qwen2.modeling_qwen2._get_unpad_data = ( # pylint: disable=protected-access - get_unpad_data - ) - elif model_type == "qwen2_moe": - transformers.models.qwen2_moe.modeling_qwen2_moe._get_unpad_data = ( # pylint: disable=protected-access - get_unpad_data - ) - transformers.models.qwen2_moe.modeling_qwen2_moe.load_balancing_loss_func = ( # pylint: disable=protected-access - load_balancing_loss_func - ) - elif model_type == "falcon": - transformers.models.falcon.modeling_falcon._get_unpad_data = ( # pylint: disable=protected-access - get_unpad_data - ) - elif model_type == "phi": - transformers.models.phi.modeling_phi._get_unpad_data = ( # pylint: disable=protected-access - get_unpad_data - ) - elif model_type == "phi3": - transformers.models.phi3.modeling_phi3._get_unpad_data = ( # pylint: disable=protected-access - get_unpad_data - ) - elif model_type == "gemma": - transformers.models.gemma.modeling_gemma._get_unpad_data = ( # pylint: disable=protected-access - get_unpad_data - ) - elif model_type == "starcoder2": - transformers.models.starcoder2.modeling_starcoder2._get_unpad_data = ( # pylint: disable=protected-access - get_unpad_data - ) - elif model_type == "gemmoe": - patch_remote(model_name, ".configuration_gemmoe", ".modeling_gemmoe") - elif model_type == "jamba": - patch_remote(model_name, ".configuration_jamba", ".modeling_jamba") def patch_remote(model_name, config_name, modeling_name): @@ -231,20 +292,19 @@ def patch_remote(model_name, config_name, modeling_name): modeling_arch = importlib.import_module(module_name) modeling_arch._get_unpad_data = get_unpad_data # pylint: disable=protected-access + # check exist load_balancing_loss_func for moe model + if hasattr(modeling_arch, "load_balancing_loss_func"): + modeling_arch.load_balancing_loss_func = load_balancing_loss_func -def configure_packing(config: "PretrainedConfig") -> None: + +def configure_packing(config: "PretrainedConfig", model_args: "ModelArguments") -> None: if getattr(config, "model_type", None) == "internlm2": # special case for custom models - attn_implementation = getattr(config, "attn_implementation", None) + attn_implementation = 
getattr(config, "attn_implementation", "") else: - attn_implementation = getattr(config, "_attn_implementation", None) + attn_implementation = getattr(config, "_attn_implementation", "") - if attn_implementation != "flash_attention_2": - raise ValueError("Efficient packing only supports for flash_attention_2. Please set config `flash_attn` is fa2") - - logger = get_logger(__name__) - - if getattr(config, "model_type", None) in SUPPORTED_MULTIPACK_MODEL_TYPES: - patch_for_multipack(getattr(config, "model_type", None)) + if getattr(config, "model_type", None) in SUPPORTED_CLASS_FOR_MULTIPACK: + patch_for_multipack(getattr(config, "model_type", None), model_args.model_name_or_path, attn_implementation) logger.info("Using packing sequences without cross-contamination attention for efficient training.") else: raise ValueError("Current model does not support packing sequences for efficient training. Please set config `efficient_packing` is False") \ No newline at end of file diff --git a/src/llamafactory/train/sft/workflow.py b/src/llamafactory/train/sft/workflow.py index d4965393..d7c29743 100644 --- a/src/llamafactory/train/sft/workflow.py +++ b/src/llamafactory/train/sft/workflow.py @@ -34,7 +34,7 @@ def run_sft( model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train) if data_args.efficient_packing: - configure_packing(model.config) + configure_packing(model.config, model_args) if training_args.predict_with_generate: tokenizer.padding_side = "left" # use left-padding in generation From 5319447aa5f81355593f2a0a73e2fc39b5969912 Mon Sep 17 00:00:00 2001 From: ancv Date: Fri, 21 Jun 2024 00:45:06 +0700 Subject: [PATCH 04/12] move configure_packing to llamafactory.model.patcher and fix constants Former-commit-id: 770f75dc8363bfa284a72159ff8ad25ec9abe4e0 --- src/llamafactory/extras/constants.py | 2 +- src/llamafactory/hparams/data_args.py | 4 +++- src/llamafactory/model/loader.py | 5 +++-- src/llamafactory/model/model_utils/packing.py | 4 ++-- src/llamafactory/model/patcher.py | 9 +++++++-- src/llamafactory/train/sft/workflow.py | 4 ---- 6 files changed, 16 insertions(+), 12 deletions(-) diff --git a/src/llamafactory/extras/constants.py b/src/llamafactory/extras/constants.py index d70922c1..466b1269 100644 --- a/src/llamafactory/extras/constants.py +++ b/src/llamafactory/extras/constants.py @@ -66,7 +66,7 @@ STAGES_USE_PAIR_DATA = {"rm", "dpo"} SUPPORTED_CLASS_FOR_S2ATTN = {"llama"} -SUPPORTED_CLASS_FOR_MULTIPACK = [ +SUPPORTED_CLASS_EFFECIENT_PACKING = [ "llama", "mistral", "mixtral", diff --git a/src/llamafactory/hparams/data_args.py b/src/llamafactory/hparams/data_args.py index 2081c6b4..d2d53ec8 100644 --- a/src/llamafactory/hparams/data_args.py +++ b/src/llamafactory/hparams/data_args.py @@ -86,7 +86,9 @@ class DataArguments: ) efficient_packing: Optional[bool] = field( default=None, - metadata={"help": "Whether or not to pack the sequences without cross-contamination attention for efficient training."}, + metadata={ + "help": "Whether or not to pack the sequences without cross-contamination attention for efficient training." 
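# A toy run (assuming mask_2d_to_4d from the packing module above; the tensor is a
# made-up example) showing how the SDPA / eager fallback expands an index-style mask
# into a block-diagonal causal mask, so the two packed sequences cannot attend to
# each other even without flash-attention's varlen kernels.
import torch

m = torch.tensor([[1, 1, 2, 2]])  # two sequences of length 2 packed into one row
print(mask_2d_to_4d(m, dtype=torch.float32))
# tensor([[[[1., 0., 0., 0.],
#           [1., 1., 0., 0.],
#           [0., 0., 1., 0.],
#           [0., 0., 1., 1.]]]])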
+ }, ) tokenized_path: Optional[str] = field( default=None, diff --git a/src/llamafactory/model/loader.py b/src/llamafactory/model/loader.py index 697a04e7..026a09be 100644 --- a/src/llamafactory/model/loader.py +++ b/src/llamafactory/model/loader.py @@ -16,7 +16,7 @@ from .patcher import patch_config, patch_model, patch_tokenizer, patch_valuehead if TYPE_CHECKING: from transformers import PretrainedConfig, PreTrainedModel, PreTrainedTokenizer, ProcessorMixin - from ..hparams import FinetuningArguments, ModelArguments + from ..hparams import FinetuningArguments, ModelArguments, DataArguments logger = get_logger(__name__) @@ -104,6 +104,7 @@ def load_config(model_args: "ModelArguments") -> "PretrainedConfig": def load_model( tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments", + data_args: "DataArguments", finetuning_args: "FinetuningArguments", is_trainable: bool = False, add_valuehead: bool = False, @@ -113,7 +114,7 @@ def load_model( """ init_kwargs = _get_init_kwargs(model_args) config = load_config(model_args) - patch_config(config, tokenizer, model_args, init_kwargs, is_trainable) + patch_config(config, tokenizer, model_args, data_args, finetuning_args, init_kwargs, is_trainable) model = None lazy_load = False diff --git a/src/llamafactory/model/model_utils/packing.py b/src/llamafactory/model/model_utils/packing.py index ce156728..606cd03b 100644 --- a/src/llamafactory/model/model_utils/packing.py +++ b/src/llamafactory/model/model_utils/packing.py @@ -19,7 +19,7 @@ from transformers.modeling_attn_mask_utils import ( from transformers.utils import is_torch_bf16_gpu_available from ...extras.logging import get_logger -from ...extras.constants import SUPPORTED_CLASS_FOR_MULTIPACK +from ...extras.constants import SUPPORTED_CLASS_EFFECIENT_PACKING if TYPE_CHECKING: from transformers import PretrainedConfig @@ -303,7 +303,7 @@ def configure_packing(config: "PretrainedConfig", model_args: "ModelArguments") else: attn_implementation = getattr(config, "_attn_implementation", "") - if getattr(config, "model_type", None) in SUPPORTED_CLASS_FOR_MULTIPACK: + if getattr(config, "model_type", None) in SUPPORTED_CLASS_EFFECIENT_PACKING: patch_for_multipack(getattr(config, "model_type", None), model_args.model_name_or_path, attn_implementation) logger.info("Using packing sequences without cross-contamination attention for efficient training.") else: diff --git a/src/llamafactory/model/patcher.py b/src/llamafactory/model/patcher.py index 87c92315..47591de6 100644 --- a/src/llamafactory/model/patcher.py +++ b/src/llamafactory/model/patcher.py @@ -19,13 +19,13 @@ from .model_utils.quantization import configure_quantization from .model_utils.rope import configure_rope from .model_utils.valuehead import prepare_valuehead_model from .model_utils.visual import autocast_projector_dtype, configure_visual_model - +from .model_utils.packing import configure_packing if TYPE_CHECKING: from transformers import PretrainedConfig, PreTrainedTokenizer from trl import AutoModelForCausalLMWithValueHead - from ..hparams import ModelArguments + from ..hparams import ModelArguments, DataArguments, FinetuningArguments logger = get_logger(__name__) @@ -40,6 +40,8 @@ def patch_config( config: "PretrainedConfig", tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments", + data_args: "DataArguments", + finetune_args: "FinetuningArguments", init_kwargs: Dict[str, Any], is_trainable: bool, ) -> None: @@ -81,6 +83,9 @@ def patch_config( if init_kwargs["device_map"] == "auto": init_kwargs["offload_folder"] = 
model_args.offload_folder + + if finetune_args.stage == "sft" and data_args.efficient_packing: + configure_packing(config, model_args) def patch_model( diff --git a/src/llamafactory/train/sft/workflow.py b/src/llamafactory/train/sft/workflow.py index d7c29743..f1e000bd 100644 --- a/src/llamafactory/train/sft/workflow.py +++ b/src/llamafactory/train/sft/workflow.py @@ -12,7 +12,6 @@ from ...model import load_model, load_tokenizer from ..trainer_utils import create_modelcard_and_push from .metric import ComputeMetrics from .trainer import CustomSeq2SeqTrainer -from ...model.model_utils.packing import configure_packing if TYPE_CHECKING: from transformers import Seq2SeqTrainingArguments, TrainerCallback @@ -33,9 +32,6 @@ def run_sft( dataset = get_dataset(model_args, data_args, training_args, stage="sft", **tokenizer_module) model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train) - if data_args.efficient_packing: - configure_packing(model.config, model_args) - if training_args.predict_with_generate: tokenizer.padding_side = "left" # use left-padding in generation From 7f42932957e1cee4681aad5f9200b8b829b230b0 Mon Sep 17 00:00:00 2001 From: ancv Date: Tue, 2 Jul 2024 18:37:55 +0700 Subject: [PATCH 05/12] move efficient_packing from data_args to model_args Former-commit-id: e8e13b09423dd08a31a3bde8f85833c6e5d43ee5 --- src/llamafactory/data/loader.py | 2 +- src/llamafactory/data/preprocess.py | 5 +++-- src/llamafactory/data/processors/supervised.py | 5 +++-- src/llamafactory/hparams/data_args.py | 6 ------ src/llamafactory/hparams/model_args.py | 6 ++++++ src/llamafactory/hparams/parser.py | 3 +++ src/llamafactory/model/loader.py | 5 ++--- src/llamafactory/model/patcher.py | 6 ++---- 8 files changed, 20 insertions(+), 18 deletions(-) diff --git a/src/llamafactory/data/loader.py b/src/llamafactory/data/loader.py index 8e7062db..5f116e4e 100644 --- a/src/llamafactory/data/loader.py +++ b/src/llamafactory/data/loader.py @@ -177,7 +177,7 @@ def get_dataset( with training_args.main_process_first(desc="pre-process dataset"): preprocess_func, print_function = get_preprocess_and_print_func( - data_args, training_args, stage, template, tokenizer, processor + data_args, model_args, training_args, stage, template, tokenizer, processor ) column_names = list(next(iter(dataset)).keys()) kwargs = {} diff --git a/src/llamafactory/data/preprocess.py b/src/llamafactory/data/preprocess.py index 3a80900c..ae69e84e 100644 --- a/src/llamafactory/data/preprocess.py +++ b/src/llamafactory/data/preprocess.py @@ -29,12 +29,13 @@ from .processors.unsupervised import preprocess_unsupervised_dataset, print_unsu if TYPE_CHECKING: from transformers import PreTrainedTokenizer, ProcessorMixin, Seq2SeqTrainingArguments - from ..hparams import DataArguments + from ..hparams import DataArguments, ModelArguments from .template import Template def get_preprocess_and_print_func( data_args: "DataArguments", + model_args: "ModelArguments", training_args: "Seq2SeqTrainingArguments", stage: Literal["pt", "sft", "rm", "ppo", "kto"], template: "Template", @@ -49,7 +50,7 @@ def get_preprocess_and_print_func( ) print_function = partial(print_unsupervised_dataset_example, tokenizer=tokenizer) elif stage == "sft" and not training_args.predict_with_generate: - if data_args.packing or data_args.efficient_packing: + if data_args.packing or model_args.efficient_packing: preprocess_func = partial( preprocess_packed_supervised_dataset, template=template, diff --git a/src/llamafactory/data/processors/supervised.py 
b/src/llamafactory/data/processors/supervised.py index 8ef55321..78811477 100644 --- a/src/llamafactory/data/processors/supervised.py +++ b/src/llamafactory/data/processors/supervised.py @@ -23,7 +23,7 @@ from .processor_utils import get_paligemma_token_type_ids, get_pixel_values, gre if TYPE_CHECKING: from transformers import PreTrainedTokenizer, ProcessorMixin - from ...hparams import DataArguments + from ...hparams import DataArguments, ModelArguments from ..template import Template @@ -125,6 +125,7 @@ def preprocess_packed_supervised_dataset( template: "Template", tokenizer: "PreTrainedTokenizer", data_args: "DataArguments", + model_args: "ModelArguments" ) -> Dict[str, List[List[int]]]: # build inputs with format ` X1 Y1 X2 Y2 ` # and labels with format ` ... Y1 ... Y2 ` @@ -176,7 +177,7 @@ def preprocess_packed_supervised_dataset( raise ValueError("The length of packed example should be identical to the cutoff length.") model_inputs["input_ids"].append(packed_input_ids) - if data_args.efficient_packing: + if model_args.efficient_packing: model_inputs["attention_mask"].append(packed_attention_mask) else: model_inputs["attention_mask"].append([1] * data_args.cutoff_len) diff --git a/src/llamafactory/hparams/data_args.py b/src/llamafactory/hparams/data_args.py index e351fccf..880be84a 100644 --- a/src/llamafactory/hparams/data_args.py +++ b/src/llamafactory/hparams/data_args.py @@ -97,12 +97,6 @@ class DataArguments: "help": "Whether or not to pack the sequences in training. Will automatically enable in pre-training." }, ) - efficient_packing: Optional[bool] = field( - default=None, - metadata={ - "help": "Whether or not to pack the sequences without cross-contamination attention for efficient training." - }, - ) tool_format: Optional[str] = field( default=None, metadata={"help": "Tool format to use for constructing function calling examples."}, diff --git a/src/llamafactory/hparams/model_args.py b/src/llamafactory/hparams/model_args.py index 087c8c38..49503022 100644 --- a/src/llamafactory/hparams/model_args.py +++ b/src/llamafactory/hparams/model_args.py @@ -109,6 +109,12 @@ class ModelArguments: default=False, metadata={"help": "Enable shift short attention (S^2-Attn) proposed by LongLoRA."}, ) + efficient_packing: Optional[bool] = field( + default=None, + metadata={ + "help": "Whether or not to pack the sequences without cross-contamination attention for efficient training." 
+ }, + ) mixture_of_depths: Optional[Literal["convert", "load"]] = field( default=None, metadata={"help": "Convert the model to mixture-of-depths (MoD) or load the MoD model."}, diff --git a/src/llamafactory/hparams/parser.py b/src/llamafactory/hparams/parser.py index 8b2ea4c1..507f7fef 100644 --- a/src/llamafactory/hparams/parser.py +++ b/src/llamafactory/hparams/parser.py @@ -170,6 +170,9 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS: if finetuning_args.stage == "ppo" and model_args.shift_attn: raise ValueError("PPO training is incompatible with S^2-Attn.") + if finetuning_args.stage != "sft" and model_args.efficient_packing: + raise ValueError("`efficient_packing` cannot be set as True except SFT.") + if finetuning_args.stage == "ppo" and finetuning_args.reward_model_type == "lora" and model_args.use_unsloth: raise ValueError("Unsloth does not support lora reward model.") diff --git a/src/llamafactory/model/loader.py b/src/llamafactory/model/loader.py index 43e65d52..fe700d53 100644 --- a/src/llamafactory/model/loader.py +++ b/src/llamafactory/model/loader.py @@ -31,7 +31,7 @@ from .patcher import patch_config, patch_model, patch_tokenizer, patch_valuehead if TYPE_CHECKING: from transformers import PretrainedConfig, PreTrainedModel, PreTrainedTokenizer, ProcessorMixin - from ..hparams import FinetuningArguments, ModelArguments, DataArguments + from ..hparams import FinetuningArguments, ModelArguments logger = get_logger(__name__) @@ -120,7 +120,6 @@ def load_config(model_args: "ModelArguments") -> "PretrainedConfig": def load_model( tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments", - data_args: "DataArguments", finetuning_args: "FinetuningArguments", is_trainable: bool = False, add_valuehead: bool = False, @@ -130,7 +129,7 @@ def load_model( """ init_kwargs = _get_init_kwargs(model_args) config = load_config(model_args) - patch_config(config, tokenizer, model_args, data_args, finetuning_args, init_kwargs, is_trainable) + patch_config(config, tokenizer, model_args, init_kwargs, is_trainable) model = None lazy_load = False diff --git a/src/llamafactory/model/patcher.py b/src/llamafactory/model/patcher.py index f1831ced..2ddfd21a 100644 --- a/src/llamafactory/model/patcher.py +++ b/src/llamafactory/model/patcher.py @@ -39,7 +39,7 @@ if TYPE_CHECKING: from transformers import PretrainedConfig, PreTrainedTokenizer from trl import AutoModelForCausalLMWithValueHead - from ..hparams import ModelArguments, DataArguments, FinetuningArguments + from ..hparams import ModelArguments logger = get_logger(__name__) @@ -54,8 +54,6 @@ def patch_config( config: "PretrainedConfig", tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments", - data_args: "DataArguments", - finetune_args: "FinetuningArguments", init_kwargs: Dict[str, Any], is_trainable: bool, ) -> None: @@ -104,7 +102,7 @@ def patch_config( if init_kwargs.get("device_map", None) == "auto": init_kwargs["offload_folder"] = model_args.offload_folder - if finetune_args.stage == "sft" and data_args.efficient_packing: + if model_args.efficient_packing: configure_packing(config, model_args) From 28c8e083f4dd5db693450e32b540a115d6704674 Mon Sep 17 00:00:00 2001 From: hiyouga <467089858@qq.com> Date: Wed, 3 Jul 2024 23:05:39 +0800 Subject: [PATCH 06/12] test Former-commit-id: a4a1ddbcb987422cd04125ff3f36f8c739061b5c --- tests/data/test_collator.py | 55 +++++++++++++++++++++++++++++++++++++ 1 file changed, 55 insertions(+) create mode 100644 tests/data/test_collator.py diff --git 
a/tests/data/test_collator.py b/tests/data/test_collator.py new file mode 100644 index 00000000..cb473d4c --- /dev/null +++ b/tests/data/test_collator.py @@ -0,0 +1,55 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + +from llamafactory.data.collator import prepare_4d_attention_mask + + +def test_4d_attention_mask(): + o = 0.0 + x = torch.finfo(torch.float16).min + attention_mask_with_indices = torch.tensor( + [ + [1, 1, 2, 2, 2, 0], + [1, 2, 2, 3, 3, 3], + ] + ) + attention_mask_computed = prepare_4d_attention_mask(attention_mask_with_indices, torch.float16) + attention_mask_expected = torch.tensor( + [ + [ + [ + [o, x, x, x, x, x], + [o, o, x, x, x, x], + [x, x, o, x, x, x], + [x, x, o, o, x, x], + [x, x, o, o, o, x], + [x, x, x, x, x, x], + ] + ], + [ + [ + [o, x, x, x, x, x], + [x, o, x, x, x, x], + [x, o, o, x, x, x], + [x, x, x, o, x, x], + [x, x, x, o, o, x], + [x, x, x, o, o, o], + ] + ], + ], + dtype=torch.float16, + ) + assert torch.all(attention_mask_computed == attention_mask_expected) From b254df2d349ecb7e0e156a4efd530bd6cdf432ac Mon Sep 17 00:00:00 2001 From: hiyouga <467089858@qq.com> Date: Wed, 3 Jul 2024 23:13:49 +0800 Subject: [PATCH 07/12] update ui Former-commit-id: 7f770f6895f1e2e0b8e4f0b49088bfae096f6d3c --- src/llamafactory/webui/components/train.py | 17 ++++----- src/llamafactory/webui/locales.py | 42 ++++++++-------------- src/llamafactory/webui/runner.py | 5 ++- 3 files changed, 23 insertions(+), 41 deletions(-) diff --git a/src/llamafactory/webui/components/train.py b/src/llamafactory/webui/components/train.py index 4636050b..9f7e0d2a 100644 --- a/src/llamafactory/webui/components/train.py +++ b/src/llamafactory/webui/components/train.py @@ -95,12 +95,11 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]: with gr.Row(): with gr.Column(): - resize_vocab = gr.Checkbox() packing = gr.Checkbox() - efficient_packing = gr.Checkbox() + neat_packing = gr.Checkbox() with gr.Column(): - upcast_layernorm = gr.Checkbox() + resize_vocab = gr.Checkbox() use_llama_pro = gr.Checkbox() with gr.Column(): @@ -114,10 +113,9 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]: warmup_steps, neftune_alpha, optim, - resize_vocab, packing, - efficient_packing, - upcast_layernorm, + neat_packing, + resize_vocab, use_llama_pro, shift_attn, report_to, @@ -131,10 +129,9 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]: warmup_steps=warmup_steps, neftune_alpha=neftune_alpha, optim=optim, - resize_vocab=resize_vocab, packing=packing, - efficient_packing=efficient_packing, - upcast_layernorm=upcast_layernorm, + neat_packing=neat_packing, + resize_vocab=resize_vocab, use_llama_pro=use_llama_pro, shift_attn=shift_attn, report_to=report_to, @@ -331,7 +328,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]: ) dataset.focus(list_datasets, [dataset_dir, training_stage], [dataset], queue=False) - training_stage.change(change_stage, [training_stage], [dataset, packing, 
efficient_packing], queue=False) + training_stage.change(change_stage, [training_stage], [dataset, packing], queue=False) reward_model.focus(list_checkpoints, [model_name, finetuning_type], [reward_model], queue=False) model_name.change(list_output_dirs, [model_name, finetuning_type, current_time], [output_dir], queue=False) finetuning_type.change(list_output_dirs, [model_name, finetuning_type, current_time], [output_dir], queue=False) diff --git a/src/llamafactory/webui/locales.py b/src/llamafactory/webui/locales.py index 852b1b3c..affc832f 100644 --- a/src/llamafactory/webui/locales.py +++ b/src/llamafactory/webui/locales.py @@ -494,20 +494,6 @@ LOCALES = { "info": "使用的优化器:adamw_torch、adamw_8bit 或 adafactor。", }, }, - "resize_vocab": { - "en": { - "label": "Resize token embeddings", - "info": "Resize the tokenizer vocab and the embedding layers.", - }, - "ru": { - "label": "Изменение размера токенных эмбеддингов", - "info": "Изменить размер словаря токенизатора и слоев эмбеддинга.", - }, - "zh": { - "label": "更改词表大小", - "info": "更改分词器词表和嵌入层的大小。", - }, - }, "packing": { "en": { "label": "Pack sequences", @@ -522,32 +508,32 @@ LOCALES = { "info": "将序列打包为等长样本。", }, }, - "efficient_packing": { + "neat_packing": { "en": { - "label": "Pack sequences for efficient training", - "info": "Pack sequences into samples of fixed length without cross-contamination attention for efficient training.", + "label": "Use neat packing", + "info": "Avoid cross-attention between packed sequences.", }, "ru": { - "label": "Пакетные последовательности для эффективного обучения", - "info": "Упакуйте последовательности в образцы фиксированной длины без учета перекрестного загрязнения для эффективного обучения.", + "label": "Используйте аккуратную упаковку", + "info": "избегайте перекрестного внимания между упакованными последовательностями.", }, "zh": { - "label": "打包序列以实现高效训练", - "info": "为了提高训练效率,将序列打包成固定长度的样本,无需注意交叉污染。", + "label": "使用无污染打包", + "info": "避免打包后的序列产生交叉注意力。", }, }, - "upcast_layernorm": { + "resize_vocab": { "en": { - "label": "Upcast LayerNorm", - "info": "Upcast weights of layernorm in float32.", + "label": "Resize token embeddings", + "info": "Resize the tokenizer vocab and the embedding layers.", }, "ru": { - "label": "Приведение весов LayerNorm", - "info": "Приведение весов LayerNorm к float32.", + "label": "Изменение размера токенных эмбеддингов", + "info": "Изменить размер словаря токенизатора и слоев эмбеддинга.", }, "zh": { - "label": "缩放归一化层", - "info": "将归一化层权重缩放至 32 位精度。", + "label": "更改词表大小", + "info": "更改分词器词表和嵌入层的大小。", }, }, "use_llama_pro": { diff --git a/src/llamafactory/webui/runner.py b/src/llamafactory/webui/runner.py index ffec54e2..e23f4d15 100644 --- a/src/llamafactory/webui/runner.py +++ b/src/llamafactory/webui/runner.py @@ -138,10 +138,9 @@ class Runner: warmup_steps=get("train.warmup_steps"), neftune_noise_alpha=get("train.neftune_alpha") or None, optim=get("train.optim"), + packing=get("train.packing") or get("train.neat_packing"), + neat_packing=get("train.neat_packing"), resize_vocab=get("train.resize_vocab"), - packing=get("train.packing"), - efficient_packing=get("train.efficient_packing"), - upcast_layernorm=get("train.upcast_layernorm"), use_llama_pro=get("train.use_llama_pro"), shift_attn=get("train.shift_attn"), report_to="all" if get("train.report_to") else "none", From ff6fc666c1038255418e421626a4c541e0a13484 Mon Sep 17 00:00:00 2001 From: hiyouga <467089858@qq.com> Date: Wed, 3 Jul 2024 23:18:58 +0800 Subject: [PATCH 08/12] update hparams Former-commit-id: 
575a02a23d9b41d00ca6291d8a40b5bdb3cbeeec --- src/llamafactory/data/collator.py | 43 ++++++++++++++++++- src/llamafactory/data/loader.py | 2 +- src/llamafactory/data/preprocess.py | 5 +-- .../data/processors/supervised.py | 19 ++++---- src/llamafactory/hparams/data_args.py | 15 ++++--- src/llamafactory/hparams/model_args.py | 8 +--- src/llamafactory/hparams/parser.py | 7 +-- src/llamafactory/train/sft/workflow.py | 1 + 8 files changed, 72 insertions(+), 28 deletions(-) diff --git a/src/llamafactory/data/collator.py b/src/llamafactory/data/collator.py index e4859ff5..0939925d 100644 --- a/src/llamafactory/data/collator.py +++ b/src/llamafactory/data/collator.py @@ -1,4 +1,7 @@ -# Copyright 2024 the LlamaFactory team. +# Copyright 2024 OpenAccess AI Collective and the LlamaFactory team. +# +# This code is inspired by the OpenAccess AI Collective's axolotl library. +# https://github.com/OpenAccess-AI-Collective/axolotl/blob/main/src/axolotl/monkeypatch/utils.py # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -19,6 +22,44 @@ import torch from transformers import DataCollatorForSeq2Seq +def prepare_4d_attention_mask(attention_mask_with_indices: "torch.Tensor", dtype: "torch.dtype") -> "torch.Tensor": + r""" + Expands the attention mask with indices from (batch_size, seq_len) to (batch_size, 1, seq_len, seq_len), + while handles packed sequences and transforms the mask to lower triangular form to prevent future peeking. + + e.g. + ``` + [1, 1, 2, 2, 2, 0] + ``` + -> + ``` + [[ + [ + [o, x, x, x, x, x], + [o, o, x, x, x, x], + [x, x, o, x, x, x], + [x, x, o, o, x, x], + [x, x, o, o, o, x], + [x, x, o, x, x, x], + ] + ]] + ``` + where `o` equals to `0.0`, `x` equals to `min_dtype`. + """ + bsz, seq_len = attention_mask_with_indices.size() + min_dtype = torch.finfo(dtype).min + expanded_mask = attention_mask_with_indices[:, None, None, :].expand(bsz, 1, seq_len, seq_len) + # Create a binary mask from the original mask where zeros remain zeros and all other values are set to one + padding_mask = torch.where(expanded_mask != 0, 1, 0) + # Create a block-diagonal mask. + attention_mask_4d = torch.eq(expanded_mask, expanded_mask.transpose(-1, -2)).int() * padding_mask + # Use the lower triangular mask to zero out the upper triangular part + attention_mask_4d *= torch.tril(torch.ones((seq_len, seq_len), dtype=torch.long)) + # Invert the attention mask. 
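# At this point attention_mask_4d holds 1 where query i may attend to key j (same packed
# sequence, j <= i, not padding) and 0 everywhere else; the torch.where below turns that
# into an additive bias, mapping 1 -> 0.0 and 0 -> the dtype's minimum value so that
# masked positions vanish after softmax.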
+ attention_mask_4d = torch.where(attention_mask_4d != 0, torch.tensor(0, dtype=dtype), min_dtype) + return attention_mask_4d + + @dataclass class PairwiseDataCollatorWithPadding(DataCollatorForSeq2Seq): r""" diff --git a/src/llamafactory/data/loader.py b/src/llamafactory/data/loader.py index 5f116e4e..8e7062db 100644 --- a/src/llamafactory/data/loader.py +++ b/src/llamafactory/data/loader.py @@ -177,7 +177,7 @@ def get_dataset( with training_args.main_process_first(desc="pre-process dataset"): preprocess_func, print_function = get_preprocess_and_print_func( - data_args, model_args, training_args, stage, template, tokenizer, processor + data_args, training_args, stage, template, tokenizer, processor ) column_names = list(next(iter(dataset)).keys()) kwargs = {} diff --git a/src/llamafactory/data/preprocess.py b/src/llamafactory/data/preprocess.py index ae69e84e..9a8b97f3 100644 --- a/src/llamafactory/data/preprocess.py +++ b/src/llamafactory/data/preprocess.py @@ -29,13 +29,12 @@ from .processors.unsupervised import preprocess_unsupervised_dataset, print_unsu if TYPE_CHECKING: from transformers import PreTrainedTokenizer, ProcessorMixin, Seq2SeqTrainingArguments - from ..hparams import DataArguments, ModelArguments + from ..hparams import DataArguments from .template import Template def get_preprocess_and_print_func( data_args: "DataArguments", - model_args: "ModelArguments", training_args: "Seq2SeqTrainingArguments", stage: Literal["pt", "sft", "rm", "ppo", "kto"], template: "Template", @@ -50,7 +49,7 @@ def get_preprocess_and_print_func( ) print_function = partial(print_unsupervised_dataset_example, tokenizer=tokenizer) elif stage == "sft" and not training_args.predict_with_generate: - if data_args.packing or model_args.efficient_packing: + if data_args.packing: preprocess_func = partial( preprocess_packed_supervised_dataset, template=template, diff --git a/src/llamafactory/data/processors/supervised.py b/src/llamafactory/data/processors/supervised.py index 78811477..747a0c1b 100644 --- a/src/llamafactory/data/processors/supervised.py +++ b/src/llamafactory/data/processors/supervised.py @@ -23,7 +23,7 @@ from .processor_utils import get_paligemma_token_type_ids, get_pixel_values, gre if TYPE_CHECKING: from transformers import PreTrainedTokenizer, ProcessorMixin - from ...hparams import DataArguments, ModelArguments + from ...hparams import DataArguments from ..template import Template @@ -125,7 +125,6 @@ def preprocess_packed_supervised_dataset( template: "Template", tokenizer: "PreTrainedTokenizer", data_args: "DataArguments", - model_args: "ModelArguments" ) -> Dict[str, List[List[int]]]: # build inputs with format ` X1 Y1 X2 Y2 ` # and labels with format ` ... Y1 ... 
Y2 ` @@ -161,26 +160,30 @@ def preprocess_packed_supervised_dataset( model_inputs = {"input_ids": [], "attention_mask": [], "labels": []} knapsacks = greedy_knapsack(lengths, data_args.cutoff_len) for knapsack in knapsacks: - packed_input_ids, packed_attention_mask, packed_labels = [], [], [] + packed_input_ids, packed_attention_masks, packed_labels = [], [], [] for i, length in enumerate(knapsack): index = length2indexes[length].pop() packed_input_ids += batch_input_ids[index] packed_labels += batch_labels[index] - packed_attention_mask += [i+1]*len(batch_input_ids[index]) + if data_args.neat_packing: + packed_attention_masks += [i + 1] * len(batch_input_ids[index]) # start from 1 + else: + packed_attention_masks += [1] * len(batch_input_ids[index]) if len(packed_input_ids) < data_args.cutoff_len: pad_length = data_args.cutoff_len - len(packed_input_ids) packed_input_ids += [tokenizer.pad_token_id] * pad_length packed_labels += [IGNORE_INDEX] * pad_length + if data_args.neat_packing: + packed_attention_masks += [0] * pad_length + else: + packed_attention_masks += [1] * pad_length # more efficient flash_attn if len(packed_input_ids) != data_args.cutoff_len: raise ValueError("The length of packed example should be identical to the cutoff length.") model_inputs["input_ids"].append(packed_input_ids) - if model_args.efficient_packing: - model_inputs["attention_mask"].append(packed_attention_mask) - else: - model_inputs["attention_mask"].append([1] * data_args.cutoff_len) + model_inputs["attention_mask"].append(packed_attention_masks) model_inputs["labels"].append(packed_labels) return model_inputs diff --git a/src/llamafactory/hparams/data_args.py b/src/llamafactory/hparams/data_args.py index 880be84a..38bbbb12 100644 --- a/src/llamafactory/hparams/data_args.py +++ b/src/llamafactory/hparams/data_args.py @@ -83,9 +83,7 @@ class DataArguments: ) ignore_pad_token_for_loss: bool = field( default=True, - metadata={ - "help": "Whether or not to ignore the tokens corresponding to padded labels in the loss computation." - }, + metadata={"help": "Whether or not to ignore the tokens corresponding to the pad tokens in loss computation."}, ) val_size: float = field( default=0.0, @@ -93,9 +91,11 @@ class DataArguments: ) packing: Optional[bool] = field( default=None, - metadata={ - "help": "Whether or not to pack the sequences in training. Will automatically enable in pre-training." - }, + metadata={"help": "Enable sequences packing in training. Will automatically enable in pre-training."}, + ) + neat_packing: bool = field( + default=False, + metadata={"help": "Enable sequence packing without cross-attention."}, ) tool_format: Optional[str] = field( default=None, @@ -112,3 +112,6 @@ class DataArguments: if self.streaming and self.max_samples is not None: raise ValueError("`max_samples` is incompatible with `streaming`.") + + if self.neat_packing and not self.packing: + raise ValueError("`neat_packing` requires `packing` is True.") diff --git a/src/llamafactory/hparams/model_args.py b/src/llamafactory/hparams/model_args.py index 49503022..4ac47512 100644 --- a/src/llamafactory/hparams/model_args.py +++ b/src/llamafactory/hparams/model_args.py @@ -109,12 +109,6 @@ class ModelArguments: default=False, metadata={"help": "Enable shift short attention (S^2-Attn) proposed by LongLoRA."}, ) - efficient_packing: Optional[bool] = field( - default=None, - metadata={ - "help": "Whether or not to pack the sequences without cross-contamination attention for efficient training." 
- }, - ) mixture_of_depths: Optional[Literal["convert", "load"]] = field( default=None, metadata={"help": "Convert the model to mixture-of-depths (MoD) or load the MoD model."}, @@ -232,6 +226,7 @@ class ModelArguments: self.compute_dtype: Optional["torch.dtype"] = None self.device_map: Optional[Union[str, Dict[str, Any]]] = None self.model_max_length: Optional[int] = None + self.block_diag_attn: bool = False if self.split_special_tokens and self.use_fast_tokenizer: raise ValueError("`split_special_tokens` is only supported for slow tokenizers.") @@ -259,4 +254,5 @@ class ModelArguments: new_arg.compute_dtype = old_arg.compute_dtype new_arg.device_map = old_arg.device_map new_arg.model_max_length = old_arg.model_max_length + new_arg.block_diag_attn = old_arg.block_diag_attn return new_arg diff --git a/src/llamafactory/hparams/parser.py b/src/llamafactory/hparams/parser.py index 507f7fef..73abc0bb 100644 --- a/src/llamafactory/hparams/parser.py +++ b/src/llamafactory/hparams/parser.py @@ -158,6 +158,9 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS: if finetuning_args.stage != "sft" and training_args.predict_with_generate: raise ValueError("`predict_with_generate` cannot be set as True except SFT.") + if finetuning_args.stage != "sft" and data_args.neat_packing: + raise ValueError("`neat_packing` cannot be set as True except SFT.") + if finetuning_args.stage == "sft" and training_args.do_predict and not training_args.predict_with_generate: raise ValueError("Please enable `predict_with_generate` to save model predictions.") @@ -170,9 +173,6 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS: if finetuning_args.stage == "ppo" and model_args.shift_attn: raise ValueError("PPO training is incompatible with S^2-Attn.") - if finetuning_args.stage != "sft" and model_args.efficient_packing: - raise ValueError("`efficient_packing` cannot be set as True except SFT.") - if finetuning_args.stage == "ppo" and finetuning_args.reward_model_type == "lora" and model_args.use_unsloth: raise ValueError("Unsloth does not support lora reward model.") @@ -314,6 +314,7 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS: model_args.device_map = {"": get_current_device()} model_args.model_max_length = data_args.cutoff_len + model_args.block_diag_attn = data_args.neat_packing data_args.packing = data_args.packing if data_args.packing is not None else finetuning_args.stage == "pt" # Log on each process the small summary diff --git a/src/llamafactory/train/sft/workflow.py b/src/llamafactory/train/sft/workflow.py index c12a70aa..0c3f9b11 100644 --- a/src/llamafactory/train/sft/workflow.py +++ b/src/llamafactory/train/sft/workflow.py @@ -28,6 +28,7 @@ from ..trainer_utils import create_modelcard_and_push from .metric import ComputeMetrics, compute_accuracy, eval_logit_processor from .trainer import CustomSeq2SeqTrainer + if TYPE_CHECKING: from transformers import Seq2SeqTrainingArguments, TrainerCallback From e671ed520bdbbf99956acc9e8e1813ce620cf944 Mon Sep 17 00:00:00 2001 From: hiyouga <467089858@qq.com> Date: Wed, 3 Jul 2024 23:23:24 +0800 Subject: [PATCH 09/12] update arg name Former-commit-id: 8a6a7b9c8a876da9c16e5ada7df461eb8cabee21 --- src/llamafactory/extras/constants.py | 20 +++++++------- src/llamafactory/model/model_utils/packing.py | 27 +++++-------------- src/llamafactory/model/patcher.py | 7 +++-- 3 files changed, 20 insertions(+), 34 deletions(-) diff --git a/src/llamafactory/extras/constants.py 
b/src/llamafactory/extras/constants.py index 6029d84f..dc326d01 100644 --- a/src/llamafactory/extras/constants.py +++ b/src/llamafactory/extras/constants.py @@ -78,22 +78,22 @@ TRAINING_STAGES = { STAGES_USE_PAIR_DATA = {"rm", "dpo"} -SUPPORTED_CLASS_FOR_S2ATTN = {"llama"} - -SUPPORTED_CLASS_EFFECIENT_PACKING = [ +SUPPORTED_CLASS_FOR_BLOCK_DIAG_ATTN = { + "falcon", + "gemma", + "gemma2", + "jamba", "llama", "mistral", "mixtral", - "qwen2", - "qwen2_moe", - "falcon", "phi", "phi3", - "gemma", - "gemmoe", + "qwen2", + "qwen2_moe", "starcoder2", - "jamba" -] +} + +SUPPORTED_CLASS_FOR_S2ATTN = {"llama"} V_HEAD_WEIGHTS_NAME = "value_head.bin" diff --git a/src/llamafactory/model/model_utils/packing.py b/src/llamafactory/model/model_utils/packing.py index 606cd03b..c60547d4 100644 --- a/src/llamafactory/model/model_utils/packing.py +++ b/src/llamafactory/model/model_utils/packing.py @@ -283,28 +283,15 @@ def patch_for_multipack(model_type, model_name, attn_implementation): ) -def patch_remote(model_name, config_name, modeling_name): - model_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True) - # we need to load the model here in order for modeling_* to be available - with init_empty_weights(): - AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True) - module_name = model_config.__class__.__module__.replace(config_name, modeling_name) - modeling_arch = importlib.import_module(module_name) - modeling_arch._get_unpad_data = get_unpad_data # pylint: disable=protected-access - # check exist load_balancing_loss_func for moe model - if hasattr(modeling_arch, "load_balancing_loss_func"): - modeling_arch.load_balancing_loss_func = load_balancing_loss_func +def configure_packing(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None: + if not is_trainable or not model_args.block_diag_attn: + return + model_type = getattr(config, "model_type", None) -def configure_packing(config: "PretrainedConfig", model_args: "ModelArguments") -> None: - if getattr(config, "model_type", None) == "internlm2": # special case for custom models - attn_implementation = getattr(config, "attn_implementation", "") - else: - attn_implementation = getattr(config, "_attn_implementation", "") - - if getattr(config, "model_type", None) in SUPPORTED_CLASS_EFFECIENT_PACKING: - patch_for_multipack(getattr(config, "model_type", None), model_args.model_name_or_path, attn_implementation) + if getattr(config, "model_type", None) in SUPPORTED_CLASS_FOR_BLOCK_DIAG_ATTN: + patch_for_block_diag_attn(model_type) logger.info("Using packing sequences without cross-contamination attention for efficient training.") else: - raise ValueError("Current model does not support packing sequences for efficient training. 
Please set config `efficient_packing` is False") \ No newline at end of file + raise ValueError("Current model does not support packing sequences for efficient training.") diff --git a/src/llamafactory/model/patcher.py b/src/llamafactory/model/patcher.py index 2ddfd21a..a99d38e0 100644 --- a/src/llamafactory/model/patcher.py +++ b/src/llamafactory/model/patcher.py @@ -29,11 +29,12 @@ from .model_utils.checkpointing import prepare_model_for_training from .model_utils.embedding import resize_embedding_layer from .model_utils.longlora import configure_longlora from .model_utils.moe import add_z3_leaf_module, configure_moe +from .model_utils.packing import configure_packing from .model_utils.quantization import configure_quantization from .model_utils.rope import configure_rope from .model_utils.valuehead import prepare_valuehead_model from .model_utils.visual import autocast_projector_dtype, configure_visual_model -from .model_utils.packing import configure_packing + if TYPE_CHECKING: from transformers import PretrainedConfig, PreTrainedTokenizer @@ -73,6 +74,7 @@ def patch_config( configure_quantization(config, tokenizer, model_args, init_kwargs) configure_moe(config, model_args, is_trainable) configure_visual_model(config) + configure_packing(config, model_args, is_trainable) if model_args.use_cache and not is_trainable: setattr(config, "use_cache", True) @@ -101,9 +103,6 @@ def patch_config( if init_kwargs.get("device_map", None) == "auto": init_kwargs["offload_folder"] = model_args.offload_folder - - if model_args.efficient_packing: - configure_packing(config, model_args) def patch_model( From 13cec0cc2f3b46ab4b6320b6462914bb3f39414a Mon Sep 17 00:00:00 2001 From: hiyouga <467089858@qq.com> Date: Wed, 3 Jul 2024 23:29:33 +0800 Subject: [PATCH 10/12] update func name Former-commit-id: c346f79f99db5296000e4d22a65e53c26e85b344 --- src/llamafactory/model/model_utils/packing.py | 47 ++++++++++++++++--- 1 file changed, 41 insertions(+), 6 deletions(-) diff --git a/src/llamafactory/model/model_utils/packing.py b/src/llamafactory/model/model_utils/packing.py index c60547d4..8ed313be 100644 --- a/src/llamafactory/model/model_utils/packing.py +++ b/src/llamafactory/model/model_utils/packing.py @@ -1,8 +1,43 @@ -# Copy from original implementation of src/axolotl/monkeypatch/multipack.py and src/axolotl/monkeypatch/utils.py from axolotl library with some changes -""" -Shared utils for the monkeypatches -""" -from typing import Optional, TYPE_CHECKING +# Copyright 2024 Musab Gultekin and the LlamaFactory team. +# +# This code is based on the Musab Gultekin's functionary library. +# https://github.com/MeetKai/functionary/blob/main/functionary/train/packing/monkey_patch_packing.py +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +# MIT License +# +# Copyright (c) 2023 Musab Gultekin +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +from typing import TYPE_CHECKING, Optional import torch import torch.nn.functional as F @@ -18,8 +53,8 @@ from transformers.modeling_attn_mask_utils import ( ) from transformers.utils import is_torch_bf16_gpu_available +from ...extras.constants import SUPPORTED_CLASS_FOR_BLOCK_DIAG_ATTN from ...extras.logging import get_logger -from ...extras.constants import SUPPORTED_CLASS_EFFECIENT_PACKING if TYPE_CHECKING: from transformers import PretrainedConfig From 51c75985b82d3188c49d463e7d6ac6bd81ca5dc1 Mon Sep 17 00:00:00 2001 From: hoshi-hiyouga Date: Wed, 3 Jul 2024 23:36:01 +0800 Subject: [PATCH 11/12] Update packing.py Former-commit-id: a36e8f2dd50e0f1c589457a7e785fdbc905d561d --- src/llamafactory/model/model_utils/packing.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/src/llamafactory/model/model_utils/packing.py b/src/llamafactory/model/model_utils/packing.py index 8ed313be..5b21bf4e 100644 --- a/src/llamafactory/model/model_utils/packing.py +++ b/src/llamafactory/model/model_utils/packing.py @@ -257,7 +257,7 @@ def load_balancing_loss_func( return overall_loss * num_experts -def patch_for_multipack(model_type, model_name, attn_implementation): +def patch_for_block_diag_attn(model_type, model_name, attn_implementation): if attn_implementation == "flash_attention_2": if model_type == "llama": transformers.models.llama.modeling_llama._get_unpad_data = ( # pylint: disable=protected-access @@ -305,10 +305,6 @@ def patch_for_multipack(model_type, model_name, attn_implementation): transformers.models.starcoder2.modeling_starcoder2._get_unpad_data = ( # pylint: disable=protected-access get_unpad_data ) - elif model_type == "gemmoe": - patch_remote(model_name, ".configuration_gemmoe", ".modeling_gemmoe") - elif model_type == "jamba": - patch_remote(model_name, ".configuration_jamba", ".modeling_jamba") else: transformers.modeling_attn_mask_utils._prepare_4d_causal_attention_mask_for_sdpa = ( # pylint: disable=protected-access patched_prepare_4d_causal_attention_mask_for_sdpa @@ -318,7 +314,6 @@ def patch_for_multipack(model_type, model_name, attn_implementation): ) - def configure_packing(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None: if not is_trainable or not model_args.block_diag_attn: return From bfdaadcc402495fc0c0d7dcb72af53f8dae81e8c Mon Sep 17 00:00:00 2001 From: hiyouga <467089858@qq.com> Date: Thu, 4 Jul 
2024 01:10:55 +0800 Subject: [PATCH 12/12] update packing Former-commit-id: cce7083024bed4c7429ddc8288d1c9190fde29f5 --- src/llamafactory/data/collator.py | 24 +- src/llamafactory/extras/constants.py | 3 - src/llamafactory/hparams/data_args.py | 2 +- src/llamafactory/model/model_utils/packing.py | 332 ++++-------------- tests/data/test_collator.py | 1 + tests/model/model_utils/test_packing.py | 42 +++ 6 files changed, 133 insertions(+), 271 deletions(-) create mode 100644 tests/model/model_utils/test_packing.py diff --git a/src/llamafactory/data/collator.py b/src/llamafactory/data/collator.py index 0939925d..6d176313 100644 --- a/src/llamafactory/data/collator.py +++ b/src/llamafactory/data/collator.py @@ -29,20 +29,22 @@ def prepare_4d_attention_mask(attention_mask_with_indices: "torch.Tensor", dtype e.g. ``` - [1, 1, 2, 2, 2, 0] + [[1, 1, 2, 2, 2, 0]] ``` -> ``` - [[ - [ - [o, x, x, x, x, x], - [o, o, x, x, x, x], - [x, x, o, x, x, x], - [x, x, o, o, x, x], - [x, x, o, o, o, x], - [x, x, o, x, x, x], - ] - ]] + [ + [ + [ + [o, x, x, x, x, x], + [o, o, x, x, x, x], + [x, x, o, x, x, x], + [x, x, o, o, x, x], + [x, x, o, o, o, x], + [x, x, o, x, x, x], + ] + ] + ] ``` where `o` equals to `0.0`, `x` equals to `min_dtype`. """ diff --git a/src/llamafactory/extras/constants.py b/src/llamafactory/extras/constants.py index dc326d01..49aa4dba 100644 --- a/src/llamafactory/extras/constants.py +++ b/src/llamafactory/extras/constants.py @@ -82,14 +82,11 @@ SUPPORTED_CLASS_FOR_BLOCK_DIAG_ATTN = { "falcon", "gemma", "gemma2", - "jamba", "llama", "mistral", - "mixtral", "phi", "phi3", "qwen2", - "qwen2_moe", "starcoder2", } diff --git a/src/llamafactory/hparams/data_args.py b/src/llamafactory/hparams/data_args.py index 38bbbb12..45c1079b 100644 --- a/src/llamafactory/hparams/data_args.py +++ b/src/llamafactory/hparams/data_args.py @@ -83,7 +83,7 @@ class DataArguments: ) ignore_pad_token_for_loss: bool = field( default=True, - metadata={"help": "Whether or not to ignore the tokens corresponding to the pad tokens in loss computation."}, + metadata={"help": "Whether or not to ignore the tokens corresponding to the pad label in loss computation."}, ) val_size: float = field( default=0.0, diff --git a/src/llamafactory/model/model_utils/packing.py b/src/llamafactory/model/model_utils/packing.py index 5b21bf4e..ba614515 100644 --- a/src/llamafactory/model/model_utils/packing.py +++ b/src/llamafactory/model/model_utils/packing.py @@ -37,281 +37,102 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. 
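+#
+# This module enables block-diagonal attention for packed sequences: for supported model types it
+# replaces the `_get_unpad_data` helper of the corresponding transformers modeling module, so the
+# flash-attention varlen kernels treat every packed sequence as an independent segment and no
+# attention crosses sequence boundaries.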
-from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, Tuple import torch import torch.nn.functional as F - -import importlib - -import transformers -from accelerate import init_empty_weights -from transformers import AutoConfig, AutoModelForCausalLM -from transformers.modeling_attn_mask_utils import ( - _prepare_4d_causal_attention_mask, - _prepare_4d_causal_attention_mask_for_sdpa, -) -from transformers.utils import is_torch_bf16_gpu_available +import transformers.models from ...extras.constants import SUPPORTED_CLASS_FOR_BLOCK_DIAG_ATTN from ...extras.logging import get_logger + if TYPE_CHECKING: from transformers import PretrainedConfig - from ...hparams import ModelArguments, DataArguments + from ...hparams import ModelArguments logger = get_logger(__name__) -@torch.jit.script -def get_max_seqlen_in_batch(attention_mask: torch.Tensor) -> torch.Tensor: - max_num = int(torch.max(attention_mask).item()) - batch_size, _ = attention_mask.shape - counts = torch.zeros((batch_size, max_num), dtype=torch.int32) - - for i in range(1, max_num + 1): - mask = attention_mask == i - counts[:, i - 1] = torch.sum(mask, dim=-1).to(dtype=torch.int32) - - result = counts.flatten() - nonzero_indices = torch.nonzero(result).squeeze(-1) - return result[nonzero_indices] - - -@torch.jit.script -def get_unpad_data(attention_mask: torch.Tensor): - device = attention_mask.device - seqlens_in_batch = get_max_seqlen_in_batch(attention_mask) - indices = torch.nonzero(attention_mask.flatten()).flatten() - max_seqlen_in_batch = seqlens_in_batch.max().item() - cu_seqlens = ( - F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) - .to(device=device) - .detach() - ) - return ( - indices, - cu_seqlens, - max_seqlen_in_batch, - ) - -def mask_2d_to_4d( - mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None -): - """ - Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. - This expansion handles packed sequences so that sequences share the same attention mask integer value - when they attend to each other within that sequence. - This expansion transforms the mask to lower triangular form to prevent future peeking. - """ - bsz, src_len = mask.size() - tgt_len = tgt_len if tgt_len is not None else src_len - - mask = mask.unsqueeze(1).unsqueeze(2) - mask = mask.expand(bsz, 1, tgt_len, src_len) - - # Create a binary mask from the original mask where zeros remain zeros and all other values are set to one - binary_mask = torch.where( - mask != 0, - torch.tensor(1, device=mask.device).to(dtype), - torch.tensor(0, device=mask.device).to(dtype), - ) - - # Create a block-diagonal mask. 
- # we multiply by the binary mask so that 0's in the original mask are correctly excluded - zero_one_mask = torch.eq(mask, mask.transpose(-1, -2)).int() * binary_mask - - # Now let's create a lower triangular mask of ones that will zero out the upper triangular part - lower_triangular_ones = torch.tril(torch.ones((tgt_len, src_len), dtype=dtype)).to( - mask.device - ) - - # Use the lower triangular mask to zero out the upper triangular part of the zero_one_mask - masked_zero_one_mask = zero_one_mask * lower_triangular_ones - - return masked_zero_one_mask - - -def patched_prepare_4d_causal_attention_mask( - attention_mask: Optional[torch.Tensor], - *args, -): - dtype = torch.bfloat16 if is_torch_bf16_gpu_available() else torch.float32 - return _prepare_4d_causal_attention_mask( - mask_2d_to_4d(attention_mask, dtype=dtype), - *args, - ) - - -def patched_prepare_4d_causal_attention_mask_for_sdpa( - attention_mask: Optional[torch.Tensor], - *args, -): - dtype = torch.bfloat16 if is_torch_bf16_gpu_available() else torch.float32 - return _prepare_4d_causal_attention_mask_for_sdpa( - mask_2d_to_4d(attention_mask, dtype=dtype), - *args, - ) - - -def set_module_name(model, name, value): - if "." in name: - parent_name = name.rsplit(".", 1)[0] - child_name = name[len(parent_name) + 1 :] - parent = model.get_submodule(parent_name) - else: - parent_name = "" - parent = model - child_name = name - - setattr(parent, child_name, value) - - -# Copy from original implementation of modeling_mixtral.py from transformers, Just change a little bit with new_attention_mask -def load_balancing_loss_func( - gate_logits: torch.Tensor, - num_experts: torch.Tensor = None, - top_k=2, - attention_mask: Optional[torch.Tensor] = None, -) -> float: +def get_seqlens_in_batch(attention_mask: "torch.Tensor") -> "torch.Tensor": r""" - Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch. + Gets the sequnce lengths in the current batch. - See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the loss - function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between - experts is too unbalanced. + e.g. + ``` + [ + [1, 1, 2, 2, 2, 0], + [1, 2, 2, 3, 3, 3], + ] + ``` + -> + ``` + [2, 3, 1, 2, 3] + ``` + """ + bsz = attention_mask.size(0) + dtype, device = attention_mask.dtype, attention_mask.device + max_num = torch.max(attention_mask) + counts: "torch.Tensor" = torch.zeros((bsz, max_num), dtype=dtype, device=device) + for i in range(max_num): + counts[:, i] = torch.sum(attention_mask == (i + 1), dim=-1) - Args: - gate_logits (Union[`torch.Tensor`, Tuple[torch.Tensor]): - Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of - shape [batch_size X sequence_length, num_experts]. - attention_mask (`torch.Tensor`, None): - The attention_mask used in forward function - shape [batch_size X sequence_length] if not None. - num_experts (`int`, *optional*): - Number of experts + counts = counts.flatten() + seqlens = counts[counts.nonzero().squeeze()] + return seqlens + + +def get_unpad_data(attention_mask: "torch.Tensor") -> Tuple["torch.Tensor", "torch.Tensor", int]: + r""" + Prepares the indices and seqlens for flash attn varlen function. Returns: - The auxiliary loss. + indices: indices of non-masked tokens from the flattened sequence. + cu_seqlens: the cumulative sequence lengths in the current batch, always starts from 0. 
+ max_seqlen_in_batch: the largest seqlen in the current batch. + + e.g. + ``` + [ + [1, 1, 2, 2, 2, 0], + [1, 2, 2, 3, 3, 3], + ] + ``` + -> + ``` + [0, 1, 2, 3, 4, 6, 7, 8, 9, 10, 11] + [0, 2, 5, 6, 8, 11] + 3 + ``` """ - if gate_logits is None or not isinstance(gate_logits, tuple): - return 0 - - if isinstance(gate_logits, tuple): - compute_device = gate_logits[0].device - concatenated_gate_logits = torch.cat( - [layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0 - ) - - routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1) - - _, selected_experts = torch.topk(routing_weights, top_k, dim=-1) - - expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts) - - if attention_mask is None: - # Compute the percentage of tokens routed to each experts - tokens_per_expert = torch.mean(expert_mask.float(), dim=0) - - # Compute the average probability of routing to these experts - router_prob_per_expert = torch.mean(routing_weights, dim=0) - else: - # ONLY ADD THIS LINE OF CODE, AND REPLACE attention_mask WITH new_attention_mask - new_attention_mask = (attention_mask != 0).int().to(attention_mask.device) - batch_size, sequence_length = new_attention_mask.shape - num_hidden_layers = concatenated_gate_logits.shape[0] // ( - batch_size * sequence_length - ) - - # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask - expert_attention_mask = ( - new_attention_mask[None, :, :, None, None] - .expand( - (num_hidden_layers, batch_size, sequence_length, top_k, num_experts) - ) - .reshape(-1, top_k, num_experts) - .to(compute_device) - ) - - # Compute the percentage of tokens routed to each experts - tokens_per_expert = torch.sum( - expert_mask.float() * expert_attention_mask, dim=0 - ) / torch.sum(expert_attention_mask, dim=0) - - # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert - router_per_expert_attention_mask = ( - new_attention_mask[None, :, :, None] - .expand((num_hidden_layers, batch_size, sequence_length, num_experts)) - .reshape(-1, num_experts) - .to(compute_device) - ) - - # Compute the average probability of routing to these experts - router_prob_per_expert = torch.sum( - routing_weights * router_per_expert_attention_mask, dim=0 - ) / torch.sum(router_per_expert_attention_mask, dim=0) - - overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0)) - return overall_loss * num_experts + seqlens_in_batch = get_seqlens_in_batch(attention_mask) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) + return indices, cu_seqlens, max_seqlen_in_batch -def patch_for_block_diag_attn(model_type, model_name, attn_implementation): - if attn_implementation == "flash_attention_2": - if model_type == "llama": - transformers.models.llama.modeling_llama._get_unpad_data = ( # pylint: disable=protected-access - get_unpad_data - ) - elif model_type == "mistral": - transformers.models.mistral.modeling_mistral._get_unpad_data = ( # pylint: disable=protected-access - get_unpad_data - ) - elif model_type == "mixtral": - transformers.models.mixtral.modeling_mixtral._get_unpad_data = ( # pylint: disable=protected-access - get_unpad_data - ) - transformers.models.mixtral.modeling_mixtral.load_balancing_loss_func = ( # pylint: disable=protected-access - load_balancing_loss_func - ) - elif model_type == 
"qwen2": - transformers.models.qwen2.modeling_qwen2._get_unpad_data = ( # pylint: disable=protected-access - get_unpad_data - ) - elif model_type == "qwen2_moe": - transformers.models.qwen2_moe.modeling_qwen2_moe._get_unpad_data = ( # pylint: disable=protected-access - get_unpad_data - ) - transformers.models.qwen2_moe.modeling_qwen2_moe.load_balancing_loss_func = ( # pylint: disable=protected-access - load_balancing_loss_func - ) - elif model_type == "falcon": - transformers.models.falcon.modeling_falcon._get_unpad_data = ( # pylint: disable=protected-access - get_unpad_data - ) - elif model_type == "phi": - transformers.models.phi.modeling_phi._get_unpad_data = ( # pylint: disable=protected-access - get_unpad_data - ) - elif model_type == "phi3": - transformers.models.phi3.modeling_phi3._get_unpad_data = ( # pylint: disable=protected-access - get_unpad_data - ) - elif model_type == "gemma": - transformers.models.gemma.modeling_gemma._get_unpad_data = ( # pylint: disable=protected-access - get_unpad_data - ) - elif model_type == "starcoder2": - transformers.models.starcoder2.modeling_starcoder2._get_unpad_data = ( # pylint: disable=protected-access - get_unpad_data - ) - else: - transformers.modeling_attn_mask_utils._prepare_4d_causal_attention_mask_for_sdpa = ( # pylint: disable=protected-access - patched_prepare_4d_causal_attention_mask_for_sdpa - ) - transformers.modeling_attn_mask_utils._prepare_4d_causal_attention_mask = ( # pylint: disable=protected-access - patched_prepare_4d_causal_attention_mask - ) +def patch_for_block_diag_attn(model_type: str) -> None: + if model_type == "falcon": + transformers.models.falcon.modeling_falcon._get_unpad_data = get_unpad_data + elif model_type == "gemma": + transformers.models.gemma.modeling_gemma._get_unpad_data = get_unpad_data + elif model_type == "gemma2": + transformers.models.gemma2.modeling_gemma2._get_unpad_data = get_unpad_data + elif model_type == "llama": + transformers.models.llama.modeling_llama._get_unpad_data = get_unpad_data + elif model_type == "mistral": + transformers.models.mistral.modeling_mistral._get_unpad_data = get_unpad_data + elif model_type == "phi": + transformers.models.phi.modeling_phi._get_unpad_data = get_unpad_data + elif model_type == "phi3": + transformers.models.phi3.modeling_phi3._get_unpad_data = get_unpad_data + elif model_type == "qwen2": + transformers.models.qwen2.modeling_qwen2._get_unpad_data = get_unpad_data + elif model_type == "starcoder2": + transformers.models.starcoder2.modeling_starcoder2._get_unpad_data = get_unpad_data def configure_packing(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None: @@ -319,9 +140,8 @@ def configure_packing(config: "PretrainedConfig", model_args: "ModelArguments", return model_type = getattr(config, "model_type", None) - - if getattr(config, "model_type", None) in SUPPORTED_CLASS_FOR_BLOCK_DIAG_ATTN: + if model_type in SUPPORTED_CLASS_FOR_BLOCK_DIAG_ATTN: patch_for_block_diag_attn(model_type) - logger.info("Using packing sequences without cross-contamination attention for efficient training.") + logger.info("Using block diagonal attention for sequence packing without cross-attention.") else: - raise ValueError("Current model does not support packing sequences for efficient training.") + raise ValueError("Current model does not support block diagonal attention.") diff --git a/tests/data/test_collator.py b/tests/data/test_collator.py index cb473d4c..58035ac2 100644 --- a/tests/data/test_collator.py +++ b/tests/data/test_collator.py @@ 
-52,4 +52,5 @@ def test_4d_attention_mask(): ], dtype=torch.float16, ) + assert list(attention_mask_computed.size()) == [2, 1, 6, 6] assert torch.all(attention_mask_computed == attention_mask_expected) diff --git a/tests/model/model_utils/test_packing.py b/tests/model/model_utils/test_packing.py new file mode 100644 index 00000000..6fd9ba3b --- /dev/null +++ b/tests/model/model_utils/test_packing.py @@ -0,0 +1,42 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + +from llamafactory.model.model_utils.packing import get_seqlens_in_batch, get_unpad_data + + +def test_get_seqlens_in_batch(): + attention_mask_with_indices = torch.tensor( + [ + [1, 1, 2, 2, 2, 0], + [1, 2, 2, 3, 3, 3], + ] + ) + seqlens_in_batch = get_seqlens_in_batch(attention_mask_with_indices) + assert list(seqlens_in_batch.size()) == [5] + assert torch.all(seqlens_in_batch == torch.tensor([2, 3, 1, 2, 3])) + + +def test_get_unpad_data(): + attention_mask_with_indices = torch.tensor( + [ + [1, 1, 2, 2, 2, 0], + [1, 2, 2, 3, 3, 3], + ] + ) + indices, cu_seqlens, max_seqlen_in_batch = get_unpad_data(attention_mask_with_indices) + assert torch.all(indices == torch.tensor([0, 1, 2, 3, 4, 6, 7, 8, 9, 10, 11])) + assert torch.all(cu_seqlens == torch.tensor([0, 2, 5, 6, 8, 11], dtype=torch.int32)) + assert max_seqlen_in_batch == 3
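To complement the tests above, here is a minimal sketch of how the new collator helper expands a packed 2-D mask into the 4-D block-diagonal causal mask. It assumes `prepare_4d_attention_mask` is importable from `llamafactory.data.collator`, as added earlier in this series; note that the padding position (mask value 0) attends to nothing, so its row is fully masked.

```python
# Minimal sketch (not part of the patch): expand a packed 2-D attention mask with sequence
# indices into the 4-D block-diagonal causal mask used for packing without cross-attention.
import torch

from llamafactory.data.collator import prepare_4d_attention_mask  # added in this patch series

attention_mask_with_indices = torch.tensor([[1, 1, 2, 2, 2, 0]])
mask_4d = prepare_4d_attention_mask(attention_mask_with_indices, dtype=torch.float16)

o, x = 0.0, torch.finfo(torch.float16).min
expected = torch.tensor(
    [
        [
            [
                [o, x, x, x, x, x],  # token 0: first token of sequence 1
                [o, o, x, x, x, x],  # token 1: attends to tokens 0-1 (same sequence)
                [x, x, o, x, x, x],  # token 2: first token of sequence 2
                [x, x, o, o, x, x],
                [x, x, o, o, o, x],
                [x, x, x, x, x, x],  # padding position: attends to nothing
            ]
        ]
    ],
    dtype=torch.float16,
)
assert mask_4d.shape == (1, 1, 6, 6)
assert torch.all(mask_4d == expected)
```

On the configuration side, the arguments introduced above suggest that this path is activated for SFT by setting both `packing: true` and `neat_packing: true`, which the parser then propagates to `model_args.block_diag_attn`.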