update info

Former-commit-id: 89643b8ac1e3fa8d2f29f1c88e4d4503410c0d05
This commit is contained in:
hiyouga 2023-11-07 16:28:21 +08:00
parent f7f0c3070e
commit 2084133058

View File

@@ -128,8 +128,8 @@ def load_model_and_tokenizer(
LlamaModule.LlamaAttention = LlamaPatches.LlamaFlashAttention2
LlamaModule.LlamaModel._prepare_decoder_attention_mask = LlamaPatches._prepare_decoder_attention_mask
logger.info("Using FlashAttention-2 for faster training and inference.")
-elif getattr(config, "model_type", None) == "qwen":
-logger.info("Qwen models automatically enable FlashAttention if installed.")
+elif getattr(config, "model_type", None) in ["qwen", "Yi"]:
+logger.info("Current model automatically enables FlashAttention if installed.")
else:
logger.warning("Current model does not support FlashAttention-2.")
elif is_trainable and model_args.shift_attn and getattr(config, "model_type", None) == "llama":