This commit is contained in:
hiyouga
2024-06-16 01:06:41 +08:00
parent 80a9e6bf94
commit 38b6b0f52e
22 changed files with 27 additions and 25 deletions

View File

@@ -1,6 +1,6 @@
# Copyright 2024 the LlamaFactory team.
#
# This code is inspired by CarperAI's trlx library.
# This code is inspired by the CarperAI's trlx library.
# https://github.com/CarperAI/trlx/blob/v0.7.0/examples/summarize_rlhf/reward_model/reward_model.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -89,8 +89,8 @@ class PairwiseTrainer(Trainer):
def _save(self, output_dir: Optional[str] = None, state_dict: Optional[Dict[str, "torch.Tensor"]] = None) -> None:
super()._save(output_dir, state_dict)
output_dir = output_dir if output_dir is not None else self.args.output_dir
if self.processor is not None:
output_dir = output_dir if output_dir is not None else self.args.output_dir
getattr(self.processor, "image_processor").save_pretrained(output_dir)
def compute_loss(