support FlashAttention2

This commit is contained in:
hiyouga
2023-09-10 20:43:56 +08:00
parent 815b92e698
commit d8aa1404be
9 changed files with 875 additions and 115 deletions

View File

@@ -43,6 +43,10 @@ class ModelArguments:
default=None,
metadata={"help": "Adopt scaled rotary positional embeddings."}
)
# Boolean option for flash attention; defaults to False (disabled).
flash_attn: Optional[bool] = field(
    default=False,
    metadata=dict(help="Enable flash attention for faster training."),
)
checkpoint_dir: Optional[str] = field(
default=None,
metadata={"help": "Path to the directory(s) containing the delta model checkpoints as well as the configurations."}