move configure_packing to llamafactory.model.patcher and fix constants

Former-commit-id: 770f75dc8363bfa284a72159ff8ad25ec9abe4e0
ancv 2024-06-21 00:45:06 +07:00
parent 988231026a
commit 5319447aa5
6 changed files with 16 additions and 12 deletions

View File

@@ -66,7 +66,7 @@ STAGES_USE_PAIR_DATA = {"rm", "dpo"}
 
 SUPPORTED_CLASS_FOR_S2ATTN = {"llama"}
 
-SUPPORTED_CLASS_FOR_MULTIPACK = [
+SUPPORTED_CLASS_EFFECIENT_PACKING = [
     "llama",
     "mistral",
     "mixtral",

View File

@@ -86,7 +86,9 @@ class DataArguments:
     )
     efficient_packing: Optional[bool] = field(
         default=None,
-        metadata={"help": "Whether or not to pack the sequences without cross-contamination attention for efficient training."},
+        metadata={
+            "help": "Whether or not to pack the sequences without cross-contamination attention for efficient training."
+        },
     )
     tokenized_path: Optional[str] = field(
         default=None,
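
Since DataArguments is parsed with transformers' HfArgumentParser, the new efficient_packing field also surfaces as a command-line option. A minimal, hedged sketch of reading it; the llamafactory.hparams import path is assumed from the package layout and is not part of this diff:

from transformers import HfArgumentParser

from llamafactory.hparams import DataArguments  # assumed public import path

# Parse only DataArguments here for illustration; the real CLI parses the full
# tuple of argument dataclasses at once.
(data_args,) = HfArgumentParser(DataArguments).parse_args_into_dataclasses(
    args=["--efficient_packing", "true"]
)
print(data_args.efficient_packing)  # -> True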

View File

@@ -16,7 +16,7 @@ from .patcher import patch_config, patch_model, patch_tokenizer, patch_valuehead
 if TYPE_CHECKING:
     from transformers import PretrainedConfig, PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
 
-    from ..hparams import FinetuningArguments, ModelArguments
+    from ..hparams import FinetuningArguments, ModelArguments, DataArguments
 
 
 logger = get_logger(__name__)
@@ -104,6 +104,7 @@ def load_config(model_args: "ModelArguments") -> "PretrainedConfig":
 def load_model(
     tokenizer: "PreTrainedTokenizer",
     model_args: "ModelArguments",
+    data_args: "DataArguments",
     finetuning_args: "FinetuningArguments",
     is_trainable: bool = False,
     add_valuehead: bool = False,
@@ -113,7 +114,7 @@ def load_model(
     """
     init_kwargs = _get_init_kwargs(model_args)
     config = load_config(model_args)
-    patch_config(config, tokenizer, model_args, init_kwargs, is_trainable)
+    patch_config(config, tokenizer, model_args, data_args, finetuning_args, init_kwargs, is_trainable)
 
     model = None
     lazy_load = False
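
With the loader change above, data_args becomes part of load_model's signature and is forwarded to patch_config. A hedged usage sketch of a call site follows; the import paths, the load_tokenizer helper, and the concrete argument values are assumptions based on the surrounding codebase, not part of this hunk:

from llamafactory.hparams import DataArguments, FinetuningArguments, ModelArguments
from llamafactory.model import load_model, load_tokenizer

model_args = ModelArguments(model_name_or_path="meta-llama/Llama-2-7b-hf")  # example model
data_args = DataArguments(efficient_packing=True)   # flag added in this commit
finetuning_args = FinetuningArguments(stage="sft")  # packing is only wired up for SFT

tokenizer = load_tokenizer(model_args)["tokenizer"]
# data_args now sits between model_args and finetuning_args:
model = load_model(tokenizer, model_args, data_args, finetuning_args, is_trainable=True)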

View File

@ -19,7 +19,7 @@ from transformers.modeling_attn_mask_utils import (
from transformers.utils import is_torch_bf16_gpu_available from transformers.utils import is_torch_bf16_gpu_available
from ...extras.logging import get_logger from ...extras.logging import get_logger
from ...extras.constants import SUPPORTED_CLASS_FOR_MULTIPACK from ...extras.constants import SUPPORTED_CLASS_EFFECIENT_PACKING
if TYPE_CHECKING: if TYPE_CHECKING:
from transformers import PretrainedConfig from transformers import PretrainedConfig
@@ -303,7 +303,7 @@ def configure_packing(config: "PretrainedConfig", model_args: "ModelArguments")
     else:
         attn_implementation = getattr(config, "_attn_implementation", "")
 
-    if getattr(config, "model_type", None) in SUPPORTED_CLASS_FOR_MULTIPACK:
+    if getattr(config, "model_type", None) in SUPPORTED_CLASS_EFFECIENT_PACKING:
         patch_for_multipack(getattr(config, "model_type", None), model_args.model_name_or_path, attn_implementation)
         logger.info("Using packing sequences without cross-contamination attention for efficient training.")
     else:
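
For context on the log message above: "packing without cross-contamination attention" means several training examples are packed into one row, but attention is masked so tokens never attend across example boundaries. The snippet below is only a self-contained illustration of that masking idea, not the patch_for_multipack implementation referenced in this hunk:

import torch

def block_diagonal_attention_mask(seq_lens):
    """Boolean mask where token i may attend to token j only if both tokens
    come from the same packed sub-sequence (a causal mask would additionally
    be applied on top of this in a decoder-only model)."""
    total = sum(seq_lens)
    mask = torch.zeros(total, total, dtype=torch.bool)
    start = 0
    for length in seq_lens:
        mask[start:start + length, start:start + length] = True
        start += length
    return mask

# Three examples of lengths 3, 2 and 4 packed into one row of 9 tokens:
# the middle example cannot attend to (or be attended from) the other two.
print(block_diagonal_attention_mask([3, 2, 4]).int())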

View File

@ -19,13 +19,13 @@ from .model_utils.quantization import configure_quantization
from .model_utils.rope import configure_rope from .model_utils.rope import configure_rope
from .model_utils.valuehead import prepare_valuehead_model from .model_utils.valuehead import prepare_valuehead_model
from .model_utils.visual import autocast_projector_dtype, configure_visual_model from .model_utils.visual import autocast_projector_dtype, configure_visual_model
from .model_utils.packing import configure_packing
if TYPE_CHECKING: if TYPE_CHECKING:
from transformers import PretrainedConfig, PreTrainedTokenizer from transformers import PretrainedConfig, PreTrainedTokenizer
from trl import AutoModelForCausalLMWithValueHead from trl import AutoModelForCausalLMWithValueHead
from ..hparams import ModelArguments from ..hparams import ModelArguments, DataArguments, FinetuningArguments
logger = get_logger(__name__) logger = get_logger(__name__)
@@ -40,6 +40,8 @@ def patch_config(
     config: "PretrainedConfig",
     tokenizer: "PreTrainedTokenizer",
     model_args: "ModelArguments",
+    data_args: "DataArguments",
+    finetune_args: "FinetuningArguments",
     init_kwargs: Dict[str, Any],
     is_trainable: bool,
 ) -> None:
@@ -81,6 +83,9 @@ def patch_config(
     if init_kwargs["device_map"] == "auto":
         init_kwargs["offload_folder"] = model_args.offload_folder
 
+    if finetune_args.stage == "sft" and data_args.efficient_packing:
+        configure_packing(config, model_args)
+
 
 def patch_model(

View File

@@ -12,7 +12,6 @@ from ...model import load_model, load_tokenizer
 from ..trainer_utils import create_modelcard_and_push
 from .metric import ComputeMetrics
 from .trainer import CustomSeq2SeqTrainer
-from ...model.model_utils.packing import configure_packing
 
 
 if TYPE_CHECKING:
     from transformers import Seq2SeqTrainingArguments, TrainerCallback
@@ -33,9 +32,6 @@ def run_sft(
     dataset = get_dataset(model_args, data_args, training_args, stage="sft", **tokenizer_module)
     model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)
 
-    if data_args.efficient_packing:
-        configure_packing(model.config, model_args)
-
     if training_args.predict_with_generate:
         tokenizer.padding_side = "left"  # use left-padding in generation