implement efficient packing without cross-contamination attention

Former-commit-id: b2c367bc61c2778dc359613dca496d9e134c2743
ancv 2024-06-12 11:56:01 +07:00
parent 1a261add61
commit 045eb155a2
9 changed files with 287 additions and 8 deletions

View File

@@ -36,7 +36,7 @@ def get_preprocess_and_print_func(
)
print_function = partial(print_unsupervised_dataset_example, tokenizer=tokenizer)
elif stage == "sft" and not training_args.predict_with_generate:
if data_args.packing:
if data_args.packing or data_args.efficient_packing:
preprocess_func = partial(
preprocess_packed_supervised_dataset,
template=template,

View File

@@ -10,7 +10,7 @@ if TYPE_CHECKING:
from transformers import ProcessorMixin
from transformers.tokenization_utils import PreTrainedTokenizer
from ...hparams import DataArguments
from ...hparams import DataArguments, FinetuningArguments
from ..template import Template
@@ -140,11 +140,12 @@ def preprocess_packed_supervised_dataset(
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
knapsacks = greedy_knapsack(lengths, data_args.cutoff_len)
for knapsack in knapsacks:
packed_input_ids, packed_labels = [], []
for length in knapsack:
packed_input_ids, packed_attention_mask, packed_labels = [], [], []
for i, length in enumerate(knapsack):
index = length2indexes[length].pop()
packed_input_ids += batch_input_ids[index]
packed_labels += batch_labels[index]
packed_attention_mask += [i + 1] * len(batch_input_ids[index])
if len(packed_input_ids) < data_args.cutoff_len:
pad_length = data_args.cutoff_len - len(packed_input_ids)
@@ -155,7 +156,10 @@ def preprocess_packed_supervised_dataset(
raise ValueError("The length of packed example should be identical to the cutoff length.")
model_inputs["input_ids"].append(packed_input_ids)
model_inputs["attention_mask"].append([1] * data_args.cutoff_len)
if data_args.efficient_packing:
model_inputs["attention_mask"].append(packed_attention_mask)
else:
model_inputs["attention_mask"].append([1] * data_args.cutoff_len)
model_inputs["labels"].append(packed_labels)
return model_inputs
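With efficient packing, each packed sample stores a segment-ID attention mask instead of an all-ones mask, so downstream attention kernels can tell which tokens belong to which original sequence. A minimal sketch of the difference, using made-up token ids and a hypothetical cutoff_len of 8 (editorial illustration, not part of the commit):

# Three sequences of lengths 3, 4 and 1 packed into one sample (cutoff_len = 8).
packed_input_ids = [11, 12, 13, 21, 22, 23, 24, 31]
# Vanilla packing: a single all-ones mask, so attention can flow across sequences.
plain_attention_mask = [1, 1, 1, 1, 1, 1, 1, 1]
# Efficient packing: segment ids (i + 1 per knapsack entry), so attention stays
# within each original sequence and cross-contamination is avoided.
packed_attention_mask = [1, 1, 1, 2, 2, 2, 2, 3]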

View File

@@ -84,6 +84,10 @@ class DataArguments:
"help": "Whether or not to pack the sequences in training. Will automatically enable in pre-training."
},
)
efficient_packing: Optional[bool] = field(
default=None,
metadata={"help": "Whether or not to pack the sequences without cross-contamination attention for efficient training."},
)
tokenized_path: Optional[str] = field(
default=None,
metadata={"help": "Path to save or load the tokenized datasets."},

View File

@@ -0,0 +1,250 @@
# Copied from the original implementations in src/axolotl/monkeypatch/multipack.py and src/axolotl/monkeypatch/utils.py of the axolotl library, with some changes
"""
Shared utils for the monkeypatches
"""
from typing import Optional, TYPE_CHECKING
import torch
import torch.nn.functional as F
import importlib
import transformers
from accelerate import init_empty_weights
from transformers import AutoConfig, AutoModelForCausalLM
from ...extras.logging import get_logger
if TYPE_CHECKING:
from transformers import PretrainedConfig
from ...hparams import ModelArguments, DataArguments
SUPPORTED_MULTIPACK_MODEL_TYPES = [
"llama",
"mistral",
"mixtral",
"qwen2",
"qwen2_moe",
"falcon",
"phi",
"phi3",
"gemma",
"gemmoe",
"starcoder2",
]
@torch.jit.script
def get_max_seqlen_in_batch(attention_mask: torch.Tensor) -> torch.Tensor:
max_num = int(torch.max(attention_mask).item())
batch_size, _ = attention_mask.shape
counts = torch.zeros((batch_size, max_num), dtype=torch.int32)
for i in range(1, max_num + 1):
mask = attention_mask == i
counts[:, i - 1] = torch.sum(mask, dim=-1).to(dtype=torch.int32)
result = counts.flatten()
nonzero_indices = torch.nonzero(result).squeeze(-1)
return result[nonzero_indices]
@torch.jit.script
def get_unpad_data(attention_mask: torch.Tensor):
device = attention_mask.device
seqlens_in_batch = get_max_seqlen_in_batch(attention_mask)
indices = torch.nonzero(attention_mask.flatten()).flatten()
max_seqlen_in_batch = seqlens_in_batch.max().item()
cu_seqlens = (
F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
.to(device=device)
.detach()
)
return (
indices,
cu_seqlens,
max_seqlen_in_batch,
)
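# --- Illustrative example (editorial note, not part of the original file) ---
# For a packed row whose attention mask holds segment ids, e.g.
#   attention_mask = torch.tensor([[1, 1, 1, 2, 2, 2, 2, 3]])
# get_max_seqlen_in_batch returns the per-segment lengths tensor([3, 4, 1]),
# and get_unpad_data then yields
#   indices             = tensor([0, 1, 2, 3, 4, 5, 6, 7])   (all non-padding positions)
#   cu_seqlens           = tensor([0, 3, 7, 8], dtype=torch.int32)
#   max_seqlen_in_batch  = 4
# These cumulative boundaries are what flash attention's variable-length kernels use
# to confine each packed sequence's attention to its own segment.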
def set_module_name(model, name, value):
if "." in name:
parent_name = name.rsplit(".", 1)[0]
child_name = name[len(parent_name) + 1 :]
parent = model.get_submodule(parent_name)
else:
parent_name = ""
parent = model
child_name = name
setattr(parent, child_name, value)
# Copied from the original implementation of modeling_mixtral.py in transformers; the only change is the use of new_attention_mask
def load_balancing_loss_func(
gate_logits: torch.Tensor,
num_experts: torch.Tensor = None,
top_k=2,
attention_mask: Optional[torch.Tensor] = None,
) -> float:
r"""
Computes auxiliary load balancing loss as in Switch Transformer, implemented in PyTorch.
See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the loss
function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between
experts is too unbalanced.
Args:
gate_logits (Union[`torch.Tensor`, Tuple[torch.Tensor]]):
Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of
shape [batch_size X sequence_length, num_experts].
attention_mask (`torch.Tensor`, None):
The attention_mask used in forward function
shape [batch_size X sequence_length] if not None.
num_experts (`int`, *optional*):
Number of experts
Returns:
The auxiliary loss.
"""
if gate_logits is None or not isinstance(gate_logits, tuple):
return 0
if isinstance(gate_logits, tuple):
compute_device = gate_logits[0].device
concatenated_gate_logits = torch.cat(
[layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0
)
routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1)
_, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)
if attention_mask is None:
# Compute the percentage of tokens routed to each expert
tokens_per_expert = torch.mean(expert_mask.float(), dim=0)
# Compute the average probability of routing to these experts
router_prob_per_expert = torch.mean(routing_weights, dim=0)
else:
# The only change from the upstream function: build new_attention_mask here and use it in place of attention_mask below
new_attention_mask = (attention_mask != 0).int().to(attention_mask.device)
batch_size, sequence_length = new_attention_mask.shape
num_hidden_layers = concatenated_gate_logits.shape[0] // (
batch_size * sequence_length
)
# Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask
expert_attention_mask = (
new_attention_mask[None, :, :, None, None]
.expand(
(num_hidden_layers, batch_size, sequence_length, top_k, num_experts)
)
.reshape(-1, top_k, num_experts)
.to(compute_device)
)
# Compute the percentage of tokens routed to each expert
tokens_per_expert = torch.sum(
expert_mask.float() * expert_attention_mask, dim=0
) / torch.sum(expert_attention_mask, dim=0)
# Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert
router_per_expert_attention_mask = (
new_attention_mask[None, :, :, None]
.expand((num_hidden_layers, batch_size, sequence_length, num_experts))
.reshape(-1, num_experts)
.to(compute_device)
)
# Compute the average probability of routing to these experts
router_prob_per_expert = torch.sum(
routing_weights * router_per_expert_attention_mask, dim=0
) / torch.sum(router_per_expert_attention_mask, dim=0)
overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
return overall_loss * num_experts
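# --- Illustrative note (editorial, not part of the original file) ---
# With segment-id masks, the stock Mixtral aux-loss code would treat ids 2, 3, ...
# as weights rather than "token present" flags. new_attention_mask = (attention_mask != 0)
# collapses every segment id back to the binary mask the load-balancing math expects,
# e.g. [1, 1, 2, 2, 3, 0] -> [1, 1, 1, 1, 1, 0], so padding is still excluded while
# every packed token is counted exactly once.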
def patch_for_multipack(model_type, model_name=None):
if model_type == "llama":
transformers.models.llama.modeling_llama._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)
elif model_type == "mistral":
transformers.models.mistral.modeling_mistral._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)
elif model_type == "mixtral":
transformers.models.mixtral.modeling_mixtral._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)
transformers.models.mixtral.modeling_mixtral.load_balancing_loss_func = ( # pylint: disable=protected-access
load_balancing_loss_func
)
elif model_type == "qwen2":
transformers.models.qwen2.modeling_qwen2._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)
elif model_type == "qwen2_moe":
transformers.models.qwen2_moe.modeling_qwen2_moe._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)
transformers.models.qwen2_moe.modeling_qwen2_moe.load_balancing_loss_func = ( # pylint: disable=protected-access
load_balancing_loss_func
)
elif model_type == "falcon":
transformers.models.falcon.modeling_falcon._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)
elif model_type == "phi":
transformers.models.phi.modeling_phi._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)
elif model_type == "phi3":
transformers.models.phi3.modeling_phi3._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)
elif model_type == "gemma":
transformers.models.gemma.modeling_gemma._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)
elif model_type == "starcoder2":
transformers.models.starcoder2.modeling_starcoder2._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)
elif model_type == "gemmoe":
patch_remote(model_name, ".configuration_gemmoe", ".modeling_gemmoe")
elif model_type == "jamba":
patch_remote(model_name, ".configuration_jamba", ".modeling_jamba")
def patch_remote(model_name, config_name, modeling_name):
model_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
# we need to load the model here in order for modeling_* to be available
with init_empty_weights():
AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
module_name = model_config.__class__.__module__.replace(config_name, modeling_name)
modeling_arch = importlib.import_module(module_name)
modeling_arch._get_unpad_data = get_unpad_data # pylint: disable=protected-access
def configure_packing(config: "PretrainedConfig") -> None:
if getattr(config, "model_type", None) == "internlm2": # special case for custom models
attn_implementation = getattr(config, "attn_implementation", None)
else:
attn_implementation = getattr(config, "_attn_implementation", None)
if attn_implementation != "flash_attention_2":
raise ValueError("Efficient packing only supports for flash_attention_2. Please set config `flash_attn` is fa2" + " " + attn_implementation)
logger = get_logger(__name__)
if getattr(config, "model_type", None) in SUPPORTED_MULTIPACK_MODEL_TYPES:
patch_for_multipack(getattr(config, "model_type", None))
logger.info("Using packing sequences without cross-contamination attention for efficient training.")
else:
raise ValueError("Current model does not support packing sequences for efficient training. Please set config `efficient_packing` is False")

View File

@@ -24,7 +24,7 @@ def run_kto(
):
tokenizer_module = load_tokenizer(model_args)
tokenizer = tokenizer_module["tokenizer"]
dataset = get_dataset(model_args, data_args, training_args, stage="kto", **tokenizer_module)
dataset = get_dataset(model_args, data_args, training_args, finetuning_args, stage="kto", **tokenizer_module)
model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)
data_collator = KTODataCollatorWithPadding(

View File

@@ -12,7 +12,7 @@ from ...model import load_model, load_tokenizer
from ..trainer_utils import create_modelcard_and_push
from .metric import ComputeMetrics
from .trainer import CustomSeq2SeqTrainer
from ...model.model_utils.packing import configure_packing
if TYPE_CHECKING:
from transformers import Seq2SeqTrainingArguments, TrainerCallback
@@ -33,6 +33,9 @@ def run_sft(
dataset = get_dataset(model_args, data_args, training_args, stage="sft", **tokenizer_module)
model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)
if data_args.efficient_packing:
configure_packing(model.config)
if training_args.predict_with_generate:
tokenizer.padding_side = "left" # use left-padding in generation

View File

@@ -83,6 +83,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
with gr.Column():
resize_vocab = gr.Checkbox()
packing = gr.Checkbox()
efficient_packing = gr.Checkbox()
with gr.Column():
upcast_layernorm = gr.Checkbox()
@@ -101,6 +102,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
optim,
resize_vocab,
packing,
efficient_packing,
upcast_layernorm,
use_llama_pro,
shift_attn,
@@ -117,6 +119,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
optim=optim,
resize_vocab=resize_vocab,
packing=packing,
efficient_packing=efficient_packing,
upcast_layernorm=upcast_layernorm,
use_llama_pro=use_llama_pro,
shift_attn=shift_attn,
@@ -313,7 +316,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
)
dataset.focus(list_datasets, [dataset_dir, training_stage], [dataset], queue=False)
training_stage.change(change_stage, [training_stage], [dataset, packing], queue=False)
training_stage.change(change_stage, [training_stage], [dataset, packing, efficient_packing], queue=False)
reward_model.focus(list_checkpoints, [model_name, finetuning_type], [reward_model], queue=False)
model_name.change(list_output_dirs, [model_name, finetuning_type, current_time], [output_dir], queue=False)
finetuning_type.change(list_output_dirs, [model_name, finetuning_type, current_time], [output_dir], queue=False)

View File

@@ -494,6 +494,20 @@ LOCALES = {
"info": "将序列打包为等长样本。",
},
},
"efficient_packing": {
"en": {
"label": "Pack sequences for efficient training",
"info": "Pack sequences into samples of fixed length without cross-contamination attention for efficient training.",
},
"ru": {
"label": "Пакетные последовательности для эффективного обучения",
"info": "Упакуйте последовательности в образцы фиксированной длины без учета перекрестного загрязнения для эффективного обучения.",
},
"zh": {
"label": "打包序列以实现高效训练",
"info": "为了提高训练效率,将序列打包成固定长度的样本,无需注意交叉污染。",
},
},
"upcast_layernorm": {
"en": {
"label": "Upcast LayerNorm",

View File

@@ -120,6 +120,7 @@ class Runner:
optim=get("train.optim"),
resize_vocab=get("train.resize_vocab"),
packing=get("train.packing"),
efficient_packing=get("train.efficient_packing"),
upcast_layernorm=get("train.upcast_layernorm"),
use_llama_pro=get("train.use_llama_pro"),
shift_attn=get("train.shift_attn"),