# Mirror of https://github.com/hiyouga/LLaMA-Factory.git
import os

import torch
from transformers.trainer import WEIGHTS_NAME

from llmtuner.extras.logging import get_logger

logger = get_logger(__name__)
def load_valuehead_params(model: torch.nn.Module, checkpoint_dir: os.PathLike) -> bool:
    vhead_file = os.path.join(checkpoint_dir, WEIGHTS_NAME)
    if not os.path.exists(vhead_file):
        logger.warning("Provided path ({}) does not contain valuehead weights.".format(checkpoint_dir))
        return False
    # Load the checkpoint on CPU to avoid device-placement mismatches.
    vhead_params = torch.load(vhead_file, map_location="cpu")
    # Register the trained value head as the reward head, plus zeroed tensors of
    # the same shape as a default head. The buffers are non-persistent so they
    # are not written back into the model's state dict on save.
    model.register_buffer("reward_head_weight", vhead_params["v_head.summary.weight"], persistent=False)
    model.register_buffer("reward_head_bias", vhead_params["v_head.summary.bias"], persistent=False)
    model.register_buffer("default_head_weight", torch.zeros_like(vhead_params["v_head.summary.weight"]), persistent=False)
    model.register_buffer("default_head_bias", torch.zeros_like(vhead_params["v_head.summary.bias"]), persistent=False)
    return True
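
# A minimal usage sketch, assuming the value-head model class from the `trl`
# package that llmtuner builds on; the model name and checkpoint path below
# are illustrative, not fixed by this file:
#
#     from trl import AutoModelForCausalLMWithValueHead
#
#     model = AutoModelForCausalLMWithValueHead.from_pretrained("path/to/base_model")
#     if load_valuehead_params(model, "path/to/reward_checkpoint"):
#         # The registered buffers are now attributes of the model, so the
#         # trained reward head can be swapped against the zeroed default head.
#         print(model.reward_head_weight.shape)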