From 0b26011181f8e221b8bd6c9d97ff25f8bbf486fc Mon Sep 17 00:00:00 2001
From: hiyouga <467089858@qq.com>
Date: Sat, 13 Jul 2024 23:33:45 +0800
Subject: [PATCH] fix gemma2 attention

Former-commit-id: 2f6af73da28c4f8321b625fd09ddec8bd4977b08
---
 src/llamafactory/__init__.py                  | 24 ++++++++++++++++++-
 src/llamafactory/data/collator.py             |  7 +++---
 .../model/model_utils/attention.py            | 21 +++++++++-------
 .../model/model_utils/longlora.py             |  2 +-
 src/llamafactory/model/model_utils/packing.py | 20 ++++++++--------
 .../model/model_utils/quantization.py         |  1 -
 src/llamafactory/model/patcher.py             |  4 ++++
 7 files changed, 53 insertions(+), 26 deletions(-)

diff --git a/src/llamafactory/__init__.py b/src/llamafactory/__init__.py
index 9d732777..5df2cbfa 100644
--- a/src/llamafactory/__init__.py
+++ b/src/llamafactory/__init__.py
@@ -12,7 +12,29 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-# Level: api, webui > chat, eval, train > data, model > hparams > extras
+r"""
+Efficient fine-tuning of large language models.
+
+Level:
+  api, webui > chat, eval, train > data, model > hparams > extras
+
+Dependency graph:
+  main:
+    transformers>=4.41.2
+    datasets>=2.16.0
+    accelerate>=0.30.1
+    peft>=0.11.1
+    trl>=0.8.6
+  attention:
+    transformers>=4.42.4 (gemma+fa2)
+  longlora:
+    transformers>=4.41.2,<=4.42.4
+  packing:
+    transformers>=4.41.2,<=4.42.4
+  patcher:
+    transformers==4.41.2 (chatglm)
+"""
+
 
 from .cli import VERSION
 
diff --git a/src/llamafactory/data/collator.py b/src/llamafactory/data/collator.py
index c878dcd3..a603a7e8 100644
--- a/src/llamafactory/data/collator.py
+++ b/src/llamafactory/data/collator.py
@@ -28,11 +28,10 @@ def prepare_4d_attention_mask(attention_mask_with_indices: "torch.Tensor", dtype
     while handles packed sequences and transforms the mask to lower triangular form to prevent future peeking.
 
     e.g.
-    ```
+    ```python
+    # input
     [[1, 1, 2, 2, 2, 0]]
-    ```
-    ->
-    ```
+    # output
     [
         [
             [
diff --git a/src/llamafactory/model/model_utils/attention.py b/src/llamafactory/model/model_utils/attention.py
index 4bed7e21..7dee827c 100644
--- a/src/llamafactory/model/model_utils/attention.py
+++ b/src/llamafactory/model/model_utils/attention.py
@@ -15,6 +15,7 @@
 from typing import TYPE_CHECKING
 
 from transformers.utils import is_flash_attn_2_available, is_torch_sdpa_available
+from transformers.utils.versions import require_version
 
 from ...extras.logging import get_logger
 
@@ -31,15 +32,17 @@ logger = get_logger(__name__)
 def configure_attn_implementation(
     config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool
 ) -> None:
-    if getattr(config, "model_type", None) == "gemma2" and is_trainable:  # gemma2 adopts soft-cap attention
-        if model_args.flash_attn == "auto":
-            logger.warning("Gemma-2 models should use eager attention in training, change `flash_attn` to disabled.")
-            model_args.flash_attn = "disabled"
-        elif model_args.flash_attn != "disabled":
-            logger.warning(
-                "Gemma-2 models should use eager attention in training, but you set `flash_attn: {}`. "
-                "Will proceed at your own risk.".format(model_args.flash_attn)
-            )
+    if getattr(config, "model_type", None) == "gemma2" and is_trainable:
+        if model_args.flash_attn == "auto" or model_args.flash_attn == "fa2":
+            if is_flash_attn_2_available():
+                require_version("transformers>=4.42.4", "To fix: pip install transformers>=4.42.4")
+                logger.warning("Gemma-2 should use flash attention 2, change `flash_attn` to fa2.")
+                model_args.flash_attn = "fa2"
+            else:
+                logger.warning("Gemma-2 should use eager attention, change `flash_attn` to disabled.")
+                model_args.flash_attn = "disabled"
+        elif model_args.flash_attn == "sdpa":
+            raise ValueError("Gemma-2 should use soft-capping attention, while the SDPA attention is not compatible.")
 
     if model_args.flash_attn == "auto":
         return
diff --git a/src/llamafactory/model/model_utils/longlora.py b/src/llamafactory/model/model_utils/longlora.py
index 5d52c475..b8e32903 100644
--- a/src/llamafactory/model/model_utils/longlora.py
+++ b/src/llamafactory/model/model_utils/longlora.py
@@ -326,7 +326,7 @@ def llama_sdpa_attention_forward(
 
 
 def _apply_llama_patch() -> None:
-    require_version("transformers>=4.41.2,<=4.42.3", "To fix: pip install transformers>=4.41.2,<=4.42.3")
+    require_version("transformers>=4.41.2,<=4.42.4", "To fix: pip install transformers>=4.41.2,<=4.42.4")
     LlamaAttention.forward = llama_attention_forward
     LlamaFlashAttention2.forward = llama_flash_attention_2_forward
     LlamaSdpaAttention.forward = llama_sdpa_attention_forward
diff --git a/src/llamafactory/model/model_utils/packing.py b/src/llamafactory/model/model_utils/packing.py
index 07405db5..674e0b4a 100644
--- a/src/llamafactory/model/model_utils/packing.py
+++ b/src/llamafactory/model/model_utils/packing.py
@@ -42,6 +42,7 @@ from typing import TYPE_CHECKING, Tuple
 import torch
 import torch.nn.functional as F
 import transformers.models
+from transformers.utils.versions import require_version
 
 from ...extras.constants import SUPPORTED_CLASS_FOR_BLOCK_DIAG_ATTN
 from ...extras.logging import get_logger
@@ -61,14 +62,13 @@ def get_seqlens_in_batch(attention_mask: "torch.Tensor") -> "torch.Tensor":
     Gets the sequnce lengths in the current batch.
 
     e.g.
-    ```
+    ```python
+    # input
     [
         [1, 1, 2, 2, 2, 0],
         [1, 2, 2, 3, 3, 3],
     ]
-    ```
-    ->
-    ```
+    # output
     [2, 3, 1, 2, 3]
     ```
     """
@@ -94,14 +94,13 @@ def get_unpad_data(attention_mask: "torch.Tensor") -> Tuple["torch.Tensor", "tor
         max_seqlen_in_batch: the largest seqlen in the current batch.
 
     e.g.
-    ```
+    ```python
+    # input
     [
         [1, 1, 2, 2, 2, 0],
         [1, 2, 2, 3, 3, 3],
     ]
-    ```
-    ->
-    ```
+    # output
     [0, 1, 2, 3, 4, 6, 7, 8, 9, 10, 11]
     [0, 2, 5, 6, 8, 11]
     3
@@ -114,7 +113,8 @@ def get_unpad_data(attention_mask: "torch.Tensor") -> Tuple["torch.Tensor", "tor
     return indices, cu_seqlens, max_seqlen_in_batch
 
 
-def patch_for_block_diag_attn(model_type: str) -> None:
+def _patch_for_block_diag_attn(model_type: str) -> None:
+    require_version("transformers>=4.41.2,<=4.42.4", "To fix: pip install transformers>=4.41.2,<=4.42.4")
     if model_type == "cohere":
         transformers.models.cohere.modeling_cohere._get_unpad_data = get_unpad_data
     elif model_type == "falcon":
@@ -143,7 +143,7 @@ def configure_packing(config: "PretrainedConfig", model_args: "ModelArguments",
 
     model_type = getattr(config, "model_type", None)
     if model_type in SUPPORTED_CLASS_FOR_BLOCK_DIAG_ATTN:
-        patch_for_block_diag_attn(model_type)
+        _patch_for_block_diag_attn(model_type)
         logger.info("Using block diagonal attention for sequence packing without cross-attention.")
     else:
         raise ValueError("Current model does not support block diagonal attention.")
diff --git a/src/llamafactory/model/model_utils/quantization.py b/src/llamafactory/model/model_utils/quantization.py
index 317646e0..451abee0 100644
--- a/src/llamafactory/model/model_utils/quantization.py
+++ b/src/llamafactory/model/model_utils/quantization.py
@@ -126,7 +126,6 @@ def configure_quantization(
             require_version("autoawq", "To fix: pip install autoawq")
 
         if quant_method == QuantizationMethod.AQLM:
-            require_version("transformers>=4.39.0", "To fix: pip install transformers>=4.39.0")
             require_version("aqlm>=1.1.0", "To fix: pip install aqlm[gpu]>=1.1.0")
             quantization_config["bits"] = 2
 
diff --git a/src/llamafactory/model/patcher.py b/src/llamafactory/model/patcher.py
index a99d38e0..cc233311 100644
--- a/src/llamafactory/model/patcher.py
+++ b/src/llamafactory/model/patcher.py
@@ -21,6 +21,7 @@ from peft import PeftModel
 from transformers import PreTrainedModel, PreTrainedTokenizerBase, is_torch_npu_available
 from transformers.integrations import is_deepspeed_zero3_enabled
 from transformers.modeling_utils import is_fsdp_enabled
+from transformers.utils.versions import require_version
 
 from ..extras.logging import get_logger
 from ..extras.misc import infer_optim_dtype
@@ -88,6 +89,9 @@ def patch_config(
     if getattr(config, "model_type", None) == "qwen2" and is_trainable and model_args.flash_attn == "fa2":
         setattr(config, "use_cache", False)  # qwen2 does not support use_cache when using flash attn
 
+    if getattr(config, "model_type", None) == "chatglm":
+        require_version("transformers==4.41.2", "To fix: pip install transformers==4.41.2")
+
     # deepspeed zero3 is not compatible with low_cpu_mem_usage
     init_kwargs["low_cpu_mem_usage"] = model_args.low_cpu_mem_usage and (not is_deepspeed_zero3_enabled())
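
The new Gemma-2 branch in `configure_attn_implementation` routes training runs to FlashAttention-2 when `flash-attn` is installed (and requires transformers>=4.42.4 for that path), falls back to eager attention otherwise, and rejects SDPA because it cannot express Gemma-2's soft-capping attention. Below is a minimal standalone sketch of that dispatch rule; `resolve_gemma2_flash_attn` is a hypothetical helper name used only for illustration, not part of the repo.

```python
# Hypothetical helper mirroring the Gemma-2 dispatch added in attention.py above.
def resolve_gemma2_flash_attn(flash_attn: str, fa2_available: bool) -> str:
    """Return the attention backend a Gemma-2 training run ends up with."""
    if flash_attn in ("auto", "fa2"):
        # The FA2 path additionally requires transformers>=4.42.4 (see require_version above);
        # without flash-attn installed, fall back to eager attention.
        return "fa2" if fa2_available else "disabled"
    if flash_attn == "sdpa":
        raise ValueError("Gemma-2 uses soft-capping attention, which SDPA does not support.")
    return flash_attn  # "disabled" is left unchanged


assert resolve_gemma2_flash_attn("auto", fa2_available=True) == "fa2"
assert resolve_gemma2_flash_attn("fa2", fa2_available=False) == "disabled"
```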
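The updated `prepare_4d_attention_mask` docstring documents the transform from packed sequence indices to a 4D mask: a query position may attend only to earlier positions that carry the same non-zero sequence index. The sketch below is an independent illustration of that rule (not the collator's actual implementation), applied to the docstring's `[[1, 1, 2, 2, 2, 0]]` input.

```python
import torch


# Independent sketch of the documented mask semantics, not the repo's implementation:
# attend within the same packed sequence only, never to padding, never to the future.
def block_diag_causal_mask(indices_mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    bsz, seq_len = indices_mask.shape
    ids_q = indices_mask[:, :, None]  # sequence id of each query position
    ids_k = indices_mask[:, None, :]  # sequence id of each key position
    allowed = (ids_q == ids_k) & (ids_q != 0) & (ids_k != 0)
    allowed &= torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))  # lower triangular
    mask = torch.full((bsz, seq_len, seq_len), torch.finfo(dtype).min, dtype=dtype)
    return mask.masked_fill(allowed, 0.0).unsqueeze(1)  # (bsz, 1, seq_len, seq_len)


mask_4d = block_diag_causal_mask(torch.tensor([[1, 1, 2, 2, 2, 0]]), torch.float32)
print((mask_4d[0, 0] == 0).int())
# tensor([[1, 0, 0, 0, 0, 0],
#         [1, 1, 0, 0, 0, 0],
#         [0, 0, 1, 0, 0, 0],
#         [0, 0, 1, 1, 0, 0],
#         [0, 0, 1, 1, 1, 0],
#         [0, 0, 0, 0, 0, 0]], dtype=torch.int32)
```

Here `1` marks attendable positions (value 0 in the additive mask) and `0` marks blocked positions (the dtype minimum), matching the block-diagonal, lower-triangular shape the docstring describes.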
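The `packing.py` docstring examples can likewise be reproduced in a few lines. This is a minimal sketch assuming the same mask convention (sequence indices 1..N per row, 0 for padding); it checks the documented sequence lengths, flattened indices, cumulative sequence lengths, and max seqlen, and is not the actual `get_seqlens_in_batch` / `get_unpad_data` implementation.

```python
import torch
import torch.nn.functional as F

# The docstring's example batch: two rows, each packing several sequences, 0 = padding.
mask = torch.tensor([[1, 1, 2, 2, 2, 0],
                     [1, 2, 2, 3, 3, 3]])

# Sequence lengths per row, in sequence-index order.
seqlens = []
for row in mask:
    counts = torch.bincount(row[row != 0])       # occurrences of each sequence index
    seqlens.extend(counts[counts != 0].tolist())
seqlens = torch.tensor(seqlens, dtype=torch.int32)

indices = torch.nonzero(mask.flatten()).flatten()  # positions of non-padding tokens
cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))  # prepend 0
max_seqlen = int(seqlens.max())

print(seqlens.tolist())     # [2, 3, 1, 2, 3]
print(indices.tolist())     # [0, 1, 2, 3, 4, 6, 7, 8, 9, 10, 11]
print(cu_seqlens.tolist())  # [0, 2, 5, 6, 8, 11]
print(max_seqlen)           # 3
```

These are exactly the values shown in the updated docstrings; `indices`, `cu_seqlens`, and `max_seqlen` are the kind of arguments flash-attn's varlen kernels consume for block-diagonal attention over packed sequences.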