From d1e6e02461fff0adc5c37815d09dba2c20781dc3 Mon Sep 17 00:00:00 2001
From: hiyouga
Date: Fri, 1 Mar 2024 13:02:41 +0800
Subject: [PATCH] fix #2649

Former-commit-id: 4e5fae2fac85227641bd16159cf296a32e0b18b4
---
 src/llmtuner/hparams/parser.py     | 10 ++--
 src/llmtuner/model/__init__.py     | 10 +++-
 src/llmtuner/model/loader.py       | 75 +++++++++++++++++++-----------
 src/llmtuner/model/patcher.py      | 32 +++++++-------
 src/llmtuner/train/dpo/workflow.py | 13 ++----
 src/llmtuner/train/ppo/workflow.py |  7 ++-
 src/llmtuner/train/pt/workflow.py  |  5 +-
 src/llmtuner/train/rm/workflow.py  | 15 ++----
 src/llmtuner/train/sft/workflow.py | 19 +++----
 9 files changed, 99 insertions(+), 87 deletions(-)

diff --git a/src/llmtuner/hparams/parser.py b/src/llmtuner/hparams/parser.py
index 4a541a22..6b55e03d 100644
--- a/src/llmtuner/hparams/parser.py
+++ b/src/llmtuner/hparams/parser.py
@@ -181,9 +181,7 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
         and finetuning_args.finetuning_type == "lora"
     ):
         logger.warning("`ddp_find_unused_parameters` needs to be set as False for LoRA in DDP training.")
-        training_args_dict = training_args.to_dict()
-        training_args_dict.update(dict(ddp_find_unused_parameters=False))
-        training_args = Seq2SeqTrainingArguments(**training_args_dict)
+        training_args.ddp_find_unused_parameters = False
 
     if finetuning_args.stage in ["rm", "ppo"] and finetuning_args.finetuning_type in ["full", "freeze"]:
         can_resume_from_checkpoint = False
@@ -205,9 +203,7 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
             raise ValueError("Output directory already exists and is not empty. Please set `overwrite_output_dir`.")
 
         if last_checkpoint is not None:
-            training_args_dict = training_args.to_dict()
-            training_args_dict.update(dict(resume_from_checkpoint=last_checkpoint))
-            training_args = Seq2SeqTrainingArguments(**training_args_dict)
+            training_args.resume_from_checkpoint = last_checkpoint
             logger.info(
                 "Resuming training from {}. Change `output_dir` or use `overwrite_output_dir` to avoid.".format(
                     training_args.resume_from_checkpoint
@@ -233,7 +229,7 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
 
     # Log on each process the small summary:
     logger.info(
-        "Process rank: {}, device: {}, n_gpu: {}\n distributed training: {}, compute dtype: {}".format(
+        "Process rank: {}, device: {}, n_gpu: {}, distributed training: {}, compute dtype: {}".format(
             training_args.local_rank,
             training_args.device,
             training_args.n_gpu,
diff --git a/src/llmtuner/model/__init__.py b/src/llmtuner/model/__init__.py
index 7d0d15d6..933ffc5b 100644
--- a/src/llmtuner/model/__init__.py
+++ b/src/llmtuner/model/__init__.py
@@ -1,5 +1,11 @@
-from .loader import load_model_and_tokenizer
+from .loader import load_model, load_model_and_tokenizer, load_tokenizer
 from .utils import dispatch_model, load_valuehead_params
 
 
-__all__ = ["load_model_and_tokenizer", "dispatch_model", "load_valuehead_params"]
+__all__ = [
+    "load_model",
+    "load_model_and_tokenizer",
+    "load_tokenizer",
+    "dispatch_model",
+    "load_valuehead_params",
+]
diff --git a/src/llmtuner/model/loader.py b/src/llmtuner/model/loader.py
index 9d453637..0760e792 100644
--- a/src/llmtuner/model/loader.py
+++ b/src/llmtuner/model/loader.py
@@ -1,4 +1,4 @@
-from typing import TYPE_CHECKING, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple
 
 from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
 from trl import AutoModelForCausalLMWithValueHead
@@ -19,38 +19,48 @@ if TYPE_CHECKING:
 logger = get_logger(__name__)
 
 
-def load_model_and_tokenizer(
-    model_args: "ModelArguments",
-    finetuning_args: "FinetuningArguments",
-    is_trainable: Optional[bool] = False,
-    add_valuehead: Optional[bool] = False,
-) -> Tuple["PreTrainedModel", "PreTrainedTokenizer"]:
-    r"""
-    Loads pretrained model and tokenizer.
-
-    Support both training and inference.
-    """
-
-    try_download_model_from_ms(model_args)
-
-    config_kwargs = {
+def _get_init_kwargs(model_args: "ModelArguments") -> Dict[str, Any]:
+    return {
         "trust_remote_code": True,
         "cache_dir": model_args.cache_dir,
         "revision": model_args.model_revision,
         "token": model_args.hf_hub_token,
     }
 
+
+def load_tokenizer(model_args: "ModelArguments") -> "PreTrainedTokenizer":
+    r"""
+    Loads pretrained tokenizer. Must be called before `load_model`.
+
+    Note: this function modifies `model_args` in place.
+    """
+    try_download_model_from_ms(model_args)
+    init_kwargs = _get_init_kwargs(model_args)
+
     tokenizer = AutoTokenizer.from_pretrained(
         model_args.model_name_or_path,
         use_fast=model_args.use_fast_tokenizer,
         split_special_tokens=model_args.split_special_tokens,
         padding_side="right",
-        **config_kwargs,
+        **init_kwargs,
     )
     patch_tokenizer(tokenizer)
+    return tokenizer
 
-    config = AutoConfig.from_pretrained(model_args.model_name_or_path, **config_kwargs)
-    patch_config(config, tokenizer, model_args, config_kwargs, is_trainable)
+
+def load_model(
+    tokenizer: "PreTrainedTokenizer",
+    model_args: "ModelArguments",
+    finetuning_args: "FinetuningArguments",
+    is_trainable: Optional[bool] = False,
+    add_valuehead: Optional[bool] = False,
+) -> "PreTrainedModel":
+    r"""
+    Loads pretrained model. Must be called after `load_tokenizer`.
+    """
+    init_kwargs = _get_init_kwargs(model_args)
+    config = AutoConfig.from_pretrained(model_args.model_name_or_path, **init_kwargs)
+    patch_config(config, tokenizer, model_args, init_kwargs, is_trainable)
 
     model = None
     if is_trainable and model_args.use_unsloth:
@@ -76,7 +86,7 @@ def load_model_and_tokenizer(
             logger.warning("Unsloth does not support loading adapters.")
 
     if model is None:
-        model = AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path, config=config, **config_kwargs)
+        model = AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path, config=config, **init_kwargs)
 
     patch_model(model, tokenizer, model_args, is_trainable)
     register_autoclass(config, model, tokenizer)
@@ -105,14 +115,13 @@ def load_model_and_tokenizer(
         model.train()
 
     trainable_params, all_param = count_parameters(model)
-    logger.info(
-        "trainable params: {:d} || all params: {:d} || trainable%: {:.4f}".format(
+    if is_trainable:
+        param_stats = "trainable params: {:d} || all params: {:d} || trainable%: {:.4f}".format(
             trainable_params, all_param, 100 * trainable_params / all_param
         )
-    )
-
-    if not is_trainable:
-        logger.info("This IS expected that the trainable params is 0 if you are using model for inference only.")
+    else:
+        param_stats = "all params: {:d}".format(all_param)
+    logger.info(param_stats)
 
     if model_args.print_param_status:
         for name, param in model.named_parameters():
@@ -122,4 +131,18 @@ def load_model_and_tokenizer(
                 )
             )
 
+    return model
+
+
+def load_model_and_tokenizer(
+    model_args: "ModelArguments",
+    finetuning_args: "FinetuningArguments",
+    is_trainable: Optional[bool] = False,
+    add_valuehead: Optional[bool] = False,
+) -> Tuple["PreTrainedModel", "PreTrainedTokenizer"]:
+    r"""
+    Loads pretrained model and tokenizer.
+    """
+    tokenizer = load_tokenizer(model_args)
+    model = load_model(tokenizer, model_args, finetuning_args, is_trainable, add_valuehead)
     return model, tokenizer
diff --git a/src/llmtuner/model/patcher.py b/src/llmtuner/model/patcher.py
index 054c7de7..aaedc1a8 100644
--- a/src/llmtuner/model/patcher.py
+++ b/src/llmtuner/model/patcher.py
@@ -102,16 +102,16 @@ def _get_quantization_dataset(tokenizer: "PreTrainedTokenizer", model_args: "Mod
     return samples
 
 
-def _configure_attn_implementation(model_args: "ModelArguments", config_kwargs: Dict[str, Any]) -> None:
+def _configure_attn_implementation(model_args: "ModelArguments", init_kwargs: Dict[str, Any]) -> None:
     if model_args.flash_attn:
         if is_flash_attn2_available():
-            config_kwargs["attn_implementation"] = "flash_attention_2"
             logger.info("Using FlashAttention-2 for faster training and inference.")
+            init_kwargs["attn_implementation"] = "flash_attention_2"
         else:
             logger.warning("FlashAttention2 is not installed.")
-            config_kwargs["attn_implementation"] = None
+            init_kwargs["attn_implementation"] = None
     else:
-        config_kwargs["attn_implementation"] = "eager"
+        init_kwargs["attn_implementation"] = "eager"
 
 
 def _configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
@@ -154,7 +154,7 @@ def _configure_quantization(
     config: "PretrainedConfig",
     tokenizer: "PreTrainedTokenizer",
     model_args: "ModelArguments",
-    config_kwargs: Dict[str, Any],
+    init_kwargs: Dict[str, Any],
 ) -> None:
     r"""
     Priority: PTQ-quantized (training) > AutoGPTQ (export) > Bitsandbytes (training)
@@ -187,13 +187,13 @@ def _configure_quantization(
         if getattr(config, "model_type", None) == "chatglm":
             raise ValueError("ChatGLM model is not supported.")
 
-        config_kwargs["quantization_config"] = GPTQConfig(
+        init_kwargs["quantization_config"] = GPTQConfig(
            bits=model_args.export_quantization_bit,
             tokenizer=tokenizer,
             dataset=_get_quantization_dataset(tokenizer, model_args),
         )
-        config_kwargs["device_map"] = "auto"
-        config_kwargs["max_memory"] = get_max_memory()
+        init_kwargs["device_map"] = "auto"
+        init_kwargs["max_memory"] = get_max_memory()
         logger.info("Quantizing model to {} bit.".format(model_args.export_quantization_bit))
 
     elif model_args.quantization_bit is not None:  # bnb
@@ -202,11 +202,11 @@ def _configure_quantization(
 
         if model_args.quantization_bit == 8:
             require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0")
-            config_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
+            init_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
 
         elif model_args.quantization_bit == 4:
             require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0")
-            config_kwargs["quantization_config"] = BitsAndBytesConfig(
+            init_kwargs["quantization_config"] = BitsAndBytesConfig(
                 load_in_4bit=True,
                 bnb_4bit_compute_dtype=model_args.compute_dtype,
                 bnb_4bit_use_double_quant=model_args.double_quantization,
@@ -262,7 +262,7 @@ def patch_config(
     config: "PretrainedConfig",
     tokenizer: "PreTrainedTokenizer",
     model_args: "ModelArguments",
-    config_kwargs: Dict[str, Any],
+    init_kwargs: Dict[str, Any],
     is_trainable: bool,
 ) -> None:
     if model_args.compute_dtype is None:  # priority: bf16 > fp16 > fp32
@@ -272,7 +272,7 @@ def patch_config(
     for dtype_name, dtype in [("fp16", torch.float16), ("bf16", torch.bfloat16), ("fp32", torch.float32)]:
         setattr(config, dtype_name, model_args.compute_dtype == dtype)
 
-    _configure_attn_implementation(model_args, config_kwargs)
+    _configure_attn_implementation(model_args, init_kwargs)
 
     if model_args.rope_scaling is not None:
         _configure_rope(config, model_args, is_trainable)
@@ -280,12 +280,12 @@ def patch_config(
     if is_trainable and model_args.shift_attn:
         _configure_longlora(config)
 
-    _configure_quantization(config, tokenizer, model_args, config_kwargs)
+    _configure_quantization(config, tokenizer, model_args, init_kwargs)
 
-    config_kwargs["torch_dtype"] = model_args.compute_dtype
+    init_kwargs["torch_dtype"] = model_args.compute_dtype
     if not is_deepspeed_zero3_enabled():
-        config_kwargs["device_map"] = {"": get_current_device()}
-        config_kwargs["low_cpu_mem_usage"] = True
+        init_kwargs["device_map"] = {"": get_current_device()}
+        init_kwargs["low_cpu_mem_usage"] = True
 
 
 def patch_model(
diff --git a/src/llmtuner/train/dpo/workflow.py b/src/llmtuner/train/dpo/workflow.py
index 9ea3b617..46106f41 100644
--- a/src/llmtuner/train/dpo/workflow.py
+++ b/src/llmtuner/train/dpo/workflow.py
@@ -2,20 +2,18 @@
 
 from typing import TYPE_CHECKING, List, Optional
 
-from transformers import Seq2SeqTrainingArguments
-
 from ...data import get_dataset, split_dataset
 from ...extras.constants import IGNORE_INDEX
 from ...extras.ploting import plot_loss
 from ...hparams import ModelArguments
-from ...model import load_model_and_tokenizer
+from ...model import load_model, load_tokenizer
 from ...train.dpo.collator import DPODataCollatorWithPadding
 from ...train.dpo.trainer import CustomDPOTrainer
 from ...train.utils import create_modelcard_and_push, create_ref_model
 
 
 if TYPE_CHECKING:
-    from transformers import TrainerCallback
+    from transformers import Seq2SeqTrainingArguments, TrainerCallback
 
     from ...hparams import DataArguments, FinetuningArguments
 
@@ -27,8 +25,9 @@ def run_dpo(
     finetuning_args: "FinetuningArguments",
     callbacks: Optional[List["TrainerCallback"]] = None,
 ):
-    model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train)
+    tokenizer = load_tokenizer(model_args)
     dataset = get_dataset(tokenizer, model_args, data_args, training_args, stage="rm")
+    model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)
     data_collator = DPODataCollatorWithPadding(
         tokenizer=tokenizer,
         pad_to_multiple_of=8,
@@ -42,9 +41,7 @@ def run_dpo(
         ref_model = create_ref_model(model_args, finetuning_args)
 
     # Update arguments
-    training_args_dict = training_args.to_dict()
-    training_args_dict.update(dict(remove_unused_columns=False))  # important for pairwise dataset
-    training_args = Seq2SeqTrainingArguments(**training_args_dict)
+    training_args.remove_unused_columns = False  # important for pairwise dataset
 
     # Initialize our Trainer
     trainer = CustomDPOTrainer(
diff --git a/src/llmtuner/train/ppo/workflow.py b/src/llmtuner/train/ppo/workflow.py
index 50a0e1d0..64333359 100644
--- a/src/llmtuner/train/ppo/workflow.py
+++ b/src/llmtuner/train/ppo/workflow.py
@@ -12,7 +12,7 @@ from ...data import get_dataset
 from ...extras.callbacks import FixValueHeadModelCallback
 from ...extras.misc import fix_valuehead_checkpoint
 from ...extras.ploting import plot_loss
-from ...model import load_model_and_tokenizer
+from ...model import load_model, load_tokenizer
 from ...train.ppo.trainer import CustomPPOTrainer
 from ...train.utils import create_ref_model, create_reward_model
 
@@ -31,10 +31,9 @@ def run_ppo(
     generating_args: "GeneratingArguments",
     callbacks: Optional[List["TrainerCallback"]] = None,
 ):
-    model, tokenizer = load_model_and_tokenizer(
-        model_args, finetuning_args, training_args.do_train, add_valuehead=True
-    )
+    tokenizer = load_tokenizer(model_args)
     dataset = get_dataset(tokenizer, model_args, data_args, training_args, stage="ppo")
+    model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train, add_valuehead=True)
 
     tokenizer.padding_side = "left"  # use left-padding in generation while using right-padding in training
     data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
diff --git a/src/llmtuner/train/pt/workflow.py b/src/llmtuner/train/pt/workflow.py
index cb91cf44..3f98a006 100644
--- a/src/llmtuner/train/pt/workflow.py
+++ b/src/llmtuner/train/pt/workflow.py
@@ -7,7 +7,7 @@ from transformers import DataCollatorForLanguageModeling, Trainer
 
 from ...data import get_dataset, split_dataset
 from ...extras.ploting import plot_loss
-from ...model import load_model_and_tokenizer
+from ...model import load_model, load_tokenizer
 from ...train.utils import create_modelcard_and_push
 
 
@@ -24,8 +24,9 @@ def run_pt(
     finetuning_args: "FinetuningArguments",
     callbacks: Optional[List["TrainerCallback"]] = None,
 ):
-    model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train)
+    tokenizer = load_tokenizer(model_args)
     dataset = get_dataset(tokenizer, model_args, data_args, training_args, stage="pt")
+    model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)
     data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
 
     # Initialize our Trainer
diff --git a/src/llmtuner/train/rm/workflow.py b/src/llmtuner/train/rm/workflow.py
index 0ec9d9de..899e55bd 100644
--- a/src/llmtuner/train/rm/workflow.py
+++ b/src/llmtuner/train/rm/workflow.py
@@ -2,13 +2,11 @@
 
 from typing import TYPE_CHECKING, List, Optional
 
-from transformers import Seq2SeqTrainingArguments
-
 from ...data import get_dataset, split_dataset
 from ...extras.callbacks import FixValueHeadModelCallback
 from ...extras.misc import fix_valuehead_checkpoint
 from ...extras.ploting import plot_loss
-from ...model import load_model_and_tokenizer
+from ...model import load_model, load_tokenizer
 from ...train.rm.collator import PairwiseDataCollatorWithPadding
 from ...train.rm.metric import compute_accuracy
 from ...train.rm.trainer import PairwiseTrainer
@@ -16,7 +14,7 @@ from ...train.utils import create_modelcard_and_push
 
 
 if TYPE_CHECKING:
-    from transformers import TrainerCallback
+    from transformers import Seq2SeqTrainingArguments, TrainerCallback
 
     from ...hparams import DataArguments, FinetuningArguments, ModelArguments
 
@@ -28,16 +26,13 @@ def run_rm(
     finetuning_args: "FinetuningArguments",
     callbacks: Optional[List["TrainerCallback"]] = None,
 ):
-    model, tokenizer = load_model_and_tokenizer(
-        model_args, finetuning_args, training_args.do_train, add_valuehead=True
-    )
+    tokenizer = load_tokenizer(model_args)
     dataset = get_dataset(tokenizer, model_args, data_args, training_args, stage="rm")
+    model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train, add_valuehead=True)
     data_collator = PairwiseDataCollatorWithPadding(tokenizer, pad_to_multiple_of=8)
 
     # Update arguments
-    training_args_dict = training_args.to_dict()
-    training_args_dict.update(dict(remove_unused_columns=False))  # important for pairwise dataset
-    training_args = Seq2SeqTrainingArguments(**training_args_dict)
+    training_args.remove_unused_columns = False  # important for pairwise dataset
 
     # Initialize our Trainer
     trainer = PairwiseTrainer(
diff --git a/src/llmtuner/train/sft/workflow.py b/src/llmtuner/train/sft/workflow.py
index 75b4d3d1..ca052438 100644
--- a/src/llmtuner/train/sft/workflow.py
+++ b/src/llmtuner/train/sft/workflow.py
@@ -2,20 +2,20 @@
 
 from typing import TYPE_CHECKING, List, Optional
 
-from transformers import DataCollatorForSeq2Seq, Seq2SeqTrainingArguments
+from transformers import DataCollatorForSeq2Seq
 
 from ...data import get_dataset, split_dataset
 from ...extras.constants import IGNORE_INDEX
 from ...extras.misc import get_logits_processor
 from ...extras.ploting import plot_loss
-from ...model import load_model_and_tokenizer
+from ...model import load_model, load_tokenizer
 from ...train.sft.metric import ComputeMetrics
 from ...train.sft.trainer import CustomSeq2SeqTrainer
 from ...train.utils import create_modelcard_and_push
 
 
 if TYPE_CHECKING:
-    from transformers import TrainerCallback
+    from transformers import Seq2SeqTrainingArguments, TrainerCallback
 
     from ...hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
 
@@ -28,8 +28,9 @@ def run_sft(
     generating_args: "GeneratingArguments",
     callbacks: Optional[List["TrainerCallback"]] = None,
 ):
-    model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train)
+    tokenizer = load_tokenizer(model_args)
     dataset = get_dataset(tokenizer, model_args, data_args, training_args, stage="sft")
+    model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)
 
     if training_args.predict_with_generate:
         tokenizer.padding_side = "left"  # use left-padding in generation
@@ -44,14 +45,8 @@ def run_sft(
     )
 
     # Override the decoding parameters of Seq2SeqTrainer
-    training_args_dict = training_args.to_dict()
-    training_args_dict.update(
-        dict(
-            generation_max_length=training_args.generation_max_length or data_args.cutoff_len,
-            generation_num_beams=data_args.eval_num_beams or training_args.generation_num_beams,
-        )
-    )
-    training_args = Seq2SeqTrainingArguments(**training_args_dict)
+    training_args.generation_max_length = training_args.generation_max_length or data_args.cutoff_len
+    training_args.generation_num_beams = data_args.eval_num_beams or training_args.generation_num_beams
 
     # Initialize our Trainer
     trainer = CustomSeq2SeqTrainer(
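
Editor's usage note (not part of the patch): the diff above splits `load_model_and_tokenizer` into `load_tokenizer` and `load_model`, so each workflow can build its dataset before the model is instantiated. A minimal sketch of the new call order, assuming `model_args`, `data_args`, `training_args` and `finetuning_args` have already been parsed as in the existing workflows; the helper name `build_sft_inputs` and the "sft" stage are illustrative only.

    from llmtuner.data import get_dataset
    from llmtuner.model import load_model, load_tokenizer

    def build_sft_inputs(model_args, data_args, training_args, finetuning_args):
        # Tokenizer first: load_tokenizer() may modify model_args in place.
        tokenizer = load_tokenizer(model_args)
        # Dataset next: it only needs the tokenizer, not the full model.
        dataset = get_dataset(tokenizer, model_args, data_args, training_args, stage="sft")
        # Model last, reusing the already-patched tokenizer and the same model_args.
        model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)
        return tokenizer, dataset, model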