From 1ce7b5e0f328549c35cfee08d943afd036413c05 Mon Sep 17 00:00:00 2001
From: hiyouga
Date: Wed, 28 Jun 2023 12:07:16 +0800
Subject: [PATCH] update loading logic

Former-commit-id: 4d0fddba213beaa55146b047a78963d1d18185a1
---
 src/utils/common.py | 48 ++++++++++++++++++++++++++++++++----------------
 src/utils/other.py  | 13 ++++++++-----
 2 files changed, 40 insertions(+), 21 deletions(-)

diff --git a/src/utils/common.py b/src/utils/common.py
index 1ff548b6..b2448150 100644
--- a/src/utils/common.py
+++ b/src/utils/common.py
@@ -85,18 +85,15 @@ def _init_adapter(
 
     if finetuning_args.finetuning_type == "freeze":
         logger.info("Fine-tuning method: Freeze")
+
         for name, param in model.named_parameters():
             if not any(trainable_layer in name for trainable_layer in finetuning_args.trainable_layers):
                 param.requires_grad_(False)
             else:
                 param.data = param.data.to(torch.float32)
 
-    if model_args.checkpoint_dir is not None:
-        if finetuning_args.finetuning_type != "lora":
-            assert is_mergeable and len(model_args.checkpoint_dir) == 1, "Only LoRA tuning accepts multiple checkpoints."
+        if model_args.checkpoint_dir is not None:
             assert load_trainable_params(model, model_args.checkpoint_dir[0]), "Model checkpoint is not correctly loaded."
-        else:
-            assert is_mergeable or len(model_args.checkpoint_dir) == 1, "Quantized model only accepts a single checkpoint."
 
     if finetuning_args.finetuning_type == "lora":
         logger.info("Fine-tuning method: LoRA")
@@ -205,9 +202,14 @@ def load_pretrained(
     if not is_trainable: # `device_map=auto` should be used for inference only
         config_kwargs["device_map"] = "auto"
 
+    if model_args.checkpoint_dir is not None and finetuning_args.finetuning_type == "full":
+        model_to_load = model_args.checkpoint_dir[0]
+    else:
+        model_to_load = model_args.model_name_or_path
+
     # Load and prepare pretrained models (without valuehead).
     model = AutoModelForCausalLM.from_pretrained(
-        model_args.model_name_or_path,
+        model_to_load,
         config=config,
         torch_dtype=torch.bfloat16 if model_args.compute_dtype == torch.bfloat16 else torch.float16,
         low_cpu_mem_usage=True,
@@ -268,17 +270,24 @@ def prepare_args(
     # Check arguments (do not check finetuning_args since it may be loaded from checkpoints)
     data_args.init_for_training()
 
-    if stage != "sft" and training_args.predict_with_generate:
-        raise ValueError("`predict_with_generate` cannot be set as True at PT, RM and PPO stages.")
+    assert stage == "sft" or (not training_args.predict_with_generate), \
+        "`predict_with_generate` cannot be set as True at PT, RM and PPO stages."
 
-    if training_args.do_train and training_args.predict_with_generate:
-        raise ValueError("`predict_with_generate` cannot be set as True while training.")
+    assert not (training_args.do_train and training_args.predict_with_generate), \
+        "`predict_with_generate` cannot be set as True while training."
 
-    if training_args.do_predict and (not training_args.predict_with_generate):
-        raise ValueError("Please enable `predict_with_generate` to save model predictions.")
+    assert (not training_args.do_predict) or training_args.predict_with_generate, \
+        "Please enable `predict_with_generate` to save model predictions."
 
-    if model_args.quantization_bit is not None and finetuning_args.finetuning_type != "lora":
-        raise ValueError("Quantization is only compatible with the LoRA method.")
+    assert model_args.quantization_bit is None or finetuning_args.finetuning_type == "lora", \
+        "Quantization is only compatible with the LoRA method."
+
+    if model_args.checkpoint_dir is not None:
+        if finetuning_args.finetuning_type != "lora":
+            assert len(model_args.checkpoint_dir) == 1, "Only LoRA tuning accepts multiple checkpoints."
+        else:
+            assert model_args.quantization_bit is None or len(model_args.checkpoint_dir) == 1, \
+                "Quantized model only accepts a single checkpoint."
 
     if model_args.quantization_bit is not None and (not training_args.do_train):
         logger.warning("Evaluating model in 4/8-bit mode may cause lower scores.")
@@ -325,8 +334,15 @@ def prepare_infer_args() -> Tuple[ModelArguments, DataTrainingArguments, Finetun
     else:
         model_args, data_args, finetuning_args, generating_args = parser.parse_args_into_dataclasses()
 
-    if model_args.quantization_bit is not None and finetuning_args.finetuning_type != "lora":
-        raise ValueError("Quantization is only compatible with the LoRA method.")
+    assert model_args.quantization_bit is None or finetuning_args.finetuning_type == "lora", \
+        "Quantization is only compatible with the LoRA method."
+
+    if model_args.checkpoint_dir is not None:
+        if finetuning_args.finetuning_type != "lora":
+            assert len(model_args.checkpoint_dir) == 1, "Only LoRA tuning accepts multiple checkpoints."
+        else:
+            assert model_args.quantization_bit is None or len(model_args.checkpoint_dir) == 1, \
+                "Quantized model only accepts a single checkpoint."
 
     if data_args.prompt_template == "default":
         logger.warning("Please specify `prompt_template` if you are using other pre-trained models.")
diff --git a/src/utils/other.py b/src/utils/other.py
index 3e3d25a8..21a56ea2 100644
--- a/src/utils/other.py
+++ b/src/utils/other.py
@@ -5,8 +5,8 @@ import torch
 import logging
 from typing import Dict, List, Optional
 
-from transformers.trainer import TRAINER_STATE_NAME, WEIGHTS_NAME
-from transformers.modeling_utils import PreTrainedModel
+from transformers.trainer import TRAINER_STATE_NAME, WEIGHTS_NAME, WEIGHTS_INDEX_NAME
+from transformers.modeling_utils import PreTrainedModel, load_sharded_checkpoint
 from transformers.generation.utils import LogitsProcessorList
 from transformers.generation.logits_process import LogitsProcessor
 
@@ -133,11 +133,14 @@ def get_state_dict(model: torch.nn.Module) -> Dict[str, torch.Tensor]: # get sta
 
 def load_trainable_params(model: torch.nn.Module, checkpoint_dir: os.PathLike) -> bool:
     weights_file = os.path.join(checkpoint_dir, WEIGHTS_NAME)
-    if not os.path.exists(weights_file):
+    if os.path.exists(weights_file):
+        model_state_dict = torch.load(weights_file, map_location="cpu")
+        model.load_state_dict(model_state_dict, strict=False) # skip missing keys
+    elif os.path.exists(os.path.join(checkpoint_dir, WEIGHTS_INDEX_NAME)):
+        load_sharded_checkpoint(model, checkpoint_dir, strict=False)
+    else:
         logger.warning("Provided path ({}) does not contain pre-trained weights.".format(checkpoint_dir))
         return False
-    model_state_dict = torch.load(weights_file, map_location="cpu")
-    model.load_state_dict(model_state_dict, strict=False) # skip missing keys
     return True
 
 
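The heart of the patch is the new fallback in load_trainable_params: load a single-file checkpoint (WEIGHTS_NAME) with strict=False when it exists, otherwise load a sharded checkpoint whenever only the index file (WEIGHTS_INDEX_NAME) is present, and warn only if neither layout is found. The sketch below restates that order outside the patch; it is an illustration rather than part of the patch, and the helper name load_partial_checkpoint is made up for the example.

# Illustrative sketch (not part of the patch) of the fallback order used by the
# new load_trainable_params; the name load_partial_checkpoint is hypothetical.
import os

import torch
from transformers.modeling_utils import PreTrainedModel, load_sharded_checkpoint
from transformers.trainer import WEIGHTS_INDEX_NAME, WEIGHTS_NAME


def load_partial_checkpoint(model: PreTrainedModel, checkpoint_dir: str) -> bool:
    weights_file = os.path.join(checkpoint_dir, WEIGHTS_NAME)        # pytorch_model.bin
    index_file = os.path.join(checkpoint_dir, WEIGHTS_INDEX_NAME)    # pytorch_model.bin.index.json
    if os.path.exists(weights_file):
        # Single-file checkpoint: load the keys it contains, skip the rest.
        state_dict = torch.load(weights_file, map_location="cpu")
        model.load_state_dict(state_dict, strict=False)
    elif os.path.exists(index_file):
        # Sharded checkpoint: transformers reads the index and loads each shard.
        load_sharded_checkpoint(model, checkpoint_dir, strict=False)
    else:
        # Neither layout found; the caller warns and reports failure.
        return False
    return True

Using strict=False in both branches mirrors the "# skip missing keys" comment in the patch: a checkpoint saved from freeze or LoRA-style tuning is expected to cover only the trainable subset of parameters.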