Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-11-04 18:02:19 +08:00)
support activation offloading via unsloth gc

Former-commit-id: d3d0dd0feba3ca6f0ae970d5856bec989d26ef67

This commit is contained in:
    parent 7f71276ad8
    commit 294a103ead
@@ -109,6 +109,7 @@ def calculate_mfu(
     deepspeed_stage: int = 0,
     disable_gc: bool = False,
     liger_kernel: bool = False,
+    unsloth_gc: bool = False,
 ) -> float:
     r"""
     Calculates MFU for given model and hyper-params.
@@ -119,6 +120,7 @@ def calculate_mfu(
         "flash_attn": flash_attn,
         "disable_gradient_checkpointing": disable_gc,
         "enable_liger_kernel": liger_kernel,
+        "use_unsloth_gc": unsloth_gc,
         "stage": "pt",
         "do_train": True,
         "finetuning_type": finetuning_type,
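Note: the following is a hedged usage sketch, not a file in this commit. It mirrors the argument dictionary assembled in the MFU benchmark above and shows where the new flag sits next to the existing speed/memory toggles; the model_name_or_path key and all concrete values are placeholders.

args = {
    "model_name_or_path": "path/to/model",   # placeholder key/value, not shown in the diff
    "flash_attn": "auto",                    # placeholder value
    "disable_gradient_checkpointing": False,
    "enable_liger_kernel": False,
    "use_unsloth_gc": True,                  # the option introduced by this commit
    "stage": "pt",
    "do_train": True,
    "finetuning_type": "lora",               # placeholder value
}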
@@ -215,6 +215,10 @@ class ModelArguments(QuantizationArguments, ProcessorArguments, ExportArguments,
         default=False,
         metadata={"help": "Whether or not to use unsloth's optimization for the LoRA training."},
     )
+    use_unsloth_gc: bool = field(
+        default=False,
+        metadata={"help": "Whether or not to use unsloth's gradient checkpointing."},
+    )
     enable_liger_kernel: bool = field(
         default=False,
         metadata={"help": "Whether or not to enable liger kernel for faster training."},
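For context, a minimal standalone sketch of how a dataclass field like the new use_unsloth_gc typically surfaces as a command-line flag through transformers' HfArgumentParser. ToyModelArguments is a hypothetical stand-in that copies only the new field; it is not the real ModelArguments class.

from dataclasses import dataclass, field

from transformers import HfArgumentParser


@dataclass
class ToyModelArguments:  # hypothetical stand-in, mirrors only the new field
    use_unsloth_gc: bool = field(
        default=False,
        metadata={"help": "Whether or not to use unsloth's gradient checkpointing."},
    )


parser = HfArgumentParser(ToyModelArguments)
(model_args,) = parser.parse_args_into_dataclasses(args=["--use_unsloth_gc", "true"])
print(model_args.use_unsloth_gc)  # True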
@@ -1,8 +1,10 @@
-# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+# Copyright 2024 HuggingFace Inc., Daniel Han-Chen & the Unsloth team and the LlamaFactory team.
 #
-# This code is inspired by the HuggingFace's Transformers and PEFT library.
+# This code is inspired by the HuggingFace's Transformers and PEFT library,
 # https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/modeling_utils.py
 # https://github.com/huggingface/peft/blob/v0.10.0/src/peft/utils/other.py
+# and the Unsloth library.
+# https://github.com/unslothai/unsloth/blob/July-2024/unsloth/models/_utils.py
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -19,7 +21,7 @@
 import inspect
 from functools import partial
 from types import MethodType
-from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
 
 import torch
 
@@ -36,8 +38,45 @@ if TYPE_CHECKING:
 logger = get_logger(__name__)
 
 
+class UnslothGradientCheckpointing(torch.autograd.Function):
+    r"""
+    Saves VRAM by smartly offloading to RAM.
+    """
+
+    @staticmethod
+    @torch.cuda.amp.custom_fwd
+    def forward(
+        ctx: "torch.autograd.Function",
+        forward_function: "torch.Module",
+        hidden_states: "torch.Tensor",
+        *args: Union["torch.Tensor", Any],
+    ) -> "torch.Tensor":
+        saved_hidden_states = hidden_states.to("cpu", non_blocking=True)
+        with torch.no_grad():
+            output = forward_function(hidden_states, *args)
+
+        ctx.save_for_backward(saved_hidden_states)
+        ctx.forward_function = forward_function
+        ctx.args = args
+        return output
+
+    @staticmethod
+    @torch.cuda.amp.custom_bwd
+    def backward(ctx: "torch.autograd.Function", grad_output: "torch.Tensor") -> "torch.Tensor":
+        (hidden_states,) = ctx.saved_tensors
+        hidden_states = hidden_states.to("cuda", non_blocking=True).detach()
+        hidden_states.requires_grad_(True)
+        with torch.enable_grad():
+            (output,) = ctx.forward_function(hidden_states, *ctx.args)
+
+        torch.autograd.backward(output, grad_output)
+        return (None, hidden_states.grad) + (None,) * len(ctx.args)
+
+
 def _gradient_checkpointing_enable(
-    self: "PreTrainedModel", gradient_checkpointing_kwargs: Optional[Dict[str, Any]] = None
+    self: "PreTrainedModel",
+    gradient_checkpointing_kwargs: Optional[Dict[str, Any]] = None,
+    use_unsloth_gc: bool = False,
 ) -> None:
     r"""
     Activates gradient checkpointing for the current model.
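To make the mechanism concrete, here is a standalone, device-agnostic sketch of the same offload-then-recompute pattern (the OffloadedCheckpoint toy and the tiny Linear/GELU block are hypothetical, not code from this commit; the AMP decorators and the extra *args handling above are omitted). The forward pass stashes the input on the CPU and runs the wrapped block without recording a graph; the backward pass moves the input back, reruns the block under enable_grad, and backpropagates through the recomputed graph.

import torch


class OffloadedCheckpoint(torch.autograd.Function):
    @staticmethod
    def forward(ctx, run_function, hidden_states):
        ctx.run_function = run_function
        ctx.device = hidden_states.device
        ctx.save_for_backward(hidden_states.to("cpu", non_blocking=True))  # offload the activation
        with torch.no_grad():  # no autograd graph is kept for the wrapped block
            return run_function(hidden_states)

    @staticmethod
    def backward(ctx, grad_output):
        (hidden_states,) = ctx.saved_tensors
        hidden_states = hidden_states.to(ctx.device, non_blocking=True).detach()  # bring it back
        hidden_states.requires_grad_(True)
        with torch.enable_grad():  # recompute the forward pass, this time with a graph
            output = ctx.run_function(hidden_states)
        torch.autograd.backward(output, grad_output)
        return None, hidden_states.grad  # no gradient for run_function itself


block = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.GELU())
x = torch.randn(4, 8, requires_grad=True)
y = OffloadedCheckpoint.apply(block, x)
y.sum().backward()
print(x.grad.shape)  # torch.Size([4, 8]); block's parameters also received gradients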
@@ -52,9 +91,12 @@ def _gradient_checkpointing_enable(
     if gradient_checkpointing_kwargs is None:
         gradient_checkpointing_kwargs = {"use_reentrant": True}
 
-    gradient_checkpointing_func = partial(checkpoint, **gradient_checkpointing_kwargs)
+    if use_unsloth_gc:
+        gradient_checkpointing_func = UnslothGradientCheckpointing.apply
+    else:
+        gradient_checkpointing_func = partial(checkpoint, **gradient_checkpointing_kwargs)
 
-    def custom_gradient_checkpointing_func(func, *args, **kwargs):
+    def custom_gradient_checkpointing_func(func, *args: Union["torch.Tensor", Any], **kwargs):
         module: "torch.nn.Module" = func.__self__
 
         if any(param.requires_grad for param in module.parameters()):
@@ -97,7 +139,10 @@ def prepare_model_for_training(model: "PreTrainedModel", model_args: "ModelArgum
         else:
             # use_reentrant=False might increase VRAM usage (have not been empirically verified yet)
             # According to: https://github.com/huggingface/transformers/issues/28339
-            model.gradient_checkpointing_enable = MethodType(_gradient_checkpointing_enable, model)
+            gradient_checkpointing_enable = partial(
+                _gradient_checkpointing_enable, use_unsloth_gc=model_args.use_unsloth_gc
+            )
+            model.gradient_checkpointing_enable = MethodType(gradient_checkpointing_enable, model)
             model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": True})
             setattr(model.config, "use_cache", False)  # turn off when gradient checkpointing is enabled
             logger.info("Gradient checkpointing enabled.")
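A brief aside on the rebinding in the hunk above: functools.partial freezes the new use_unsloth_gc keyword into the module-level function, and types.MethodType binds the resulting callable to a single model instance so it behaves like a regular method. The toy below (hypothetical names, not from this commit) shows the same mechanics.

from functools import partial
from types import MethodType


def _enable(self, gradient_checkpointing_kwargs=None, use_unsloth_gc=False):
    # Stand-in for the module-level _gradient_checkpointing_enable above.
    print(type(self).__name__, gradient_checkpointing_kwargs, use_unsloth_gc)


class ToyModel:  # hypothetical stand-in for a PreTrainedModel instance
    pass


model = ToyModel()
enable = partial(_enable, use_unsloth_gc=True)                    # freeze the new flag
model.gradient_checkpointing_enable = MethodType(enable, model)   # bind to this one instance
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": True})
# prints: ToyModel {'use_reentrant': True} True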