mirror of
https://github.com/PrimitiveAnything/PrimitiveAnything.git
synced 2026-05-08 00:58:55 +08:00
init
This commit is contained in:
275
primitive_anything/utils/__init__.py
Executable file
275
primitive_anything/utils/__init__.py
Executable file
@@ -0,0 +1,275 @@
|
||||
from math import ceil
|
||||
from pathlib import Path
|
||||
import os
|
||||
import re
|
||||
|
||||
from beartype.typing import Tuple
|
||||
from einops import rearrange, repeat
|
||||
from toolz import valmap
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from torch.nn import Module
|
||||
import torch.nn.functional as F
|
||||
import yaml
|
||||
|
||||
from .typing import typecheck
|
||||
|
||||
|
||||
def count_parameters(model):
|
||||
return sum(p.numel() for p in model.parameters() if p.requires_grad)
|
||||
|
||||
|
||||
def path_mkdir(path):
|
||||
path = Path(path)
|
||||
path.mkdir(parents=True, exist_ok=True)
|
||||
return path
|
||||
|
||||
|
||||
def load_yaml(path, default_path=None):
|
||||
path = path_exists(path)
|
||||
with open(path, mode='r') as fp:
|
||||
cfg_s = yaml.load(fp, Loader=yaml.FullLoader)
|
||||
|
||||
if default_path is not None:
|
||||
default_path = path_exists(default_path)
|
||||
with open(default_path, mode='r') as fp:
|
||||
cfg = yaml.load(fp, Loader=yaml.FullLoader)
|
||||
else:
|
||||
# try current dir default
|
||||
default_path = path.parent / 'default.yml'
|
||||
if default_path.exists():
|
||||
with open(default_path, mode='r') as fp:
|
||||
cfg = yaml.load(fp, Loader=yaml.FullLoader)
|
||||
else:
|
||||
cfg = {}
|
||||
|
||||
update_recursive(cfg, cfg_s)
|
||||
return cfg
|
||||
|
||||
|
||||
def dump_yaml(cfg, path):
|
||||
with open(path, mode='w') as f:
|
||||
return yaml.safe_dump(cfg, f)
|
||||
|
||||
|
||||
def update_recursive(dict1, dict2):
|
||||
''' Update two config dictionaries recursively.
|
||||
Args:
|
||||
dict1 (dict): first dictionary to be updated
|
||||
dict2 (dict): second dictionary which entries should be used
|
||||
'''
|
||||
for k, v in dict2.items():
|
||||
if k not in dict1:
|
||||
dict1[k] = dict()
|
||||
if isinstance(v, dict):
|
||||
update_recursive(dict1[k], v)
|
||||
else:
|
||||
dict1[k] = v
|
||||
|
||||
|
||||
def load_latest_checkpoint(checkpoint_dir):
|
||||
pattern = re.compile(rf".+\.ckpt\.(\d+)\.pt")
|
||||
max_epoch = -1
|
||||
latest_checkpoint = None
|
||||
|
||||
for filename in os.listdir(checkpoint_dir):
|
||||
match = pattern.match(filename)
|
||||
if match:
|
||||
num_epoch = int(match.group(1))
|
||||
if num_epoch > max_epoch:
|
||||
max_epoch = num_epoch
|
||||
latest_checkpoint = checkpoint_dir / filename
|
||||
|
||||
if not exists(latest_checkpoint):
|
||||
raise FileNotFoundError(f"No checkpoint files found in {checkpoint_dir}")
|
||||
|
||||
checkpoint = torch.load(latest_checkpoint)
|
||||
return checkpoint, latest_checkpoint
|
||||
|
||||
|
||||
def torch_to(inp, device, non_blocking=False):
|
||||
nb = non_blocking # set to True when doing distributed jobs
|
||||
if isinstance(inp, torch.Tensor):
|
||||
return inp.to(device, non_blocking=nb)
|
||||
elif isinstance(inp, (list, tuple)):
|
||||
return type(inp)(map(lambda t: t.to(device, non_blocking=nb) if isinstance(t, torch.Tensor) else t, inp))
|
||||
elif isinstance(inp, dict):
|
||||
return valmap(lambda t: t.to(device, non_blocking=nb) if isinstance(t, torch.Tensor) else t, inp)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
# helper functions
|
||||
|
||||
def exists(v):
|
||||
return v is not None
|
||||
|
||||
def default(v, d):
|
||||
return v if exists(v) else d
|
||||
|
||||
def first(it):
|
||||
return it[0]
|
||||
|
||||
def identity(t, *args, **kwargs):
|
||||
return t
|
||||
|
||||
def divisible_by(num, den):
|
||||
return (num % den) == 0
|
||||
|
||||
def is_odd(n):
|
||||
return not divisible_by(n, 2)
|
||||
|
||||
def is_empty(x):
|
||||
return len(x) == 0
|
||||
|
||||
def is_tensor_empty(t: Tensor):
|
||||
return t.numel() == 0
|
||||
|
||||
def set_module_requires_grad_(
|
||||
module: Module,
|
||||
requires_grad: bool
|
||||
):
|
||||
for param in module.parameters():
|
||||
param.requires_grad = requires_grad
|
||||
|
||||
def l1norm(t):
|
||||
return F.normalize(t, dim = -1, p = 1)
|
||||
|
||||
def l2norm(t):
|
||||
return F.normalize(t, dim = -1, p = 2)
|
||||
|
||||
def safe_cat(tensors, dim):
|
||||
tensors = [*filter(exists, tensors)]
|
||||
|
||||
if len(tensors) == 0:
|
||||
return None
|
||||
elif len(tensors) == 1:
|
||||
return first(tensors)
|
||||
|
||||
return torch.cat(tensors, dim = dim)
|
||||
|
||||
def pad_at_dim(t, padding, dim = -1, value = 0):
|
||||
ndim = t.ndim
|
||||
right_dims = (ndim - dim - 1) if dim >= 0 else (-dim - 1)
|
||||
zeros = (0, 0) * right_dims
|
||||
return F.pad(t, (*zeros, *padding), value = value)
|
||||
|
||||
def pad_to_length(t, length, dim = -1, value = 0, right = True):
|
||||
curr_length = t.shape[dim]
|
||||
remainder = length - curr_length
|
||||
|
||||
if remainder <= 0:
|
||||
return t
|
||||
|
||||
padding = (0, remainder) if right else (remainder, 0)
|
||||
return pad_at_dim(t, padding, dim = dim, value = value)
|
||||
|
||||
def masked_mean(tensor, mask, dim = -1, eps = 1e-5):
|
||||
if not exists(mask):
|
||||
return tensor.mean(dim = dim)
|
||||
|
||||
mask = rearrange(mask, '... -> ... 1')
|
||||
tensor = tensor.masked_fill(~mask, 0.)
|
||||
|
||||
total_el = mask.sum(dim = dim)
|
||||
num = tensor.sum(dim = dim)
|
||||
den = total_el.float().clamp(min = eps)
|
||||
mean = num / den
|
||||
mean = mean.masked_fill(total_el == 0, 0.)
|
||||
return mean
|
||||
|
||||
def cycle(dl):
|
||||
while True:
|
||||
for data in dl:
|
||||
yield data
|
||||
|
||||
def maybe_del(d: dict, *keys):
|
||||
for key in keys:
|
||||
if key not in d:
|
||||
continue
|
||||
|
||||
del d[key]
|
||||
|
||||
|
||||
# tensor helper functions
|
||||
|
||||
@typecheck
|
||||
def discretize(
|
||||
t: Tensor,
|
||||
*,
|
||||
continuous_range: Tuple[float, float],
|
||||
num_discrete: int = 128
|
||||
) -> Tensor:
|
||||
lo, hi = continuous_range
|
||||
assert hi > lo
|
||||
|
||||
t = (t - lo) / (hi - lo)
|
||||
t *= num_discrete
|
||||
t -= 0.5
|
||||
|
||||
return t.round().long().clamp(min = 0, max = num_discrete - 1)
|
||||
|
||||
@typecheck
|
||||
def undiscretize(
|
||||
t: Tensor,
|
||||
*,
|
||||
continuous_range = Tuple[float, float],
|
||||
num_discrete: int = 128
|
||||
) -> Tensor:
|
||||
lo, hi = continuous_range
|
||||
assert hi > lo
|
||||
|
||||
t = t.float()
|
||||
|
||||
t += 0.5
|
||||
t /= num_discrete
|
||||
return t * (hi - lo) + lo
|
||||
|
||||
@typecheck
|
||||
def gaussian_blur_1d(
|
||||
t: Tensor,
|
||||
*,
|
||||
sigma: float = 1.,
|
||||
kernel_size: int = 5
|
||||
) -> Tensor:
|
||||
|
||||
_, _, channels, device, dtype = *t.shape, t.device, t.dtype
|
||||
|
||||
width = int(ceil(sigma * kernel_size))
|
||||
width += (width + 1) % 2
|
||||
half_width = width // 2
|
||||
|
||||
distance = torch.arange(-half_width, half_width + 1, dtype = dtype, device = device)
|
||||
|
||||
gaussian = torch.exp(-(distance ** 2) / (2 * sigma ** 2))
|
||||
gaussian = l1norm(gaussian)
|
||||
|
||||
kernel = repeat(gaussian, 'n -> c 1 n', c = channels)
|
||||
|
||||
t = rearrange(t, 'b n c -> b c n')
|
||||
out = F.conv1d(t, kernel, padding = half_width, groups = channels)
|
||||
return rearrange(out, 'b c n -> b n c')
|
||||
|
||||
@typecheck
|
||||
def scatter_mean(
|
||||
tgt: Tensor,
|
||||
indices: Tensor,
|
||||
src = Tensor,
|
||||
*,
|
||||
dim: int = -1,
|
||||
eps: float = 1e-5
|
||||
):
|
||||
"""
|
||||
todo: update to pytorch 2.1 and try https://pytorch.org/docs/stable/generated/torch.Tensor.scatter_reduce_.html#torch.Tensor.scatter_reduce_
|
||||
"""
|
||||
num = tgt.scatter_add(dim, indices, src)
|
||||
den = torch.zeros_like(tgt).scatter_add(dim, indices, torch.ones_like(src))
|
||||
return num / den.clamp(min = eps)
|
||||
|
||||
def prob_mask_like(shape, prob, device):
|
||||
if prob == 1:
|
||||
return torch.ones(shape, device = device, dtype = torch.bool)
|
||||
elif prob == 0:
|
||||
return torch.zeros(shape, device = device, dtype = torch.bool)
|
||||
else:
|
||||
return torch.zeros(shape, device = device).float().uniform_(0, 1) < prob
|
||||
Reference in New Issue
Block a user