mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-04 12:42:51 +08:00
parent
ca6a3bc76f
commit
f66e6b91c7
@ -110,9 +110,17 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
||||
# Run PPO step
|
||||
stats = self.step(queries, responses, rewards)
|
||||
self.tokenizer.padding_side = "left" # restore padding side
|
||||
loss_meter.update(stats["ppo/loss/total"], n=len(rewards))
|
||||
loss_meter.update(float(stats["ppo/loss/total"]), n=len(rewards))
|
||||
reward_meter.update(torch.stack(rewards).mean().item(), n=len(rewards))
|
||||
|
||||
if self.config.log_with is not None:
|
||||
try:
|
||||
batch["query"] = self.tokenizer.batch_decode(queries, skip_special_tokens=True)
|
||||
batch["response"] = self.tokenizer.batch_decode(responses, skip_special_tokens=True)
|
||||
self.log_stats(stats, batch, rewards)
|
||||
except:
|
||||
logger.warning("Failed to save stats due to unknown errors.")
|
||||
|
||||
self.state.global_step += 1
|
||||
self.log_callback.on_step_end(self.args, self.state, self.control)
|
||||
|
||||
|
@ -42,6 +42,7 @@ def run_ppo(
|
||||
ppo_epochs=1,
|
||||
max_grad_norm=training_args.max_grad_norm,
|
||||
seed=training_args.seed,
|
||||
log_with=training_args.report_to,
|
||||
optimize_cuda_cache=True,
|
||||
accelerator_kwargs={"step_scheduler_with_optimizer": False}
|
||||
)
|
||||
|
Loading…
x
Reference in New Issue
Block a user