From 7b3c1f29ffb049f06fa24d6459de64e718677ca2 Mon Sep 17 00:00:00 2001
From: hiyouga <467089858@qq.com>
Date: Thu, 4 Jul 2024 01:52:43 +0800
Subject: [PATCH] fix packing for eager/sdpa attn

Former-commit-id: 6fd6aa4530f81a2ed306eeb2a5167607288b62c6
---
 src/llamafactory/data/__init__.py             |  3 ++-
 src/llamafactory/data/collator.py             | 24 ++++++++++++++++---
 src/llamafactory/extras/constants.py          |  1 +
 src/llamafactory/hparams/data_args.py         |  3 ---
 src/llamafactory/hparams/finetuning_args.py   | 19 ++++++++++-----
 src/llamafactory/hparams/parser.py            |  4 ++++
 src/llamafactory/model/model_utils/packing.py |  4 +++-
 src/llamafactory/train/callbacks.py           |  4 ++--
 src/llamafactory/train/sft/workflow.py        |  9 +++----
 9 files changed, 51 insertions(+), 20 deletions(-)

diff --git a/src/llamafactory/data/__init__.py b/src/llamafactory/data/__init__.py
index 307853bc..4da742b4 100644
--- a/src/llamafactory/data/__init__.py
+++ b/src/llamafactory/data/__init__.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from .collator import KTODataCollatorWithPadding, PairwiseDataCollatorWithPadding
+from .collator import KTODataCollatorWithPadding, PairwiseDataCollatorWithPadding, SFTDataCollatorWith4DAttentionMask
 from .data_utils import Role, split_dataset
 from .loader import get_dataset
 from .template import TEMPLATES, Template, get_template_and_fix_tokenizer
@@ -21,6 +21,7 @@ from .template import TEMPLATES, Template, get_template_and_fix_tokenizer
 __all__ = [
     "KTODataCollatorWithPadding",
     "PairwiseDataCollatorWithPadding",
+    "SFTDataCollatorWith4DAttentionMask",
     "Role",
     "split_dataset",
     "get_dataset",
diff --git a/src/llamafactory/data/collator.py b/src/llamafactory/data/collator.py
index 6d176313..91871eaa 100644
--- a/src/llamafactory/data/collator.py
+++ b/src/llamafactory/data/collator.py
@@ -16,7 +16,7 @@
 # limitations under the License.
 
 from dataclasses import dataclass
-from typing import Any, Dict, Sequence
+from typing import Any, Dict, Literal, Sequence
 
 import torch
 from transformers import DataCollatorForSeq2Seq
@@ -62,13 +62,31 @@ def prepare_4d_attention_mask(attention_mask_with_indices: "torch.Tensor", dtype
     return attention_mask_4d
 
 
+@dataclass
+class SFTDataCollatorWith4DAttentionMask(DataCollatorForSeq2Seq):
+    r"""
+    Data collator for 4d attention mask.
+    """
+
+    block_diag_attn: bool = False
+    attn_implementation: Literal["eager", "sdpa", "flash_attention_2"] = "eager"
+    compute_dtype: "torch.dtype" = torch.float32
+
+    def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
+        features = super().__call__(features)
+        if self.block_diag_attn and self.attn_implementation != "flash_attention_2":
+            features["attention_mask"] = prepare_4d_attention_mask(features["attention_mask"], self.compute_dtype)
+
+        return features
+
+
 @dataclass
 class PairwiseDataCollatorWithPadding(DataCollatorForSeq2Seq):
     r"""
     Data collator for pairwise data.
     """
 
-    def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
+    def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
         r"""
         Pads batched data to the longest sequence in the batch.
 
@@ -100,7 +118,7 @@ class KTODataCollatorWithPadding(DataCollatorForSeq2Seq):
     Data collator for KTO data.
""" - def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, torch.Tensor]: + def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]: target_features = [] kl_features = [] kto_tags = [] diff --git a/src/llamafactory/extras/constants.py b/src/llamafactory/extras/constants.py index 47781791..1ac61305 100644 --- a/src/llamafactory/extras/constants.py +++ b/src/llamafactory/extras/constants.py @@ -79,6 +79,7 @@ TRAINING_STAGES = { STAGES_USE_PAIR_DATA = {"rm", "dpo"} SUPPORTED_CLASS_FOR_BLOCK_DIAG_ATTN = { + "cohere", "falcon", "gemma", "gemma2", diff --git a/src/llamafactory/hparams/data_args.py b/src/llamafactory/hparams/data_args.py index 45c1079b..a1025af7 100644 --- a/src/llamafactory/hparams/data_args.py +++ b/src/llamafactory/hparams/data_args.py @@ -112,6 +112,3 @@ class DataArguments: if self.streaming and self.max_samples is not None: raise ValueError("`max_samples` is incompatible with `streaming`.") - - if self.neat_packing and not self.packing: - raise ValueError("`neat_packing` requires `packing` is True.") diff --git a/src/llamafactory/hparams/finetuning_args.py b/src/llamafactory/hparams/finetuning_args.py index 3867c0ec..923cc431 100644 --- a/src/llamafactory/hparams/finetuning_args.py +++ b/src/llamafactory/hparams/finetuning_args.py @@ -376,14 +376,21 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreA if self.use_galore and self.use_badam: raise ValueError("Cannot use GaLore with BAdam together.") - if self.loraplus_lr_ratio is not None and self.finetuning_type != "lora": - raise ValueError("`loraplus_lr_ratio` is only valid for LoRA training.") - - if self.pissa_init and self.finetuning_type != "lora": - raise ValueError("`pissa_init` is only valid for LoRA training.") - if self.pissa_init and (self.stage in ["ppo", "kto"] or self.use_ref_model): raise ValueError("Cannot use PiSSA for current training stage.") if self.train_mm_proj_only and self.finetuning_type != "full": raise ValueError("`train_mm_proj_only` is only valid for full training.") + + if self.finetuning_type != "lora": + if self.loraplus_lr_ratio is not None: + raise ValueError("`loraplus_lr_ratio` is only valid for LoRA training.") + + if self.use_rslora: + raise ValueError("`use_rslora` is only valid for LoRA training.") + + if self.use_dora: + raise ValueError("`use_dora` is only valid for LoRA training.") + + if self.pissa_init: + raise ValueError("`pissa_init` is only valid for LoRA training.") diff --git a/src/llamafactory/hparams/parser.py b/src/llamafactory/hparams/parser.py index 73abc0bb..ca9a9589 100644 --- a/src/llamafactory/hparams/parser.py +++ b/src/llamafactory/hparams/parser.py @@ -233,6 +233,10 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS: if model_args.use_unsloth and is_deepspeed_zero3_enabled(): raise ValueError("Unsloth is incompatible with DeepSpeed ZeRO-3.") + if data_args.neat_packing and not data_args.packing: + logger.warning("`neat_packing` requires `packing` is True. 
Change it to True.") + data_args.packing = True + _verify_model_args(model_args, finetuning_args) _check_extra_dependencies(model_args, finetuning_args, training_args) diff --git a/src/llamafactory/model/model_utils/packing.py b/src/llamafactory/model/model_utils/packing.py index ba614515..eec5d957 100644 --- a/src/llamafactory/model/model_utils/packing.py +++ b/src/llamafactory/model/model_utils/packing.py @@ -115,7 +115,9 @@ def get_unpad_data(attention_mask: "torch.Tensor") -> Tuple["torch.Tensor", "tor def patch_for_block_diag_attn(model_type: str) -> None: - if model_type == "falcon": + if model_type == "cohere": + transformers.models.cohere.modeling_cohere._get_unpad_data = get_unpad_data + elif model_type == "falcon": transformers.models.falcon.modeling_falcon._get_unpad_data = get_unpad_data elif model_type == "gemma": transformers.models.gemma.modeling_gemma._get_unpad_data = get_unpad_data diff --git a/src/llamafactory/train/callbacks.py b/src/llamafactory/train/callbacks.py index 4d024278..97eb6d1c 100644 --- a/src/llamafactory/train/callbacks.py +++ b/src/llamafactory/train/callbacks.py @@ -79,9 +79,8 @@ def fix_valuehead_checkpoint( if name.startswith("v_head."): v_head_state_dict[name] = param else: - decoder_state_dict[name.replace("pretrained_model.", "")] = param + decoder_state_dict[name.replace("pretrained_model.", "", count=1)] = param - os.remove(path_to_checkpoint) model.pretrained_model.save_pretrained( output_dir, state_dict=decoder_state_dict or None, safe_serialization=safe_serialization ) @@ -91,6 +90,7 @@ def fix_valuehead_checkpoint( else: torch.save(v_head_state_dict, os.path.join(output_dir, V_HEAD_WEIGHTS_NAME)) + os.remove(path_to_checkpoint) logger.info("Value head model saved at: {}".format(output_dir)) diff --git a/src/llamafactory/train/sft/workflow.py b/src/llamafactory/train/sft/workflow.py index 0c3f9b11..dea3c1a8 100644 --- a/src/llamafactory/train/sft/workflow.py +++ b/src/llamafactory/train/sft/workflow.py @@ -17,9 +17,7 @@ from typing import TYPE_CHECKING, List, Optional -from transformers import DataCollatorForSeq2Seq - -from ...data import get_dataset, split_dataset +from ...data import SFTDataCollatorWith4DAttentionMask, get_dataset, split_dataset from ...extras.constants import IGNORE_INDEX from ...extras.misc import get_logits_processor from ...extras.ploting import plot_loss @@ -54,10 +52,13 @@ def run_sft( if getattr(model, "is_quantized", False) and not training_args.do_train: setattr(model, "_hf_peft_config_loaded", True) # hack here: make model compatible with prediction - data_collator = DataCollatorForSeq2Seq( + data_collator = SFTDataCollatorWith4DAttentionMask( tokenizer=tokenizer, pad_to_multiple_of=8 if tokenizer.padding_side == "right" else None, # for shift short attention label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id, + block_diag_attn=model_args.block_diag_attn, + attn_implementation=getattr(model.config, "_attn_implementation", None), + compute_dtype=model_args.compute_dtype, ) # Override the decoding parameters of Seq2SeqTrainer
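
Note for reviewers: the sketch below mimics what SFTDataCollatorWith4DAttentionMask does when block_diag_attn is enabled and the attention implementation is eager or sdpa, namely expanding a packed 1-D mask of sequence indices into a causal, block-diagonal 4-D mask. It is a minimal standalone illustration, not the library's prepare_4d_attention_mask implementation; the index convention (0 for padding, k for the k-th packed sequence) and the helper name block_diag_4d_mask are assumptions made for this example. With flash_attention_2 the collator leaves the 1-D mask untouched, since packing is handled instead through the patched _get_unpad_data in packing.py.

import torch

def block_diag_4d_mask(indices: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    # indices: (bsz, seq_len); assumed convention: 0 = padding, k > 0 = k-th packed sequence.
    bsz, seq_len = indices.size()
    min_value = torch.finfo(dtype).min
    # token i may attend to token j iff both belong to the same packed sequence and j <= i
    same_seq = indices.unsqueeze(2) == indices.unsqueeze(1)              # (bsz, seq, seq)
    non_pad = (indices.unsqueeze(2) != 0) & (indices.unsqueeze(1) != 0)  # drop padding positions
    causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device=indices.device))
    allowed = same_seq & non_pad & causal
    mask_4d = torch.full((bsz, seq_len, seq_len), min_value, dtype=dtype)
    mask_4d = mask_4d.masked_fill(allowed, 0.0)   # 0 where attention is allowed, dtype-min elsewhere
    return mask_4d.unsqueeze(1)                   # (bsz, 1, seq_len, seq_len)

# two sequences of lengths 2 and 3 packed into one row, plus one pad token
indices = torch.tensor([[1, 1, 2, 2, 2, 0]])
mask = block_diag_4d_mask(indices, torch.float32)
print(mask.shape)         # torch.Size([1, 1, 6, 6])
print(mask[0, 0] == 0)    # True only inside each causal block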