diff --git a/scripts/cal_mfu.py b/scripts/cal_mfu.py
index 03280b52..99944a6f 100644
--- a/scripts/cal_mfu.py
+++ b/scripts/cal_mfu.py
@@ -109,6 +109,7 @@ def calculate_mfu(
     deepspeed_stage: int = 0,
     disable_gc: bool = False,
     liger_kernel: bool = False,
+    unsloth_gc: bool = False,
 ) -> float:
     r"""
     Calculates MFU for given model and hyper-params.
@@ -119,6 +120,7 @@ def calculate_mfu(
         "flash_attn": flash_attn,
         "disable_gradient_checkpointing": disable_gc,
         "enable_liger_kernel": liger_kernel,
+        "use_unsloth_gc": unsloth_gc,
         "stage": "pt",
         "do_train": True,
         "finetuning_type": finetuning_type,
diff --git a/src/llamafactory/hparams/model_args.py b/src/llamafactory/hparams/model_args.py
index 6fd23b0c..9e8c019c 100644
--- a/src/llamafactory/hparams/model_args.py
+++ b/src/llamafactory/hparams/model_args.py
@@ -215,6 +215,10 @@ class ModelArguments(QuantizationArguments, ProcessorArguments, ExportArguments,
         default=False,
         metadata={"help": "Whether or not to use unsloth's optimization for the LoRA training."},
     )
+    use_unsloth_gc: bool = field(
+        default=False,
+        metadata={"help": "Whether or not to use unsloth's gradient checkpointing."},
+    )
     enable_liger_kernel: bool = field(
         default=False,
         metadata={"help": "Whether or not to enable liger kernel for faster training."},
diff --git a/src/llamafactory/model/model_utils/checkpointing.py b/src/llamafactory/model/model_utils/checkpointing.py
index f4f3d8a5..337b137a 100644
--- a/src/llamafactory/model/model_utils/checkpointing.py
+++ b/src/llamafactory/model/model_utils/checkpointing.py
@@ -1,8 +1,10 @@
-# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+# Copyright 2024 HuggingFace Inc., Daniel Han-Chen & the Unsloth team and the LlamaFactory team.
 #
-# This code is inspired by the HuggingFace's Transformers and PEFT library.
+# This code is inspired by the HuggingFace's Transformers and PEFT library,
 # https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/modeling_utils.py
 # https://github.com/huggingface/peft/blob/v0.10.0/src/peft/utils/other.py
+# and the Unsloth library.
+# https://github.com/unslothai/unsloth/blob/July-2024/unsloth/models/_utils.py
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -19,7 +21,7 @@
 import inspect
 from functools import partial
 from types import MethodType
-from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
 
 import torch
 
@@ -36,8 +38,45 @@ if TYPE_CHECKING:
 logger = get_logger(__name__)
 
 
+class UnslothGradientCheckpointing(torch.autograd.Function):
+    r"""
+    Saves VRAM by smartly offloading to RAM.
+    """
+
+    @staticmethod
+    @torch.cuda.amp.custom_fwd
+    def forward(
+        ctx: "torch.autograd.Function",
+        forward_function: "torch.Module",
+        hidden_states: "torch.Tensor",
+        *args: Union["torch.Tensor", Any],
+    ) -> "torch.Tensor":
+        saved_hidden_states = hidden_states.to("cpu", non_blocking=True)
+        with torch.no_grad():
+            output = forward_function(hidden_states, *args)
+
+        ctx.save_for_backward(saved_hidden_states)
+        ctx.forward_function = forward_function
+        ctx.args = args
+        return output
+
+    @staticmethod
+    @torch.cuda.amp.custom_bwd
+    def backward(ctx: "torch.autograd.Function", grad_output: "torch.Tensor") -> "torch.Tensor":
+        (hidden_states,) = ctx.saved_tensors
+        hidden_states = hidden_states.to("cuda", non_blocking=True).detach()
+        hidden_states.requires_grad_(True)
+        with torch.enable_grad():
+            (output,) = ctx.forward_function(hidden_states, *ctx.args)
+
+        torch.autograd.backward(output, grad_output)
+        return (None, hidden_states.grad) + (None,) * len(ctx.args)
+
+
 def _gradient_checkpointing_enable(
-    self: "PreTrainedModel", gradient_checkpointing_kwargs: Optional[Dict[str, Any]] = None
+    self: "PreTrainedModel",
+    gradient_checkpointing_kwargs: Optional[Dict[str, Any]] = None,
+    use_unsloth_gc: bool = False,
 ) -> None:
     r"""
     Activates gradient checkpointing for the current model.
@@ -52,9 +91,12 @@ def _gradient_checkpointing_enable(
     if gradient_checkpointing_kwargs is None:
         gradient_checkpointing_kwargs = {"use_reentrant": True}
 
-    gradient_checkpointing_func = partial(checkpoint, **gradient_checkpointing_kwargs)
+    if use_unsloth_gc:
+        gradient_checkpointing_func = UnslothGradientCheckpointing.apply
+    else:
+        gradient_checkpointing_func = partial(checkpoint, **gradient_checkpointing_kwargs)
 
-    def custom_gradient_checkpointing_func(func, *args, **kwargs):
+    def custom_gradient_checkpointing_func(func, *args: Union["torch.Tensor", Any], **kwargs):
         module: "torch.nn.Module" = func.__self__
 
         if any(param.requires_grad for param in module.parameters()):
@@ -97,7 +139,10 @@ def prepare_model_for_training(model: "PreTrainedModel", model_args: "ModelArgum
         else:
             # use_reentrant=False might increase VRAM usage (have not been empirically verified yet)
             # According to: https://github.com/huggingface/transformers/issues/28339
-            model.gradient_checkpointing_enable = MethodType(_gradient_checkpointing_enable, model)
+            gradient_checkpointing_enable = partial(
+                _gradient_checkpointing_enable, use_unsloth_gc=model_args.use_unsloth_gc
+            )
+            model.gradient_checkpointing_enable = MethodType(gradient_checkpointing_enable, model)
             model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": True})
             setattr(model.config, "use_cache", False)  # turn off when gradient checkpointing is enabled
             logger.info("Gradient checkpointing enabled.")