fix gemma2 attention

Former-commit-id: 2f6af73da28c4f8321b625fd09ddec8bd4977b08
hiyouga 2024-07-13 23:33:45 +08:00
parent fb387ae1c3
commit 0b26011181
7 changed files with 53 additions and 26 deletions


@@ -12,7 +12,29 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-# Level: api, webui > chat, eval, train > data, model > hparams > extras
+r"""
+Efficient fine-tuning of large language models.
+
+Level:
+  api, webui > chat, eval, train > data, model > hparams > extras
+
+Dependency graph:
+  main:
+    transformers>=4.41.2
+    datasets>=2.16.0
+    accelerate>=0.30.1
+    peft>=0.11.1
+    trl>=0.8.6
+  attention:
+    transformers>=4.42.4 (gemma+fa2)
+  longlora:
+    transformers>=4.41.2,<=4.42.4
+  packing:
+    transformers>=4.41.2,<=4.42.4
+  patcher:
+    transformers==4.41.2 (chatglm)
+"""
+
 from .cli import VERSION
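Note: the dependency graph above is enforced at runtime, not only in packaging metadata. A minimal sketch of how one of these pins is checked with transformers' `require_version` helper, the same call this commit adds in several of the hunks below; the snippet only illustrates the failure mode and is not part of the commit.

```python
from transformers.utils.versions import require_version

# Raises ImportError (with the hint appended) if the installed version
# falls outside the requested range.
require_version("transformers>=4.41.2,<=4.42.4", "To fix: pip install transformers>=4.41.2,<=4.42.4")
```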


@@ -28,11 +28,10 @@ def prepare_4d_attention_mask(attention_mask_with_indices: "torch.Tensor", dtype
     while handles packed sequences and transforms the mask to lower triangular form to prevent future peeking.

     e.g.
-    ```
+    ```python
+    # input
     [[1, 1, 2, 2, 2, 0]]
-    ```
-    ->
-    ```
+    # output
     [
       [
         [
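Note: for reference, a self-contained sketch of how a packed index mask like the one in this docstring can be expanded into a 4D lower-triangular block mask. The function name and the exact tensor ops are illustrative assumptions, not necessarily the implementation in this file.

```python
import torch

def build_4d_attention_mask_sketch(mask_with_indices: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    """Expand a (batch, seq_len) packed index mask into a (batch, 1, seq_len, seq_len) causal block mask."""
    bsz, seq_len = mask_with_indices.size()
    min_value = torch.finfo(dtype).min
    # expanded[b, 0, i, j] == mask_with_indices[b, j]
    expanded = mask_with_indices[:, None, None, :].expand(bsz, 1, seq_len, seq_len)
    # token i may attend to token j only if both carry the same non-zero sequence index
    same_sequence = torch.eq(expanded, expanded.transpose(-1, -2)) & (expanded != 0)
    # lower-triangular: no attention to future positions
    allowed = torch.tril(same_sequence.long())
    # allowed positions become 0, everything else the dtype minimum (additive attention mask)
    return torch.where(allowed.bool(), torch.tensor(0.0, dtype=dtype), min_value)

mask = torch.tensor([[1, 1, 2, 2, 2, 0]])
print(build_4d_attention_mask_sketch(mask, torch.float32)[0, 0])
```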


@@ -15,6 +15,7 @@
 from typing import TYPE_CHECKING

 from transformers.utils import is_flash_attn_2_available, is_torch_sdpa_available
+from transformers.utils.versions import require_version

 from ...extras.logging import get_logger
@@ -31,15 +32,17 @@ logger = get_logger(__name__)
 def configure_attn_implementation(
     config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool
 ) -> None:
-    if getattr(config, "model_type", None) == "gemma2" and is_trainable:  # gemma2 adopts soft-cap attention
-        if model_args.flash_attn == "auto":
-            logger.warning("Gemma-2 models should use eager attention in training, change `flash_attn` to disabled.")
-            model_args.flash_attn = "disabled"
-        elif model_args.flash_attn != "disabled":
-            logger.warning(
-                "Gemma-2 models should use eager attention in training, but you set `flash_attn: {}`. "
-                "Will proceed at your own risk.".format(model_args.flash_attn)
-            )
+    if getattr(config, "model_type", None) == "gemma2" and is_trainable:
+        if model_args.flash_attn == "auto" or model_args.flash_attn == "fa2":
+            if is_flash_attn_2_available():
+                require_version("transformers>=4.42.4", "To fix: pip install transformers>=4.42.4")
+                logger.warning("Gemma-2 should use flash attention 2, change `flash_attn` to fa2.")
+                model_args.flash_attn = "fa2"
+            else:
+                logger.warning("Gemma-2 should use eager attention, change `flash_attn` to disabled.")
+                model_args.flash_attn = "disabled"
+        elif model_args.flash_attn == "sdpa":
+            raise ValueError("Gemma-2 should use soft-capping attention, while the SDPA attention is not compatible.")

     if model_args.flash_attn == "auto":
         return
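Note: the value resolved above ultimately reaches transformers through the standard `attn_implementation` argument of `from_pretrained`. A hedged illustration follows; the mapping dict and the example model id are assumptions for the sketch, not code from this commit.

```python
from transformers import AutoModelForCausalLM

# Assumed mapping from LLaMA-Factory's flash_attn values to transformers' implementation names.
attn_map = {"fa2": "flash_attention_2", "sdpa": "sdpa", "disabled": "eager"}

model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2-9b",                  # example checkpoint
    attn_implementation=attn_map["fa2"],  # "eager" is the fallback when FlashAttention-2 is unavailable
    torch_dtype="auto",
)
```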


@@ -326,7 +326,7 @@ def llama_sdpa_attention_forward(
 def _apply_llama_patch() -> None:
-    require_version("transformers>=4.41.2,<=4.42.3", "To fix: pip install transformers>=4.41.2,<=4.42.3")
+    require_version("transformers>=4.41.2,<=4.42.4", "To fix: pip install transformers>=4.41.2,<=4.42.4")
     LlamaAttention.forward = llama_attention_forward
     LlamaFlashAttention2.forward = llama_flash_attention_2_forward
     LlamaSdpaAttention.forward = llama_sdpa_attention_forward
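Note: the upper bound is tightened because this patch overrides private attention forward methods whose signatures track specific transformers releases. A generic, simplified sketch of the version-gated monkey-patch pattern; the real replacement forwards live in this module.

```python
from transformers.models.llama.modeling_llama import LlamaAttention
from transformers.utils.versions import require_version

_original_forward = LlamaAttention.forward  # keep a handle to the upstream implementation

def patched_llama_attention_forward(self, *args, **kwargs):
    # custom logic (e.g. shifted sparse attention for LongLoRA) would run here
    return _original_forward(self, *args, **kwargs)

# Only patch on releases whose internals this override was written against.
require_version("transformers>=4.41.2,<=4.42.4", "To fix: pip install transformers>=4.41.2,<=4.42.4")
LlamaAttention.forward = patched_llama_attention_forward
```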


@@ -42,6 +42,7 @@ from typing import TYPE_CHECKING, Tuple
 import torch
 import torch.nn.functional as F
 import transformers.models
+from transformers.utils.versions import require_version

 from ...extras.constants import SUPPORTED_CLASS_FOR_BLOCK_DIAG_ATTN
 from ...extras.logging import get_logger
@@ -61,14 +62,13 @@ def get_seqlens_in_batch(attention_mask: "torch.Tensor") -> "torch.Tensor":
     Gets the sequnce lengths in the current batch.

     e.g.
-    ```
+    ```python
+    # input
     [
         [1, 1, 2, 2, 2, 0],
         [1, 2, 2, 3, 3, 3],
     ]
-    ```
-    ->
-    ```
+    # output
     [2, 3, 1, 2, 3]
     ```
     """
@@ -94,14 +94,13 @@ def get_unpad_data(attention_mask: "torch.Tensor") -> Tuple["torch.Tensor", "tor
         max_seqlen_in_batch: the largest seqlen in the current batch.

     e.g.
-    ```
+    ```python
+    # input
     [
         [1, 1, 2, 2, 2, 0],
         [1, 2, 2, 3, 3, 3],
     ]
-    ```
-    ->
-    ```
+    # output
     [0, 1, 2, 3, 4, 6, 7, 8, 9, 10, 11]
     [0, 2, 5, 6, 8, 11]
     3
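Note: building on the sequence lengths, the three return values in this example follow from flattening the mask and taking a cumulative sum. A sketch, assuming the `seqlens_in_batch_sketch` helper from the previous snippet is available.

```python
import torch
import torch.nn.functional as F

def unpad_data_sketch(attention_mask: torch.Tensor):
    seqlens = seqlens_in_batch_sketch(attention_mask)                             # [2, 3, 1, 2, 3]
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()   # positions of real tokens
    max_seqlen = int(seqlens.max())                                               # 3
    cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))   # [0, 2, 5, 6, 8, 11]
    return indices, cu_seqlens, max_seqlen
```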
@@ -114,7 +113,8 @@ def get_unpad_data(attention_mask: "torch.Tensor") -> Tuple["torch.Tensor", "tor
     return indices, cu_seqlens, max_seqlen_in_batch


-def patch_for_block_diag_attn(model_type: str) -> None:
+def _patch_for_block_diag_attn(model_type: str) -> None:
+    require_version("transformers>=4.41.2,<=4.42.4", "To fix: pip install transformers>=4.41.2,<=4.42.4")
     if model_type == "cohere":
         transformers.models.cohere.modeling_cohere._get_unpad_data = get_unpad_data
     elif model_type == "falcon":
@@ -143,7 +143,7 @@ def configure_packing(config: "PretrainedConfig", model_args: "ModelArguments",
     model_type = getattr(config, "model_type", None)
     if model_type in SUPPORTED_CLASS_FOR_BLOCK_DIAG_ATTN:
-        patch_for_block_diag_attn(model_type)
+        _patch_for_block_diag_attn(model_type)
         logger.info("Using block diagonal attention for sequence packing without cross-attention.")
     else:
         raise ValueError("Current model does not support block diagonal attention.")
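Note: an illustrative sanity check, not part of the commit, assuming `_patch_for_block_diag_attn` and `get_unpad_data` from this module are importable in the current scope.

```python
import transformers.models

# After configure_packing (or calling the helper directly), the cohere modeling module
# should resolve _get_unpad_data to the packing-aware implementation defined here.
_patch_for_block_diag_attn("cohere")
assert transformers.models.cohere.modeling_cohere._get_unpad_data is get_unpad_data
```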


@@ -126,7 +126,6 @@ def configure_quantization(
             require_version("autoawq", "To fix: pip install autoawq")

         if quant_method == QuantizationMethod.AQLM:
-            require_version("transformers>=4.39.0", "To fix: pip install transformers>=4.39.0")
             require_version("aqlm>=1.1.0", "To fix: pip install aqlm[gpu]>=1.1.0")
             quantization_config["bits"] = 2


@@ -21,6 +21,7 @@ from peft import PeftModel
 from transformers import PreTrainedModel, PreTrainedTokenizerBase, is_torch_npu_available
 from transformers.integrations import is_deepspeed_zero3_enabled
 from transformers.modeling_utils import is_fsdp_enabled
+from transformers.utils.versions import require_version

 from ..extras.logging import get_logger
 from ..extras.misc import infer_optim_dtype
@@ -88,6 +89,9 @@ def patch_config(
     if getattr(config, "model_type", None) == "qwen2" and is_trainable and model_args.flash_attn == "fa2":
         setattr(config, "use_cache", False)  # qwen2 does not support use_cache when using flash attn

+    if getattr(config, "model_type", None) == "chatglm":
+        require_version("transformers==4.41.2", "To fix: pip install transformers==4.41.2")
+
     # deepspeed zero3 is not compatible with low_cpu_mem_usage
     init_kwargs["low_cpu_mem_usage"] = model_args.low_cpu_mem_usage and (not is_deepspeed_zero3_enabled())
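Note: for context on the last two lines, `low_cpu_mem_usage` relies on meta-device (empty-weight) initialization, which conflicts with ZeRO-3 partitioning parameters at construction time, hence the gating. A minimal sketch of the same check in isolation; the `requested` variable stands in for `model_args.low_cpu_mem_usage`.

```python
from transformers.integrations import is_deepspeed_zero3_enabled

requested = True  # stand-in for model_args.low_cpu_mem_usage
# Force-disable low_cpu_mem_usage whenever DeepSpeed ZeRO-3 is active.
init_kwargs = {"low_cpu_mem_usage": requested and not is_deepspeed_zero3_enabled()}
```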