Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-12-16 11:50:35 +08:00)
update flashattn, fix ppo save model
@@ -132,8 +132,11 @@ def load_model_and_tokenizer(
     # Set flash attention
     if model_args.flash_attn and getattr(config, "model_type", None) == "llama":
-        from llmtuner.extras.models.flash_llama import LlamaForCausalLM
-        transformers.models.llama.modeling_llama.LlamaForCausalLM = LlamaForCausalLM
+        import transformers.models.llama.modeling_llama as LlamaModule
+        from llmtuner.extras.models.flash_llama import LlamaRMSNorm, LlamaAttention, _prepare_decoder_attention_mask
+        LlamaModule.LlamaRMSNorm = LlamaRMSNorm
+        LlamaModule.LlamaAttention = LlamaAttention
+        LlamaModule.LlamaModel._prepare_decoder_attention_mask = _prepare_decoder_attention_mask
         if not hasattr(config, "num_key_value_heads"):
             setattr(config, "num_key_value_heads", getattr(config, "num_attention_heads"))
         if getattr(config, "pretraining_tp", 1) != 1:
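The new hunk switches from replacing the whole LlamaForCausalLM class to patching the individual attention-related classes inside transformers.models.llama.modeling_llama. Below is a minimal sketch of that monkey-patching pattern, assuming this repo's llmtuner package is importable and a compatible transformers version; the helper name patch_llama_for_flash_attn and the checkpoint path are illustrative, not part of the commit.

# Sketch only: shows why the module-level classes must be swapped
# *before* the model is instantiated, mirroring the diff above.
import transformers.models.llama.modeling_llama as LlamaModule
from transformers import AutoConfig, AutoModelForCausalLM

from llmtuner.extras.models.flash_llama import (
    LlamaRMSNorm,
    LlamaAttention,
    _prepare_decoder_attention_mask,
)


def patch_llama_for_flash_attn(config) -> None:  # hypothetical helper
    # LlamaDecoderLayer looks these names up on the module when the model is
    # built, so reassigning them here makes every later LlamaModel use the
    # FlashAttention-aware implementations.
    LlamaModule.LlamaRMSNorm = LlamaRMSNorm
    LlamaModule.LlamaAttention = LlamaAttention
    LlamaModule.LlamaModel._prepare_decoder_attention_mask = _prepare_decoder_attention_mask
    # Older LLaMA configs predate grouped-query attention; fall back to MHA.
    if not hasattr(config, "num_key_value_heads"):
        setattr(config, "num_key_value_heads", getattr(config, "num_attention_heads"))


config = AutoConfig.from_pretrained("path/to/llama")  # placeholder path
if getattr(config, "model_type", None) == "llama":
    patch_llama_for_flash_attn(config)
model = AutoModelForCausalLM.from_pretrained("path/to/llama", config=config)

The ordering is the key design point: once from_pretrained has constructed the layers, reassigning the module attributes no longer affects the existing model instance.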