Upgrade peft; fix #1088 and #1411

Former-commit-id: aa7d104f8e050d12cb8f585bc8a52c850995500f
This commit is contained in:
hiyouga
2023-11-07 16:13:36 +08:00
parent 37a0d62a82
commit 2eb65d21ac
15 changed files with 133 additions and 99 deletions

View File

@@ -1,6 +1,9 @@
import os
import torch
from typing import TYPE_CHECKING
from transformers.utils import cached_file
from transformers.trainer import WEIGHTS_NAME, SAFE_WEIGHTS_NAME
from peft import (
PeftModel,
TaskType,
@@ -23,8 +26,7 @@ def init_adapter(
model: "PreTrainedModel",
model_args: "ModelArguments",
finetuning_args: "FinetuningArguments",
is_trainable: bool,
is_mergeable: bool
is_trainable: bool
) -> "PreTrainedModel":
r"""
Initializes the adapters.
@@ -61,7 +63,7 @@ def init_adapter(
latest_checkpoint = None
if model_args.checkpoint_dir is not None:
if (is_trainable and finetuning_args.resume_lora_training) or (not is_mergeable): # continually fine-tuning
if is_trainable and finetuning_args.resume_lora_training: # continually fine-tuning
checkpoints_to_merge, latest_checkpoint = model_args.checkpoint_dir[:-1], model_args.checkpoint_dir[-1]
else:
checkpoints_to_merge = model_args.checkpoint_dir
@@ -92,10 +94,33 @@ def init_adapter(
modules_to_save=finetuning_args.additional_target
)
model = get_peft_model(model, lora_config)
if id(model.peft_config) != id(model.base_model.peft_config): # https://github.com/huggingface/peft/issues/923
model.base_model.peft_config = model.peft_config
if model_args.checkpoint_dir is not None:
logger.info("Loaded fine-tuned model from checkpoint(s): {}".format(",".join(model_args.checkpoint_dir)))
return model
def load_valuehead_params(
    model: "PreTrainedModel",
    model_args: "ModelArguments"
) -> None:
    r"""
    Loads value-head parameters for the reward model and attaches them to ``model``.

    Resolves the weights file via ``transformers.utils.cached_file`` from
    ``model_args.reward_model`` (a local path or Hub repo id), trying the
    PyTorch weights file first and falling back to the safetensors file.
    The loaded tensors are registered as non-persistent buffers, so they are
    not written out again when the model is saved.

    Raises:
        ValueError: if neither weights file can be found at the given path,
            chained from the underlying lookup error.
    """
    kwargs = {
        "path_or_repo_id": model_args.reward_model,
        "cache_dir": model_args.cache_dir,
        "token": model_args.hf_hub_token,
        "revision": model_args.model_revision
    }
    try:
        # Prefer the classic PyTorch weights file name.
        vhead_file = cached_file(filename=WEIGHTS_NAME, **kwargs)
    except Exception:  # narrowed from bare `except:` so Ctrl-C / SystemExit still propagate
        try:
            # Fall back to the safetensors file name.
            vhead_file = cached_file(filename=SAFE_WEIGHTS_NAME, **kwargs)
        except Exception as err:
            # Chain the cause so the original lookup error is preserved in the traceback.
            raise ValueError("Provided path ({}) does not contain valuehead weights.".format(model_args.reward_model)) from err
    # map_location="cpu" avoids requiring a GPU just to read the checkpoint.
    vhead_params = torch.load(vhead_file, map_location="cpu")
    # persistent=False keeps these buffers out of the model's own state_dict.
    model.register_buffer("reward_head_weight", vhead_params["v_head.summary.weight"], persistent=False)
    model.register_buffer("reward_head_bias", vhead_params["v_head.summary.bias"], persistent=False)
    model.register_buffer("default_head_weight", torch.zeros_like(vhead_params["v_head.summary.weight"]), persistent=False)
    model.register_buffer("default_head_bias", torch.zeros_like(vhead_params["v_head.summary.bias"]), persistent=False)