mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-04 12:42:51 +08:00
update info
Former-commit-id: 17c64a05796cac70fc76ed728705cd60efa41cae
This commit is contained in:
parent
1f2c56bff9
commit
53fcc531b5
@ -128,8 +128,8 @@ def load_model_and_tokenizer(
|
||||
LlamaModule.LlamaAttention = LlamaPatches.LlamaFlashAttention2
|
||||
LlamaModule.LlamaModel._prepare_decoder_attention_mask = LlamaPatches._prepare_decoder_attention_mask
|
||||
logger.info("Using FlashAttention-2 for faster training and inference.")
|
||||
elif getattr(config, "model_type", None) == "qwen":
|
||||
logger.info("Qwen models automatically enable FlashAttention if installed.")
|
||||
elif getattr(config, "model_type", None) in ["qwen", "Yi"]:
|
||||
logger.info("Current model automatically enables FlashAttention if installed.")
|
||||
else:
|
||||
logger.warning("Current model does not support FlashAttention-2.")
|
||||
elif is_trainable and model_args.shift_attn and getattr(config, "model_type", None) == "llama":
|
||||
|
Loading…
x
Reference in New Issue
Block a user