mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-23 06:12:50 +08:00
parent
efd60f0306
commit
1494fa1f18
@ -74,6 +74,7 @@ def fix_valuehead_checkpoint(
|
|||||||
path_to_checkpoint = os.path.join(output_dir, WEIGHTS_NAME)
|
path_to_checkpoint = os.path.join(output_dir, WEIGHTS_NAME)
|
||||||
state_dict: Dict[str, torch.Tensor] = torch.load(path_to_checkpoint, map_location="cpu")
|
state_dict: Dict[str, torch.Tensor] = torch.load(path_to_checkpoint, map_location="cpu")
|
||||||
|
|
||||||
|
os.remove(path_to_checkpoint)
|
||||||
decoder_state_dict = {}
|
decoder_state_dict = {}
|
||||||
v_head_state_dict = {}
|
v_head_state_dict = {}
|
||||||
for name, param in state_dict.items():
|
for name, param in state_dict.items():
|
||||||
@ -91,7 +92,6 @@ def fix_valuehead_checkpoint(
|
|||||||
else:
|
else:
|
||||||
torch.save(v_head_state_dict, os.path.join(output_dir, V_HEAD_WEIGHTS_NAME))
|
torch.save(v_head_state_dict, os.path.join(output_dir, V_HEAD_WEIGHTS_NAME))
|
||||||
|
|
||||||
os.remove(path_to_checkpoint)
|
|
||||||
logger.info("Value head model saved at: {}".format(output_dir))
|
logger.info("Value head model saved at: {}".format(output_dir))
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user