Mirror of https://github.com/hiyouga/LLaMA-Factory.git, synced 2025-12-16 20:00:36 +08:00
fix #2982
@@ -55,11 +55,11 @@ def run_ppo(
         seed=training_args.seed,
         optimize_device_cache=True,
         target=finetuning_args.ppo_target,
-        log_with=finetuning_args.ppo_logger,
         use_score_scaling=finetuning_args.ppo_score_norm,
         use_score_norm=finetuning_args.ppo_score_norm,
         whiten_rewards=finetuning_args.ppo_whiten_rewards,
         accelerator_kwargs={"step_scheduler_with_optimizer": False},
+        log_with=training_args.report_to[0] if training_args.report_to is not None else None,
+        project_kwargs={"logging_dir": training_args.logging_dir},
     )
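The first hunk drops the PPO-specific ppo_logger option and instead derives the TRL logger from the standard Hugging Face TrainingArguments, so --report_to and --logging_dir control PPO logging the same way they do for the other training stages. A minimal sketch of how the two new keyword values are resolved; the TrainingArguments instance below is a made-up stand-in, not the one LLaMA-Factory builds:

# Sketch only: resolving log_with / project_kwargs the way the new code does.
from transformers import TrainingArguments

training_args = TrainingArguments(output_dir="ppo_output", report_to=["tensorboard"])

log_with = training_args.report_to[0] if training_args.report_to is not None else None
project_kwargs = {"logging_dir": training_args.logging_dir}

print(log_with)        # "tensorboard"
print(project_kwargs)  # {"logging_dir": "ppo_output/runs/<date>_<host>"} by default

In the TRL versions this file targets, PPOTrainer forwards log_with and project_kwargs to the Accelerator it creates, so the PPO tracker should write into the run's logging_dir rather than a separately configured location.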
@@ -71,10 +71,10 @@ def run_ppo(
     num_training_steps = training_args.num_train_epochs * math.ceil(len(dataset) / total_train_batch_size)

     optimizer = create_custom_optimzer(model, training_args, finetuning_args)
-    create_custom_scheduler(training_args, num_training_steps, optimizer)
     if optimizer is None:
         optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=training_args.learning_rate)

+    create_custom_scheduler(training_args, num_training_steps, optimizer)
     lr_scheduler = get_scheduler(
         training_args.lr_scheduler_type,
         optimizer=optimizer,
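The second hunk moves the create_custom_scheduler call below the AdamW fallback, so scheduler setup happens only after the final optimizer is known, immediately before get_scheduler is called. A self-contained sketch of the resulting ordering; the model, hyperparameters, and the stub below are placeholders rather than LLaMA-Factory's real objects:

# Sketch only: post-fix ordering of optimizer and scheduler creation.
import math

import torch
from torch.optim import AdamW
from transformers import get_scheduler


def create_custom_scheduler_stub(num_training_steps, optimizer):
    # Placeholder for llmtuner's create_custom_scheduler, which does nothing
    # unless a custom layer-wise optimizer is in use.
    return None


model = torch.nn.Linear(8, 2)                  # stand-in for the fine-tuned model
learning_rate = 5e-5
num_training_steps = 3 * math.ceil(1000 / 8)   # epochs * ceil(len(dataset) / batch)

optimizer = None                               # as if create_custom_optimzer returned None
if optimizer is None:
    optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=learning_rate)

# Scheduler setup now runs after the fallback, against the optimizer actually used.
create_custom_scheduler_stub(num_training_steps, optimizer)
lr_scheduler = get_scheduler(
    "cosine",
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=num_training_steps,
)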