Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-11-04 18:02:19 +08:00)
	adapt for badam with ds zero3
Former-commit-id: fff2a020ec8713022bd8145f4a7168168ea07ca4
parent 4bd276f58f
commit ba303fd1aa
@@ -184,12 +184,12 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
     ):
         raise ValueError("Distributed training does not support layer-wise GaLore.")
 
-    if (
-        finetuning_args.use_badam
-        and finetuning_args.badam_mode == "layer"
-        and training_args.parallel_mode.value == "distributed"
-    ):
-        raise ValueError("Layer-wise BAdam does not yet support distributed training, use ratio-wise BAdam.")
+    # if (
+    #     finetuning_args.use_badam
+    #     and finetuning_args.badam_mode == "layer"
+    #     and training_args.parallel_mode.value == "distributed"
+    # ):
+    #     raise ValueError("Layer-wise BAdam does not yet support distributed training, use ratio-wise BAdam.")
 
     if (finetuning_args.use_galore or finetuning_args.use_badam) and training_args.deepspeed is not None:
         raise ValueError("GaLore and BAdam are incompatible with DeepSpeed yet.")
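Note: this hunk only comments out the layer-wise-BAdam-vs-distributed guard; the `--deepspeed` config-file guard below it stays, presumably because the ZeRO-3 path added by this commit goes through Accelerate's `deepspeed_plugin` rather than a `--deepspeed` config argument. A minimal standalone sketch of the check that survives; the argument objects are stubbed with `SimpleNamespace`, and only the field names visible in the hunk are assumed:

    from types import SimpleNamespace

    def check_badam_compat(finetuning_args, training_args) -> None:
        # Guard kept by this commit: GaLore/BAdam still reject a DeepSpeed
        # config file passed via training_args.deepspeed.
        if (finetuning_args.use_galore or finetuning_args.use_badam) and training_args.deepspeed is not None:
            raise ValueError("GaLore and BAdam are incompatible with DeepSpeed yet.")

    # Hypothetical stand-ins for the real argument dataclasses:
    finetuning_args = SimpleNamespace(use_galore=False, use_badam=True)
    training_args = SimpleNamespace(deepspeed=None)
    check_badam_compat(finetuning_args, training_args)  # passes: no config file given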
@@ -55,6 +55,21 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
             output_dir = output_dir if output_dir is not None else self.args.output_dir
             getattr(self.processor, "image_processor").save_pretrained(output_dir)
 
+    def training_step(self, *args, **kwargs):
+        r"""
+        Update the reference to deepspeed optimizer
+        """
+        if self.finetuning_args.use_badam and \
+            self.args.deepspeed_plugin is not None and \
+            self.args.deepspeed_plugin.zero_stage == 3:
+
+            ds_optim = self.optimizer.optimizer
+            badam_optim = ds_optim.optimizer
+            badam_optim.ds_optimizer = ds_optim
+
+        return super().training_step(*args, **kwargs)
+
+
     def prediction_step(
         self,
         model: "torch.nn.Module",
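Note: the new `training_step` assumes a three-level optimizer wrapper chain under ZeRO-3 and, before each step, hands BAdam's BlockOptimizer a back-reference to the DeepSpeed optimizer so it can reach the partitioned parameter state when switching blocks. A runnable mock of that chain; the wrapper classes are an assumption, and only the attribute path `self.optimizer.optimizer.optimizer` is taken from the hunk:

    from types import SimpleNamespace

    badam_optim = SimpleNamespace(ds_optimizer=None)      # stand-in for badam.BlockOptimizer
    ds_optim = SimpleNamespace(optimizer=badam_optim)     # stand-in for the DeepSpeed ZeRO-3 optimizer
    self_optimizer = SimpleNamespace(optimizer=ds_optim)  # stand-in for the Trainer's self.optimizer

    # The rewiring performed in training_step:
    ds_optim = self_optimizer.optimizer
    badam_optim = ds_optim.optimizer
    badam_optim.ds_optimizer = ds_optim

    assert badam_optim.ds_optimizer is ds_optim  # BlockOptimizer can now reach ZeRO-3 state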
@@ -309,6 +309,12 @@ def _create_badam_optimizer(
         dict(params=decay_params, weight_decay=training_args.weight_decay),
     ]
 
+    ds_zero3_enabled = False
+    if hasattr(training_args, "deepspeed_plugin") and training_args.deepspeed_plugin is not None:
+        assert training_args.deepspeed_plugin.zero_stage == 3, f"BAdam only supports deepspeed ZeRO-3 stage, got {training_args.deepspeed_plugin.zero_stage}"
+        assert finetuning_args.badam_mode == "layer", "BAdam only supports layer-wise update in ZeRO-3 stage"
+        ds_zero3_enabled = True
+
     if finetuning_args.badam_mode == "layer":
         from badam import BlockOptimizer
 
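Note: the ZeRO-3 detection added here reads as a small standalone helper. A sketch under the same assumptions as the hunk, with stubbed argument objects; `deepspeed_plugin.zero_stage` is the attribute HF `TrainingArguments` exposes via Accelerate:

    from types import SimpleNamespace

    def detect_ds_zero3(training_args, finetuning_args) -> bool:
        # Standalone restatement of the detection in the hunk above.
        if getattr(training_args, "deepspeed_plugin", None) is None:
            return False
        assert training_args.deepspeed_plugin.zero_stage == 3, (
            f"BAdam only supports deepspeed ZeRO-3 stage, got {training_args.deepspeed_plugin.zero_stage}"
        )
        assert finetuning_args.badam_mode == "layer", "BAdam only supports layer-wise update in ZeRO-3 stage"
        return True

    # Hypothetical stubs, for illustration only:
    plugin = SimpleNamespace(zero_stage=3)
    assert detect_ds_zero3(SimpleNamespace(deepspeed_plugin=plugin), SimpleNamespace(badam_mode="layer"))
    assert not detect_ds_zero3(SimpleNamespace(deepspeed_plugin=None), SimpleNamespace(badam_mode="layer"))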
@@ -321,6 +327,7 @@ def _create_badam_optimizer(
             start_block=finetuning_args.badam_start_block,
             switch_mode=finetuning_args.badam_switch_mode,
             verbose=finetuning_args.badam_verbose,
+            ds_zero3_enabled=ds_zero3_enabled
         )
         logger.info(
             f"Using BAdam optimizer with layer-wise update, switch mode is {finetuning_args.badam_switch_mode}, "
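Note: the only change in this last hunk is threading `ds_zero3_enabled` into the `BlockOptimizer` constructor. A hedged construction sketch follows: the `start_block`, `switch_mode`, `verbose`, and `ds_zero3_enabled` keywords come from the hunks above, while `base_optimizer`, `named_parameters_list`, and `switch_block_every` follow BAdam's documented API and should be checked against the installed badam version:

    import torch
    from badam import BlockOptimizer

    def build_badam_optimizer(model: torch.nn.Module, lr: float, ds_zero3_enabled: bool):
        # The inner optimizer that BAdam applies to the active block.
        base_optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
        return BlockOptimizer(
            base_optimizer=base_optimizer,
            named_parameters_list=list(model.named_parameters()),
            switch_block_every=50,              # assumed stand-in for badam_switch_interval
            start_block=None,                   # assumed default of badam_start_block
            switch_mode="ascending",            # a documented BAdam switch mode
            verbose=1,
            ds_zero3_enabled=ds_zero3_enabled,  # the flag this commit threads through
        )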