adapt for BAdam with DeepSpeed ZeRO-3

Author: Jonery
Date: 2024-06-17 18:18:10 +08:00
Commit: 756566342d (parent db569a2d61)
Former-commit-id: 33b4372778
3 changed files with 28 additions and 6 deletions


@@ -55,6 +55,21 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
        output_dir = output_dir if output_dir is not None else self.args.output_dir
        getattr(self.processor, "image_processor").save_pretrained(output_dir)

    def training_step(self, *args, **kwargs):
        r"""
        Updates the reference to the DeepSpeed optimizer before each step.
        """
        if (
            self.finetuning_args.use_badam
            and self.args.deepspeed_plugin is not None
            and self.args.deepspeed_plugin.zero_stage == 3
        ):
            # Under ZeRO-3, DeepSpeed wraps BAdam's block optimizer; re-link them
            # so BAdam can reach the sharded optimizer state when it switches blocks.
            ds_optim = self.optimizer.optimizer
            badam_optim = ds_optim.optimizer
            badam_optim.ds_optimizer = ds_optim

        return super().training_step(*args, **kwargs)

    def prediction_step(
        self,
        model: "torch.nn.Module",
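As context for the override above: under ZeRO-3 the Trainer's `self.optimizer` is a wrapper around DeepSpeed's ZeRO optimizer, which in turn wraps BAdam's `BlockOptimizer`. Below is a minimal runnable sketch of that unwrap-and-link step; the class names are stand-ins for the real Accelerate, DeepSpeed, and badam wrappers, not the actual APIs.

```python
# Stand-in classes only; the real chain is the Trainer-level wrapper ->
# DeepSpeed ZeRO-3 optimizer -> badam.BlockOptimizer.
class FakeBlockOptimizer:               # plays badam.BlockOptimizer
    def __init__(self):
        self.ds_optimizer = None        # back-reference set on each training step

class FakeZeroStage3Optimizer:          # plays DeepSpeed's ZeRO-3 optimizer
    def __init__(self, inner):
        self.optimizer = inner

class FakeTrainerOptimizer:             # plays the Trainer-level wrapper
    def __init__(self, inner):
        self.optimizer = inner

optimizer = FakeTrainerOptimizer(FakeZeroStage3Optimizer(FakeBlockOptimizer()))

# The unwrap-and-link performed by the training_step override:
ds_optim = optimizer.optimizer          # DeepSpeed's ZeRO-3 optimizer
badam_optim = ds_optim.optimizer        # BAdam's block optimizer
badam_optim.ds_optimizer = ds_optim     # hand BAdam a handle to the sharded state
assert badam_optim.ds_optimizer is ds_optim
```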


@@ -309,6 +309,12 @@ def _create_badam_optimizer(
        dict(params=decay_params, weight_decay=training_args.weight_decay),
    ]

    # DeepSpeed is only compatible with BAdam at ZeRO stage 3 with layer-wise updates.
    ds_zero3_enabled = False
    if hasattr(training_args, "deepspeed_plugin") and training_args.deepspeed_plugin is not None:
        assert (
            training_args.deepspeed_plugin.zero_stage == 3
        ), f"BAdam only supports DeepSpeed ZeRO-3, got stage {training_args.deepspeed_plugin.zero_stage}"
        assert finetuning_args.badam_mode == "layer", "BAdam only supports layer-wise update under ZeRO-3"
        ds_zero3_enabled = True

    if finetuning_args.badam_mode == "layer":
        from badam import BlockOptimizer

@@ -321,6 +327,7 @@ def _create_badam_optimizer(
            start_block=finetuning_args.badam_start_block,
            switch_mode=finetuning_args.badam_switch_mode,
            verbose=finetuning_args.badam_verbose,
            ds_zero3_enabled=ds_zero3_enabled,  # new flag introduced by this commit
        )
        logger.info(
            f"Using BAdam optimizer with layer-wise update, switch mode is {finetuning_args.badam_switch_mode}, "