fix packing for eager/sdpa attn

Former-commit-id: 6fd6aa4530f81a2ed306eeb2a5167607288b62c6
hiyouga 2024-07-04 01:52:43 +08:00
parent a38ff842d0
commit 7b3c1f29ff
9 changed files with 51 additions and 20 deletions
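
What the change does: with `neat_packing`, each packed example carries an attention mask of per-token sequence indices rather than 0/1 values. For the eager and SDPA backends, the new `SFTDataCollatorWith4DAttentionMask` expands that mask into a 4D block-diagonal additive mask, while flash_attention_2 keeps the flat mask and goes through the patched `_get_unpad_data` instead. The body of `prepare_4d_attention_mask` is not part of this diff; the snippet below is a minimal sketch of the idea, with `build_block_diag_mask` as a hypothetical stand-in name.

```python
import torch


def build_block_diag_mask(attention_mask_with_indices: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    """Turn a (bsz, seq_len) mask of sequence indices into a (bsz, 1, seq_len, seq_len) additive mask."""
    bsz, seq_len = attention_mask_with_indices.size()
    min_dtype = torch.finfo(dtype).min
    # expanded[b, 0, i, j] holds the sequence index of token j
    expanded = attention_mask_with_indices[:, None, None, :].expand(bsz, 1, seq_len, seq_len)
    # tokens i and j may interact only if they belong to the same non-padding segment
    same_segment = torch.eq(expanded, expanded.transpose(-1, -2)).long() * (expanded != 0).long()
    # and only causally (j <= i) inside that segment
    causal = torch.tril(torch.ones((seq_len, seq_len), dtype=torch.long, device=attention_mask_with_indices.device))
    allowed = (same_segment * causal).to(dtype)
    # additive mask: 0 where attention is allowed, a large negative value elsewhere
    return (1.0 - allowed) * min_dtype


# one packed example holding two sequences of lengths 2 and 3, plus one pad token
packed = torch.tensor([[1, 1, 2, 2, 2, 0]])
mask = build_block_diag_mask(packed, torch.float32)
print(mask[0, 0, 2, 1].item())  # most negative float32 value: a token of sequence 2 cannot see sequence 1
print(mask[0, 0, 3, 2].item())  # 0.0: causal attention within sequence 2 is allowed
```

For the packed example `[[1, 1, 2, 2, 2, 0]]`, the resulting mask lets tokens 2 to 4 attend only to each other (causally) and blocks them from tokens 0 and 1 and from the padding position.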

View File

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from .collator import KTODataCollatorWithPadding, PairwiseDataCollatorWithPadding
+from .collator import KTODataCollatorWithPadding, PairwiseDataCollatorWithPadding, SFTDataCollatorWith4DAttentionMask
 from .data_utils import Role, split_dataset
 from .loader import get_dataset
 from .template import TEMPLATES, Template, get_template_and_fix_tokenizer
@@ -21,6 +21,7 @@ from .template import TEMPLATES, Template, get_template_and_fix_tokenizer
 __all__ = [
     "KTODataCollatorWithPadding",
     "PairwiseDataCollatorWithPadding",
+    "SFTDataCollatorWith4DAttentionMask",
     "Role",
     "split_dataset",
     "get_dataset",

View File

@@ -16,7 +16,7 @@
 # limitations under the License.

 from dataclasses import dataclass
-from typing import Any, Dict, Sequence
+from typing import Any, Dict, Literal, Sequence

 import torch
 from transformers import DataCollatorForSeq2Seq
@@ -62,13 +62,31 @@ def prepare_4d_attention_mask(attention_mask_with_indices: "torch.Tensor", dtype
     return attention_mask_4d


+@dataclass
+class SFTDataCollatorWith4DAttentionMask(DataCollatorForSeq2Seq):
+    r"""
+    Data collator for 4d attention mask.
+    """
+
+    block_diag_attn: bool = False
+    attn_implementation: Literal["eager", "sdpa", "flash_attention_2"] = "eager"
+    compute_dtype: "torch.dtype" = torch.float32
+
+    def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
+        features = super().__call__(features)
+        if self.block_diag_attn and self.attn_implementation != "flash_attention_2":
+            features["attention_mask"] = prepare_4d_attention_mask(features["attention_mask"], self.compute_dtype)
+
+        return features
+
+
 @dataclass
 class PairwiseDataCollatorWithPadding(DataCollatorForSeq2Seq):
     r"""
     Data collator for pairwise data.
     """

-    def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
+    def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
         r"""
         Pads batched data to the longest sequence in the batch.
@@ -100,7 +118,7 @@ class KTODataCollatorWithPadding(DataCollatorForSeq2Seq):
     Data collator for KTO data.
     """

-    def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
+    def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
         target_features = []
         kl_features = []
         kto_tags = []
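
A hypothetical usage sketch of the new collator (the checkpoint path and example features are made up, and the package is assumed to be importable as `llamafactory`): a flat mask of sequence indices goes in, and for non-flash backends a 4D additive mask comes out.

```python
import torch
from transformers import AutoTokenizer

from llamafactory.data import SFTDataCollatorWith4DAttentionMask

tokenizer = AutoTokenizer.from_pretrained("path/to/any/causal-lm")  # placeholder checkpoint
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token  # make sure padding is defined

collator = SFTDataCollatorWith4DAttentionMask(
    tokenizer=tokenizer,
    label_pad_token_id=-100,      # IGNORE_INDEX in the trainer wiring
    block_diag_attn=True,         # run_sft passes model_args.block_diag_attn here
    attn_implementation="sdpa",   # run_sft reads model.config._attn_implementation
    compute_dtype=torch.bfloat16, # run_sft passes model_args.compute_dtype
)

# one packed example holding two sequences (indices 1 and 2) of lengths 2 and 3
features = [{"input_ids": [5, 6, 7, 8, 9], "attention_mask": [1, 1, 2, 2, 2], "labels": [5, 6, 7, 8, 9]}]
batch = collator(features)
print(batch["attention_mask"].shape)  # expected: torch.Size([1, 1, 5, 5]) for non-flash backends
```

In the workflow diff further down, `run_sft` wires exactly these three extra arguments from `model_args` and `model.config`, so the collator falls back to plain `DataCollatorForSeq2Seq` behavior whenever `block_diag_attn` is off or the backend is flash_attention_2.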

View File

@@ -79,6 +79,7 @@ TRAINING_STAGES = {
 STAGES_USE_PAIR_DATA = {"rm", "dpo"}

 SUPPORTED_CLASS_FOR_BLOCK_DIAG_ATTN = {
+    "cohere",
     "falcon",
     "gemma",
     "gemma2",

View File

@@ -112,6 +112,3 @@ class DataArguments:
         if self.streaming and self.max_samples is not None:
             raise ValueError("`max_samples` is incompatible with `streaming`.")
-
-        if self.neat_packing and not self.packing:
-            raise ValueError("`neat_packing` requires `packing` is True.")

View File

@@ -376,14 +376,21 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreA
         if self.use_galore and self.use_badam:
             raise ValueError("Cannot use GaLore with BAdam together.")

-        if self.loraplus_lr_ratio is not None and self.finetuning_type != "lora":
-            raise ValueError("`loraplus_lr_ratio` is only valid for LoRA training.")
-
-        if self.pissa_init and self.finetuning_type != "lora":
-            raise ValueError("`pissa_init` is only valid for LoRA training.")
-
         if self.pissa_init and (self.stage in ["ppo", "kto"] or self.use_ref_model):
             raise ValueError("Cannot use PiSSA for current training stage.")

         if self.train_mm_proj_only and self.finetuning_type != "full":
             raise ValueError("`train_mm_proj_only` is only valid for full training.")

+        if self.finetuning_type != "lora":
+            if self.loraplus_lr_ratio is not None:
+                raise ValueError("`loraplus_lr_ratio` is only valid for LoRA training.")
+
+            if self.use_rslora:
+                raise ValueError("`use_rslora` is only valid for LoRA training.")
+
+            if self.use_dora:
+                raise ValueError("`use_dora` is only valid for LoRA training.")
+
+            if self.pissa_init:
+                raise ValueError("`pissa_init` is only valid for LoRA training.")

View File

@@ -233,6 +233,10 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
     if model_args.use_unsloth and is_deepspeed_zero3_enabled():
         raise ValueError("Unsloth is incompatible with DeepSpeed ZeRO-3.")

+    if data_args.neat_packing and not data_args.packing:
+        logger.warning("`neat_packing` requires `packing` is True. Change it to True.")
+        data_args.packing = True
+
     _verify_model_args(model_args, finetuning_args)
     _check_extra_dependencies(model_args, finetuning_args, training_args)

View File

@@ -115,7 +115,9 @@ def get_unpad_data(attention_mask: "torch.Tensor") -> Tuple["torch.Tensor", "tor
 def patch_for_block_diag_attn(model_type: str) -> None:
-    if model_type == "falcon":
+    if model_type == "cohere":
+        transformers.models.cohere.modeling_cohere._get_unpad_data = get_unpad_data
+    elif model_type == "falcon":
         transformers.models.falcon.modeling_falcon._get_unpad_data = get_unpad_data
     elif model_type == "gemma":
         transformers.models.gemma.modeling_gemma._get_unpad_data = get_unpad_data
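
For flash_attention_2, the collator leaves the indexed attention mask flat and packing is enforced here instead, by monkey-patching `_get_unpad_data` in the relevant modeling modules with the repository's `get_unpad_data` (its body sits above this hunk and is not shown). Below is a rough, hypothetical sketch of what such a packing-aware replacement has to compute, assuming segment indices increase from left to right within each row:

```python
from typing import Tuple

import torch
import torch.nn.functional as F


def get_unpad_data(attention_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, int]:
    # attention_mask holds sequence indices, e.g. [[1, 1, 2, 2, 2, 0]] -> segment lengths [2, 3]
    lengths = []
    for row in attention_mask:  # keep row-major order so lengths line up with `indices`
        counts = torch.bincount(row[row > 0])
        lengths.append(counts[counts > 0])
    seqlens_in_batch = torch.cat(lengths).to(torch.int32)
    # positions of all non-padding tokens in the flattened batch
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    max_seqlen_in_batch = int(seqlens_in_batch.max())
    # cumulative segment boundaries, e.g. [0, 2, 5], as expected by flash-attn varlen kernels
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
    return indices, cu_seqlens, max_seqlen_in_batch
```

The flash-attn varlen kernels then treat every packed segment as its own sequence via `cu_seqlens`, which is why adding `cohere` to this patch and to `SUPPORTED_CLASS_FOR_BLOCK_DIAG_ATTN` is all that is needed to support that architecture.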

View File

@@ -79,9 +79,8 @@ def fix_valuehead_checkpoint(
         if name.startswith("v_head."):
             v_head_state_dict[name] = param
         else:
-            decoder_state_dict[name.replace("pretrained_model.", "")] = param
+            decoder_state_dict[name.replace("pretrained_model.", "", count=1)] = param

-    os.remove(path_to_checkpoint)
     model.pretrained_model.save_pretrained(
         output_dir, state_dict=decoder_state_dict or None, safe_serialization=safe_serialization
     )
@@ -91,6 +90,7 @@ def fix_valuehead_checkpoint(
     else:
         torch.save(v_head_state_dict, os.path.join(output_dir, V_HEAD_WEIGHTS_NAME))

+    os.remove(path_to_checkpoint)
     logger.info("Value head model saved at: {}".format(output_dir))

View File

@@ -17,9 +17,7 @@
 from typing import TYPE_CHECKING, List, Optional

-from transformers import DataCollatorForSeq2Seq
-
-from ...data import get_dataset, split_dataset
+from ...data import SFTDataCollatorWith4DAttentionMask, get_dataset, split_dataset
 from ...extras.constants import IGNORE_INDEX
 from ...extras.misc import get_logits_processor
 from ...extras.ploting import plot_loss
@@ -54,10 +52,13 @@ def run_sft(
     if getattr(model, "is_quantized", False) and not training_args.do_train:
         setattr(model, "_hf_peft_config_loaded", True)  # hack here: make model compatible with prediction

-    data_collator = DataCollatorForSeq2Seq(
+    data_collator = SFTDataCollatorWith4DAttentionMask(
         tokenizer=tokenizer,
         pad_to_multiple_of=8 if tokenizer.padding_side == "right" else None,  # for shift short attention
         label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id,
+        block_diag_attn=model_args.block_diag_attn,
+        attn_implementation=getattr(model.config, "_attn_implementation", None),
+        compute_dtype=model_args.compute_dtype,
     )

     # Override the decoding parameters of Seq2SeqTrainer