fix packing for eager/sdpa attn

This commit is contained in:
hiyouga
2024-07-04 01:52:43 +08:00
parent 87d9b2d005
commit 6fd6aa4530
9 changed files with 51 additions and 20 deletions

View File

@@ -112,6 +112,3 @@ class DataArguments:
if self.streaming and self.max_samples is not None:
raise ValueError("`max_samples` is incompatible with `streaming`.")
if self.neat_packing and not self.packing:
raise ValueError("`neat_packing` requires `packing` is True.")

View File

@@ -376,14 +376,21 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreA
if self.use_galore and self.use_badam:
raise ValueError("Cannot use GaLore with BAdam together.")
if self.loraplus_lr_ratio is not None and self.finetuning_type != "lora":
raise ValueError("`loraplus_lr_ratio` is only valid for LoRA training.")
if self.pissa_init and self.finetuning_type != "lora":
raise ValueError("`pissa_init` is only valid for LoRA training.")
if self.pissa_init and (self.stage in ["ppo", "kto"] or self.use_ref_model):
raise ValueError("Cannot use PiSSA for current training stage.")
if self.train_mm_proj_only and self.finetuning_type != "full":
raise ValueError("`train_mm_proj_only` is only valid for full training.")
if self.finetuning_type != "lora":
if self.loraplus_lr_ratio is not None:
raise ValueError("`loraplus_lr_ratio` is only valid for LoRA training.")
if self.use_rslora:
raise ValueError("`use_rslora` is only valid for LoRA training.")
if self.use_dora:
raise ValueError("`use_dora` is only valid for LoRA training.")
if self.pissa_init:
raise ValueError("`pissa_init` is only valid for LoRA training.")

View File

@@ -233,6 +233,10 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
if model_args.use_unsloth and is_deepspeed_zero3_enabled():
raise ValueError("Unsloth is incompatible with DeepSpeed ZeRO-3.")
if data_args.neat_packing and not data_args.packing:
logger.warning("`neat_packing` requires `packing` is True. Change it to True.")
data_args.packing = True
_verify_model_args(model_args, finetuning_args)
_check_extra_dependencies(model_args, finetuning_args, training_args)