Mirror of https://github.com/hiyouga/LLaMA-Factory.git
fix ppo trainer
Former-commit-id: a03b2e5ef0d5d6b1b27753438745385d290cb211
commit 446129ca7a
parent 834c4e8ad9
@@ -106,7 +106,8 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
                 DistributedDataParallelKwargs(find_unused_parameters=training_args.ddp_find_unused_parameters)
             ]
             ppo_config.accelerator_kwargs["deepspeed_plugin"] = training_args.deepspeed_plugin
-            if ppo_config.log_with == "tensorboard":  # tensorboard raises error about accelerator_kwargs
+            if ppo_config.log_with is not None:
+                logger.warning("PPOTrainer cannot use external logger when DeepSpeed is enabled.")
                 ppo_config.log_with = None
 
         # Create optimizer and scheduler
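In effect, the patch broadens the old tensorboard-only special case: whenever a DeepSpeed plugin is injected through accelerator_kwargs, any external logger configured via log_with is dropped with a warning instead of being kept. Below is a minimal, self-contained sketch of that guard; the PPOConfig-like stand-in dataclass, the apply_deepspeed_logger_guard helper, and the deepspeed_plugin-is-not-None check are illustrative assumptions, not code taken from the commit.

    import logging
    from dataclasses import dataclass, field
    from typing import Any, Dict, Optional

    logger = logging.getLogger(__name__)

    @dataclass
    class PPOConfigStub:
        # Stand-in for the trl PPOConfig fields the patch touches.
        log_with: Optional[str] = None                       # e.g. "tensorboard" or "wandb"
        accelerator_kwargs: Dict[str, Any] = field(default_factory=dict)

    def apply_deepspeed_logger_guard(ppo_config: PPOConfigStub, deepspeed_plugin: Any) -> None:
        # Sketch of the patched logic: once DeepSpeed is wired into the
        # accelerator kwargs, no external logger is allowed to stay configured.
        if deepspeed_plugin is None:
            return
        ppo_config.accelerator_kwargs["deepspeed_plugin"] = deepspeed_plugin
        if ppo_config.log_with is not None:                   # was: == "tensorboard"
            logger.warning("PPOTrainer cannot use external logger when DeepSpeed is enabled.")
            ppo_config.log_with = None

    # Example: a config that asked for wandb logging gets it stripped under DeepSpeed.
    cfg = PPOConfigStub(log_with="wandb")
    apply_deepspeed_logger_guard(cfg, deepspeed_plugin=object())
    assert cfg.log_with is None

Compared with the old check, which only cleared log_with for tensorboard, this catches every logger backend that would otherwise conflict with the DeepSpeed-configured Accelerator.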