support full-parameter PPO

hiyouga
2023-11-16 02:08:04 +08:00
parent 8350bcf85d
commit ce78303600
20 changed files with 288 additions and 145 deletions


@@ -1,8 +1,5 @@
-import torch
 from typing import TYPE_CHECKING
-from transformers.utils import cached_file
-from transformers.trainer import WEIGHTS_NAME, SAFE_WEIGHTS_NAME
 from peft import PeftModel, TaskType, LoraConfig, get_peft_model
 from llmtuner.extras.logging import get_logger
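For reference, the two dropped transformers imports are constants naming the default Hugging Face checkpoint files, which only the helper removed below probed for:

# Constants re-exported by transformers.trainer:
# WEIGHTS_NAME == "pytorch_model.bin"
# SAFE_WEIGHTS_NAME == "model.safetensors"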
@@ -98,30 +95,3 @@ def init_adapter(
         logger.info("Loaded fine-tuned model from checkpoint(s): {}".format(",".join(model_args.checkpoint_dir)))
 
     return model
-
-
-def load_valuehead_params(
-    model: "PreTrainedModel",
-    model_args: "ModelArguments"
-) -> bool:
-    kwargs = {
-        "path_or_repo_id": model_args.reward_model,
-        "cache_dir": model_args.cache_dir,
-        "token": model_args.hf_hub_token,
-        "revision": model_args.model_revision
-    }
-    try:
-        vhead_file = cached_file(filename=WEIGHTS_NAME, **kwargs)
-    except:
-        try:
-            vhead_file = cached_file(filename=SAFE_WEIGHTS_NAME, **kwargs)
-        except:
-            logger.warning("Provided path ({}) does not contain valuehead weights.".format(model_args.reward_model))
-            return False
-
-    vhead_params = torch.load(vhead_file, map_location="cpu")
-    model.register_buffer("reward_head_weight", vhead_params["v_head.summary.weight"], persistent=False)
-    model.register_buffer("reward_head_bias", vhead_params["v_head.summary.bias"], persistent=False)
-    model.register_buffer("default_head_weight", torch.zeros_like(vhead_params["v_head.summary.weight"]), persistent=False)
-    model.register_buffer("default_head_bias", torch.zeros_like(vhead_params["v_head.summary.bias"]), persistent=False)
-    return True
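
For context: the removed helper downloaded a reward-model checkpoint and attached its value head to the model as frozen, non-persistent buffers, which does not fit full-parameter PPO where the value head must stay trainable. Below is a minimal, hypothetical sketch, not the code this commit actually adds, of the same lookup returning a plain state dict for the caller to load into trainable parameters; the function name and signature are illustrative only. Note also that the removed code fell back to torch.load on the safetensors file, which torch.load cannot parse; the sketch uses the safetensors loader for that branch.

from typing import Dict, Optional

import torch
from transformers.utils import cached_file
from transformers.trainer import WEIGHTS_NAME, SAFE_WEIGHTS_NAME


def load_valuehead_state_dict(
    path_or_repo_id: str,
    cache_dir: Optional[str] = None,
    token: Optional[str] = None,
    revision: str = "main",
) -> Optional[Dict[str, torch.Tensor]]:
    # Hypothetical sketch: resolve the checkpoint file exactly as the removed
    # helper did, but hand back the raw tensors instead of registering buffers.
    kwargs = {
        "path_or_repo_id": path_or_repo_id,
        "cache_dir": cache_dir,
        "token": token,
        "revision": revision,
    }
    try:
        # Prefer the PyTorch checkpoint, as the removed helper did.
        vhead_file = cached_file(filename=WEIGHTS_NAME, **kwargs)
        return torch.load(vhead_file, map_location="cpu")
    except Exception:
        pass
    try:
        # Fall back to safetensors, using its own loader.
        from safetensors.torch import load_file

        vhead_file = cached_file(filename=SAFE_WEIGHTS_NAME, **kwargs)
        return load_file(vhead_file, device="cpu")
    except Exception:
        return None  # caller decides how to handle a missing value head

A caller could then copy the "v_head.summary.weight" and "v_head.summary.bias" entries of the returned dict into a trainable value head, rather than registering them as frozen buffers.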