diff --git a/src/llamafactory/hparams/parser.py b/src/llamafactory/hparams/parser.py
index 6311297e..fe108657 100644
--- a/src/llamafactory/hparams/parser.py
+++ b/src/llamafactory/hparams/parser.py
@@ -184,12 +184,12 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
     ):
         raise ValueError("Distributed training does not support layer-wise GaLore.")
 
-    if (
-        finetuning_args.use_badam
-        and finetuning_args.badam_mode == "layer"
-        and training_args.parallel_mode.value == "distributed"
-    ):
-        raise ValueError("Layer-wise BAdam does not yet support distributed training, use ratio-wise BAdam.")
+    # if (
+    #     finetuning_args.use_badam
+    #     and finetuning_args.badam_mode == "layer"
+    #     and training_args.parallel_mode.value == "distributed"
+    # ):
+    #     raise ValueError("Layer-wise BAdam does not yet support distributed training, use ratio-wise BAdam.")
 
     if (finetuning_args.use_galore or finetuning_args.use_badam) and training_args.deepspeed is not None:
         raise ValueError("GaLore and BAdam are incompatible with DeepSpeed yet.")
diff --git a/src/llamafactory/train/sft/trainer.py b/src/llamafactory/train/sft/trainer.py
index 35671e1b..cd73bf5c 100644
--- a/src/llamafactory/train/sft/trainer.py
+++ b/src/llamafactory/train/sft/trainer.py
@@ -55,6 +55,21 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
             output_dir = output_dir if output_dir is not None else self.args.output_dir
             getattr(self.processor, "image_processor").save_pretrained(output_dir)
 
+    def training_step(self, *args, **kwargs):
+        r"""
+        Update the reference to deepspeed optimizer
+        """
+        if self.finetuning_args.use_badam and \
+           self.args.deepspeed_plugin is not None and \
+           self.args.deepspeed_plugin.zero_stage == 3:
+
+            ds_optim = self.optimizer.optimizer
+            badam_optim = ds_optim.optimizer
+            badam_optim.ds_optimizer = ds_optim
+
+        return super().training_step(*args, **kwargs)
+
+
     def prediction_step(
         self,
         model: "torch.nn.Module",
diff --git a/src/llamafactory/train/utils.py b/src/llamafactory/train/utils.py
index 23834f2d..b189922b 100644
--- a/src/llamafactory/train/utils.py
+++ b/src/llamafactory/train/utils.py
@@ -309,6 +309,12 @@ def _create_badam_optimizer(
         dict(params=decay_params, weight_decay=training_args.weight_decay),
     ]
 
+    ds_zero3_enabled = False
+    if hasattr(training_args, "deepspeed_plugin") and training_args.deepspeed_plugin is not None:
+        assert training_args.deepspeed_plugin.zero_stage == 3, f"BAdam only supports deepspeed ZeRO-3 stage, got {training_args.deepspeed_plugin.zero_stage}"
+        assert finetuning_args.badam_mode == "layer", "BAdam only supports layer-wise update in ZeRO-3 stage"
+        ds_zero3_enabled = True
+
     if finetuning_args.badam_mode == "layer":
         from badam import BlockOptimizer
 
@@ -321,6 +327,7 @@
             start_block=finetuning_args.badam_start_block,
             switch_mode=finetuning_args.badam_switch_mode,
             verbose=finetuning_args.badam_verbose,
+            ds_zero3_enabled=ds_zero3_enabled
         )
         logger.info(
             f"Using BAdam optimizer with layer-wise update, switch mode is {finetuning_args.badam_switch_mode}, "
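
Note on the training_step override in trainer.py: judging from the two .optimizer dereferences in the patch, under DeepSpeed ZeRO-3 the Trainer's self.optimizer wraps the DeepSpeed ZeRO optimizer, which in turn wraps the BAdam BlockOptimizer. The sketch below is a minimal mock of that assumed wrapping (the class names are illustrative stand-ins, not the real Accelerate/DeepSpeed/BAdam types), showing what rebinding ds_optimizer accomplishes:

    # Illustrative mock only: stands in for the optimizer wrapping assumed above.
    class BlockOptimizerMock:
        """Stand-in for badam.BlockOptimizer; it needs a handle to the ZeRO-3 wrapper."""
        def __init__(self) -> None:
            self.ds_optimizer = None  # filled in lazily, once DeepSpeed has built its wrapper

    class ZeroStage3OptimizerMock:
        """Stand-in for the ZeRO-3 optimizer wrapper created by the DeepSpeed engine."""
        def __init__(self, inner) -> None:
            self.optimizer = inner  # the wrapped user optimizer lives at .optimizer

    class AcceleratedOptimizerMock:
        """Stand-in for the wrapper the Trainer stores as self.optimizer."""
        def __init__(self, inner) -> None:
            self.optimizer = inner

    badam_optim = BlockOptimizerMock()
    trainer_optimizer = AcceleratedOptimizerMock(ZeroStage3OptimizerMock(badam_optim))

    # Same walk as the patched training_step: unwrap once to reach the ZeRO wrapper,
    # once more to reach the BlockOptimizer, then give the BlockOptimizer a
    # back-reference so it can talk to the ZeRO-3 wrapper when it switches blocks.
    ds_optim = trainer_optimizer.optimizer
    ds_optim.optimizer.ds_optimizer = ds_optim
    assert badam_optim.ds_optimizer is ds_optim

The override runs on every step rather than once at setup, presumably because the ZeRO-3 wrapper only exists after DeepSpeed initialization, i.e. after _create_badam_optimizer has already constructed the BlockOptimizer.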