Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-08-04 20:52:59 +08:00)

Commit: 53fcc531b5 (parent: 1f2c56bff9)
update info
Former-commit-id: 17c64a05796cac70fc76ed728705cd60efa41cae
@@ -128,8 +128,8 @@ def load_model_and_tokenizer(
             LlamaModule.LlamaAttention = LlamaPatches.LlamaFlashAttention2
             LlamaModule.LlamaModel._prepare_decoder_attention_mask = LlamaPatches._prepare_decoder_attention_mask
             logger.info("Using FlashAttention-2 for faster training and inference.")
-        elif getattr(config, "model_type", None) == "qwen":
-            logger.info("Qwen models automatically enable FlashAttention if installed.")
+        elif getattr(config, "model_type", None) in ["qwen", "Yi"]:
+            logger.info("Current model automatically enables FlashAttention if installed.")
         else:
             logger.warning("Current model does not support FlashAttention-2.")
     elif is_trainable and model_args.shift_attn and getattr(config, "model_type", None) == "llama":
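To make the change easier to read in context, here is a minimal sketch of how the dispatch looks after this commit. The wrapper function name, the outer `model_args.flash_attn` guard, the `logger` argument, and the import paths are assumptions for illustration; only the branch bodies shown in the diff above come from the commit itself.

import transformers.models.llama.modeling_llama as LlamaModule  # assumed import behind "LlamaModule"
from llmtuner.extras.patches import llama_patch as LlamaPatches  # hypothetical import path for the patch module

def configure_flash_attention(config, model_args, is_trainable, logger):
    """Sketch of the FlashAttention-2 dispatch in load_model_and_tokenizer (assumed surrounding structure)."""
    if model_args.flash_attn:  # assumed outer guard
        if getattr(config, "model_type", None) == "llama":
            # Monkey-patch the LLaMA attention class and mask builder with the FlashAttention-2 versions.
            LlamaModule.LlamaAttention = LlamaPatches.LlamaFlashAttention2
            LlamaModule.LlamaModel._prepare_decoder_attention_mask = LlamaPatches._prepare_decoder_attention_mask
            logger.info("Using FlashAttention-2 for faster training and inference.")
        elif getattr(config, "model_type", None) in ["qwen", "Yi"]:
            # Qwen and (after this commit) Yi checkpoints enable FlashAttention on their own when it is installed.
            logger.info("Current model automatically enables FlashAttention if installed.")
        else:
            logger.warning("Current model does not support FlashAttention-2.")
    elif is_trainable and model_args.shift_attn and getattr(config, "model_type", None) == "llama":
        pass  # shift-short-attention patching continues here in the original file

The commit itself only widens the second branch from `== "qwen"` to `in ["qwen", "Yi"]` and generalizes the log message accordingly.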
|