support FlashAttention2

This commit is contained in:
hiyouga
2023-09-10 20:43:56 +08:00
parent 815b92e698
commit d8aa1404be
9 changed files with 875 additions and 115 deletions

View File

@@ -4,6 +4,7 @@ import torch
from types import MethodType
from typing import TYPE_CHECKING, Literal, Optional, Tuple
import transformers
from transformers import (
AutoConfig,
AutoModelForCausalLM,
@@ -84,7 +85,8 @@ def load_model_and_tokenizer(
config = AutoConfig.from_pretrained(model_to_load, **config_kwargs)
if is_trainable and hasattr(config, "fp16") and hasattr(config, "bf16"): # fix Qwen config
# Fix config (for Qwen)
if is_trainable and hasattr(config, "fp16") and hasattr(config, "bf16"):
if model_args.compute_dtype == torch.bfloat16:
setattr(config, "bf16", True)
else:
@@ -105,6 +107,7 @@ def load_model_and_tokenizer(
if is_trainable:
if model_args.rope_scaling == "dynamic":
assert not model_args.flash_attn, "Flash attention does not support dynamic rope scaling."
logger.warning(
"Dynamic NTK may not work well with fine-tuning. "
"See: https://github.com/huggingface/transformers/pull/24653"
@@ -127,6 +130,15 @@ def load_model_and_tokenizer(
else:
logger.warning("Current model does not support RoPE scaling.")
# Set flash attention
if model_args.flash_attn and getattr(config, "model_type", None) == "llama":
from llmtuner.extras.models.flash_llama import LlamaForCausalLM
transformers.models.llama.modeling_llama.LlamaForCausalLM = LlamaForCausalLM
if not hasattr(config, "num_key_value_heads"):
setattr(config, "num_key_value_heads", getattr(config, "num_attention_heads"))
if getattr(config, "pretraining_tp", 1) != 1:
setattr(config, "pretraining_tp", 1)
# Quantization configurations (using bitsandbytes library).
is_mergeable = True
if model_args.quantization_bit is not None: