move configure_packing to llamafactory.model.patcher and fix constants
Former-commit-id: 770f75dc83
@@ -19,7 +19,7 @@ from transformers.modeling_attn_mask_utils import (
 from transformers.utils import is_torch_bf16_gpu_available
 
 from ...extras.logging import get_logger
-from ...extras.constants import SUPPORTED_CLASS_FOR_MULTIPACK
+from ...extras.constants import SUPPORTED_CLASS_EFFECIENT_PACKING
 
 if TYPE_CHECKING:
     from transformers import PretrainedConfig
@@ -303,7 +303,7 @@ def configure_packing(config: "PretrainedConfig", model_args: "ModelArguments")
     else:
         attn_implementation = getattr(config, "_attn_implementation", "")
 
-    if getattr(config, "model_type", None) in SUPPORTED_CLASS_FOR_MULTIPACK:
+    if getattr(config, "model_type", None) in SUPPORTED_CLASS_EFFECIENT_PACKING:
         patch_for_multipack(getattr(config, "model_type", None), model_args.model_name_or_path, attn_implementation)
         logger.info("Using packing sequences without cross-contamination attention for efficient training.")
     else:
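For context, the log message in this diff refers to packing several short sequences into one training example while keeping their attention isolated, so tokens of one packed sequence cannot attend to tokens of another. The sketch below is only an illustration of that idea; it is not LLaMA-Factory's patch_for_multipack implementation, the helper name is hypothetical, and the dense boolean mask is an assumption (FlashAttention-based patches typically pass cumulative sequence lengths instead of materializing a mask).

import torch

def block_diagonal_attention_mask(seq_lens):
    """Hypothetical helper: build a (total_len, total_len) boolean mask for one
    packed example, allowing attention only within each original sequence."""
    # Label every position in the packed row with the index of its source sequence.
    segment_ids = torch.repeat_interleave(torch.arange(len(seq_lens)), torch.tensor(seq_lens))
    # Position i may attend to position j only if both carry the same label,
    # which yields a block-diagonal mask and avoids cross-contamination.
    return segment_ids.unsqueeze(0) == segment_ids.unsqueeze(1)

# Packing sequences of lengths 3 and 2 into one row of length 5:
# tokens 0-2 attend only to each other, tokens 3-4 only to each other.
print(block_diagonal_attention_mask([3, 2]).int())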