Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-08-22 22:02:51 +08:00)
commit 28e613efd0
parent 654116c0b1
@@ -99,10 +99,10 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
         )
 
         # Add deepspeed config
-        ppo_config.accelerator_kwargs["kwargs_handlers"] = [
-            DistributedDataParallelKwargs(find_unused_parameters=training_args.ddp_find_unused_parameters)
-        ]
         if training_args.deepspeed_plugin is not None:
+            ppo_config.accelerator_kwargs["kwargs_handlers"] = [
+                DistributedDataParallelKwargs(find_unused_parameters=training_args.ddp_find_unused_parameters)
+            ]
             ppo_config.accelerator_kwargs["deepspeed_plugin"] = training_args.deepspeed_plugin
 
         # Create optimizer and scheduler
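For reference, the effect of this hunk is that both Accelerator keyword arguments are now set only when a DeepSpeed plugin is configured; a plain DDP run no longer installs the kwargs_handlers entry unconditionally. Below is a minimal sketch of the resulting logic, assuming TRL's PPOConfig and Accelerate's DistributedDataParallelKwargs; the training_args object and the configure_accelerator_kwargs helper are stand-ins for illustration, not part of the repository.

from accelerate import DistributedDataParallelKwargs
from trl import PPOConfig


def configure_accelerator_kwargs(ppo_config: PPOConfig, training_args) -> PPOConfig:
    # Guard defensively in case accelerator_kwargs is None on this TRL version.
    ppo_config.accelerator_kwargs = dict(ppo_config.accelerator_kwargs or {})

    # Mirror the patched block: only touch the Accelerator kwargs under DeepSpeed.
    if training_args.deepspeed_plugin is not None:
        # Forward the user's find_unused_parameters choice to DDP.
        ppo_config.accelerator_kwargs["kwargs_handlers"] = [
            DistributedDataParallelKwargs(
                find_unused_parameters=training_args.ddp_find_unused_parameters
            )
        ]
        # Reuse the DeepSpeed plugin already built from the training arguments.
        ppo_config.accelerator_kwargs["deepspeed_plugin"] = training_args.deepspeed_plugin

    return ppo_config

TRL's PPOTrainer passes accelerator_kwargs through to accelerate.Accelerator when it constructs its distributed backend, so these entries end up on the Accelerator instance used for PPO training.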