From 1494fa1f18fb7c69fee8ca3cf23f782d0a3daa05 Mon Sep 17 00:00:00 2001 From: hiyouga <467089858@qq.com> Date: Thu, 29 Aug 2024 20:16:01 +0800 Subject: [PATCH] fix #5305 Former-commit-id: 364b757e306f7a154359a2bc8245a839f39c4fab --- src/llamafactory/train/callbacks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/llamafactory/train/callbacks.py b/src/llamafactory/train/callbacks.py index 4f34791b..c9612e6e 100644 --- a/src/llamafactory/train/callbacks.py +++ b/src/llamafactory/train/callbacks.py @@ -74,6 +74,7 @@ def fix_valuehead_checkpoint( path_to_checkpoint = os.path.join(output_dir, WEIGHTS_NAME) state_dict: Dict[str, torch.Tensor] = torch.load(path_to_checkpoint, map_location="cpu") + os.remove(path_to_checkpoint) decoder_state_dict = {} v_head_state_dict = {} for name, param in state_dict.items(): @@ -91,7 +92,6 @@ def fix_valuehead_checkpoint( else: torch.save(v_head_state_dict, os.path.join(output_dir, V_HEAD_WEIGHTS_NAME)) - os.remove(path_to_checkpoint) logger.info("Value head model saved at: {}".format(output_dir))