hiyouga
2024-03-28 20:22:31 +08:00
parent 6c94305e47
commit 8d603f8820
4 changed files with 12 additions and 14 deletions

@@ -55,11 +55,11 @@ def run_ppo(
         seed=training_args.seed,
         optimize_device_cache=True,
         target=finetuning_args.ppo_target,
-        log_with=finetuning_args.ppo_logger,
         use_score_scaling=finetuning_args.ppo_score_norm,
         use_score_norm=finetuning_args.ppo_score_norm,
         whiten_rewards=finetuning_args.ppo_whiten_rewards,
         accelerator_kwargs={"step_scheduler_with_optimizer": False},
+        log_with=training_args.report_to[0] if training_args.report_to is not None else None,
         project_kwargs={"logging_dir": training_args.logging_dir},
     )
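
Net effect of this hunk: the TRL PPOConfig logger is now derived from the standard transformers report_to setting rather than a dedicated ppo_logger option. A minimal sketch of the selection expression used above; the helper name select_ppo_logger and the example values are illustrative, not part of the commit:

from typing import List, Optional

def select_ppo_logger(report_to: Optional[List[str]]) -> Optional[str]:
    # Same expression as in the diff: use the first requested reporting
    # backend if any were given, otherwise leave TRL logging disabled.
    return report_to[0] if report_to is not None else None

assert select_ppo_logger(["wandb"]) == "wandb"
assert select_ppo_logger(None) is None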
@@ -71,10 +71,10 @@ def run_ppo(
     num_training_steps = training_args.num_train_epochs * math.ceil(len(dataset) / total_train_batch_size)
     optimizer = create_custom_optimzer(model, training_args, finetuning_args)
-    create_custom_scheduler(training_args, num_training_steps, optimizer)
     if optimizer is None:
         optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=training_args.learning_rate)
+    create_custom_scheduler(training_args, num_training_steps, optimizer)
     lr_scheduler = get_scheduler(
         training_args.lr_scheduler_type,
         optimizer=optimizer,
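
The second hunk moves the create_custom_scheduler call below the optimizer-is-None fallback, so it runs with the optimizer that will actually be stepped (previously, whenever create_custom_optimzer returned None, the call received None before the AdamW fallback existed). Below is a self-contained sketch of the corrected ordering; the toy model, learning rate, warmup steps, scheduler type, and the stand-in maybe_custom_optimizer helper are assumptions for illustration only:

import torch
from torch.optim import AdamW
from transformers import get_scheduler

def maybe_custom_optimizer(model):
    # Stand-in for create_custom_optimzer: returns None when no custom
    # optimizer is configured for this run.
    return None

model = torch.nn.Linear(4, 2)
num_training_steps = 100

optimizer = maybe_custom_optimizer(model)
if optimizer is None:
    # Fallback: plain AdamW over the trainable parameters only.
    optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-5)

# Scheduler setup happens only after the fallback, so it is always bound to
# the optimizer that is actually used for training.
lr_scheduler = get_scheduler(
    "cosine", optimizer=optimizer, num_warmup_steps=0, num_training_steps=num_training_steps
)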