diff --git a/src/llamafactory/data/preprocess.py b/src/llamafactory/data/preprocess.py
index 97789c39..cf207d7e 100644
--- a/src/llamafactory/data/preprocess.py
+++ b/src/llamafactory/data/preprocess.py
@@ -36,7 +36,7 @@ def get_preprocess_and_print_func(
         )
         print_function = partial(print_unsupervised_dataset_example, tokenizer=tokenizer)
     elif stage == "sft" and not training_args.predict_with_generate:
-        if data_args.packing:
+        if data_args.packing or data_args.efficient_packing:
             preprocess_func = partial(
                 preprocess_packed_supervised_dataset,
                 template=template,
diff --git a/src/llamafactory/data/processors/supervised.py b/src/llamafactory/data/processors/supervised.py
index 19d60280..3406576b 100644
--- a/src/llamafactory/data/processors/supervised.py
+++ b/src/llamafactory/data/processors/supervised.py
@@ -10,7 +10,7 @@ if TYPE_CHECKING:
     from transformers import ProcessorMixin
     from transformers.tokenization_utils import PreTrainedTokenizer
 
-    from ...hparams import DataArguments
+    from ...hparams import DataArguments, FinetuningArguments
     from ..template import Template
 
 
@@ -140,11 +140,12 @@ def preprocess_packed_supervised_dataset(
     model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
     knapsacks = greedy_knapsack(lengths, data_args.cutoff_len)
     for knapsack in knapsacks:
-        packed_input_ids, packed_labels = [], []
-        for length in knapsack:
+        packed_input_ids, packed_attention_mask, packed_labels = [], [], []
+        for i, length in enumerate(knapsack):
             index = length2indexes[length].pop()
             packed_input_ids += batch_input_ids[index]
             packed_labels += batch_labels[index]
+            packed_attention_mask += [i + 1] * len(batch_input_ids[index])
 
         if len(packed_input_ids) < data_args.cutoff_len:
             pad_length = data_args.cutoff_len - len(packed_input_ids)
@@ -155,7 +156,10 @@ def preprocess_packed_supervised_dataset(
             raise ValueError("The length of packed example should be identical to the cutoff length.")
 
         model_inputs["input_ids"].append(packed_input_ids)
-        model_inputs["attention_mask"].append([1] * data_args.cutoff_len)
+        if data_args.efficient_packing:
+            model_inputs["attention_mask"].append(packed_attention_mask)
+        else:
+            model_inputs["attention_mask"].append([1] * data_args.cutoff_len)
         model_inputs["labels"].append(packed_labels)
 
     return model_inputs
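
Note (illustration only, not part of the patch): with `efficient_packing` enabled, the hunk above stores a 1-based sequence index per token in the attention mask instead of all ones, so the boundaries between packed sequences remain recoverable downstream. A minimal sketch with made-up token ids; the hunk does not show how the mask itself is padded, so marking padding with 0 is an assumption here:

# Two hypothetical tokenized examples packed into one knapsack with cutoff_len = 8.
sequences = [[11, 12, 13], [21, 22]]
cutoff_len = 8

packed_input_ids, packed_attention_mask = [], []
for i, seq in enumerate(sequences):
    packed_input_ids += seq
    packed_attention_mask += [i + 1] * len(seq)  # each token carries the index of its sequence

pad_length = cutoff_len - len(packed_input_ids)
packed_input_ids += [0] * pad_length         # pad token id is a placeholder here
packed_attention_mask += [0] * pad_length    # assumption: padding positions marked with 0

print(packed_attention_mask)  # [1, 1, 1, 2, 2, 0, 0, 0]
print(packed_input_ids)       # [11, 12, 13, 21, 22, 0, 0, 0]
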
diff --git a/src/llamafactory/hparams/data_args.py b/src/llamafactory/hparams/data_args.py
index 1e0cd08c..2081c6b4 100644
--- a/src/llamafactory/hparams/data_args.py
+++ b/src/llamafactory/hparams/data_args.py
@@ -84,6 +84,10 @@ class DataArguments:
             "help": "Whether or not to pack the sequences in training. Will automatically enable in pre-training."
         },
     )
+    efficient_packing: Optional[bool] = field(
+        default=None,
+        metadata={"help": "Whether or not to pack sequences without cross-contamination in attention for efficient training."},
+    )
     tokenized_path: Optional[str] = field(
         default=None,
         metadata={"help": "Path to save or load the tokenized datasets."},
diff --git a/src/llamafactory/model/model_utils/packing.py b/src/llamafactory/model/model_utils/packing.py
new file mode 100644
index 00000000..fe718ebb
--- /dev/null
+++ b/src/llamafactory/model/model_utils/packing.py
@@ -0,0 +1,250 @@
+# Copied from src/axolotl/monkeypatch/multipack.py and src/axolotl/monkeypatch/utils.py in the axolotl library, with some changes.
+"""
+Shared utils for the monkeypatches
+"""
+import importlib
+from typing import TYPE_CHECKING, Optional
+
+import torch
+import torch.nn.functional as F
+import transformers
+from accelerate import init_empty_weights
+from transformers import AutoConfig, AutoModelForCausalLM
+
+from ...extras.logging import get_logger
+
+if TYPE_CHECKING:
+    from transformers import PretrainedConfig
+
+    from ...hparams import DataArguments, ModelArguments
+
+
+logger = get_logger(__name__)
+
+
+SUPPORTED_MULTIPACK_MODEL_TYPES = [
+    "llama",
+    "mistral",
+    "mixtral",
+    "qwen2",
+    "qwen2_moe",
+    "falcon",
+    "phi",
+    "phi3",
+    "gemma",
+    "gemmoe",
+    "starcoder2",
+]
+
+
+@torch.jit.script
+def get_max_seqlen_in_batch(attention_mask: torch.Tensor) -> torch.Tensor:
+    max_num = int(torch.max(attention_mask).item())
+    batch_size, _ = attention_mask.shape
+    counts = torch.zeros((batch_size, max_num), dtype=torch.int32)
+
+    for i in range(1, max_num + 1):
+        mask = attention_mask == i
+        counts[:, i - 1] = torch.sum(mask, dim=-1).to(dtype=torch.int32)
+
+    result = counts.flatten()
+    nonzero_indices = torch.nonzero(result).squeeze(-1)
+    return result[nonzero_indices]
+
+
+@torch.jit.script
+def get_unpad_data(attention_mask: torch.Tensor):
+    device = attention_mask.device
+    seqlens_in_batch = get_max_seqlen_in_batch(attention_mask)
+    indices = torch.nonzero(attention_mask.flatten()).flatten()
+    max_seqlen_in_batch = seqlens_in_batch.max().item()
+    cu_seqlens = (
+        F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
+        .to(device=device)
+        .detach()
+    )
+    return (
+        indices,
+        cu_seqlens,
+        max_seqlen_in_batch,
+    )
+
+
+def set_module_name(model, name, value):
+    if "." in name:
+        parent_name = name.rsplit(".", 1)[0]
+        child_name = name[len(parent_name) + 1 :]
+        parent = model.get_submodule(parent_name)
+    else:
+        parent_name = ""
+        parent = model
+        child_name = name
+
+    setattr(parent, child_name, value)
+
+
+# Copied from the original implementation of modeling_mixtral.py in transformers; the only change is the use of new_attention_mask.
+def load_balancing_loss_func(
+    gate_logits: torch.Tensor,
+    num_experts: torch.Tensor = None,
+    top_k=2,
+    attention_mask: Optional[torch.Tensor] = None,
+) -> float:
+    r"""
+    Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch.
+
+    See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the loss
+    function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between
+    experts is too unbalanced.
+
+    Args:
+        gate_logits (Union[`torch.Tensor`, Tuple[torch.Tensor]]):
+            Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of
+            shape [batch_size X sequence_length, num_experts].
+        attention_mask (`torch.Tensor`, None):
+            The attention_mask used in the forward function, of
+            shape [batch_size X sequence_length] if not None.
+        num_experts (`int`, *optional*):
+            Number of experts.
+
+    Returns:
+        The auxiliary loss.
+    """
+    if gate_logits is None or not isinstance(gate_logits, tuple):
+        return 0
+
+    if isinstance(gate_logits, tuple):
+        compute_device = gate_logits[0].device
+        concatenated_gate_logits = torch.cat(
+            [layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0
+        )
+
+    routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1)
+
+    _, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
+
+    expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)
+
+    if attention_mask is None:
+        # Compute the percentage of tokens routed to each expert
+        tokens_per_expert = torch.mean(expert_mask.float(), dim=0)
+
+        # Compute the average probability of routing to these experts
+        router_prob_per_expert = torch.mean(routing_weights, dim=0)
+    else:
+        # Only this line is added relative to transformers: the segment-id mask is collapsed back to a
+        # binary mask, and attention_mask is replaced with new_attention_mask below.
+        new_attention_mask = (attention_mask != 0).int().to(attention_mask.device)
+        batch_size, sequence_length = new_attention_mask.shape
+        num_hidden_layers = concatenated_gate_logits.shape[0] // (
+            batch_size * sequence_length
+        )
+
+        # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask
+        expert_attention_mask = (
+            new_attention_mask[None, :, :, None, None]
+            .expand(
+                (num_hidden_layers, batch_size, sequence_length, top_k, num_experts)
+            )
+            .reshape(-1, top_k, num_experts)
+            .to(compute_device)
+        )
+
+        # Compute the percentage of tokens routed to each expert
+        tokens_per_expert = torch.sum(
+            expert_mask.float() * expert_attention_mask, dim=0
+        ) / torch.sum(expert_attention_mask, dim=0)
+
+        # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert
+        router_per_expert_attention_mask = (
+            new_attention_mask[None, :, :, None]
+            .expand((num_hidden_layers, batch_size, sequence_length, num_experts))
+            .reshape(-1, num_experts)
+            .to(compute_device)
+        )
+
+        # Compute the average probability of routing to these experts
+        router_prob_per_expert = torch.sum(
+            routing_weights * router_per_expert_attention_mask, dim=0
+        ) / torch.sum(router_per_expert_attention_mask, dim=0)
+
+    overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
+    return overall_loss * num_experts
+
+
+def patch_for_multipack(model_type, model_name=None):
+    if model_type == "llama":
+        transformers.models.llama.modeling_llama._get_unpad_data = (  # pylint: disable=protected-access
+            get_unpad_data
+        )
+    elif model_type == "mistral":
+        transformers.models.mistral.modeling_mistral._get_unpad_data = (  # pylint: disable=protected-access
+            get_unpad_data
+        )
+    elif model_type == "mixtral":
+        transformers.models.mixtral.modeling_mixtral._get_unpad_data = (  # pylint: disable=protected-access
+            get_unpad_data
+        )
+        transformers.models.mixtral.modeling_mixtral.load_balancing_loss_func = (  # pylint: disable=protected-access
+            load_balancing_loss_func
+        )
+    elif model_type == "qwen2":
+        transformers.models.qwen2.modeling_qwen2._get_unpad_data = (  # pylint: disable=protected-access
+            get_unpad_data
+        )
+    elif model_type == "qwen2_moe":
+        transformers.models.qwen2_moe.modeling_qwen2_moe._get_unpad_data = (  # pylint: disable=protected-access
+            get_unpad_data
+        )
+        transformers.models.qwen2_moe.modeling_qwen2_moe.load_balancing_loss_func = (  # pylint: disable=protected-access
+            load_balancing_loss_func
+        )
+    elif model_type == "falcon":
+        transformers.models.falcon.modeling_falcon._get_unpad_data = (  # pylint: disable=protected-access
+            get_unpad_data
+        )
+    elif model_type == "phi":
+        transformers.models.phi.modeling_phi._get_unpad_data = (  # pylint: disable=protected-access
+            get_unpad_data
+        )
+    elif model_type == "phi3":
+        transformers.models.phi3.modeling_phi3._get_unpad_data = (  # pylint: disable=protected-access
+            get_unpad_data
+        )
+    elif model_type == "gemma":
+        transformers.models.gemma.modeling_gemma._get_unpad_data = (  # pylint: disable=protected-access
+            get_unpad_data
+        )
+    elif model_type == "starcoder2":
+        transformers.models.starcoder2.modeling_starcoder2._get_unpad_data = (  # pylint: disable=protected-access
+            get_unpad_data
+        )
+    elif model_type == "gemmoe":
+        patch_remote(model_name, ".configuration_gemmoe", ".modeling_gemmoe")
+    elif model_type == "jamba":
+        patch_remote(model_name, ".configuration_jamba", ".modeling_jamba")
+
+
+def patch_remote(model_name, config_name, modeling_name):
+    model_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
+    # we need to load the model here in order for modeling_* to be available
+    with init_empty_weights():
+        AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
+    module_name = model_config.__class__.__module__.replace(config_name, modeling_name)
+    modeling_arch = importlib.import_module(module_name)
+    modeling_arch._get_unpad_data = get_unpad_data  # pylint: disable=protected-access
+
+
+def configure_packing(config: "PretrainedConfig") -> None:
+    if getattr(config, "model_type", None) == "internlm2":  # special case for custom models
+        attn_implementation = getattr(config, "attn_implementation", None)
+    else:
+        attn_implementation = getattr(config, "_attn_implementation", None)
+
+    if attn_implementation != "flash_attention_2":
+        raise ValueError(
+            "Efficient packing only supports flash_attention_2, but got {}. Please set `flash_attn: fa2`.".format(
+                attn_implementation
+            )
+        )
+
+    if getattr(config, "model_type", None) in SUPPORTED_MULTIPACK_MODEL_TYPES:
+        patch_for_multipack(getattr(config, "model_type", None))
+        logger.info("Packing sequences without cross-contamination attention for efficient training.")
+    else:
+        raise ValueError(
+            "Current model does not support efficient packing. Please set `efficient_packing: false`."
+        )
\ No newline at end of file
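
Note (illustration only, not part of the patch): the monkeypatched `_get_unpad_data` is what flash_attention_2 model code calls to unpad a batch. With the segment-id mask from efficient packing it yields one `cu_seqlens` entry per packed sequence, so the varlen attention kernels never attend across sequence boundaries. The sketch below reproduces that computation in plain eager PyTorch for a single packed row, using the made-up values from the earlier example:

# Illustration only: what the patched helper computes from a segment-id attention mask.
import torch
import torch.nn.functional as F

attention_mask = torch.tensor([[1, 1, 1, 2, 2, 0, 0, 0]], dtype=torch.int32)

# per-sequence lengths in packing order: one entry per segment id
seqlens_in_batch = torch.tensor(
    [int((attention_mask == i).sum()) for i in range(1, int(attention_mask.max()) + 1)],
    dtype=torch.int32,
)
indices = torch.nonzero(attention_mask.flatten()).flatten()  # positions of real (non-padding) tokens
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))

print(seqlens_in_batch)  # tensor([3, 2], dtype=torch.int32)
print(cu_seqlens)        # tensor([0, 3, 5], dtype=torch.int32)
print(indices)           # tensor([0, 1, 2, 3, 4])

The companion change to `load_balancing_loss_func` is much smaller: since the mask now holds segment ids rather than 0/1 values, `(attention_mask != 0)` collapses it back to a binary padding mask before the MoE routing statistics are computed, leaving the rest of the transformers implementation untouched.
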
diff --git a/src/llamafactory/train/kto/workflow.py b/src/llamafactory/train/kto/workflow.py
index c79b160b..f003e157 100644
--- a/src/llamafactory/train/kto/workflow.py
+++ b/src/llamafactory/train/kto/workflow.py
@@ -24,7 +24,7 @@ def run_kto(
 ):
     tokenizer_module = load_tokenizer(model_args)
     tokenizer = tokenizer_module["tokenizer"]
-    dataset = get_dataset(model_args, data_args, training_args, stage="kto", **tokenizer_module)
+    dataset = get_dataset(model_args, data_args, training_args, finetuning_args, stage="kto", **tokenizer_module)
     model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)
 
     data_collator = KTODataCollatorWithPadding(
diff --git a/src/llamafactory/train/sft/workflow.py b/src/llamafactory/train/sft/workflow.py
index f09b5173..d4965393 100644
--- a/src/llamafactory/train/sft/workflow.py
+++ b/src/llamafactory/train/sft/workflow.py
@@ -12,7 +12,7 @@ from ...model import load_model, load_tokenizer
 from ..trainer_utils import create_modelcard_and_push
 from .metric import ComputeMetrics
 from .trainer import CustomSeq2SeqTrainer
-
+from ...model.model_utils.packing import configure_packing
 
 if TYPE_CHECKING:
     from transformers import Seq2SeqTrainingArguments, TrainerCallback
@@ -33,6 +33,9 @@ def run_sft(
     dataset = get_dataset(model_args, data_args, training_args, stage="sft", **tokenizer_module)
     model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)
 
+    if data_args.efficient_packing:
+        configure_packing(model.config)
+
     if training_args.predict_with_generate:
         tokenizer.padding_side = "left"  # use left-padding in generation
diff --git a/src/llamafactory/webui/components/train.py b/src/llamafactory/webui/components/train.py
index 72dfc858..dccc8500 100644
--- a/src/llamafactory/webui/components/train.py
+++ b/src/llamafactory/webui/components/train.py
@@ -83,6 +83,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
             with gr.Column():
                 resize_vocab = gr.Checkbox()
                 packing = gr.Checkbox()
+                efficient_packing = gr.Checkbox()
 
             with gr.Column():
                 upcast_layernorm = gr.Checkbox()
@@ -101,6 +102,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
             optim,
             resize_vocab,
             packing,
+            efficient_packing,
             upcast_layernorm,
             use_llama_pro,
             shift_attn,
@@ -117,6 +119,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
             optim=optim,
             resize_vocab=resize_vocab,
             packing=packing,
+            efficient_packing=efficient_packing,
             upcast_layernorm=upcast_layernorm,
             use_llama_pro=use_llama_pro,
             shift_attn=shift_attn,
@@ -313,7 +316,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
     )
 
     dataset.focus(list_datasets, [dataset_dir, training_stage], [dataset], queue=False)
-    training_stage.change(change_stage, [training_stage], [dataset, packing], queue=False)
+    training_stage.change(change_stage, [training_stage], [dataset, packing, efficient_packing], queue=False)
     reward_model.focus(list_checkpoints, [model_name, finetuning_type], [reward_model], queue=False)
     model_name.change(list_output_dirs, [model_name, finetuning_type, current_time], [output_dir], queue=False)
     finetuning_type.change(list_output_dirs, [model_name, finetuning_type, current_time], [output_dir], queue=False)
diff --git a/src/llamafactory/webui/locales.py b/src/llamafactory/webui/locales.py
index e30feab2..05cf3bed 100644
--- a/src/llamafactory/webui/locales.py
+++ b/src/llamafactory/webui/locales.py
@@ -494,6 +494,20 @@ LOCALES = {
             "info": "将序列打包为等长样本。",
         },
     },
+    "efficient_packing": {
+        "en": {
+            "label": "Pack sequences for efficient training",
+            "info": "Pack sequences into samples of fixed length without attention cross-contamination between them for efficient training.",
+        },
+        "ru": {
+            "label": "Упаковка последовательностей для эффективного обучения",
+            "info": "Упаковывать последовательности в образцы фиксированной длины без перекрёстного внимания между ними для эффективного обучения.",
+        },
+        "zh": {
+            "label": "打包序列以实现高效训练",
+            "info": "将序列打包为固定长度的样本,并避免样本间的注意力交叉污染,以实现高效训练。",
+        },
+    },
     "upcast_layernorm": {
         "en": {
             "label": "Upcast LayerNorm",
diff --git a/src/llamafactory/webui/runner.py b/src/llamafactory/webui/runner.py
index 35014628..852805da 100644
--- a/src/llamafactory/webui/runner.py
+++ b/src/llamafactory/webui/runner.py
@@ -120,6 +120,7 @@ class Runner:
             optim=get("train.optim"),
             resize_vocab=get("train.resize_vocab"),
             packing=get("train.packing"),
+            efficient_packing=get("train.efficient_packing"),
             upcast_layernorm=get("train.upcast_layernorm"),
             use_llama_pro=get("train.use_llama_pro"),
             shift_attn=get("train.shift_attn"),
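
Note (hedged usage sketch, not part of the patch): one way the new option might be enabled from Python, assuming the existing `run_exp` entry point in `llamafactory.train.tuner` (the function behind `llamafactory-cli train`) accepts a plain argument dict. The model and dataset names are placeholders, and every key except `efficient_packing` is a pre-existing LLaMA-Factory argument; `configure_packing` raises unless the model is loaded with flash_attention_2, hence `flash_attn="fa2"`.

# Hedged sketch: parameter names other than efficient_packing are assumed from existing hparams.
from llamafactory.train.tuner import run_exp

run_exp(dict(
    stage="sft",
    do_train=True,
    model_name_or_path="meta-llama/Meta-Llama-3-8B-Instruct",  # placeholder model
    dataset="alpaca_en_demo",                                  # placeholder dataset name
    template="llama3",
    finetuning_type="lora",
    output_dir="saves/llama3-8b/lora/sft-packed",
    cutoff_len=2048,
    flash_attn="fa2",            # required by configure_packing()
    efficient_packing=True,      # new flag: segment-id attention mask, no cross-contamination
    per_device_train_batch_size=1,
    gradient_accumulation_steps=8,
    learning_rate=1.0e-4,
    num_train_epochs=1.0,
))
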