Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-08-02 03:32:50 +08:00)
support activation offloading via unsloth gc

Former-commit-id: fb72a3adb0916232cc9ac9f0c725c02d07b9354c
parent 7ccb86b215
commit 0daee7cb39
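At a glance, the commit threads one new boolean option through three layers: an `unsloth_gc` knob in the MFU benchmarking helper, a `use_unsloth_gc` field on the `ModelArguments` dataclass, and an `UnslothGradientCheckpointing` autograd function that offloads layer inputs to CPU RAM during the forward pass and recomputes them during backward. As a rough sketch of how a caller might enable it after this change (the surrounding keys mirror the dict in the first hunk below; the values are illustrative placeholders, not taken from a real config):

# Hypothetical training-argument dict; only "use_unsloth_gc" is new in this commit.
args = {
    "model_name_or_path": "path/to/base-model",  # placeholder
    "stage": "pt",
    "do_train": True,
    "finetuning_type": "lora",
    "disable_gradient_checkpointing": False,
    "enable_liger_kernel": False,
    "use_unsloth_gc": True,  # offload activations to CPU RAM via unsloth-style checkpointing
}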
@@ -109,6 +109,7 @@ def calculate_mfu(
     deepspeed_stage: int = 0,
     disable_gc: bool = False,
     liger_kernel: bool = False,
+    unsloth_gc: bool = False,
 ) -> float:
     r"""
     Calculates MFU for given model and hyper-params.
@@ -119,6 +120,7 @@ def calculate_mfu(
         "flash_attn": flash_attn,
         "disable_gradient_checkpointing": disable_gc,
         "enable_liger_kernel": liger_kernel,
+        "use_unsloth_gc": unsloth_gc,
         "stage": "pt",
         "do_train": True,
         "finetuning_type": finetuning_type,
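For reference, MFU (model FLOPs utilization) compares achieved training throughput against the hardware's theoretical peak. A back-of-envelope version using the common 6·N·D approximation looks like the sketch below; the model size, throughput, and peak-FLOPs numbers are illustrative and not produced by this script:

# Rough MFU estimate with the 6 * params * tokens/s approximation (illustrative numbers).
num_params = 7e9          # e.g. a 7B-parameter model
tokens_per_second = 3000  # measured training throughput
peak_flops = 312e12       # e.g. A100 BF16 peak
mfu = 6 * num_params * tokens_per_second / peak_flops
print(f"MFU ~ {mfu:.1%}")  # ~ 40.4%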
@@ -215,6 +215,10 @@ class ModelArguments(QuantizationArguments, ProcessorArguments, ExportArguments,
         default=False,
         metadata={"help": "Whether or not to use unsloth's optimization for the LoRA training."},
     )
+    use_unsloth_gc: bool = field(
+        default=False,
+        metadata={"help": "Whether or not to use unsloth's gradient checkpointing."},
+    )
     enable_liger_kernel: bool = field(
         default=False,
         metadata={"help": "Whether or not to enable liger kernel for faster training."},
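Because `ModelArguments` is a dataclass, the new field should surface automatically as a `--use_unsloth_gc` option wherever the dataclass is parsed. A self-contained sketch of that behavior, assuming the usual `transformers.HfArgumentParser` flow (`ToyModelArguments` is a stand-in, not LLaMA-Factory code):

from dataclasses import dataclass, field

from transformers import HfArgumentParser


@dataclass
class ToyModelArguments:  # stand-in for the real ModelArguments
    use_unsloth_gc: bool = field(
        default=False,
        metadata={"help": "Whether or not to use unsloth's gradient checkpointing."},
    )


parser = HfArgumentParser(ToyModelArguments)
(toy_args,) = parser.parse_args_into_dataclasses(["--use_unsloth_gc", "true"])
print(toy_args.use_unsloth_gc)  # True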
@@ -1,8 +1,10 @@
-# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+# Copyright 2024 HuggingFace Inc., Daniel Han-Chen & the Unsloth team and the LlamaFactory team.
 #
-# This code is inspired by the HuggingFace's Transformers and PEFT library.
+# This code is inspired by the HuggingFace's Transformers and PEFT library,
 # https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/modeling_utils.py
 # https://github.com/huggingface/peft/blob/v0.10.0/src/peft/utils/other.py
+# and the Unsloth library.
+# https://github.com/unslothai/unsloth/blob/July-2024/unsloth/models/_utils.py
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -19,7 +21,7 @@
 import inspect
 from functools import partial
 from types import MethodType
-from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union

 import torch

@@ -36,8 +38,45 @@ if TYPE_CHECKING:
 logger = get_logger(__name__)


+class UnslothGradientCheckpointing(torch.autograd.Function):
+    r"""
+    Saves VRAM by smartly offloading to RAM.
+    """
+
+    @staticmethod
+    @torch.cuda.amp.custom_fwd
+    def forward(
+        ctx: "torch.autograd.Function",
+        forward_function: "torch.Module",
+        hidden_states: "torch.Tensor",
+        *args: Union["torch.Tensor", Any],
+    ) -> "torch.Tensor":
+        saved_hidden_states = hidden_states.to("cpu", non_blocking=True)
+        with torch.no_grad():
+            output = forward_function(hidden_states, *args)
+
+        ctx.save_for_backward(saved_hidden_states)
+        ctx.forward_function = forward_function
+        ctx.args = args
+        return output
+
+    @staticmethod
+    @torch.cuda.amp.custom_bwd
+    def backward(ctx: "torch.autograd.Function", grad_output: "torch.Tensor") -> "torch.Tensor":
+        (hidden_states,) = ctx.saved_tensors
+        hidden_states = hidden_states.to("cuda", non_blocking=True).detach()
+        hidden_states.requires_grad_(True)
+        with torch.enable_grad():
+            (output,) = ctx.forward_function(hidden_states, *ctx.args)
+
+        torch.autograd.backward(output, grad_output)
+        return (None, hidden_states.grad) + (None,) * len(ctx.args)
+
+
 def _gradient_checkpointing_enable(
-    self: "PreTrainedModel", gradient_checkpointing_kwargs: Optional[Dict[str, Any]] = None
+    self: "PreTrainedModel",
+    gradient_checkpointing_kwargs: Optional[Dict[str, Any]] = None,
+    use_unsloth_gc: bool = False,
 ) -> None:
     r"""
     Activates gradient checkpointing for the current model.
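The idea in `UnslothGradientCheckpointing` is to copy the layer input to CPU during the forward pass while running the layer under `no_grad`, then in the backward pass move the input back to the GPU, re-run the layer with grad enabled, and backpropagate through the recomputed graph. A stripped-down, self-contained illustration of the same offload-and-recompute pattern on a toy linear layer (a sketch, not the class above: it omits the AMP decorators and the tuple-returning transformer-layer convention):

import torch
import torch.nn as nn


class OffloadCheckpoint(torch.autograd.Function):
    """Toy version: offload the input to CPU in forward, recompute in backward."""

    @staticmethod
    def forward(ctx, forward_fn, hidden_states):
        ctx.save_for_backward(hidden_states.to("cpu", non_blocking=True))
        ctx.forward_fn = forward_fn
        with torch.no_grad():  # no autograd graph is kept for the real forward pass
            return forward_fn(hidden_states)

    @staticmethod
    def backward(ctx, grad_output):
        (hidden_states,) = ctx.saved_tensors
        hidden_states = hidden_states.to(grad_output.device, non_blocking=True).detach()
        hidden_states.requires_grad_(True)
        with torch.enable_grad():  # rebuild the graph, then backprop through it
            output = ctx.forward_fn(hidden_states)
        torch.autograd.backward(output, grad_output)
        return None, hidden_states.grad  # no gradient for forward_fn itself


device = "cuda" if torch.cuda.is_available() else "cpu"
layer = nn.Linear(16, 16).to(device)
x = torch.randn(4, 16, device=device, requires_grad=True)
out = OffloadCheckpoint.apply(layer, x)
out.sum().backward()
print(x.grad.shape, layer.weight.grad.shape)  # both populated via recomputation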
@@ -52,9 +91,12 @@ def _gradient_checkpointing_enable(
     if gradient_checkpointing_kwargs is None:
         gradient_checkpointing_kwargs = {"use_reentrant": True}

-    gradient_checkpointing_func = partial(checkpoint, **gradient_checkpointing_kwargs)
+    if use_unsloth_gc:
+        gradient_checkpointing_func = UnslothGradientCheckpointing.apply
+    else:
+        gradient_checkpointing_func = partial(checkpoint, **gradient_checkpointing_kwargs)

-    def custom_gradient_checkpointing_func(func, *args, **kwargs):
+    def custom_gradient_checkpointing_func(func, *args: Union["torch.Tensor", Any], **kwargs):
         module: "torch.nn.Module" = func.__self__

         if any(param.requires_grad for param in module.parameters()):
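For comparison, the non-unsloth branch keeps wrapping `torch.utils.checkpoint.checkpoint`, which also recomputes the forward during backward but leaves the saved inputs on the compute device. A minimal toy call matching the `use_reentrant=True` default used above (assumes a PyTorch version that accepts the `use_reentrant` keyword):

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

layer = nn.Linear(16, 16)
x = torch.randn(4, 16, requires_grad=True)
y = checkpoint(layer, x, use_reentrant=True)  # inputs stay on-device; only the graph is dropped
y.sum().backward()
print(x.grad.shape, layer.weight.grad.shape)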
@@ -97,7 +139,10 @@ def prepare_model_for_training(model: "PreTrainedModel", model_args: "ModelArgum
     else:
         # use_reentrant=False might increase VRAM usage (have not been empirically verified yet)
         # According to: https://github.com/huggingface/transformers/issues/28339
-        model.gradient_checkpointing_enable = MethodType(_gradient_checkpointing_enable, model)
+        gradient_checkpointing_enable = partial(
+            _gradient_checkpointing_enable, use_unsloth_gc=model_args.use_unsloth_gc
+        )
+        model.gradient_checkpointing_enable = MethodType(gradient_checkpointing_enable, model)
         model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": True})
         setattr(model.config, "use_cache", False)  # turn off when gradient checkpointing is enabled
         logger.info("Gradient checkpointing enabled.")
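The binding trick in the hunk above, wrapping the enable function in `functools.partial` to pin `use_unsloth_gc` and then attaching the result with `types.MethodType`, can be seen in isolation in this toy sketch (the names here are made up for illustration):

from functools import partial
from types import MethodType


def enable(self, gradient_checkpointing_kwargs=None, use_unsloth_gc=False):
    print(type(self).__name__, gradient_checkpointing_kwargs, use_unsloth_gc)


class ToyModel:
    pass


model = ToyModel()
# pin the extra keyword first, then bind the partial as a method of this instance
model.gradient_checkpointing_enable = MethodType(partial(enable, use_unsloth_gc=True), model)
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": True})
# -> ToyModel {'use_reentrant': True} True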