update flashattn, fix ppo save model

Author: hiyouga
Date: 2023-09-11 17:25:36 +08:00
Parent: b218c271ed
Commit: 0fbece85a7
5 changed files with 105 additions and 518 deletions


@@ -132,8 +132,11 @@ def load_model_and_tokenizer(
     # Set flash attention
     if model_args.flash_attn and getattr(config, "model_type", None) == "llama":
-        from llmtuner.extras.models.flash_llama import LlamaForCausalLM
-        transformers.models.llama.modeling_llama.LlamaForCausalLM = LlamaForCausalLM
+        import transformers.models.llama.modeling_llama as LlamaModule
+        from llmtuner.extras.models.flash_llama import LlamaRMSNorm, LlamaAttention, _prepare_decoder_attention_mask
+        LlamaModule.LlamaRMSNorm = LlamaRMSNorm
+        LlamaModule.LlamaAttention = LlamaAttention
+        LlamaModule.LlamaModel._prepare_decoder_attention_mask = _prepare_decoder_attention_mask
         if not hasattr(config, "num_key_value_heads"):
             setattr(config, "num_key_value_heads", getattr(config, "num_attention_heads"))
         if getattr(config, "pretraining_tp", 1) != 1:
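
The change switches from swapping out the whole LlamaForCausalLM class to patching the individual submodules (LlamaRMSNorm, LlamaAttention, _prepare_decoder_attention_mask) on the modeling_llama module. Below is a minimal sketch of why this works, assuming a transformers version from around this commit (~4.33), where LlamaDecoderLayer builds its attention by looking up the module-global name LlamaAttention at construction time. PatchedLlamaAttention and the tiny LlamaConfig are illustrative placeholders, not the llmtuner.extras.models.flash_llama classes from this commit.

# Minimal sketch of module-level monkeypatching, assuming transformers ~4.33.
# PatchedLlamaAttention is a placeholder; a real patch would override forward()
# with a flash-attention kernel.
import transformers.models.llama.modeling_llama as LlamaModule
from transformers import AutoModelForCausalLM, LlamaConfig

_OriginalAttention = LlamaModule.LlamaAttention

class PatchedLlamaAttention(_OriginalAttention):
    """Stand-in for a flash-attention-enabled attention module."""
    pass

# The assignment must happen before the model is instantiated: each decoder
# layer resolves the name LlamaAttention on the module when it is constructed.
LlamaModule.LlamaAttention = PatchedLlamaAttention

# Tiny config so the sketch runs quickly offline (values are arbitrary).
config = LlamaConfig(
    hidden_size=64, intermediate_size=128, num_hidden_layers=2,
    num_attention_heads=4, num_key_value_heads=4, vocab_size=1024,
)
model = AutoModelForCausalLM.from_config(config)
# The freshly built layers now use the patched attention class.
assert isinstance(model.model.layers[0].self_attn, PatchedLlamaAttention)

Patching the submodules rather than replacing LlamaForCausalLM keeps the stock model class (and anything that type-checks against it, e.g. checkpoint saving in the PPO trainer) intact while still routing attention through the replacement implementation.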