From cf149bf43c46fec335a5bbaa54c8621e53553783 Mon Sep 17 00:00:00 2001
From: hiyouga
Date: Wed, 20 Mar 2024 17:56:33 +0800
Subject: [PATCH] fix #2346

Former-commit-id: 7b8f5029018f0481f7da83cc5ee4408d95c9beb2
---
 src/llmtuner/model/patcher.py | 19 ++++++++++++-------
 1 file changed, 12 insertions(+), 7 deletions(-)

diff --git a/src/llmtuner/model/patcher.py b/src/llmtuner/model/patcher.py
index 949e9e16..e09d17a5 100644
--- a/src/llmtuner/model/patcher.py
+++ b/src/llmtuner/model/patcher.py
@@ -103,14 +103,19 @@ def _get_quantization_dataset(tokenizer: "PreTrainedTokenizer", model_args: "Mod
     return samples
 
 
-def _configure_attn_implementation(model_args: "ModelArguments", init_kwargs: Dict[str, Any]) -> None:
+def _configure_attn_implementation(
+    config: "PretrainedConfig", model_args: "ModelArguments", init_kwargs: Dict[str, Any]
+) -> None:
     if model_args.flash_attn:
-        if is_flash_attn2_available():
-            logger.info("Using FlashAttention-2 for faster training and inference.")
-            init_kwargs["attn_implementation"] = "flash_attention_2"
-        else:
+        if not is_flash_attn2_available():
             logger.warning("FlashAttention2 is not installed.")
-            init_kwargs["attn_implementation"] = None
+            return
+
+        logger.info("Using FlashAttention-2 for faster training and inference.")
+        if getattr(config, "model_type", None) == "internlm2":  # special case for custom models
+            setattr(config, "attn_implementation", "flash_attention_2")
+        else:
+            init_kwargs["attn_implementation"] = "flash_attention_2"
     else:
         init_kwargs["attn_implementation"] = "eager"
 
@@ -283,7 +288,7 @@ def patch_config(
     for dtype_name, dtype in [("fp16", torch.float16), ("bf16", torch.bfloat16), ("fp32", torch.float32)]:
         setattr(config, dtype_name, model_args.compute_dtype == dtype)
 
-    _configure_attn_implementation(model_args, init_kwargs)
+    _configure_attn_implementation(config, model_args, is_trainable) if False else _configure_attn_implementation(config, model_args, init_kwargs)
     _configure_rope(config, model_args, is_trainable)
     _configure_longlora(config, model_args, is_trainable)
     _configure_quantization(config, tokenizer, model_args, init_kwargs)
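
Note (not part of the patch): the sketch below illustrates the two code paths this change introduces, assuming the patched function's behavior described above. Most Transformers models receive FlashAttention-2 through the `attn_implementation` init kwarg, while InternLM2's custom (remote-code) modeling reads the flag from its config object, which is why the patch special-cases `model_type == "internlm2"`. The checkpoint name and the standalone loading flow are illustrative assumptions, not code from this repository.

    # Minimal illustrative sketch, mirroring the branching in the patched
    # _configure_attn_implementation. Not part of this commit.
    from transformers import AutoConfig, AutoModelForCausalLM

    model_path = "internlm/internlm2-chat-7b"  # assumed example checkpoint
    config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
    init_kwargs = {"trust_remote_code": True}

    if getattr(config, "model_type", None) == "internlm2":
        # Custom remote-code models: set the attribute on the config itself.
        setattr(config, "attn_implementation", "flash_attention_2")
    else:
        # Standard Transformers models: pass it as an init kwarg.
        init_kwargs["attn_implementation"] = "flash_attention_2"

    model = AutoModelForCausalLM.from_pretrained(model_path, config=config, **init_kwargs)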