Merge pull request #4746 from yzoaim/fix

fix src/llamafactory/train/callbacks.py

Former-commit-id: 40c3b88b68b205e4124a9704d73500e3c404364d
This commit is contained in:
hoshi-hiyouga 2024-07-10 13:32:49 +08:00 committed by GitHub
commit 2528487847

View File

@ -79,7 +79,7 @@ def fix_valuehead_checkpoint(
if name.startswith("v_head."):
v_head_state_dict[name] = param
else:
decoder_state_dict[name.replace("pretrained_model.", "", count=1)] = param
decoder_state_dict[name.replace("pretrained_model.", "", 1)] = param
model.pretrained_model.save_pretrained(
output_dir, state_dict=decoder_state_dict or None, safe_serialization=safe_serialization