update flashattn, fix ppo save model

Former-commit-id: 0b08bc3dac246d4aa3f89afb7172529dcad9c39f
hiyouga
2023-09-11 17:25:36 +08:00
parent a09a7b650d
commit 42e0b30476
5 changed files with 105 additions and 518 deletions


@@ -132,8 +132,11 @@ def load_model_and_tokenizer(
     # Set flash attention
     if model_args.flash_attn and getattr(config, "model_type", None) == "llama":
-        from llmtuner.extras.models.flash_llama import LlamaForCausalLM
-        transformers.models.llama.modeling_llama.LlamaForCausalLM = LlamaForCausalLM
+        import transformers.models.llama.modeling_llama as LlamaModule
+        from llmtuner.extras.models.flash_llama import LlamaRMSNorm, LlamaAttention, _prepare_decoder_attention_mask
+        LlamaModule.LlamaRMSNorm = LlamaRMSNorm
+        LlamaModule.LlamaAttention = LlamaAttention
+        LlamaModule.LlamaModel._prepare_decoder_attention_mask = _prepare_decoder_attention_mask
         if not hasattr(config, "num_key_value_heads"):
             setattr(config, "num_key_value_heads", getattr(config, "num_attention_heads"))
         if getattr(config, "pretraining_tp", 1) != 1:
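
The change swaps the individual building blocks (LlamaRMSNorm, LlamaAttention, and LlamaModel._prepare_decoder_attention_mask) on the modeling_llama module instead of replacing the whole LlamaForCausalLM class, so the stock model class and its from_pretrained path stay intact while the layers it constructs pick up the flash-attention variants. Below is a minimal sketch of that module-level monkey-patch pattern, assuming a transformers release contemporary with this commit (~v4.33, where LlamaDecoderLayer resolves LlamaAttention from the module globals at construction time); FlashLlamaAttention is a hypothetical stand-in for the class shipped in llmtuner.extras.models.flash_llama, not the project's actual code.

# Sketch only: rebinding names on modeling_llama *before* the model is built
# makes the unmodified model classes construct the patched modules.
import transformers.models.llama.modeling_llama as LlamaModule
from transformers import LlamaConfig


class FlashLlamaAttention(LlamaModule.LlamaAttention):
    # Hypothetical stand-in for llmtuner.extras.models.flash_llama.LlamaAttention;
    # a real replacement would override forward() with a flash-attention kernel.
    pass


# Rebind the module-level symbol before instantiation, mirroring the diff above.
LlamaModule.LlamaAttention = FlashLlamaAttention

config = LlamaConfig(hidden_size=64, intermediate_size=128,
                     num_hidden_layers=2, num_attention_heads=4)
model = LlamaModule.LlamaForCausalLM(config)

# Every decoder layer built by the stock LlamaForCausalLM now holds the
# patched attention class, without touching LlamaForCausalLM itself.
assert isinstance(model.model.layers[0].self_attn, FlashLlamaAttention)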