generalized packing & fix #6343

Former-commit-id: 2d107d3aefd5af61163056634c8b91fe3cb3e77c
hiyouga 2024-12-17 10:26:19 +00:00
parent 4caf043cf8
commit bff1b94583
4 changed files with 10 additions and 57 deletions

View File

@@ -30,7 +30,7 @@ Dependency graph:
   longlora:
     transformers>=4.41.2,<=4.46.1
   packing:
-    transformers>=4.41.2,<=4.46.1
+    transformers>=4.43.0,<=4.46.1

 Disable version checking: DISABLE_VERSION_CHECK=1
 Enable VRAM recording: RECORD_VRAM=1
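The tightened pin applies only to the packing feature. As a minimal sketch, this is how such a pin is enforced at runtime with the same require_version helper the packing module already imports (skipped when DISABLE_VERSION_CHECK=1 is set); the message string mirrors the one added further down in this commit:

# Minimal sketch: enforce the new packing requirement at import time.
from transformers.utils.versions import require_version

require_version("transformers>=4.43.0,<=4.46.1", "To fix: pip install transformers>=4.43.0,<=4.46.1")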

View File

@@ -81,19 +81,6 @@ TRAINING_STAGES = {

 STAGES_USE_PAIR_DATA = {"rm", "dpo"}

-SUPPORTED_CLASS_FOR_BLOCK_DIAG_ATTN = {
-    "cohere",
-    "falcon",
-    "gemma",
-    "gemma2",
-    "llama",
-    "mistral",
-    "phi",
-    "phi3",
-    "qwen2",
-    "starcoder2",
-}
-
 SUPPORTED_CLASS_FOR_S2ATTN = {"llama"}

 VIDEO_PLACEHOLDER = os.environ.get("VIDEO_PLACEHOLDER", "<video>")
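The allowlist can be dropped because, from transformers 4.43.0 on, flash-attention models route unpadding through one shared helper module. A hedged sketch (not the project's exact code) of the single patch point that replaces the per-model dispatch:

# Sketch, assuming transformers>=4.43.0: patching the shared helper covers every
# architecture that the removed SUPPORTED_CLASS_FOR_BLOCK_DIAG_ATTN set enumerated.
import transformers.modeling_flash_attention_utils


def patch_shared_unpad(block_diag_get_unpad_data) -> None:
    # One assignment instead of ten model-specific monkey patches.
    transformers.modeling_flash_attention_utils._get_unpad_data = block_diag_get_unpad_data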

View File

@@ -44,13 +44,14 @@ import torch.nn.functional as F
 from transformers.utils.versions import require_version

 from ...extras import logging
-from ...extras.constants import SUPPORTED_CLASS_FOR_BLOCK_DIAG_ATTN
 from ...extras.packages import is_transformers_version_greater_than


+if is_transformers_version_greater_than("4.43.0"):
+    import transformers.modeling_flash_attention_utils
+
+
 if TYPE_CHECKING:
-    from transformers import PretrainedConfig
-
     from ...hparams import ModelArguments
@@ -113,45 +114,10 @@ def get_unpad_data(attention_mask: "torch.Tensor") -> Tuple["torch.Tensor", "torch.Tensor", int]:
     return indices, cu_seqlens, max_seqlen_in_batch


-def _patch_for_block_diag_attn(model_type: str) -> None:
-    require_version("transformers>=4.41.2,<=4.46.1", "To fix: pip install transformers>=4.41.2,<=4.46.1")
-    if is_transformers_version_greater_than("4.43.0"):
-        import transformers.modeling_flash_attention_utils
-
-        transformers.modeling_flash_attention_utils._get_unpad_data = get_unpad_data
-        return
-
-    import transformers.models
-
-    if model_type == "cohere":
-        transformers.models.cohere.modeling_cohere._get_unpad_data = get_unpad_data
-    elif model_type == "falcon":
-        transformers.models.falcon.modeling_falcon._get_unpad_data = get_unpad_data
-    elif model_type == "gemma":
-        transformers.models.gemma.modeling_gemma._get_unpad_data = get_unpad_data
-    elif model_type == "gemma2":
-        transformers.models.gemma2.modeling_gemma2._get_unpad_data = get_unpad_data
-    elif model_type == "llama":
-        transformers.models.llama.modeling_llama._get_unpad_data = get_unpad_data
-    elif model_type == "mistral":
-        transformers.models.mistral.modeling_mistral._get_unpad_data = get_unpad_data
-    elif model_type == "phi":
-        transformers.models.phi.modeling_phi._get_unpad_data = get_unpad_data
-    elif model_type == "phi3":
-        transformers.models.phi3.modeling_phi3._get_unpad_data = get_unpad_data
-    elif model_type == "qwen2":
-        transformers.models.qwen2.modeling_qwen2._get_unpad_data = get_unpad_data
-    elif model_type == "starcoder2":
-        transformers.models.starcoder2.modeling_starcoder2._get_unpad_data = get_unpad_data
-
-
-def configure_packing(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
+def configure_packing(model_args: "ModelArguments", is_trainable: bool) -> None:
     if not is_trainable or not model_args.block_diag_attn:
         return

-    model_type = getattr(config, "model_type", None)
-    if model_type in SUPPORTED_CLASS_FOR_BLOCK_DIAG_ATTN:
-        _patch_for_block_diag_attn(model_type)
-        logger.info_rank0("Using block diagonal attention for sequence packing without cross-attention.")
-    else:
-        raise ValueError("Current model does not support block diagonal attention.")
+    require_version("transformers>=4.43.0,<=4.46.1", "To fix: pip install transformers>=4.43.0,<=4.46.1")
+    transformers.modeling_flash_attention_utils._get_unpad_data = get_unpad_data
+    logger.info_rank0("Using block diagonal attention for sequence packing without cross-attention.")
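For reference, a hedged sketch of what a block-diagonal get_unpad_data can look like, matching the return signature shown in the hunk header above. It assumes the data collator marks tokens of the k-th packed sequence with value k in attention_mask (0 for padding), which is what turns each packed sequence into its own attention block:

# Sketch only; assumes the packed attention_mask uses 1, 2, 3, ... per sequence.
from typing import Tuple

import torch
import torch.nn.functional as F


def get_unpad_data(attention_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, int]:
    # per-sequence lengths (the zero bin counts padding and is skipped)
    seqlens_in_batch = torch.bincount(attention_mask.flatten())[1:]
    # positions of all non-padding tokens in the flattened batch
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    max_seqlen_in_batch = int(seqlens_in_batch.max().item())
    # cumulative sequence lengths expected by flash-attn varlen kernels
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
    return indices, cu_seqlens, max_seqlen_in_batch


# Tiny worked example: two sequences of lengths 3 and 2 packed into one row,
# followed by one padding token.
mask = torch.tensor([[1, 1, 1, 2, 2, 0]])
indices, cu_seqlens, max_len = get_unpad_data(mask)
# indices -> tensor([0, 1, 2, 3, 4]); cu_seqlens -> tensor([0, 3, 5], dtype=int32); max_len -> 3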

View File

@@ -96,7 +96,7 @@ def patch_config(
     configure_quantization(config, tokenizer, model_args, init_kwargs)
     configure_moe(config, model_args, is_trainable)
     configure_visual_model(config)
-    configure_packing(config, model_args, is_trainable)
+    configure_packing(model_args, is_trainable)

     if model_args.use_cache and not is_trainable:
         setattr(config, "use_cache", True)
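With the config argument gone, enabling packing at training time reduces to a flag on ModelArguments. A usage sketch; the import paths and constructor fields below are assumptions based on the repository layout, not taken from this diff:

# Usage sketch (paths and field names assumed).
from llamafactory.hparams import ModelArguments
from llamafactory.model.model_utils.packing import configure_packing

# block_diag_attn is the flag checked by configure_packing in the hunk above;
# model_name_or_path is a placeholder value.
model_args = ModelArguments(model_name_or_path="meta-llama/Meta-Llama-3-8B", block_diag_attn=True)
configure_packing(model_args, is_trainable=True)  # patches _get_unpad_data once, for any architecture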