mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-10-16 00:28:10 +08:00
Update callbacks.py
Former-commit-id: 526376967deaad73b7ca11063a2e3f0c9a0add98
This commit is contained in:
parent
18057e14ef
commit
460a40756c
@ -79,7 +79,7 @@ def fix_valuehead_checkpoint(
|
|||||||
if name.startswith("v_head."):
|
if name.startswith("v_head."):
|
||||||
v_head_state_dict[name] = param
|
v_head_state_dict[name] = param
|
||||||
else:
|
else:
|
||||||
decoder_state_dict[name.replace("pretrained_model.", "",1)] = param
|
decoder_state_dict[name.replace("pretrained_model.", "", 1)] = param
|
||||||
|
|
||||||
model.pretrained_model.save_pretrained(
|
model.pretrained_model.save_pretrained(
|
||||||
output_dir, state_dict=decoder_state_dict or None, safe_serialization=safe_serialization
|
output_dir, state_dict=decoder_state_dict or None, safe_serialization=safe_serialization
|
||||||
|
Loading…
x
Reference in New Issue
Block a user