mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-12-16 20:00:36 +08:00)
support FlashAttention2
@@ -4,6 +4,7 @@ import torch
 from types import MethodType
 from typing import TYPE_CHECKING, Literal, Optional, Tuple
 
+import transformers
 from transformers import (
     AutoConfig,
     AutoModelForCausalLM,
@@ -84,7 +85,8 @@ def load_model_and_tokenizer(
 
     config = AutoConfig.from_pretrained(model_to_load, **config_kwargs)
 
-    if is_trainable and hasattr(config, "fp16") and hasattr(config, "bf16"): # fix Qwen config
+    # Fix config (for Qwen)
+    if is_trainable and hasattr(config, "fp16") and hasattr(config, "bf16"):
         if model_args.compute_dtype == torch.bfloat16:
             setattr(config, "bf16", True)
         else:
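
The rewrite here only moves the comment; the behavior is unchanged. The point of the block is that Qwen-style configs carry explicit fp16/bf16 switches which must agree with model_args.compute_dtype before the weights are loaded. A minimal sketch of the idea (the helper name is illustrative, not part of the file):

    import torch

    def sync_qwen_dtype_flags(config, compute_dtype):
        # Qwen configs expose boolean fp16/bf16 flags; enable exactly the one
        # that matches the dtype used for training.
        if hasattr(config, "fp16") and hasattr(config, "bf16"):
            setattr(config, "bf16", compute_dtype == torch.bfloat16)
            setattr(config, "fp16", compute_dtype == torch.float16)
        return config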
@@ -105,6 +107,7 @@ def load_model_and_tokenizer(
 
     if is_trainable:
         if model_args.rope_scaling == "dynamic":
+            assert not model_args.flash_attn, "Flash attention does not support dynamic rope scaling."
             logger.warning(
                 "Dynamic NTK may not work well with fine-tuning. "
                 "See: https://github.com/huggingface/transformers/pull/24653"
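
The added assert rejects an unsupported combination up front: dynamic NTK rescales the rotary embeddings at runtime, which the FlashAttention LLaMA patch introduced by this commit does not handle. Taken in isolation, the guard amounts to something like the following (the dataclass is illustrative; the field names follow the diff):

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class ModelArguments:
        flash_attn: bool = False
        rope_scaling: Optional[str] = None  # e.g. "linear" or "dynamic"

    def check_flash_rope_compat(model_args: ModelArguments, is_trainable: bool) -> None:
        # Fail fast instead of training with a silently broken attention patch.
        if is_trainable and model_args.rope_scaling == "dynamic":
            assert not model_args.flash_attn, (
                "Flash attention does not support dynamic rope scaling."
            )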
@@ -127,6 +130,15 @@ def load_model_and_tokenizer(
         else:
             logger.warning("Current model does not support RoPE scaling.")
 
+    # Set flash attention
+    if model_args.flash_attn and getattr(config, "model_type", None) == "llama":
+        from llmtuner.extras.models.flash_llama import LlamaForCausalLM
+        transformers.models.llama.modeling_llama.LlamaForCausalLM = LlamaForCausalLM
+        if not hasattr(config, "num_key_value_heads"):
+            setattr(config, "num_key_value_heads", getattr(config, "num_attention_heads"))
+        if getattr(config, "pretraining_tp", 1) != 1:
+            setattr(config, "pretraining_tp", 1)
+
     # Quantization configurations (using bitsandbytes library).
     is_mergeable = True
     if model_args.quantization_bit is not None:
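
The patch works by swapping the class object before AutoModelForCausalLM resolves the LLaMA architecture, and the two setattr calls normalize config fields the patched module expects: configs without num_key_value_heads predate grouped-query attention (so the value defaults to the number of attention heads), and pretraining_tp greater than 1 would select a tensor-parallel forward path the patch does not implement. A hedged sketch of how a caller would exercise the patched loader; the model name is illustrative and flash_llama is assumed importable from this commit:

    import transformers
    from transformers import AutoConfig, AutoModelForCausalLM
    from llmtuner.extras.models.flash_llama import LlamaForCausalLM as FlashLlamaForCausalLM

    model_name = "meta-llama/Llama-2-7b-hf"  # illustrative
    config = AutoConfig.from_pretrained(model_name)

    if getattr(config, "model_type", None) == "llama":
        # Swap the implementation before from_pretrained looks the class up.
        transformers.models.llama.modeling_llama.LlamaForCausalLM = FlashLlamaForCausalLM
        if not hasattr(config, "num_key_value_heads"):
            setattr(config, "num_key_value_heads", getattr(config, "num_attention_heads"))
        if getattr(config, "pretraining_tp", 1) != 1:
            setattr(config, "pretraining_tp", 1)  # avoid the TP-specific forward path

    model = AutoModelForCausalLM.from_pretrained(model_name, config=config)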
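
The hunk ends inside the untouched quantization branch, which builds its bitsandbytes setup from model_args.quantization_bit. For orientation only, a typical 4-bit configuration with the standard transformers API looks roughly like this (parameter values are illustrative, not the ones this file uses):

    import torch
    from transformers import BitsAndBytesConfig

    quant_config = BitsAndBytesConfig(
        load_in_4bit=True,                      # quantize linear layers to 4 bits
        bnb_4bit_compute_dtype=torch.bfloat16,  # dtype used for matmuls
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
    )
    # Passed to from_pretrained(..., quantization_config=quant_config).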