Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-08-23 14:22:51 +08:00)

fix gemma2 attention

Commit: 0b26011181
Parent: fb387ae1c3
Former-commit-id: 2f6af73da28c4f8321b625fd09ddec8bd4977b08
@@ -12,7 +12,29 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-# Level: api, webui > chat, eval, train > data, model > hparams > extras
+r"""
+Efficient fine-tuning of large language models.
+
+Level:
+  api, webui > chat, eval, train > data, model > hparams > extras
+
+Dependency graph:
+  main:
+    transformers>=4.41.2
+    datasets>=2.16.0
+    accelerate>=0.30.1
+    peft>=0.11.1
+    trl>=0.8.6
+  attention:
+    transformers>=4.42.4 (gemma+fa2)
+  longlora:
+    transformers>=4.41.2,<=4.42.4
+  packing:
+    transformers>=4.41.2,<=4.42.4
+  patcher:
+    transformers==4.41.2 (chatglm)
+"""
+
 
 from .cli import VERSION
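The docstring above is informational; the pins it lists are enforced at the call sites changed later in this commit via transformers' `require_version` helper. A minimal illustration of that pattern (not part of the commit itself):

```python
# Illustrative only: the version-gate pattern used by the hunks below.
# require_version raises with the supplied hint when the requirement is unmet.
from transformers.utils.versions import require_version

require_version("transformers>=4.42.4", "To fix: pip install transformers>=4.42.4")  # gemma + fa2
```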
@@ -28,11 +28,10 @@ def prepare_4d_attention_mask(attention_mask_with_indices: "torch.Tensor", dtype
     while handles packed sequences and transforms the mask to lower triangular form to prevent future peeking.
 
     e.g.
-    ```
+    ```python
+    # input
     [[1, 1, 2, 2, 2, 0]]
-    ```
-    ->
-    ```
+    # output
     [
       [
         [
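For readers following the docstring example: a rough sketch of how such a block-diagonal, lower-triangular 4D mask can be built from packed sequence indices. This is illustrative only; `build_4d_mask` is a hypothetical name, not the function patched here, and the repository's implementation may differ in details.

```python
import torch

def build_4d_mask(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    # mask: (batch, seq_len) of packed-sequence indices, 0 = padding, e.g. [[1, 1, 2, 2, 2, 0]].
    bsz, seq_len = mask.size()
    min_value = torch.finfo(dtype).min
    # expanded[b, 0, i, j] == mask[b, j]; its transpose gives mask[b, i].
    expanded = mask[:, None, None, :].expand(bsz, 1, seq_len, seq_len)
    same_seq = expanded.eq(expanded.transpose(-1, -2)) & expanded.ne(0)
    causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device=mask.device))
    allowed = same_seq & causal  # attend only within the same packed sequence, never to future positions
    zero = torch.zeros((), dtype=dtype, device=mask.device)
    neg = torch.full((), min_value, dtype=dtype, device=mask.device)
    return torch.where(allowed, zero, neg)  # (batch, 1, seq_len, seq_len)
```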
@ -15,6 +15,7 @@
|
|||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from transformers.utils import is_flash_attn_2_available, is_torch_sdpa_available
|
from transformers.utils import is_flash_attn_2_available, is_torch_sdpa_available
|
||||||
|
from transformers.utils.versions import require_version
|
||||||
|
|
||||||
from ...extras.logging import get_logger
|
from ...extras.logging import get_logger
|
||||||
|
|
||||||
@@ -31,15 +32,17 @@ logger = get_logger(__name__)
 def configure_attn_implementation(
     config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool
 ) -> None:
-    if getattr(config, "model_type", None) == "gemma2" and is_trainable:  # gemma2 adopts soft-cap attention
-        if model_args.flash_attn == "auto":
-            logger.warning("Gemma-2 models should use eager attention in training, change `flash_attn` to disabled.")
-            model_args.flash_attn = "disabled"
-        elif model_args.flash_attn != "disabled":
-            logger.warning(
-                "Gemma-2 models should use eager attention in training, but you set `flash_attn: {}`. "
-                "Will proceed at your own risk.".format(model_args.flash_attn)
-            )
+    if getattr(config, "model_type", None) == "gemma2" and is_trainable:
+        if model_args.flash_attn == "auto" or model_args.flash_attn == "fa2":
+            if is_flash_attn_2_available():
+                require_version("transformers>=4.42.4", "To fix: pip install transformers>=4.42.4")
+                logger.warning("Gemma-2 should use flash attention 2, change `flash_attn` to fa2.")
+                model_args.flash_attn = "fa2"
+            else:
+                logger.warning("Gemma-2 should use eager attention, change `flash_attn` to disabled.")
+                model_args.flash_attn = "disabled"
+        elif model_args.flash_attn == "sdpa":
+            raise ValueError("Gemma-2 should use soft-capping attention, while the SDPA attention is not compatible.")
 
     if model_args.flash_attn == "auto":
         return
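In short, for trainable Gemma-2 models this hunk upgrades `auto`/`fa2` to FlashAttention-2 when it is available (gated on transformers>=4.42.4 per the `require_version` call), falls back to eager attention otherwise, and rejects SDPA because it cannot express Gemma-2's soft-capping attention. A condensed, standalone restatement of that decision (hypothetical helper name, not the repository API):

```python
def resolve_gemma2_flash_attn(requested: str, fa2_available: bool) -> str:
    """Condensed restatement of the Gemma-2 branch above, for illustration only."""
    if requested in ("auto", "fa2"):
        # Prefer FlashAttention-2 when installed; otherwise fall back to eager attention.
        return "fa2" if fa2_available else "disabled"
    if requested == "sdpa":
        raise ValueError("Gemma-2 should use soft-capping attention, while the SDPA attention is not compatible.")
    return requested  # e.g. an explicit "disabled" is left untouched
```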
@@ -326,7 +326,7 @@ def llama_sdpa_attention_forward(
 
 
 def _apply_llama_patch() -> None:
-    require_version("transformers>=4.41.2,<=4.42.3", "To fix: pip install transformers>=4.41.2,<=4.42.3")
+    require_version("transformers>=4.41.2,<=4.42.4", "To fix: pip install transformers>=4.41.2,<=4.42.4")
     LlamaAttention.forward = llama_attention_forward
     LlamaFlashAttention2.forward = llama_flash_attention_2_forward
     LlamaSdpaAttention.forward = llama_sdpa_attention_forward
@ -42,6 +42,7 @@ from typing import TYPE_CHECKING, Tuple
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
import transformers.models
|
import transformers.models
|
||||||
|
from transformers.utils.versions import require_version
|
||||||
|
|
||||||
from ...extras.constants import SUPPORTED_CLASS_FOR_BLOCK_DIAG_ATTN
|
from ...extras.constants import SUPPORTED_CLASS_FOR_BLOCK_DIAG_ATTN
|
||||||
from ...extras.logging import get_logger
|
from ...extras.logging import get_logger
|
||||||
@@ -61,14 +62,13 @@ def get_seqlens_in_batch(attention_mask: "torch.Tensor") -> "torch.Tensor":
     Gets the sequnce lengths in the current batch.
 
     e.g.
-    ```
+    ```python
+    # input
     [
         [1, 1, 2, 2, 2, 0],
         [1, 2, 2, 3, 3, 3],
     ]
-    ```
-    ->
-    ```
+    # output
     [2, 3, 1, 2, 3]
     ```
     """
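A small sketch that reproduces the documented example, assuming the packed mask uses 1-based sequence indices with 0 as padding (illustrative; the repository's implementation may differ in details):

```python
import torch

def seqlens_in_batch_sketch(attention_mask: torch.Tensor) -> torch.Tensor:
    # Count tokens per packed-sequence index (1, 2, 3, ...) in each row,
    # then drop zero counts coming from padding or absent indices.
    max_index = int(attention_mask.max())
    counts = torch.stack(
        [(attention_mask == i + 1).sum(dim=-1) for i in range(max_index)], dim=-1
    ).flatten()
    return counts[counts.nonzero().squeeze(-1)]

mask = torch.tensor([[1, 1, 2, 2, 2, 0], [1, 2, 2, 3, 3, 3]])
print(seqlens_in_batch_sketch(mask))  # tensor([2, 3, 1, 2, 3])
```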
@@ -94,14 +94,13 @@ def get_unpad_data(attention_mask: "torch.Tensor") -> Tuple["torch.Tensor", "tor
         max_seqlen_in_batch: the largest seqlen in the current batch.
 
     e.g.
-    ```
+    ```python
+    # input
     [
         [1, 1, 2, 2, 2, 0],
         [1, 2, 2, 3, 3, 3],
     ]
-    ```
-    ->
-    ```
+    # output
     [0, 1, 2, 3, 4, 6, 7, 8, 9, 10, 11]
     [0, 2, 5, 6, 8, 11]
     3
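And similarly for the three documented outputs (indices of non-padding tokens in the flattened mask, cumulative sequence lengths, and the largest sequence length). Again a hedged sketch that continues the one above, not necessarily the exact implementation:

```python
import torch
import torch.nn.functional as F

def unpad_data_sketch(attention_mask: torch.Tensor):
    seqlens = seqlens_in_batch_sketch(attention_mask)            # [2, 3, 1, 2, 3]
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    max_seqlen_in_batch = int(seqlens.max())                     # 3
    # Prepend 0 to the running total so cu_seqlens marks each sequence's start offset.
    cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))
    return indices, cu_seqlens, max_seqlen_in_batch

print(unpad_data_sketch(mask))
# (tensor([0, 1, 2, 3, 4, 6, 7, 8, 9, 10, 11]), tensor([0, 2, 5, 6, 8, 11], dtype=torch.int32), 3)
```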
@@ -114,7 +113,8 @@ def get_unpad_data(attention_mask: "torch.Tensor") -> Tuple["torch.Tensor", "tor
     return indices, cu_seqlens, max_seqlen_in_batch
 
 
-def patch_for_block_diag_attn(model_type: str) -> None:
+def _patch_for_block_diag_attn(model_type: str) -> None:
+    require_version("transformers>=4.41.2,<=4.42.4", "To fix: pip install transformers>=4.41.2,<=4.42.4")
     if model_type == "cohere":
         transformers.models.cohere.modeling_cohere._get_unpad_data = get_unpad_data
     elif model_type == "falcon":
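The renamed `_patch_for_block_diag_attn` monkey-patches each model's private `_get_unpad_data` hook in transformers (hence the new pin to the range where that hook exists), so FlashAttention unpads along the packed sub-sequences rather than treating each row as one long sequence. Conceptually, for a hypothetical extra branch (the diff shows cohere and falcon; `unpad_data_sketch` stands in for the repository's `get_unpad_data`):

```python
import transformers.models

def patch_sketch(model_type: str) -> None:
    # Replace the private unpad hook so packed rows are split into their
    # constituent sequences before FlashAttention sees them.
    if model_type == "llama":  # illustrative branch only
        transformers.models.llama.modeling_llama._get_unpad_data = unpad_data_sketch
```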
@@ -143,7 +143,7 @@ def configure_packing(config: "PretrainedConfig", model_args: "ModelArguments",
 
     model_type = getattr(config, "model_type", None)
     if model_type in SUPPORTED_CLASS_FOR_BLOCK_DIAG_ATTN:
-        patch_for_block_diag_attn(model_type)
+        _patch_for_block_diag_attn(model_type)
         logger.info("Using block diagonal attention for sequence packing without cross-attention.")
     else:
         raise ValueError("Current model does not support block diagonal attention.")
@ -126,7 +126,6 @@ def configure_quantization(
|
|||||||
require_version("autoawq", "To fix: pip install autoawq")
|
require_version("autoawq", "To fix: pip install autoawq")
|
||||||
|
|
||||||
if quant_method == QuantizationMethod.AQLM:
|
if quant_method == QuantizationMethod.AQLM:
|
||||||
require_version("transformers>=4.39.0", "To fix: pip install transformers>=4.39.0")
|
|
||||||
require_version("aqlm>=1.1.0", "To fix: pip install aqlm[gpu]>=1.1.0")
|
require_version("aqlm>=1.1.0", "To fix: pip install aqlm[gpu]>=1.1.0")
|
||||||
quantization_config["bits"] = 2
|
quantization_config["bits"] = 2
|
||||||
|
|
||||||
|
@ -21,6 +21,7 @@ from peft import PeftModel
|
|||||||
from transformers import PreTrainedModel, PreTrainedTokenizerBase, is_torch_npu_available
|
from transformers import PreTrainedModel, PreTrainedTokenizerBase, is_torch_npu_available
|
||||||
from transformers.integrations import is_deepspeed_zero3_enabled
|
from transformers.integrations import is_deepspeed_zero3_enabled
|
||||||
from transformers.modeling_utils import is_fsdp_enabled
|
from transformers.modeling_utils import is_fsdp_enabled
|
||||||
|
from transformers.utils.versions import require_version
|
||||||
|
|
||||||
from ..extras.logging import get_logger
|
from ..extras.logging import get_logger
|
||||||
from ..extras.misc import infer_optim_dtype
|
from ..extras.misc import infer_optim_dtype
|
||||||
@@ -88,6 +89,9 @@ def patch_config(
     if getattr(config, "model_type", None) == "qwen2" and is_trainable and model_args.flash_attn == "fa2":
         setattr(config, "use_cache", False)  # qwen2 does not support use_cache when using flash attn
 
+    if getattr(config, "model_type", None) == "chatglm":
+        require_version("transformers==4.41.2", "To fix: pip install transformers==4.41.2")
+
     # deepspeed zero3 is not compatible with low_cpu_mem_usage
     init_kwargs["low_cpu_mem_usage"] = model_args.low_cpu_mem_usage and (not is_deepspeed_zero3_enabled())
 