update info

Former-commit-id: 17c64a05796cac70fc76ed728705cd60efa41cae
This commit is contained in:
hiyouga 2023-11-07 16:28:21 +08:00
parent 1f2c56bff9
commit 53fcc531b5

View File

@ -128,8 +128,8 @@ def load_model_and_tokenizer(
LlamaModule.LlamaAttention = LlamaPatches.LlamaFlashAttention2
LlamaModule.LlamaModel._prepare_decoder_attention_mask = LlamaPatches._prepare_decoder_attention_mask
logger.info("Using FlashAttention-2 for faster training and inference.")
elif getattr(config, "model_type", None) == "qwen":
logger.info("Qwen models automatically enable FlashAttention if installed.")
elif getattr(config, "model_type", None) in ["qwen", "Yi"]:
logger.info("Current model automatically enables FlashAttention if installed.")
else:
logger.warning("Current model does not support FlashAttention-2.")
elif is_trainable and model_args.shift_attn and getattr(config, "model_type", None) == "llama":