mirror of https://github.com/hiyouga/LLaMA-Factory.git
[misc] upgrade format to py39 (#7256)
@@ -21,7 +21,7 @@
 import inspect
 from functools import WRAPPER_ASSIGNMENTS, partial, wraps
 from types import MethodType
-from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Callable, Optional, Union
 
 import torch
 
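This hunk is the core of the py39 upgrade: PEP 585 (Python 3.9) made the built-in dict and tuple types subscriptable, so the typing.Dict and typing.Tuple aliases can be dropped from the import. A minimal sketch of the two spellings side by side (illustrative names, not code from this diff):

    from typing import Any, Dict, Optional, Tuple  # aliases only needed before 3.9

    def old_style(kwargs: Optional[Dict[str, Any]] = None) -> Tuple[int, ...]:
        return (1,)

    def new_style(kwargs: Optional[dict[str, Any]] = None) -> tuple[int, ...]:
        return (1,)

Both run on Python 3.9+; the commit standardizes on the second form.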
@@ -40,9 +40,7 @@ logger = logging.get_logger(__name__)
 
 def get_unsloth_gradient_checkpointing_func() -> Callable:
     class UnslothGradientCheckpointing(torch.autograd.Function):
-        r"""
-        Saves VRAM by smartly offloading to RAM.
-        """
+        r"""Saves VRAM by smartly offloading to RAM."""
 
         @staticmethod
         @torch.cuda.amp.custom_fwd
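Only the docstring changes here (multi-line collapsed to the one-line form the new format rules prefer), but for context, the class implements gradient checkpointing that keeps saved activations in host RAM. A rough sketch of that general pattern, assuming a single-tensor output; this is my own illustration, not the project's implementation:

    import torch

    class OffloadedCheckpoint(torch.autograd.Function):
        # Sketch: stash the layer input in CPU RAM during forward, then move it
        # back to the GPU and recompute the activations during backward.
        @staticmethod
        def forward(ctx, forward_function, hidden_states, *args):
            saved = hidden_states.to("cpu", non_blocking=True)  # offload to RAM
            with torch.no_grad():  # do not keep the activation graph in VRAM
                output = forward_function(hidden_states, *args)

            ctx.save_for_backward(saved)
            ctx.forward_function, ctx.args = forward_function, args
            return output

        @staticmethod
        def backward(ctx, grad_output):
            (hidden_states,) = ctx.saved_tensors
            hidden_states = hidden_states.to(grad_output.device, non_blocking=True)
            hidden_states.requires_grad_(True)
            with torch.enable_grad():  # recompute forward with grad tracking
                output = ctx.forward_function(hidden_states, *ctx.args)

            torch.autograd.backward(output, grad_output)
            return (None, hidden_states.grad) + (None,) * len(ctx.args)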
@@ -77,13 +75,11 @@ def get_unsloth_gradient_checkpointing_func() -> Callable:
 
 
 def get_custom_gradient_checkpointing_func(gradient_checkpointing_func: Callable) -> Callable:
-    r"""
-    Only applies gradient checkpointing to trainable layers.
-    """
+    r"""Only applies gradient checkpointing to trainable layers."""
 
     @wraps(gradient_checkpointing_func, assigned=WRAPPER_ASSIGNMENTS + ("__self__",))
     def custom_gradient_checkpointing_func(func: Callable, *args: Union["torch.Tensor", Any], **kwargs):
-        module: "torch.nn.Module" = func.__self__
+        module: torch.nn.Module = func.__self__
 
         has_grad = False
         if any(param.requires_grad for param in module.parameters()):
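Unquoting the torch.nn.Module annotation is safe because torch is imported at runtime (first hunk); quoted annotations are only needed for names behind TYPE_CHECKING. Functionally, the wrapper inspects the layer bound at func.__self__ and routes through checkpointing only when it has trainable parameters. A hedged sketch of that dispatch in isolation (illustrative names, not the exact project code):

    from functools import WRAPPER_ASSIGNMENTS, wraps
    from typing import Any, Callable, Union

    import torch

    def make_selective_checkpointing(gradient_checkpointing_func: Callable) -> Callable:
        @wraps(gradient_checkpointing_func, assigned=WRAPPER_ASSIGNMENTS + ("__self__",))
        def wrapped(func: Callable, *args: Union[torch.Tensor, Any], **kwargs):
            module: torch.nn.Module = func.__self__  # the layer whose bound forward this is
            if not any(param.requires_grad for param in module.parameters()):
                return func(*args, **kwargs)  # frozen layer: plain forward, no recompute

            return gradient_checkpointing_func(func, *args, **kwargs)

        return wrapped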
@@ -103,11 +99,10 @@ def get_custom_gradient_checkpointing_func(gradient_checkpointing_func: Callable
 
 def _gradient_checkpointing_enable(
     self: "PreTrainedModel",
-    gradient_checkpointing_kwargs: Optional[Dict[str, Any]] = None,
+    gradient_checkpointing_kwargs: Optional[dict[str, Any]] = None,
     use_unsloth_gc: bool = False,
 ) -> None:
-    r"""
-    Activates gradient checkpointing for the current model.
+    r"""Activates gradient checkpointing for the current model.
 
     Modification of the original method to enable gradient checkpointing for block-wise optimizer.
     """
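Because this module-level function takes self as an explicit first parameter, it is presumably bound onto the model object at setup time; the MethodType import in the first hunk points the same way. A hypothetical usage sketch, given a loaded PreTrainedModel named model (the use_reentrant kwarg is an assumption, not shown in this diff):

    from types import MethodType

    # Assumed wiring: swap in the patched method, then enable checkpointing.
    model.gradient_checkpointing_enable = MethodType(_gradient_checkpointing_enable, model)
    model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})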
@@ -134,17 +129,18 @@ def _gradient_checkpointing_enable(
 
 
 def _fp32_forward_post_hook(
-    module: "torch.nn.Module", args: Tuple["torch.Tensor"], output: "torch.Tensor"
+    module: "torch.nn.Module", args: tuple["torch.Tensor"], output: "torch.Tensor"
 ) -> "torch.Tensor":
     return output.to(torch.float32)
 
 
 def prepare_model_for_training(model: "PreTrainedModel", model_args: "ModelArguments") -> None:
-    r"""
-    Includes:
-    (1) cast the layernorm in fp32
-    (2) make output embedding layer require grads
-    (3) add the upcasting of the lm_head in fp32
+    r"""Prepare the model before training.
+
+    Include:
+    (1) cast the layernorm in fp32
+    (2) make output embedding layer require grads
+    (3) add the upcasting of the lm_head in fp32.
     """
     if model_args.upcast_layernorm:
         logger.info_rank0("Upcasting layernorm weights in float32.")
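The Tuple-to-tuple rename aside, _fp32_forward_post_hook is a standard PyTorch forward hook: it upcasts whatever the layer returns to float32. Based on item (3) of the docstring, it is presumably attached to the output head so logits are produced in full precision. A hedged wiring sketch using stock APIs (given a loaded model; not the exact project code):

    import torch

    # Assumed registration: upcast lm_head outputs to fp32 for stable loss math.
    output_layer: torch.nn.Module = model.get_output_embeddings()
    if output_layer is not None:
        output_layer.register_forward_hook(_fp32_forward_post_hook)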