fix callback

Former-commit-id: 22d9a9c2af
This commit is contained in:
hiyouga
2023-07-15 17:18:16 +08:00
parent a696148d6b
commit 70b5232f9a
4 changed files with 9 additions and 5 deletions

View File

@@ -40,6 +40,7 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
self.state = TrainerState()
self.control = TrainerControl()
self.data_collator = self.accelerator.prepare(kwargs["data_collator"]) # override the data collator of PPOTrainer
self._remove_log()
def ppo_train(self, max_target_length: int) -> None:
r"""