# -*- coding: utf-8 -*-

"""
Adapted from: https://github.com/openai/guided-diffusion/blob/22e0df8183507e13a7813f8d38d51b072ca1e67c/guided_diffusion/nn.py#L124
"""

import torch
from typing import Callable, Iterable, Sequence, Union


def checkpoint(
|
|
func: Callable[..., Union[torch.Tensor, Sequence[torch.Tensor]]],
|
|
inputs: Sequence[torch.Tensor],
|
|
params: Iterable[torch.Tensor],
|
|
flag: bool,
|
|
use_deepspeed: bool = False
|
|
):
|
|
"""
|
|
Evaluate a function without caching intermediate activations, allowing for
|
|
reduced memory at the expense of extra compute in the backward pass.
|
|
:param func: the function to evaluate.
|
|
:param inputs: the argument sequence to pass to `func`.
|
|
:param params: a sequence of parameters `func` depends on but does not
|
|
explicitly take as arguments.
|
|
:param flag: if False, disable gradient checkpointing.
|
|
:param use_deepspeed: if True, use deepspeed
|
|
"""
|
|
if flag:
|
|
if use_deepspeed:
|
|
import deepspeed
|
|
return deepspeed.checkpointing.checkpoint(func, *inputs)
|
|
|
|
args = tuple(inputs) + tuple(params)
|
|
return CheckpointFunction.apply(func, len(inputs), *args)
|
|
else:
|
|
return func(*inputs)
|
|
|
|
|
|
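# A typical call site, sketched here as a comment for illustration (the names
# are hypothetical and not part of the original file): inside an nn.Module
# whose `use_checkpoint` flag gates recomputation, e.g.
#
#     def forward(self, x):
#         return checkpoint(self._forward, (x,), self.parameters(), self.use_checkpoint)
#
# `self.parameters()` should cover every parameter `self._forward` touches,
# so that CheckpointFunction.backward produces gradients for all of them.

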
class CheckpointFunction(torch.autograd.Function):
    @staticmethod
    @torch.cuda.amp.custom_fwd
    def forward(ctx, run_function, length, *args):
        ctx.run_function = run_function
        ctx.input_tensors = list(args[:length])
        ctx.input_params = list(args[length:])

        # Run the forward pass without building an autograd graph; the
        # intermediate activations are recomputed in backward instead.
        with torch.no_grad():
            output_tensors = ctx.run_function(*ctx.input_tensors)
        return output_tensors

    @staticmethod
    @torch.cuda.amp.custom_bwd
    def backward(ctx, *output_grads):
        # Detach the saved inputs and re-enable gradients so the forward
        # pass can be replayed, this time recording a graph.
        ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
        with torch.enable_grad():
            # Fixes a bug where the first op in run_function modifies the
            # Tensor storage in place, which is not allowed for detach()'d
            # Tensors.
            shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
            output_tensors = ctx.run_function(*shallow_copies)
        input_grads = torch.autograd.grad(
            output_tensors,
            ctx.input_tensors + ctx.input_params,
            output_grads,
            allow_unused=True,
        )
        del ctx.input_tensors
        del ctx.input_params
        del output_tensors
        # The two leading None gradients correspond to `run_function` and
        # `length`, which are non-tensor arguments to forward.
        return (None, None) + input_grads
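

if __name__ == "__main__":
    # Minimal end-to-end sketch, added for illustration (not part of the
    # original file); the layer, shapes, and flag value are assumptions.
    # It wraps a small closure in `checkpoint` and confirms that gradients
    # reach both the input and the parameters captured by the closure.
    linear = torch.nn.Linear(8, 8)

    def block(x):
        return torch.relu(linear(x))

    x = torch.randn(4, 8, requires_grad=True)
    y = checkpoint(block, (x,), linear.parameters(), flag=True)
    y.sum().backward()
    print(x.grad.shape, linear.weight.grad.shape)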