Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-08-23 14:22:51 +08:00)
generalized packing & fix #6343
Former-commit-id: 2d107d3aefd5af61163056634c8b91fe3cb3e77c
Parent: 4caf043cf8
Commit: bff1b94583
@@ -30,7 +30,7 @@ Dependency graph:
 longlora:
   transformers>=4.41.2,<=4.46.1
 packing:
-  transformers>=4.41.2,<=4.46.1
+  transformers>=4.43.0,<=4.46.1
 
 Disable version checking: DISABLE_VERSION_CHECK=1
 Enable VRAM recording: RECORD_VRAM=1
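Note: both switches above are plain environment variables, so a launcher can set them before LLaMA-Factory is imported. The lines below are only an illustrative sketch; they assume the flags are read at import/check time, which this diff does not show.

import os

# Skip the transformers version pin checks above (use with care).
os.environ["DISABLE_VERSION_CHECK"] = "1"
# Enable VRAM recording during the run.
os.environ["RECORD_VRAM"] = "1"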
@@ -81,19 +81,6 @@ TRAINING_STAGES = {
 STAGES_USE_PAIR_DATA = {"rm", "dpo"}
 
-SUPPORTED_CLASS_FOR_BLOCK_DIAG_ATTN = {
-    "cohere",
-    "falcon",
-    "gemma",
-    "gemma2",
-    "llama",
-    "mistral",
-    "phi",
-    "phi3",
-    "qwen2",
-    "starcoder2",
-}
-
 SUPPORTED_CLASS_FOR_S2ATTN = {"llama"}
 
 VIDEO_PLACEHOLDER = os.environ.get("VIDEO_PLACEHOLDER", "<video>")
@@ -44,13 +44,14 @@ import torch.nn.functional as F
 from transformers.utils.versions import require_version
 
 from ...extras import logging
-from ...extras.constants import SUPPORTED_CLASS_FOR_BLOCK_DIAG_ATTN
 from ...extras.packages import is_transformers_version_greater_than
 
 
-if TYPE_CHECKING:
-    from transformers import PretrainedConfig
+if is_transformers_version_greater_than("4.43.0"):
+    import transformers.modeling_flash_attention_utils
 
 
+if TYPE_CHECKING:
     from ...hparams import ModelArguments
@@ -113,45 +114,10 @@ def get_unpad_data(attention_mask: "torch.Tensor") -> Tuple["torch.Tensor", "torch.Tensor", int]:
     return indices, cu_seqlens, max_seqlen_in_batch
 
 
-def _patch_for_block_diag_attn(model_type: str) -> None:
-    require_version("transformers>=4.41.2,<=4.46.1", "To fix: pip install transformers>=4.41.2,<=4.46.1")
-    if is_transformers_version_greater_than("4.43.0"):
-        import transformers.modeling_flash_attention_utils
-
-        transformers.modeling_flash_attention_utils._get_unpad_data = get_unpad_data
-        return
-
-    import transformers.models
-
-    if model_type == "cohere":
-        transformers.models.cohere.modeling_cohere._get_unpad_data = get_unpad_data
-    elif model_type == "falcon":
-        transformers.models.falcon.modeling_falcon._get_unpad_data = get_unpad_data
-    elif model_type == "gemma":
-        transformers.models.gemma.modeling_gemma._get_unpad_data = get_unpad_data
-    elif model_type == "gemma2":
-        transformers.models.gemma2.modeling_gemma2._get_unpad_data = get_unpad_data
-    elif model_type == "llama":
-        transformers.models.llama.modeling_llama._get_unpad_data = get_unpad_data
-    elif model_type == "mistral":
-        transformers.models.mistral.modeling_mistral._get_unpad_data = get_unpad_data
-    elif model_type == "phi":
-        transformers.models.phi.modeling_phi._get_unpad_data = get_unpad_data
-    elif model_type == "phi3":
-        transformers.models.phi3.modeling_phi3._get_unpad_data = get_unpad_data
-    elif model_type == "qwen2":
-        transformers.models.qwen2.modeling_qwen2._get_unpad_data = get_unpad_data
-    elif model_type == "starcoder2":
-        transformers.models.starcoder2.modeling_starcoder2._get_unpad_data = get_unpad_data
-
-
-def configure_packing(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
+def configure_packing(model_args: "ModelArguments", is_trainable: bool) -> None:
     if not is_trainable or not model_args.block_diag_attn:
         return
 
-    model_type = getattr(config, "model_type", None)
-    if model_type in SUPPORTED_CLASS_FOR_BLOCK_DIAG_ATTN:
-        _patch_for_block_diag_attn(model_type)
-        logger.info_rank0("Using block diagonal attention for sequence packing without cross-attention.")
-    else:
-        raise ValueError("Current model does not support block diagonal attention.")
+    require_version("transformers>=4.43.0,<=4.46.1", "To fix: pip install transformers>=4.43.0,<=4.46.1")
+    transformers.modeling_flash_attention_utils._get_unpad_data = get_unpad_data
+    logger.info_rank0("Using block diagonal attention for sequence packing without cross-attention.")
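Note: `get_unpad_data` (whose return statement appears as context above) is the helper that gets patched into transformers; it converts a packed batch into the varlen metadata (flattened token indices, cu_seqlens offsets, max sequence length) consumed by FlashAttention-2, so each packed sequence attends only to itself (block-diagonal attention). The sketch below is an illustrative re-derivation, not the repository's implementation, and it assumes the convention that the attention mask of a packed sample stores 1-based sequence ids rather than plain 0/1 values.

from typing import Tuple

import torch
import torch.nn.functional as F


def unpad_data_sketch(attention_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, int]:
    # attention_mask: (batch, seq_len); packed samples carry 1-based sequence ids,
    # e.g. [1, 1, 1, 2, 2, 0] packs a length-3 and a length-2 sequence plus padding.
    seqlens = []
    for row in attention_mask:
        ids = row[row > 0].to(torch.int64)
        counts = torch.bincount(ids)[1:]  # tokens per sequence id within this row
        seqlens.append(counts[counts > 0])
    seqlens_in_batch = torch.cat(seqlens).to(torch.int32)

    # flattened positions of real (non-padding) tokens
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    # cumulative offsets starting at 0, as expected by flash-attn varlen kernels
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
    max_seqlen_in_batch = int(seqlens_in_batch.max())
    return indices, cu_seqlens, max_seqlen_in_batch


# Two sequences of lengths 3 and 2 packed into a single row:
mask = torch.tensor([[1, 1, 1, 2, 2, 0]])
print(unpad_data_sketch(mask)[1])  # tensor([0, 3, 5], dtype=torch.int32)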
@@ -96,7 +96,7 @@ def patch_config(
     configure_quantization(config, tokenizer, model_args, init_kwargs)
     configure_moe(config, model_args, is_trainable)
     configure_visual_model(config)
-    configure_packing(config, model_args, is_trainable)
+    configure_packing(model_args, is_trainable)
 
     if model_args.use_cache and not is_trainable:
         setattr(config, "use_cache", True)
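Note: the call-site change above is the whole integration. Because transformers>=4.43.0 routes every FlashAttention-2 model through the shared `_get_unpad_data` helper, a single assignment replaces the former per-architecture allow-list, which is why `configure_packing` no longer needs the model config. A quick sanity check could look like the sketch below; the `llamafactory.model.model_utils.packing` module path and the stand-in for ModelArguments are assumptions, not part of this diff.

from types import SimpleNamespace

import transformers.modeling_flash_attention_utils as fa_utils
from llamafactory.model.model_utils.packing import configure_packing, get_unpad_data

# Stand-in exposing only the field configure_packing reads (hypothetical).
model_args = SimpleNamespace(block_diag_attn=True)
configure_packing(model_args, is_trainable=True)
assert fa_utils._get_unpad_data is get_unpad_data, "packing patch was not applied"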