from math import ceil
from pathlib import Path
import os
import re

from beartype.typing import Tuple
from einops import rearrange, repeat
from toolz import valmap
import torch
from torch import Tensor
from torch.nn import Module
import torch.nn.functional as F
import yaml

from .typing import typecheck


def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


def path_mkdir(path):
    path = Path(path)
    path.mkdir(parents=True, exist_ok=True)
    return path


def path_exists(path):
    # assumed helper: load_yaml below calls path_exists, which was missing from
    # this file; this minimal version validates the path and returns it as a Path
    path = Path(path)
    if not path.exists():
        raise FileNotFoundError(f"{path} does not exist")
    return path


def load_yaml(path, default_path=None):
    path = path_exists(path)
    with open(path, mode='r') as fp:
        cfg_s = yaml.load(fp, Loader=yaml.FullLoader)

    if default_path is not None:
        default_path = path_exists(default_path)
        with open(default_path, mode='r') as fp:
            cfg = yaml.load(fp, Loader=yaml.FullLoader)
    else:
        # try current dir default
        default_path = path.parent / 'default.yml'
        if default_path.exists():
            with open(default_path, mode='r') as fp:
                cfg = yaml.load(fp, Loader=yaml.FullLoader)
        else:
            cfg = {}

    update_recursive(cfg, cfg_s)
    return cfg
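
# Illustrative lookup order (file names hypothetical): load_yaml('cfgs/run.yml')
# reads cfgs/run.yml, then merges it on top of cfgs/default.yml from the same
# directory if that file exists, so values from run.yml win over the defaults.

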
def dump_yaml(cfg, path):
    with open(path, mode='w') as f:
        return yaml.safe_dump(cfg, f)


def update_recursive(dict1, dict2):
    '''Update two config dictionaries recursively.

    Args:
        dict1 (dict): first dictionary to be updated
        dict2 (dict): second dictionary whose entries should be used
    '''
    for k, v in dict2.items():
        if k not in dict1:
            dict1[k] = dict()
        if isinstance(v, dict):
            update_recursive(dict1[k], v)
        else:
            dict1[k] = v
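
# Worked example of the merge semantics above (values from dict2 win, nested
# dicts are merged rather than replaced):
#   cfg = {'optim': {'lr': 1e-3, 'wd': 0.0}}
#   update_recursive(cfg, {'optim': {'lr': 3e-4}})
#   cfg == {'optim': {'lr': 3e-4, 'wd': 0.0}}

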
def load_latest_checkpoint(checkpoint_dir):
    checkpoint_dir = Path(checkpoint_dir)  # accept str or Path; '/' joining below needs a Path
    pattern = re.compile(r".+\.ckpt\.(\d+)\.pt")
    max_epoch = -1
    latest_checkpoint = None

    for filename in os.listdir(checkpoint_dir):
        match = pattern.match(filename)
        if match:
            num_epoch = int(match.group(1))
            if num_epoch > max_epoch:
                max_epoch = num_epoch
                latest_checkpoint = checkpoint_dir / filename

    if not exists(latest_checkpoint):
        raise FileNotFoundError(f"No checkpoint files found in {checkpoint_dir}")

    checkpoint = torch.load(latest_checkpoint)
    return checkpoint, latest_checkpoint
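
# e.g. a directory containing 'model.ckpt.3.pt' and 'model.ckpt.12.pt' yields
# the epoch-12 file, since the numeric suffix captured by the regex is compared.

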
def torch_to(inp, device, non_blocking=False):
    nb = non_blocking  # set to True when doing distributed jobs
    if isinstance(inp, torch.Tensor):
        return inp.to(device, non_blocking=nb)
    elif isinstance(inp, (list, tuple)):
        return type(inp)(map(lambda t: t.to(device, non_blocking=nb) if isinstance(t, torch.Tensor) else t, inp))
    elif isinstance(inp, dict):
        return valmap(lambda t: t.to(device, non_blocking=nb) if isinstance(t, torch.Tensor) else t, inp)
    else:
        raise NotImplementedError
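
# e.g. torch_to({'pts': torch.zeros(8, 3), 'name': 'cube'}, 'cuda') moves the
# tensor values to the device and passes non-tensor values through unchanged.

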
# helper functions

def exists(v):
    return v is not None

def default(v, d):
    return v if exists(v) else d

def first(it):
    return it[0]

def identity(t, *args, **kwargs):
    return t

def divisible_by(num, den):
    return (num % den) == 0

def is_odd(n):
    return not divisible_by(n, 2)

def is_empty(x):
    return len(x) == 0

def is_tensor_empty(t: Tensor):
    return t.numel() == 0

def set_module_requires_grad_(
    module: Module,
    requires_grad: bool
):
    for param in module.parameters():
        param.requires_grad = requires_grad

def l1norm(t):
    return F.normalize(t, dim = -1, p = 1)

def l2norm(t):
    return F.normalize(t, dim = -1, p = 2)

def safe_cat(tensors, dim):
    tensors = [*filter(exists, tensors)]

    if len(tensors) == 0:
        return None
    elif len(tensors) == 1:
        return first(tensors)

    return torch.cat(tensors, dim = dim)

def pad_at_dim(t, padding, dim = -1, value = 0):
    ndim = t.ndim
    right_dims = (ndim - dim - 1) if dim >= 0 else (-dim - 1)
    zeros = (0, 0) * right_dims
    return F.pad(t, (*zeros, *padding), value = value)
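
# F.pad takes pad widths for the *last* dimension first, so the (0, 0) pairs
# above skip the dimensions to the right of `dim`. e.g. for t of shape (2, 3, 4),
# pad_at_dim(t, (1, 1), dim = 1) returns shape (2, 5, 4).
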
def pad_to_length(t, length, dim = -1, value = 0, right = True):
    curr_length = t.shape[dim]
    remainder = length - curr_length

    if remainder <= 0:
        return t

    padding = (0, remainder) if right else (remainder, 0)
    return pad_at_dim(t, padding, dim = dim, value = value)

def masked_mean(tensor, mask, dim = -1, eps = 1e-5):
    if not exists(mask):
        return tensor.mean(dim = dim)

    mask = rearrange(mask, '... -> ... 1')
    tensor = tensor.masked_fill(~mask, 0.)

    total_el = mask.sum(dim = dim)
    num = tensor.sum(dim = dim)
    den = total_el.float().clamp(min = eps)
    mean = num / den
    mean = mean.masked_fill(total_el == 0, 0.)
    return mean
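
# Note the mask gains a trailing singleton axis, so this expects a feature dim
# after the reduced one: e.g. tensor (b, n, d), mask (b, n), dim = 1 averages
# over n while ignoring masked-out rows (fully-masked rows yield 0).

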
def cycle(dl):
    while True:
        for data in dl:
            yield data

def maybe_del(d: dict, *keys):
    for key in keys:
        if key not in d:
            continue

        del d[key]


# tensor helper functions

@typecheck
def discretize(
    t: Tensor,
    *,
    continuous_range: Tuple[float, float],
    num_discrete: int = 128
) -> Tensor:
    lo, hi = continuous_range
    assert hi > lo

    t = (t - lo) / (hi - lo)
    t *= num_discrete
    t -= 0.5

    return t.round().long().clamp(min = 0, max = num_discrete - 1)
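
# discretize and undiscretize below are inverses up to quantization: with
# continuous_range = (0., 1.) and num_discrete = 128, a value of 0.3 maps to
# bin round(0.3 * 128 - 0.5) = 38, and undiscretize maps bin 38 back to the
# bin center (38 + 0.5) / 128 = 0.3008.
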
@typecheck
def undiscretize(
    t: Tensor,
    *,
    continuous_range: Tuple[float, float],
    num_discrete: int = 128
) -> Tensor:
    lo, hi = continuous_range
    assert hi > lo

    t = t.float()

    t += 0.5
    t /= num_discrete
    return t * (hi - lo) + lo
@typecheck
def gaussian_blur_1d(
    t: Tensor,
    *,
    sigma: float = 1.,
    kernel_size: int = 5
) -> Tensor:

    _, _, channels, device, dtype = *t.shape, t.device, t.dtype

    width = int(ceil(sigma * kernel_size))
    width += (width + 1) % 2  # force an odd kernel width so the blur is centered
    half_width = width // 2

    distance = torch.arange(-half_width, half_width + 1, dtype = dtype, device = device)

    gaussian = torch.exp(-(distance ** 2) / (2 * sigma ** 2))
    gaussian = l1norm(gaussian)

    kernel = repeat(gaussian, 'n -> c 1 n', c = channels)

    t = rearrange(t, 'b n c -> b c n')
    out = F.conv1d(t, kernel, padding = half_width, groups = channels)
    return rearrange(out, 'b c n -> b n c')
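
# This is a depthwise (groups = channels) 1d convolution with an L1-normalized
# Gaussian kernel, applied per channel to a (batch, length, channels) tensor.
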
@typecheck
def scatter_mean(
    tgt: Tensor,
    indices: Tensor,
    src: Tensor,
    *,
    dim: int = -1,
    eps: float = 1e-5
):
    """
    todo: update to pytorch 2.1 and try https://pytorch.org/docs/stable/generated/torch.Tensor.scatter_reduce_.html#torch.Tensor.scatter_reduce_
    """
    num = tgt.scatter_add(dim, indices, src)
    den = torch.zeros_like(tgt).scatter_add(dim, indices, torch.ones_like(src))
    return num / den.clamp(min = eps)
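
# A minimal sketch of the todo above, assuming PyTorch >= 1.12 where
# Tensor.scatter_reduce exists; its semantics differ slightly (no eps clamp,
# and include_self controls whether tgt's own values enter the mean):
#   out = tgt.scatter_reduce(dim, indices, src, reduce = 'mean', include_self = False)

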
def prob_mask_like(shape, prob, device):
    if prob == 1:
        return torch.ones(shape, device = device, dtype = torch.bool)
    elif prob == 0:
        return torch.zeros(shape, device = device, dtype = torch.bool)
    else:
        return torch.zeros(shape, device = device).float().uniform_(0, 1) < prob