diff --git a/examples/extras/badam/sft.sh b/examples/extras/badam/sft.sh
index daa63913..656cfdba 100644
--- a/examples/extras/badam/sft.sh
+++ b/examples/extras/badam/sft.sh
@@ -31,6 +31,5 @@ python ../../../src/train_bash.py \
     --use_badam \
     --switch_mode descending \
     --badam_verbose 2 \
-    --switch_block_every 50 \
-    --pure_bf16 \
+    --switch_block_every 50
 
diff --git a/setup.py b/setup.py
index fd5bdf7e..b2eb4afd 100644
--- a/setup.py
+++ b/setup.py
@@ -24,6 +24,7 @@ extra_require = {
     "metrics": ["nltk", "jieba", "rouge-chinese"],
     "unsloth": ["torch==2.2.0", "unsloth[cu121-ampere-torch220]"],
     "galore": ["galore-torch"],
+    "badam": ["torch>=2.1.0"],
     "vllm": ["vllm>=0.3.3"],
     "bitsandbytes": ["bitsandbytes>=0.39.0"],
     "gptq": ["optimum>=1.16.0", "auto-gptq>=0.5.0"],
diff --git a/src/llmtuner/model/utils.py b/src/llmtuner/model/utils.py
index e83a903e..fd587efd 100644
--- a/src/llmtuner/model/utils.py
+++ b/src/llmtuner/model/utils.py
@@ -150,30 +150,24 @@ def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
             Additional keyword arguments passed along to the `torch.utils.checkpoint.checkpoint` function.
     """
     from torch.utils.checkpoint import checkpoint
+    import functools
 
     if not self.supports_gradient_checkpointing:
         raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
 
     if gradient_checkpointing_kwargs is None:
-        gradient_checkpointing_kwargs = {}
+        gradient_checkpointing_kwargs = {"use_reentrant": True}
 
-    # gradient_checkpointing_func = functools.partial(checkpoint, **gradient_checkpointing_kwargs)
+    checkpoint = functools.partial(checkpoint, **gradient_checkpointing_kwargs)
 
     def gradient_checkpointing_func(func, *args, **kwargs):
         module = func.__self__
 
-        if any([p.requires_grad for p in module.parameters()]):
+        if any(p.requires_grad for p in module.parameters()):
             for arg in args:
                 if torch.is_tensor(arg) and torch.is_floating_point(arg):
                     arg.requires_grad_(True)
 
         return checkpoint(func, *args, **kwargs)
 
-    self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=gradient_checkpointing_func)
-
-    if getattr(self, "_hf_peft_config_loaded", False):
-        # When using PEFT + gradient checkpointing + Trainer we need to make sure the input has requires_grad=True
-        # we do it also on PEFT: https://github.com/huggingface/peft/blob/85013987aa82aa1af3da1236b6902556ce3e483e/src/peft/peft_model.py#L334
-        # When training with PEFT, only LoRA layers will have requires grad set to True, but the output of frozen layers need to propagate
-        # the gradients to make sure the gradient flows.
-        self.enable_input_require_grads()
\ No newline at end of file
+    self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=gradient_checkpointing_func)
\ No newline at end of file
diff --git a/src/llmtuner/train/sft/trainer.py b/src/llmtuner/train/sft/trainer.py
index d750f491..de741426 100644
--- a/src/llmtuner/train/sft/trainer.py
+++ b/src/llmtuner/train/sft/trainer.py
@@ -29,7 +29,7 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
     def __init__(self, finetuning_args: "FinetuningArguments", **kwargs) -> None:
         super().__init__(**kwargs)
         self.finetuning_args = finetuning_args
-        if version.parse(torch.__version__) >= version.parse("1.13"):
+        if finetuning_args.use_badam:
             from badam import clip_grad_norm_for_sparse_tensor
 
             self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_for_sparse_tensor, self.accelerator)