diff --git a/src/llmtuner/train/ppo/workflow.py b/src/llmtuner/train/ppo/workflow.py index 41a99e2c..4fb9d593 100644 --- a/src/llmtuner/train/ppo/workflow.py +++ b/src/llmtuner/train/ppo/workflow.py @@ -45,7 +45,7 @@ def run_ppo( mini_batch_size=training_args.per_device_train_batch_size, batch_size=training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps, gradient_accumulation_steps=training_args.gradient_accumulation_steps, - ppo_epochs=1, + ppo_epochs=finetuning_args.ppo_epochs, max_grad_norm=training_args.max_grad_norm, seed=training_args.seed, optimize_device_cache=True,