From 973aac32037cd1f1986cadd2bb80ef2cf0fc386e Mon Sep 17 00:00:00 2001 From: "-.-" Date: Wed, 10 Jul 2024 12:05:51 +0800 Subject: [PATCH 1/2] fix src/llamafactory/train/callbacks.py Former-commit-id: cff89a2e8907f3fe89406006105cb6494e2ee993 --- src/llamafactory/train/callbacks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/llamafactory/train/callbacks.py b/src/llamafactory/train/callbacks.py index 97eb6d1c..e1e3de99 100644 --- a/src/llamafactory/train/callbacks.py +++ b/src/llamafactory/train/callbacks.py @@ -79,7 +79,7 @@ def fix_valuehead_checkpoint( if name.startswith("v_head."): v_head_state_dict[name] = param else: - decoder_state_dict[name.replace("pretrained_model.", "", count=1)] = param + decoder_state_dict[name.replace("pretrained_model.", "",1)] = param model.pretrained_model.save_pretrained( output_dir, state_dict=decoder_state_dict or None, safe_serialization=safe_serialization From 4edd7c3529a2117be0dc91b504efe3269f92fe6c Mon Sep 17 00:00:00 2001 From: hoshi-hiyouga Date: Wed, 10 Jul 2024 13:32:20 +0800 Subject: [PATCH 2/2] Update callbacks.py Former-commit-id: 39cd89ce17220dc50c8331299ae5af230fe40cc9 --- src/llamafactory/train/callbacks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/llamafactory/train/callbacks.py b/src/llamafactory/train/callbacks.py index e1e3de99..e7ce09a2 100644 --- a/src/llamafactory/train/callbacks.py +++ b/src/llamafactory/train/callbacks.py @@ -79,7 +79,7 @@ def fix_valuehead_checkpoint( if name.startswith("v_head."): v_head_state_dict[name] = param else: - decoder_state_dict[name.replace("pretrained_model.", "",1)] = param + decoder_state_dict[name.replace("pretrained_model.", "", 1)] = param model.pretrained_model.save_pretrained( output_dir, state_dict=decoder_state_dict or None, safe_serialization=safe_serialization