Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-08-22 22:02:51 +08:00)

Implement efficient packing without cross-contamination attention

Former-commit-id: b2c367bc61c2778dc359613dca496d9e134c2743
This commit is contained in: parent 1a261add61, commit 045eb155a2
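The core idea of packing without cross-contamination: instead of a flat 0/1 attention mask, each packed row carries segment ids (1 for the first packed sequence, 2 for the second, and 0 for padding), so tokens may only attend within their own sequence. A minimal sketch of the resulting block-diagonal attention pattern follows (illustrative only, not part of this commit; causal masking is applied separately by the model):

import torch

# Hypothetical packed row: two sequences of lengths 3 and 2, plus one padding slot.
segment_ids = torch.tensor([1, 1, 1, 2, 2, 0])

# Tokens may attend to each other only inside the same non-padding segment,
# which yields a block-diagonal pattern instead of one big shared context.
allowed = (segment_ids.unsqueeze(0) == segment_ids.unsqueeze(1)) & (segment_ids.unsqueeze(0) != 0)
print(allowed.int())
# tensor([[1, 1, 1, 0, 0, 0],
#         [1, 1, 1, 0, 0, 0],
#         [1, 1, 1, 0, 0, 0],
#         [0, 0, 0, 1, 1, 0],
#         [0, 0, 0, 1, 1, 0],
#         [0, 0, 0, 0, 0, 0]])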
@@ -36,7 +36,7 @@ def get_preprocess_and_print_func(
         )
         print_function = partial(print_unsupervised_dataset_example, tokenizer=tokenizer)
     elif stage == "sft" and not training_args.predict_with_generate:
-        if data_args.packing:
+        if data_args.packing or data_args.efficient_packing:
             preprocess_func = partial(
                 preprocess_packed_supervised_dataset,
                 template=template,
@@ -10,7 +10,7 @@ if TYPE_CHECKING:
     from transformers import ProcessorMixin
     from transformers.tokenization_utils import PreTrainedTokenizer
 
-    from ...hparams import DataArguments
+    from ...hparams import DataArguments, FinetuningArguments
     from ..template import Template
 
 
@@ -140,11 +140,12 @@ def preprocess_packed_supervised_dataset(
     model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
     knapsacks = greedy_knapsack(lengths, data_args.cutoff_len)
     for knapsack in knapsacks:
-        packed_input_ids, packed_labels = [], []
-        for length in knapsack:
+        packed_input_ids, packed_attention_mask, packed_labels = [], [], []
+        for i, length in enumerate(knapsack):
             index = length2indexes[length].pop()
             packed_input_ids += batch_input_ids[index]
             packed_labels += batch_labels[index]
+            packed_attention_mask += [i+1] * len(batch_input_ids[index])
 
         if len(packed_input_ids) < data_args.cutoff_len:
             pad_length = data_args.cutoff_len - len(packed_input_ids)
@@ -155,7 +156,10 @@ def preprocess_packed_supervised_dataset(
             raise ValueError("The length of packed example should be identical to the cutoff length.")
 
         model_inputs["input_ids"].append(packed_input_ids)
-        model_inputs["attention_mask"].append([1] * data_args.cutoff_len)
+        if data_args.efficient_packing:
+            model_inputs["attention_mask"].append(packed_attention_mask)
+        else:
+            model_inputs["attention_mask"].append([1] * data_args.cutoff_len)
         model_inputs["labels"].append(packed_labels)
 
     return model_inputs
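To make the change above concrete, here is a hypothetical packed example with cutoff_len = 8 and two sequences of lengths 3 and 4 in one knapsack. The token ids and the pad value 0 are made up, the -100 label follows the usual IGNORE_INDEX convention for prompt and padding positions, and the assumption is that padded positions receive mask value 0:

packed_input_ids      = [101, 7, 8, 101, 9, 10, 11, 0]        # hypothetical token ids, last slot padded
packed_attention_mask = [1,   1, 1, 2,   2, 2,  2,  0]        # segment ids; 0 assumed for padding
packed_labels         = [-100, 7, 8, -100, 9, 10, 11, -100]   # IGNORE_INDEX on prompt/padding positions
# With plain packing (efficient_packing disabled) the mask would instead be [1] * cutoff_len.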
@@ -84,6 +84,10 @@ class DataArguments:
             "help": "Whether or not to pack the sequences in training. Will automatically enable in pre-training."
         },
     )
+    efficient_packing: Optional[bool] = field(
+        default=None,
+        metadata={"help": "Whether or not to pack the sequences without cross-contamination attention for efficient training."},
+    )
     tokenized_path: Optional[str] = field(
         default=None,
         metadata={"help": "Path to save or load the tokenized datasets."},
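Assuming the package layout at this commit, enabling the new flag from Python would look roughly like the following sketch (not an officially documented entry point; normally the flag is passed through the training arguments instead):

from llamafactory.hparams import DataArguments

data_args = DataArguments(cutoff_len=2048, efficient_packing=True)
# Either flag routes SFT data through the packed preprocessor; efficient_packing
# additionally switches the attention mask to per-sequence segment ids.
assert data_args.packing or data_args.efficient_packing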
250 src/llamafactory/model/model_utils/packing.py (new file)
@@ -0,0 +1,250 @@
# Copied from the original implementations of src/axolotl/monkeypatch/multipack.py and
# src/axolotl/monkeypatch/utils.py in the axolotl library, with some changes.
"""
Shared utils for the monkeypatches
"""
from typing import Optional, TYPE_CHECKING

import torch
import torch.nn.functional as F

import importlib

import transformers
from accelerate import init_empty_weights
from transformers import AutoConfig, AutoModelForCausalLM

from ...extras.logging import get_logger


if TYPE_CHECKING:
    from transformers import PretrainedConfig

    from ...hparams import ModelArguments, DataArguments


SUPPORTED_MULTIPACK_MODEL_TYPES = [
    "llama",
    "mistral",
    "mixtral",
    "qwen2",
    "qwen2_moe",
    "falcon",
    "phi",
    "phi3",
    "gemma",
    "gemmoe",
    "starcoder2",
]


@torch.jit.script
def get_max_seqlen_in_batch(attention_mask: torch.Tensor) -> torch.Tensor:
    max_num = int(torch.max(attention_mask).item())
    batch_size, _ = attention_mask.shape
    counts = torch.zeros((batch_size, max_num), dtype=torch.int32)

    for i in range(1, max_num + 1):
        mask = attention_mask == i
        counts[:, i - 1] = torch.sum(mask, dim=-1).to(dtype=torch.int32)

    result = counts.flatten()
    nonzero_indices = torch.nonzero(result).squeeze(-1)
    return result[nonzero_indices]


@torch.jit.script
def get_unpad_data(attention_mask: torch.Tensor):
    device = attention_mask.device
    seqlens_in_batch = get_max_seqlen_in_batch(attention_mask)
    indices = torch.nonzero(attention_mask.flatten()).flatten()
    max_seqlen_in_batch = seqlens_in_batch.max().item()
    cu_seqlens = (
        F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
        .to(device=device)
        .detach()
    )
    return (
        indices,
        cu_seqlens,
        max_seqlen_in_batch,
    )


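As an aside (not part of the new file), a quick illustration of what these two helpers return for a toy segment-id mask. flash-attention's variable-length kernels consume cu_seqlens, so each packed sequence is attended to independently:

import torch
import torch.nn.functional as F

attention_mask = torch.tensor([[1, 1, 1, 2, 2, 0]], dtype=torch.int32)  # one packed row

seqlens = get_max_seqlen_in_batch(attention_mask)             # tensor([3, 2], dtype=torch.int32)
indices = torch.nonzero(attention_mask.flatten()).flatten()   # tensor([0, 1, 2, 3, 4])
cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))
print(cu_seqlens)                                             # tensor([0, 3, 5], dtype=torch.int32)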
def set_module_name(model, name, value):
    if "." in name:
        parent_name = name.rsplit(".", 1)[0]
        child_name = name[len(parent_name) + 1 :]
        parent = model.get_submodule(parent_name)
    else:
        parent_name = ""
        parent = model
        child_name = name

    setattr(parent, child_name, value)


# Copied from the original implementation of modeling_mixtral.py in transformers; the only change is the use of new_attention_mask.
def load_balancing_loss_func(
    gate_logits: torch.Tensor,
    num_experts: torch.Tensor = None,
    top_k=2,
    attention_mask: Optional[torch.Tensor] = None,
) -> float:
    r"""
    Computes the auxiliary load balancing loss as in Switch Transformer, implemented in PyTorch.

    See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the loss
    function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between
    experts is too unbalanced.

    Args:
        gate_logits (Union[`torch.Tensor`, Tuple[torch.Tensor]]):
            Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of
            shape [batch_size X sequence_length, num_experts].
        attention_mask (`torch.Tensor`, None):
            The attention_mask used in the forward function,
            shape [batch_size X sequence_length] if not None.
        num_experts (`int`, *optional*):
            Number of experts.

    Returns:
        The auxiliary loss.
    """
    if gate_logits is None or not isinstance(gate_logits, tuple):
        return 0

    if isinstance(gate_logits, tuple):
        compute_device = gate_logits[0].device
        concatenated_gate_logits = torch.cat(
            [layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0
        )

    routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1)

    _, selected_experts = torch.topk(routing_weights, top_k, dim=-1)

    expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)

    if attention_mask is None:
        # Compute the percentage of tokens routed to each expert
        tokens_per_expert = torch.mean(expert_mask.float(), dim=0)

        # Compute the average probability of routing to these experts
        router_prob_per_expert = torch.mean(routing_weights, dim=0)
    else:
        # ONLY ADD THIS LINE OF CODE, AND REPLACE attention_mask WITH new_attention_mask
        new_attention_mask = (attention_mask != 0).int().to(attention_mask.device)
        batch_size, sequence_length = new_attention_mask.shape
        num_hidden_layers = concatenated_gate_logits.shape[0] // (
            batch_size * sequence_length
        )

        # Compute the mask that masks all padding tokens as 0 with the same shape as expert_mask
        expert_attention_mask = (
            new_attention_mask[None, :, :, None, None]
            .expand(
                (num_hidden_layers, batch_size, sequence_length, top_k, num_experts)
            )
            .reshape(-1, top_k, num_experts)
            .to(compute_device)
        )

        # Compute the percentage of tokens routed to each expert
        tokens_per_expert = torch.sum(
            expert_mask.float() * expert_attention_mask, dim=0
        ) / torch.sum(expert_attention_mask, dim=0)

        # Compute the mask that masks all padding tokens as 0 with the same shape as tokens_per_expert
        router_per_expert_attention_mask = (
            new_attention_mask[None, :, :, None]
            .expand((num_hidden_layers, batch_size, sequence_length, num_experts))
            .reshape(-1, num_experts)
            .to(compute_device)
        )

        # Compute the average probability of routing to these experts
        router_prob_per_expert = torch.sum(
            routing_weights * router_per_expert_attention_mask, dim=0
        ) / torch.sum(router_per_expert_attention_mask, dim=0)

    overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
    return overall_loss * num_experts


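For reference, the quantity computed above matches the Switch Transformer auxiliary loss (equations (4)-(6) of the paper), up to the scaling coefficient applied by the caller; when an attention mask is passed, the averages run only over non-padding tokens, which is exactly what new_attention_mask enforces. A sketch of the standard formulation:

\mathcal{L}_{\text{aux}} = N \sum_{i=1}^{N} f_i \, P_i,
\qquad
f_i = \frac{1}{T} \sum_{t=1}^{T} \mathbb{1}\{\text{token } t \text{ is routed to expert } i\},
\qquad
P_i = \frac{1}{T} \sum_{t=1}^{T} p_i(x_t)

where N is the number of experts, T the number of unmasked token-routing decisions, and p_i(x_t) the router probability assigned to expert i for token t.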
def patch_for_multipack(model_type, model_name=None):
    if model_type == "llama":
        transformers.models.llama.modeling_llama._get_unpad_data = (  # pylint: disable=protected-access
            get_unpad_data
        )
    elif model_type == "mistral":
        transformers.models.mistral.modeling_mistral._get_unpad_data = (  # pylint: disable=protected-access
            get_unpad_data
        )
    elif model_type == "mixtral":
        transformers.models.mixtral.modeling_mixtral._get_unpad_data = (  # pylint: disable=protected-access
            get_unpad_data
        )
        transformers.models.mixtral.modeling_mixtral.load_balancing_loss_func = (  # pylint: disable=protected-access
            load_balancing_loss_func
        )
    elif model_type == "qwen2":
        transformers.models.qwen2.modeling_qwen2._get_unpad_data = (  # pylint: disable=protected-access
            get_unpad_data
        )
    elif model_type == "qwen2_moe":
        transformers.models.qwen2_moe.modeling_qwen2_moe._get_unpad_data = (  # pylint: disable=protected-access
            get_unpad_data
        )
        transformers.models.qwen2_moe.modeling_qwen2_moe.load_balancing_loss_func = (  # pylint: disable=protected-access
            load_balancing_loss_func
        )
    elif model_type == "falcon":
        transformers.models.falcon.modeling_falcon._get_unpad_data = (  # pylint: disable=protected-access
            get_unpad_data
        )
    elif model_type == "phi":
        transformers.models.phi.modeling_phi._get_unpad_data = (  # pylint: disable=protected-access
            get_unpad_data
        )
    elif model_type == "phi3":
        transformers.models.phi3.modeling_phi3._get_unpad_data = (  # pylint: disable=protected-access
            get_unpad_data
        )
    elif model_type == "gemma":
        transformers.models.gemma.modeling_gemma._get_unpad_data = (  # pylint: disable=protected-access
            get_unpad_data
        )
    elif model_type == "starcoder2":
        transformers.models.starcoder2.modeling_starcoder2._get_unpad_data = (  # pylint: disable=protected-access
            get_unpad_data
        )
    elif model_type == "gemmoe":
        patch_remote(model_name, ".configuration_gemmoe", ".modeling_gemmoe")
    elif model_type == "jamba":
        patch_remote(model_name, ".configuration_jamba", ".modeling_jamba")


def patch_remote(model_name, config_name, modeling_name):
    model_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
    # we need to load the model here in order for modeling_* to be available
    with init_empty_weights():
        AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
    module_name = model_config.__class__.__module__.replace(config_name, modeling_name)
    modeling_arch = importlib.import_module(module_name)
    modeling_arch._get_unpad_data = get_unpad_data  # pylint: disable=protected-access


def configure_packing(config: "PretrainedConfig") -> None:
    if getattr(config, "model_type", None) == "internlm2":  # special case for custom models
        attn_implementation = getattr(config, "attn_implementation", None)
    else:
        attn_implementation = getattr(config, "_attn_implementation", None)

    if attn_implementation != "flash_attention_2":
        raise ValueError(
            "Efficient packing only supports flash_attention_2, but `{}` is set. Please set `flash_attn: fa2`.".format(attn_implementation)
        )

    logger = get_logger(__name__)

    if getattr(config, "model_type", None) in SUPPORTED_MULTIPACK_MODEL_TYPES:
        patch_for_multipack(getattr(config, "model_type", None))
        logger.info("Using packing sequences without cross-contamination attention for efficient training.")
    else:
        raise ValueError("Current model does not support packing sequences for efficient training. Please set `efficient_packing` to False.")
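A hedged sketch of how the patch is meant to take effect, assuming a transformers version from the same era in which modeling_llama still exposes _get_unpad_data; the run_sft change below wires this call in right after the model is loaded:

import transformers
from transformers import LlamaConfig
from llamafactory.model.model_utils.packing import configure_packing, get_unpad_data

config = LlamaConfig()                              # any model_type in SUPPORTED_MULTIPACK_MODEL_TYPES
config._attn_implementation = "flash_attention_2"   # LLaMA-Factory sets this when flash_attn is fa2
configure_packing(config)

# The llama modeling module now unpads by segment id instead of a plain 0/1 mask.
assert transformers.models.llama.modeling_llama._get_unpad_data is get_unpad_data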
@@ -24,7 +24,7 @@ def run_kto(
 ):
     tokenizer_module = load_tokenizer(model_args)
     tokenizer = tokenizer_module["tokenizer"]
-    dataset = get_dataset(model_args, data_args, training_args, stage="kto", **tokenizer_module)
+    dataset = get_dataset(model_args, data_args, training_args, finetuning_args, stage="kto", **tokenizer_module)
     model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)
 
     data_collator = KTODataCollatorWithPadding(
@@ -12,7 +12,7 @@ from ...model import load_model, load_tokenizer
 from ..trainer_utils import create_modelcard_and_push
 from .metric import ComputeMetrics
 from .trainer import CustomSeq2SeqTrainer
 
+from ...model.model_utils.packing import configure_packing
 
 if TYPE_CHECKING:
     from transformers import Seq2SeqTrainingArguments, TrainerCallback
@@ -33,6 +33,9 @@ def run_sft(
     dataset = get_dataset(model_args, data_args, training_args, stage="sft", **tokenizer_module)
     model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)
 
+    if data_args.efficient_packing:
+        configure_packing(model.config)
+
     if training_args.predict_with_generate:
         tokenizer.padding_side = "left"  # use left-padding in generation
@@ -83,6 +83,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
             with gr.Column():
                 resize_vocab = gr.Checkbox()
                 packing = gr.Checkbox()
+                efficient_packing = gr.Checkbox()
 
             with gr.Column():
                 upcast_layernorm = gr.Checkbox()
@@ -101,6 +102,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
             optim,
             resize_vocab,
             packing,
+            efficient_packing,
             upcast_layernorm,
             use_llama_pro,
             shift_attn,
@@ -117,6 +119,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
             optim=optim,
             resize_vocab=resize_vocab,
             packing=packing,
+            efficient_packing=efficient_packing,
             upcast_layernorm=upcast_layernorm,
             use_llama_pro=use_llama_pro,
             shift_attn=shift_attn,
@@ -313,7 +316,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
     )
 
     dataset.focus(list_datasets, [dataset_dir, training_stage], [dataset], queue=False)
-    training_stage.change(change_stage, [training_stage], [dataset, packing], queue=False)
+    training_stage.change(change_stage, [training_stage], [dataset, packing, efficient_packing], queue=False)
     reward_model.focus(list_checkpoints, [model_name, finetuning_type], [reward_model], queue=False)
     model_name.change(list_output_dirs, [model_name, finetuning_type, current_time], [output_dir], queue=False)
     finetuning_type.change(list_output_dirs, [model_name, finetuning_type, current_time], [output_dir], queue=False)
@@ -494,6 +494,20 @@ LOCALES = {
             "info": "将序列打包为等长样本。",
         },
     },
+    "efficient_packing": {
+        "en": {
+            "label": "Pack sequences for efficient training",
+            "info": "Pack sequences into samples of fixed length without cross-contamination attention for efficient training.",
+        },
+        "ru": {
+            "label": "Пакетные последовательности для эффективного обучения",
+            "info": "Упакуйте последовательности в образцы фиксированной длины без учета перекрестного загрязнения для эффективного обучения.",
+        },
+        "zh": {
+            "label": "打包序列以实现高效训练",
+            "info": "为了提高训练效率,将序列打包成固定长度的样本,无需注意交叉污染。",
+        },
+    },
     "upcast_layernorm": {
         "en": {
             "label": "Upcast LayerNorm",
@@ -120,6 +120,7 @@ class Runner:
             optim=get("train.optim"),
             resize_vocab=get("train.resize_vocab"),
             packing=get("train.packing"),
+            efficient_packing=get("train.efficient_packing"),
             upcast_layernorm=get("train.upcast_layernorm"),
             use_llama_pro=get("train.use_llama_pro"),
             shift_attn=get("train.shift_attn"),