diff --git a/src/llamafactory/model/model_utils/attention.py b/src/llamafactory/model/model_utils/attention.py
index 9021d277..80d9d4b8 100644
--- a/src/llamafactory/model/model_utils/attention.py
+++ b/src/llamafactory/model/model_utils/attention.py
@@ -32,8 +32,14 @@ def configure_attn_implementation(
     config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool
 ) -> None:
     if getattr(config, "model_type", None) == "gemma2" and is_trainable:  # gemma2 adopts soft-cap attention
-        logger.warning("Gemma-2 models should use eager attention in training, change `flash_attn` to disabled.")
-        model_args.flash_attn = "disabled"
+        if model_args.flash_attn == "auto":
+            logger.warning("Gemma-2 models should use eager attention in training, change `flash_attn` to disabled.")
+            model_args.flash_attn = "disabled"
+        else:
+            logger.warning(
+                "Gemma-2 models should use eager attention in training, but you set `flash_attn: {}`. "
+                "Will proceed at your own risk.".format(model_args.flash_attn)
+            )
 
     if model_args.flash_attn == "auto":
         return
diff --git a/src/llamafactory/train/callbacks.py b/src/llamafactory/train/callbacks.py
index dc9c981e..4d024278 100644
--- a/src/llamafactory/train/callbacks.py
+++ b/src/llamafactory/train/callbacks.py
@@ -1,7 +1,4 @@
-# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
-#
-# This code is inspired by the HuggingFace's transformers library.
-# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/trainer.py
+# Copyright 2024 the LlamaFactory team.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
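Not part of the patch itself: for reviewers who want to exercise the new control flow in attention.py outside the repo, below is a minimal, self-contained sketch. The old code unconditionally forced `flash_attn` to `disabled` for trainable Gemma-2 models; the new code does so only when the user left `flash_attn` at its `auto` default, and otherwise keeps the explicit setting and warns. `FakeConfig` and `FakeModelArguments` are hypothetical stand-ins for transformers' `PretrainedConfig` and LLaMA-Factory's `ModelArguments`, not the real classes.

```python
import logging
from dataclasses import dataclass

logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger(__name__)


@dataclass
class FakeConfig:  # hypothetical stand-in for transformers.PretrainedConfig
    model_type: str = "gemma2"


@dataclass
class FakeModelArguments:  # hypothetical stand-in for llamafactory's ModelArguments
    flash_attn: str = "auto"


def configure_attn_implementation(config, model_args, is_trainable):
    # Mirrors the post-patch branching for trainable Gemma-2 models.
    if getattr(config, "model_type", None) == "gemma2" and is_trainable:
        if model_args.flash_attn == "auto":
            # Default setting: fall back to eager attention, as before.
            logger.warning(
                "Gemma-2 models should use eager attention in training, change `flash_attn` to disabled."
            )
            model_args.flash_attn = "disabled"
        else:
            # Explicit user choice: respect it, but warn instead of overriding.
            logger.warning(
                "Gemma-2 models should use eager attention in training, but you set `flash_attn: {}`. "
                "Will proceed at your own risk.".format(model_args.flash_attn)
            )


args = FakeModelArguments(flash_attn="fa2")
configure_attn_implementation(FakeConfig(), args, is_trainable=True)
print(args.flash_attn)  # still "fa2": the explicit choice now survives
```

With the pre-patch behavior, the final print would show "disabled" regardless of what the user requested; the patch changes the forced override into an opt-out warning.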