update loading logic

Former-commit-id: 4d0fddba213beaa55146b047a78963d1d18185a1
hiyouga 2023-06-28 12:07:16 +08:00
parent 2ff577810a
commit 1ce7b5e0f3
2 changed files with 40 additions and 21 deletions

Changed file 1 of 2

@@ -85,18 +85,15 @@ def _init_adapter(
     if finetuning_args.finetuning_type == "freeze":
         logger.info("Fine-tuning method: Freeze")
 
         for name, param in model.named_parameters():
             if not any(trainable_layer in name for trainable_layer in finetuning_args.trainable_layers):
                 param.requires_grad_(False)
             else:
                 param.data = param.data.to(torch.float32)
 
-    if model_args.checkpoint_dir is not None:
-        if finetuning_args.finetuning_type != "lora":
-            assert is_mergeable and len(model_args.checkpoint_dir) == 1, "Only LoRA tuning accepts multiple checkpoints."
+        if model_args.checkpoint_dir is not None:
             assert load_trainable_params(model, model_args.checkpoint_dir[0]), "Model checkpoint is not correctly loaded."
-        else:
-            assert is_mergeable or len(model_args.checkpoint_dir) == 1, "Quantized model only accepts a single checkpoint."
 
     if finetuning_args.finetuning_type == "lora":
         logger.info("Fine-tuning method: LoRA")
@@ -205,9 +202,14 @@ def load_pretrained(
     if not is_trainable: # `device_map=auto` should be used for inference only
         config_kwargs["device_map"] = "auto"
 
+    if model_args.checkpoint_dir is not None and finetuning_args.finetuning_type == "full":
+        model_to_load = model_args.checkpoint_dir[0]
+    else:
+        model_to_load = model_args.model_name_or_path
+
     # Load and prepare pretrained models (without valuehead).
     model = AutoModelForCausalLM.from_pretrained(
-        model_args.model_name_or_path,
+        model_to_load,
         config=config,
         torch_dtype=torch.bfloat16 if model_args.compute_dtype == torch.bfloat16 else torch.float16,
         low_cpu_mem_usage=True,
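The new `model_to_load` switch relies on a full-parameter checkpoint saved via `save_pretrained` (as the Trainer does) being a complete model directory, so it can be passed straight to `from_pretrained`. A minimal sketch of the same selection logic, with hypothetical argument values:

from transformers import AutoModelForCausalLM

def resolve_model_to_load(model_name_or_path, checkpoint_dir, finetuning_type):
    # Mirrors the added branch: full fine-tuning resumes directly from the
    # saved checkpoint directory, all other methods start from the base model.
    if checkpoint_dir is not None and finetuning_type == "full":
        return checkpoint_dir[0]
    return model_name_or_path

# Hypothetical values for illustration only.
model_to_load = resolve_model_to_load(
    model_name_or_path="huggyllama/llama-7b",
    checkpoint_dir=["path/to/full_sft_checkpoint"],
    finetuning_type="full",
)
model = AutoModelForCausalLM.from_pretrained(model_to_load)  # reads config + weights from the checkpoint dir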
@@ -268,17 +270,24 @@ def prepare_args(
     # Check arguments (do not check finetuning_args since it may be loaded from checkpoints)
     data_args.init_for_training()
 
-    if stage != "sft" and training_args.predict_with_generate:
-        raise ValueError("`predict_with_generate` cannot be set as True at PT, RM and PPO stages.")
+    assert stage == "sft" or (not training_args.predict_with_generate), \
+        "`predict_with_generate` cannot be set as True at PT, RM and PPO stages."
 
-    if training_args.do_train and training_args.predict_with_generate:
-        raise ValueError("`predict_with_generate` cannot be set as True while training.")
+    assert not (training_args.do_train and training_args.predict_with_generate), \
+        "`predict_with_generate` cannot be set as True while training."
 
-    if training_args.do_predict and (not training_args.predict_with_generate):
-        raise ValueError("Please enable `predict_with_generate` to save model predictions.")
+    assert (not training_args.do_predict) or training_args.predict_with_generate, \
+        "Please enable `predict_with_generate` to save model predictions."
 
-    if model_args.quantization_bit is not None and finetuning_args.finetuning_type != "lora":
-        raise ValueError("Quantization is only compatible with the LoRA method.")
+    assert model_args.quantization_bit is None or finetuning_args.finetuning_type == "lora", \
+        "Quantization is only compatible with the LoRA method."
+
+    if model_args.checkpoint_dir is not None:
+        if finetuning_args.finetuning_type != "lora":
+            assert len(model_args.checkpoint_dir) == 1, "Only LoRA tuning accepts multiple checkpoints."
+        else:
+            assert model_args.quantization_bit is None or len(model_args.checkpoint_dir) == 1, \
+                "Quantized model only accepts a single checkpoint."
 
     if model_args.quantization_bit is not None and (not training_args.do_train):
         logger.warning("Evaluating model in 4/8-bit mode may cause lower scores.")
@@ -325,8 +334,15 @@ def prepare_infer_args() -> Tuple[ModelArguments, DataTrainingArguments, Finetun
     else:
         model_args, data_args, finetuning_args, generating_args = parser.parse_args_into_dataclasses()
 
-    if model_args.quantization_bit is not None and finetuning_args.finetuning_type != "lora":
-        raise ValueError("Quantization is only compatible with the LoRA method.")
+    assert model_args.quantization_bit is None or finetuning_args.finetuning_type == "lora", \
+        "Quantization is only compatible with the LoRA method."
+
+    if model_args.checkpoint_dir is not None:
+        if finetuning_args.finetuning_type != "lora":
+            assert len(model_args.checkpoint_dir) == 1, "Only LoRA tuning accepts multiple checkpoints."
+        else:
+            assert model_args.quantization_bit is None or len(model_args.checkpoint_dir) == 1, \
+                "Quantized model only accepts a single checkpoint."
 
     if data_args.prompt_template == "default":
         logger.warning("Please specify `prompt_template` if you are using other pre-trained models.")

Changed file 2 of 2

@@ -5,8 +5,8 @@ import torch
 import logging
 from typing import Dict, List, Optional
 
-from transformers.trainer import TRAINER_STATE_NAME, WEIGHTS_NAME
-from transformers.modeling_utils import PreTrainedModel
+from transformers.trainer import TRAINER_STATE_NAME, WEIGHTS_NAME, WEIGHTS_INDEX_NAME
+from transformers.modeling_utils import PreTrainedModel, load_sharded_checkpoint
 from transformers.generation.utils import LogitsProcessorList
 from transformers.generation.logits_process import LogitsProcessor
@@ -133,11 +133,14 @@ def get_state_dict(model: torch.nn.Module) -> Dict[str, torch.Tensor]: # get sta
 
 def load_trainable_params(model: torch.nn.Module, checkpoint_dir: os.PathLike) -> bool:
     weights_file = os.path.join(checkpoint_dir, WEIGHTS_NAME)
-    if not os.path.exists(weights_file):
+    if os.path.exists(weights_file):
+        model_state_dict = torch.load(weights_file, map_location="cpu")
+        model.load_state_dict(model_state_dict, strict=False) # skip missing keys
+    elif os.path.exists(os.path.join(checkpoint_dir, WEIGHTS_INDEX_NAME)):
+        load_sharded_checkpoint(model, checkpoint_dir, strict=False)
+    else:
         logger.warning("Provided path ({}) does not contain pre-trained weights.".format(checkpoint_dir))
         return False
-    model_state_dict = torch.load(weights_file, map_location="cpu")
-    model.load_state_dict(model_state_dict, strict=False) # skip missing keys
     return True
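For reference, `WEIGHTS_NAME` and `WEIGHTS_INDEX_NAME` are the transformers constants for `pytorch_model.bin` and `pytorch_model.bin.index.json`; the index file is what `save_pretrained` writes when a large checkpoint is split into shards, and `load_sharded_checkpoint` loads each shard listed there. A minimal sketch of the same dispatch the updated loader performs, with a hypothetical path:

import os
from transformers.trainer import WEIGHTS_NAME, WEIGHTS_INDEX_NAME

def describe_checkpoint(checkpoint_dir: str) -> str:
    # Same dispatch as the updated load_trainable_params, minus the actual loading.
    if os.path.exists(os.path.join(checkpoint_dir, WEIGHTS_NAME)):
        return "single-file checkpoint (pytorch_model.bin)"
    if os.path.exists(os.path.join(checkpoint_dir, WEIGHTS_INDEX_NAME)):
        return "sharded checkpoint (pytorch_model.bin.index.json + shard files)"
    return "no usable weights found"

print(describe_checkpoint("path/to/checkpoint"))  # hypothetical path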