# Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-08-04 20:52:59 +08:00)
import os
import torch
from typing import Dict

from transformers.trainer import WEIGHTS_NAME, WEIGHTS_INDEX_NAME
from transformers.modeling_utils import load_sharded_checkpoint

from llmtuner.extras.constants import VALUE_HEAD_FILE_NAME
from llmtuner.extras.logging import get_logger


logger = get_logger(__name__)


def get_state_dict(model: torch.nn.Module) -> Dict[str, torch.Tensor]:
    r"""Returns a state dict containing only the trainable parameters of the model."""
    state_dict = model.state_dict()
    filtered_state_dict = {}

    for k, v in model.named_parameters():
        if v.requires_grad:
            # detach and move to CPU so the returned tensors are safe to serialize
            filtered_state_dict[k] = state_dict[k].cpu().clone().detach()

    return filtered_state_dict
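
# Usage sketch, not part of the original module: with a LoRA-wrapped model only the
# adapter weights require gradients, so the filtered dict stays small. The name
# `output_dir` is an illustrative assumption:
#
#   trainable_state = get_state_dict(model)
#   torch.save(trainable_state, os.path.join(output_dir, WEIGHTS_NAME))
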
def load_trainable_params(model: torch.nn.Module, checkpoint_dir: os.PathLike) -> bool:
    r"""Loads trainable parameters from a checkpoint directory, returning True on success."""
    weights_file = os.path.join(checkpoint_dir, WEIGHTS_NAME)
    if os.path.exists(weights_file):
        model_state_dict = torch.load(weights_file, map_location="cpu")
        model.load_state_dict(model_state_dict, strict=False)  # skip missing keys
    elif os.path.exists(os.path.join(checkpoint_dir, WEIGHTS_INDEX_NAME)):
        # the checkpoint is sharded into multiple weight files described by an index
        load_sharded_checkpoint(model, checkpoint_dir, strict=False)
    else:
        logger.warning("Provided path ({}) does not contain pre-trained weights.".format(checkpoint_dir))
        return False
    return True
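
# Usage sketch, not part of the original module: resuming the trainable weights saved
# via get_state_dict above; the checkpoint path is an illustrative assumption:
#
#   if not load_trainable_params(model, "output/checkpoint-1000"):
#       raise ValueError("Checkpoint could not be loaded.")
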
def load_valuehead_params(model: torch.nn.Module, checkpoint_dir: os.PathLike) -> bool:
    r"""Loads value head parameters as buffers on the model, returning True on success."""
    valuehead_file = os.path.join(checkpoint_dir, VALUE_HEAD_FILE_NAME)
    if not os.path.exists(valuehead_file):
        logger.warning("Provided path ({}) does not contain valuehead weights.".format(checkpoint_dir))
        return False
    valuehead_state_dict = torch.load(valuehead_file, map_location="cpu")
    # register the reward head as buffers so they move with the model's device
    # but are excluded from gradient updates
    model.register_buffer("reward_head_weight", valuehead_state_dict["summary.weight"])
    model.register_buffer("reward_head_bias", valuehead_state_dict["summary.bias"])
    # zero-initialized default head, kept to restore the original (non-reward) output
    model.register_buffer("default_head_weight", torch.zeros_like(valuehead_state_dict["summary.weight"]))
    model.register_buffer("default_head_bias", torch.zeros_like(valuehead_state_dict["summary.bias"]))
    return True
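
# Usage sketch, not part of the original module: attaching a reward model's value head
# before scoring responses; the directory name is an illustrative assumption:
#
#   if load_valuehead_params(model, "path_to_reward_model"):
#       # the "reward_head_*" and "default_head_*" buffers are now attached to `model`
#       ...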