Former-commit-id: 364b757e306f7a154359a2bc8245a839f39c4fab
This commit is contained in:
hiyouga 2024-08-29 20:16:01 +08:00
parent efd60f0306
commit 1494fa1f18

View File

@@ -74,6 +74,7 @@ def fix_valuehead_checkpoint(
path_to_checkpoint = os.path.join(output_dir, WEIGHTS_NAME)
state_dict: Dict[str, torch.Tensor] = torch.load(path_to_checkpoint, map_location="cpu")
os.remove(path_to_checkpoint)
decoder_state_dict = {}
v_head_state_dict = {}
for name, param in state_dict.items():
@@ -91,7 +92,6 @@ def fix_valuehead_checkpoint(
else:
torch.save(v_head_state_dict, os.path.join(output_dir, V_HEAD_WEIGHTS_NAME))
os.remove(path_to_checkpoint)
logger.info("Value head model saved at: {}".format(output_dir))