This commit is contained in:
hiyouga
2024-08-29 20:30:18 +08:00
parent 364b757e30
commit ad72f3e065
4 changed files with 7 additions and 8 deletions

View File

@@ -75,8 +75,7 @@ def fix_valuehead_checkpoint(
state_dict: Dict[str, torch.Tensor] = torch.load(path_to_checkpoint, map_location="cpu")
os.remove(path_to_checkpoint)
decoder_state_dict = {}
v_head_state_dict = {}
decoder_state_dict, v_head_state_dict = {}, {}
for name, param in state_dict.items():
if name.startswith("v_head."):
v_head_state_dict[name] = param