mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-04 12:42:51 +08:00
parent
ca6a3bc76f
commit
f66e6b91c7
@@ -110,9 +110,17 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
|||||||
# Run PPO step
|
# Run PPO step
|
||||||
stats = self.step(queries, responses, rewards)
|
stats = self.step(queries, responses, rewards)
|
||||||
self.tokenizer.padding_side = "left" # restore padding side
|
self.tokenizer.padding_side = "left" # restore padding side
|
||||||
loss_meter.update(stats["ppo/loss/total"], n=len(rewards))
|
loss_meter.update(float(stats["ppo/loss/total"]), n=len(rewards))
|
||||||
reward_meter.update(torch.stack(rewards).mean().item(), n=len(rewards))
|
reward_meter.update(torch.stack(rewards).mean().item(), n=len(rewards))
|
||||||
|
|
||||||
|
if self.config.log_with is not None:
|
||||||
|
try:
|
||||||
|
batch["query"] = self.tokenizer.batch_decode(queries, skip_special_tokens=True)
|
||||||
|
batch["response"] = self.tokenizer.batch_decode(responses, skip_special_tokens=True)
|
||||||
|
self.log_stats(stats, batch, rewards)
|
||||||
|
except:
|
||||||
|
logger.warning("Failed to save stats due to unknown errors.")
|
||||||
|
|
||||||
self.state.global_step += 1
|
self.state.global_step += 1
|
||||||
self.log_callback.on_step_end(self.args, self.state, self.control)
|
self.log_callback.on_step_end(self.args, self.state, self.control)
|
||||||
|
|
||||||
|
@@ -42,6 +42,7 @@ def run_ppo(
|
|||||||
ppo_epochs=1,
|
ppo_epochs=1,
|
||||||
max_grad_norm=training_args.max_grad_norm,
|
max_grad_norm=training_args.max_grad_norm,
|
||||||
seed=training_args.seed,
|
seed=training_args.seed,
|
||||||
|
log_with=training_args.report_to,
|
||||||
optimize_cuda_cache=True,
|
optimize_cuda_cache=True,
|
||||||
accelerator_kwargs={"step_scheduler_with_optimizer": False}
|
accelerator_kwargs={"step_scheduler_with_optimizer": False}
|
||||||
)
|
)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user