From 5319447aa5f81355593f2a0a73e2fc39b5969912 Mon Sep 17 00:00:00 2001
From: ancv
Date: Fri, 21 Jun 2024 00:45:06 +0700
Subject: [PATCH] move configure_packing to llamafactory.model.patcher and fix constants

Former-commit-id: 770f75dc8363bfa284a72159ff8ad25ec9abe4e0
---
 src/llamafactory/extras/constants.py          | 2 +-
 src/llamafactory/hparams/data_args.py         | 4 +++-
 src/llamafactory/model/loader.py              | 5 +++--
 src/llamafactory/model/model_utils/packing.py | 4 ++--
 src/llamafactory/model/patcher.py             | 9 +++++++--
 src/llamafactory/train/sft/workflow.py        | 4 ----
 6 files changed, 16 insertions(+), 12 deletions(-)

diff --git a/src/llamafactory/extras/constants.py b/src/llamafactory/extras/constants.py
index d70922c1..466b1269 100644
--- a/src/llamafactory/extras/constants.py
+++ b/src/llamafactory/extras/constants.py
@@ -66,7 +66,7 @@ STAGES_USE_PAIR_DATA = {"rm", "dpo"}
 
 SUPPORTED_CLASS_FOR_S2ATTN = {"llama"}
 
-SUPPORTED_CLASS_FOR_MULTIPACK = [
+SUPPORTED_CLASS_EFFECIENT_PACKING = [
     "llama",
     "mistral",
     "mixtral",
diff --git a/src/llamafactory/hparams/data_args.py b/src/llamafactory/hparams/data_args.py
index 2081c6b4..d2d53ec8 100644
--- a/src/llamafactory/hparams/data_args.py
+++ b/src/llamafactory/hparams/data_args.py
@@ -86,7 +86,9 @@ class DataArguments:
     )
     efficient_packing: Optional[bool] = field(
         default=None,
-        metadata={"help": "Whether or not to pack the sequences without cross-contamination attention for efficient training."},
+        metadata={
+            "help": "Whether or not to pack the sequences without cross-contamination attention for efficient training."
+        },
     )
     tokenized_path: Optional[str] = field(
         default=None,
diff --git a/src/llamafactory/model/loader.py b/src/llamafactory/model/loader.py
index 697a04e7..026a09be 100644
--- a/src/llamafactory/model/loader.py
+++ b/src/llamafactory/model/loader.py
@@ -16,7 +16,7 @@ from .patcher import patch_config, patch_model, patch_tokenizer, patch_valuehead
 if TYPE_CHECKING:
     from transformers import PretrainedConfig, PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
 
-    from ..hparams import FinetuningArguments, ModelArguments
+    from ..hparams import FinetuningArguments, ModelArguments, DataArguments
 
 
 logger = get_logger(__name__)
@@ -104,6 +104,7 @@ def load_config(model_args: "ModelArguments") -> "PretrainedConfig":
 def load_model(
     tokenizer: "PreTrainedTokenizer",
     model_args: "ModelArguments",
+    data_args: "DataArguments",
     finetuning_args: "FinetuningArguments",
     is_trainable: bool = False,
     add_valuehead: bool = False,
@@ -113,7 +114,7 @@ def load_model(
     """
     init_kwargs = _get_init_kwargs(model_args)
     config = load_config(model_args)
-    patch_config(config, tokenizer, model_args, init_kwargs, is_trainable)
+    patch_config(config, tokenizer, model_args, data_args, finetuning_args, init_kwargs, is_trainable)
 
     model = None
     lazy_load = False
diff --git a/src/llamafactory/model/model_utils/packing.py b/src/llamafactory/model/model_utils/packing.py
index ce156728..606cd03b 100644
--- a/src/llamafactory/model/model_utils/packing.py
+++ b/src/llamafactory/model/model_utils/packing.py
@@ -19,7 +19,7 @@ from transformers.modeling_attn_mask_utils import (
 from transformers.utils import is_torch_bf16_gpu_available
 
 from ...extras.logging import get_logger
-from ...extras.constants import SUPPORTED_CLASS_FOR_MULTIPACK
+from ...extras.constants import SUPPORTED_CLASS_EFFECIENT_PACKING
 
 if TYPE_CHECKING:
     from transformers import PretrainedConfig
@@ -303,7 +303,7 @@ def configure_packing(config: "PretrainedConfig", model_args: "ModelArguments")
     else:
         attn_implementation = getattr(config, "_attn_implementation", "")
 
-    if getattr(config, "model_type", None) in SUPPORTED_CLASS_FOR_MULTIPACK:
+    if getattr(config, "model_type", None) in SUPPORTED_CLASS_EFFECIENT_PACKING:
         patch_for_multipack(getattr(config, "model_type", None), model_args.model_name_or_path, attn_implementation)
         logger.info("Using packing sequences without cross-contamination attention for efficient training.")
     else:
diff --git a/src/llamafactory/model/patcher.py b/src/llamafactory/model/patcher.py
index 87c92315..47591de6 100644
--- a/src/llamafactory/model/patcher.py
+++ b/src/llamafactory/model/patcher.py
@@ -19,13 +19,13 @@ from .model_utils.quantization import configure_quantization
 from .model_utils.rope import configure_rope
 from .model_utils.valuehead import prepare_valuehead_model
 from .model_utils.visual import autocast_projector_dtype, configure_visual_model
-
+from .model_utils.packing import configure_packing
 
 if TYPE_CHECKING:
     from transformers import PretrainedConfig, PreTrainedTokenizer
     from trl import AutoModelForCausalLMWithValueHead
 
-    from ..hparams import ModelArguments
+    from ..hparams import ModelArguments, DataArguments, FinetuningArguments
 
 
 logger = get_logger(__name__)
@@ -40,6 +40,8 @@ def patch_config(
     config: "PretrainedConfig",
     tokenizer: "PreTrainedTokenizer",
     model_args: "ModelArguments",
+    data_args: "DataArguments",
+    finetune_args: "FinetuningArguments",
     init_kwargs: Dict[str, Any],
     is_trainable: bool,
 ) -> None:
@@ -81,6 +83,9 @@ def patch_config(
 
     if init_kwargs["device_map"] == "auto":
         init_kwargs["offload_folder"] = model_args.offload_folder
+
+    if finetune_args.stage == "sft" and data_args.efficient_packing:
+        configure_packing(config, model_args)
 
 
 def patch_model(
diff --git a/src/llamafactory/train/sft/workflow.py b/src/llamafactory/train/sft/workflow.py
index d7c29743..f1e000bd 100644
--- a/src/llamafactory/train/sft/workflow.py
+++ b/src/llamafactory/train/sft/workflow.py
@@ -12,7 +12,6 @@ from ...model import load_model, load_tokenizer
 from ..trainer_utils import create_modelcard_and_push
 from .metric import ComputeMetrics
 from .trainer import CustomSeq2SeqTrainer
-from ...model.model_utils.packing import configure_packing
 
 if TYPE_CHECKING:
     from transformers import Seq2SeqTrainingArguments, TrainerCallback
@@ -33,9 +32,6 @@ def run_sft(
     dataset = get_dataset(model_args, data_args, training_args, stage="sft", **tokenizer_module)
     model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)
 
-    if data_args.efficient_packing:
-        configure_packing(model.config, model_args)
-
     if training_args.predict_with_generate:
         tokenizer.padding_side = "left"  # use left-padding in generation
 
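
Note for reviewers: below is a minimal, self-contained sketch of the gating logic this patch
adds to patch_config, runnable in isolation. The stand-in dataclasses and the wrapper name
maybe_configure_packing are illustrative assumptions, not repository code; only the condition
itself mirrors the lines added in the diff.

from dataclasses import dataclass
from typing import Optional


@dataclass
class DataArguments:  # stand-in for llamafactory.hparams.DataArguments
    efficient_packing: Optional[bool] = None


@dataclass
class FinetuningArguments:  # stand-in for llamafactory.hparams.FinetuningArguments
    stage: str = "sft"


def configure_packing(config, model_args) -> None:
    # Stub: the real implementation lives in llamafactory.model.model_utils.packing and
    # patches the model types listed in SUPPORTED_CLASS_EFFECIENT_PACKING.
    print(f"packing enabled for model_type={getattr(config, 'model_type', None)!r}")


def maybe_configure_packing(config, model_args, data_args, finetune_args) -> None:
    # Hypothetical wrapper mirroring the new lines in patch_config: packing without
    # cross-contamination attention is only enabled for SFT runs that opt in.
    if finetune_args.stage == "sft" and data_args.efficient_packing:
        configure_packing(config, model_args)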