Former-commit-id: b0b0138e1d841bf5bd2e08b2eea78057b33be69a
This commit is contained in:
hiyouga 2023-09-27 22:57:09 +08:00
parent ca6a3bc76f
commit f66e6b91c7
2 changed files with 10 additions and 1 deletion

View File

@ -110,9 +110,17 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
# Run PPO step # Run PPO step
stats = self.step(queries, responses, rewards) stats = self.step(queries, responses, rewards)
self.tokenizer.padding_side = "left" # restore padding side self.tokenizer.padding_side = "left" # restore padding side
loss_meter.update(stats["ppo/loss/total"], n=len(rewards)) loss_meter.update(float(stats["ppo/loss/total"]), n=len(rewards))
reward_meter.update(torch.stack(rewards).mean().item(), n=len(rewards)) reward_meter.update(torch.stack(rewards).mean().item(), n=len(rewards))
if self.config.log_with is not None:
try:
batch["query"] = self.tokenizer.batch_decode(queries, skip_special_tokens=True)
batch["response"] = self.tokenizer.batch_decode(responses, skip_special_tokens=True)
self.log_stats(stats, batch, rewards)
except:
logger.warning("Failed to save stats due to unknown errors.")
self.state.global_step += 1 self.state.global_step += 1
self.log_callback.on_step_end(self.args, self.state, self.control) self.log_callback.on_step_end(self.args, self.state, self.control)

View File

@ -42,6 +42,7 @@ def run_ppo(
ppo_epochs=1, ppo_epochs=1,
max_grad_norm=training_args.max_grad_norm, max_grad_norm=training_args.max_grad_norm,
seed=training_args.seed, seed=training_args.seed,
log_with=training_args.report_to,
optimize_cuda_cache=True, optimize_cuda_cache=True,
accelerator_kwargs={"step_scheduler_with_optimizer": False} accelerator_kwargs={"step_scheduler_with_optimizer": False}
) )