Former-commit-id: d0940d0dbd03d4bbcc955304566b0d5507edf9e6
This commit is contained in:
hiyouga 2023-09-27 22:57:09 +08:00
parent dd623325e8
commit 6c5d8f089e
2 changed files with 10 additions and 1 deletion

View File

@ -110,9 +110,17 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
# Run PPO step
stats = self.step(queries, responses, rewards)
self.tokenizer.padding_side = "left" # restore padding side
loss_meter.update(stats["ppo/loss/total"], n=len(rewards))
loss_meter.update(float(stats["ppo/loss/total"]), n=len(rewards))
reward_meter.update(torch.stack(rewards).mean().item(), n=len(rewards))
if self.config.log_with is not None:
try:
batch["query"] = self.tokenizer.batch_decode(queries, skip_special_tokens=True)
batch["response"] = self.tokenizer.batch_decode(responses, skip_special_tokens=True)
self.log_stats(stats, batch, rewards)
except:
logger.warning("Failed to save stats due to unknown errors.")
self.state.global_step += 1
self.log_callback.on_step_end(self.args, self.state, self.control)

View File

@ -42,6 +42,7 @@ def run_ppo(
ppo_epochs=1,
max_grad_norm=training_args.max_grad_norm,
seed=training_args.seed,
log_with=training_args.report_to,
optimize_cuda_cache=True,
accelerator_kwargs={"step_scheduler_with_optimizer": False}
)