alter rewards data type

This commit is contained in:
hiyouga
2023-06-02 14:19:51 +08:00
parent e6126244c1
commit 50d9a20f81
12 changed files with 40 additions and 50 deletions

View File

@@ -109,7 +109,8 @@ class PeftTrainer(Seq2SeqTrainer):
if hasattr(model, "v_head"): # save valuehead weights
torch.save(get_state_dict(getattr(model, "v_head")), os.path.join(output_dir, VALUE_HEAD_FILE_NAME))
torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
with open(os.path.join(output_dir, TRAINING_ARGS_NAME), "w", encoding="utf-8") as f:
f.write(self.args.to_json_string() + "\n")
self.finetuning_args.save_to_json(os.path.join(output_dir, FINETUNING_ARGS_NAME))
def _load_best_model(self):