mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-23 22:32:54 +08:00
Merge pull request #6362 from hiyouga/hiyouga/mllm_packing
[model] generalized packing Former-commit-id: 9708a39179d7872ff2039086fcadb021265974cc
This commit is contained in:
commit
a8a990a9a7
@ -30,7 +30,7 @@ Dependency graph:
|
|||||||
longlora:
|
longlora:
|
||||||
transformers>=4.41.2,<=4.46.1
|
transformers>=4.41.2,<=4.46.1
|
||||||
packing:
|
packing:
|
||||||
transformers>=4.41.2,<=4.46.1
|
transformers>=4.43.0,<=4.46.1
|
||||||
|
|
||||||
Disable version checking: DISABLE_VERSION_CHECK=1
|
Disable version checking: DISABLE_VERSION_CHECK=1
|
||||||
Enable VRAM recording: RECORD_VRAM=1
|
Enable VRAM recording: RECORD_VRAM=1
|
||||||
|
@ -81,19 +81,6 @@ TRAINING_STAGES = {
|
|||||||
|
|
||||||
STAGES_USE_PAIR_DATA = {"rm", "dpo"}
|
STAGES_USE_PAIR_DATA = {"rm", "dpo"}
|
||||||
|
|
||||||
SUPPORTED_CLASS_FOR_BLOCK_DIAG_ATTN = {
|
|
||||||
"cohere",
|
|
||||||
"falcon",
|
|
||||||
"gemma",
|
|
||||||
"gemma2",
|
|
||||||
"llama",
|
|
||||||
"mistral",
|
|
||||||
"phi",
|
|
||||||
"phi3",
|
|
||||||
"qwen2",
|
|
||||||
"starcoder2",
|
|
||||||
}
|
|
||||||
|
|
||||||
SUPPORTED_CLASS_FOR_S2ATTN = {"llama"}
|
SUPPORTED_CLASS_FOR_S2ATTN = {"llama"}
|
||||||
|
|
||||||
VIDEO_PLACEHOLDER = os.environ.get("VIDEO_PLACEHOLDER", "<video>")
|
VIDEO_PLACEHOLDER = os.environ.get("VIDEO_PLACEHOLDER", "<video>")
|
||||||
|
@ -44,13 +44,14 @@ import torch.nn.functional as F
|
|||||||
from transformers.utils.versions import require_version
|
from transformers.utils.versions import require_version
|
||||||
|
|
||||||
from ...extras import logging
|
from ...extras import logging
|
||||||
from ...extras.constants import SUPPORTED_CLASS_FOR_BLOCK_DIAG_ATTN
|
|
||||||
from ...extras.packages import is_transformers_version_greater_than
|
from ...extras.packages import is_transformers_version_greater_than
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if is_transformers_version_greater_than("4.43.0"):
|
||||||
from transformers import PretrainedConfig
|
import transformers.modeling_flash_attention_utils
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
from ...hparams import ModelArguments
|
from ...hparams import ModelArguments
|
||||||
|
|
||||||
|
|
||||||
@ -113,45 +114,10 @@ def get_unpad_data(attention_mask: "torch.Tensor") -> Tuple["torch.Tensor", "tor
|
|||||||
return indices, cu_seqlens, max_seqlen_in_batch
|
return indices, cu_seqlens, max_seqlen_in_batch
|
||||||
|
|
||||||
|
|
||||||
def _patch_for_block_diag_attn(model_type: str) -> None:
|
def configure_packing(model_args: "ModelArguments", is_trainable: bool) -> None:
|
||||||
require_version("transformers>=4.41.2,<=4.46.1", "To fix: pip install transformers>=4.41.2,<=4.46.1")
|
|
||||||
if is_transformers_version_greater_than("4.43.0"):
|
|
||||||
import transformers.modeling_flash_attention_utils
|
|
||||||
|
|
||||||
transformers.modeling_flash_attention_utils._get_unpad_data = get_unpad_data
|
|
||||||
return
|
|
||||||
|
|
||||||
import transformers.models
|
|
||||||
|
|
||||||
if model_type == "cohere":
|
|
||||||
transformers.models.cohere.modeling_cohere._get_unpad_data = get_unpad_data
|
|
||||||
elif model_type == "falcon":
|
|
||||||
transformers.models.falcon.modeling_falcon._get_unpad_data = get_unpad_data
|
|
||||||
elif model_type == "gemma":
|
|
||||||
transformers.models.gemma.modeling_gemma._get_unpad_data = get_unpad_data
|
|
||||||
elif model_type == "gemma2":
|
|
||||||
transformers.models.gemma2.modeling_gemma2._get_unpad_data = get_unpad_data
|
|
||||||
elif model_type == "llama":
|
|
||||||
transformers.models.llama.modeling_llama._get_unpad_data = get_unpad_data
|
|
||||||
elif model_type == "mistral":
|
|
||||||
transformers.models.mistral.modeling_mistral._get_unpad_data = get_unpad_data
|
|
||||||
elif model_type == "phi":
|
|
||||||
transformers.models.phi.modeling_phi._get_unpad_data = get_unpad_data
|
|
||||||
elif model_type == "phi3":
|
|
||||||
transformers.models.phi3.modeling_phi3._get_unpad_data = get_unpad_data
|
|
||||||
elif model_type == "qwen2":
|
|
||||||
transformers.models.qwen2.modeling_qwen2._get_unpad_data = get_unpad_data
|
|
||||||
elif model_type == "starcoder2":
|
|
||||||
transformers.models.starcoder2.modeling_starcoder2._get_unpad_data = get_unpad_data
|
|
||||||
|
|
||||||
|
|
||||||
def configure_packing(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
|
|
||||||
if not is_trainable or not model_args.block_diag_attn:
|
if not is_trainable or not model_args.block_diag_attn:
|
||||||
return
|
return
|
||||||
|
|
||||||
model_type = getattr(config, "model_type", None)
|
require_version("transformers>=4.43.0,<=4.46.1", "To fix: pip install transformers>=4.43.0,<=4.46.1")
|
||||||
if model_type in SUPPORTED_CLASS_FOR_BLOCK_DIAG_ATTN:
|
transformers.modeling_flash_attention_utils._get_unpad_data = get_unpad_data
|
||||||
_patch_for_block_diag_attn(model_type)
|
|
||||||
logger.info_rank0("Using block diagonal attention for sequence packing without cross-attention.")
|
logger.info_rank0("Using block diagonal attention for sequence packing without cross-attention.")
|
||||||
else:
|
|
||||||
raise ValueError("Current model does not support block diagonal attention.")
|
|
||||||
|
@ -96,7 +96,7 @@ def patch_config(
|
|||||||
configure_quantization(config, tokenizer, model_args, init_kwargs)
|
configure_quantization(config, tokenizer, model_args, init_kwargs)
|
||||||
configure_moe(config, model_args, is_trainable)
|
configure_moe(config, model_args, is_trainable)
|
||||||
configure_visual_model(config)
|
configure_visual_model(config)
|
||||||
configure_packing(config, model_args, is_trainable)
|
configure_packing(model_args, is_trainable)
|
||||||
|
|
||||||
if model_args.use_cache and not is_trainable:
|
if model_args.use_cache and not is_trainable:
|
||||||
setattr(config, "use_cache", True)
|
setattr(config, "use_cache", True)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user