mirror of https://github.com/PrimitiveAnything/PrimitiveAnything.git
synced 2025-12-27 18:40:34 +08:00
init
275  primitive_anything/utils/__init__.py  (Executable file)
@@ -0,0 +1,275 @@
from math import ceil
from pathlib import Path
import os
import re

from beartype.typing import Tuple
from einops import rearrange, repeat
from toolz import valmap
import torch
from torch import Tensor
from torch.nn import Module
import torch.nn.functional as F
import yaml

from .typing import typecheck

def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


def path_mkdir(path):
    path = Path(path)
    path.mkdir(parents=True, exist_ok=True)
    return path

def load_yaml(path, default_path=None):
    path = path_exists(path)
    with open(path, mode='r') as fp:
        cfg_s = yaml.load(fp, Loader=yaml.FullLoader)

    if default_path is not None:
        default_path = path_exists(default_path)
        with open(default_path, mode='r') as fp:
            cfg = yaml.load(fp, Loader=yaml.FullLoader)
    else:
        # fall back to a default.yml sitting next to the given config
        default_path = path.parent / 'default.yml'
        if default_path.exists():
            with open(default_path, mode='r') as fp:
                cfg = yaml.load(fp, Loader=yaml.FullLoader)
        else:
            cfg = {}

    update_recursive(cfg, cfg_s)
    return cfg

def dump_yaml(cfg, path):
    with open(path, mode='w') as f:
        return yaml.safe_dump(cfg, f)

def update_recursive(dict1, dict2):
    '''Update two config dictionaries recursively.

    Args:
        dict1 (dict): first dictionary, updated in place
        dict2 (dict): second dictionary, whose entries take precedence
    '''
    for k, v in dict2.items():
        if k not in dict1:
            dict1[k] = dict()
        if isinstance(v, dict):
            update_recursive(dict1[k], v)
        else:
            dict1[k] = v

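# Picks the checkpoint with the highest epoch among files named like
# "<name>.ckpt.<epoch>.pt"; checkpoint_dir is expected to be a pathlib.Path,
# since it is joined to the filename with "/".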
def load_latest_checkpoint(checkpoint_dir):
    pattern = re.compile(r".+\.ckpt\.(\d+)\.pt")
    max_epoch = -1
    latest_checkpoint = None

    for filename in os.listdir(checkpoint_dir):
        match = pattern.match(filename)
        if match:
            num_epoch = int(match.group(1))
            if num_epoch > max_epoch:
                max_epoch = num_epoch
                latest_checkpoint = checkpoint_dir / filename

    if not exists(latest_checkpoint):
        raise FileNotFoundError(f"No checkpoint files found in {checkpoint_dir}")

    checkpoint = torch.load(latest_checkpoint)
    return checkpoint, latest_checkpoint

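# Moves a tensor, or a list/tuple/dict of tensors, to the given device; non-tensor
# entries are passed through unchanged.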
def torch_to(inp, device, non_blocking=False):
    nb = non_blocking  # set to True when doing distributed jobs
    if isinstance(inp, torch.Tensor):
        return inp.to(device, non_blocking=nb)
    elif isinstance(inp, (list, tuple)):
        return type(inp)(map(lambda t: t.to(device, non_blocking=nb) if isinstance(t, torch.Tensor) else t, inp))
    elif isinstance(inp, dict):
        return valmap(lambda t: t.to(device, non_blocking=nb) if isinstance(t, torch.Tensor) else t, inp)
    else:
        raise NotImplementedError

# helper functions

def exists(v):
    return v is not None

def default(v, d):
    return v if exists(v) else d

def first(it):
    return it[0]

def identity(t, *args, **kwargs):
    return t

def divisible_by(num, den):
    return (num % den) == 0

def is_odd(n):
    return not divisible_by(n, 2)

def is_empty(x):
    return len(x) == 0

def is_tensor_empty(t: Tensor):
    return t.numel() == 0

def set_module_requires_grad_(
    module: Module,
    requires_grad: bool
):
    for param in module.parameters():
        param.requires_grad = requires_grad

def l1norm(t):
    return F.normalize(t, dim = -1, p = 1)

def l2norm(t):
    return F.normalize(t, dim = -1, p = 2)

def safe_cat(tensors, dim):
    tensors = [*filter(exists, tensors)]

    if len(tensors) == 0:
        return None
    elif len(tensors) == 1:
        return first(tensors)

    return torch.cat(tensors, dim = dim)

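# F.pad consumes the padding tuple from the last dimension backwards, so the (0, 0)
# pairs below skip over the trailing dims and `padding` lands on the requested dim.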
def pad_at_dim(t, padding, dim = -1, value = 0):
    ndim = t.ndim
    right_dims = (ndim - dim - 1) if dim >= 0 else (-dim - 1)
    zeros = (0, 0) * right_dims
    return F.pad(t, (*zeros, *padding), value = value)

def pad_to_length(t, length, dim = -1, value = 0, right = True):
    curr_length = t.shape[dim]
    remainder = length - curr_length

    if remainder <= 0:
        return t

    padding = (0, remainder) if right else (remainder, 0)
    return pad_at_dim(t, padding, dim = dim, value = value)

def masked_mean(tensor, mask, dim = -1, eps = 1e-5):
    if not exists(mask):
        return tensor.mean(dim = dim)

    mask = rearrange(mask, '... -> ... 1')
    tensor = tensor.masked_fill(~mask, 0.)

    total_el = mask.sum(dim = dim)
    num = tensor.sum(dim = dim)
    den = total_el.float().clamp(min = eps)
    mean = num / den
    mean = mean.masked_fill(total_el == 0, 0.)
    return mean

def cycle(dl):
    while True:
        for data in dl:
            yield data

def maybe_del(d: dict, *keys):
    for key in keys:
        if key not in d:
            continue

        del d[key]

# tensor helper functions

@typecheck
def discretize(
    t: Tensor,
    *,
    continuous_range: Tuple[float, float],
    num_discrete: int = 128
) -> Tensor:
    lo, hi = continuous_range
    assert hi > lo

    t = (t - lo) / (hi - lo)
    t *= num_discrete
    t -= 0.5

    return t.round().long().clamp(min = 0, max = num_discrete - 1)

@typecheck
|
||||
def undiscretize(
|
||||
t: Tensor,
|
||||
*,
|
||||
continuous_range = Tuple[float, float],
|
||||
num_discrete: int = 128
|
||||
) -> Tensor:
|
||||
lo, hi = continuous_range
|
||||
assert hi > lo
|
||||
|
||||
t = t.float()
|
||||
|
||||
t += 0.5
|
||||
t /= num_discrete
|
||||
return t * (hi - lo) + lo
|
||||
|
||||
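# Depthwise 1D Gaussian smoothing over the sequence dimension of a
# (batch, seq, channels) tensor, using an L1-normalized Gaussian kernel
# and a grouped conv1d.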
@typecheck
def gaussian_blur_1d(
    t: Tensor,
    *,
    sigma: float = 1.,
    kernel_size: int = 5
) -> Tensor:

    _, _, channels, device, dtype = *t.shape, t.device, t.dtype

    width = int(ceil(sigma * kernel_size))
    width += (width + 1) % 2
    half_width = width // 2

    distance = torch.arange(-half_width, half_width + 1, dtype = dtype, device = device)

    gaussian = torch.exp(-(distance ** 2) / (2 * sigma ** 2))
    gaussian = l1norm(gaussian)

    kernel = repeat(gaussian, 'n -> c 1 n', c = channels)

    t = rearrange(t, 'b n c -> b c n')
    out = F.conv1d(t, kernel, padding = half_width, groups = channels)
    return rearrange(out, 'b c n -> b n c')

@typecheck
def scatter_mean(
    tgt: Tensor,
    indices: Tensor,
    src: Tensor,
    *,
    dim: int = -1,
    eps: float = 1e-5
):
    """
    todo: update to pytorch 2.1 and try https://pytorch.org/docs/stable/generated/torch.Tensor.scatter_reduce_.html#torch.Tensor.scatter_reduce_
    """
    num = tgt.scatter_add(dim, indices, src)
    den = torch.zeros_like(tgt).scatter_add(dim, indices, torch.ones_like(src))
    return num / den.clamp(min = eps)

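# Returns a boolean mask of the given shape where each entry is independently True
# with probability `prob`.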
def prob_mask_like(shape, prob, device):
    if prob == 1:
        return torch.ones(shape, device = device, dtype = torch.bool)
    elif prob == 0:
        return torch.zeros(shape, device = device, dtype = torch.bool)
    else:
        return torch.zeros(shape, device = device).float().uniform_(0, 1) < prob
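A minimal usage sketch for the discretization helpers above (illustrative only, not part of the commit; the exact bin indices assume torch's default round-half-to-even rounding):

import torch
from primitive_anything.utils import discretize, undiscretize

x = torch.tensor([-1., 0., 1.])
ids = discretize(x, continuous_range=(-1., 1.), num_discrete=128)        # tensor([  0,  64, 127])
x_hat = undiscretize(ids, continuous_range=(-1., 1.), num_discrete=128)  # bin centers, roughly [-0.992, 0.008, 0.992]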
65  primitive_anything/utils/logger.py  (Executable file)
@@ -0,0 +1,65 @@
import logging
import time
import os


class Verbose:
    mute = False

def print_log(s, logger=None, level='info'):
    if Verbose.mute:
        return None

    if logger is None:
        logger = logging.getLogger('trainer')
    if level == 'info':
        print_info(s)
        logger.info(s)
    elif level == 'warning':
        print_warning(s)
        logger.warning(s)
    elif level == 'error':
        print_error(s)
        logger.error(s)
    else:
        raise NotImplementedError

def create_logger(log_dir, name='trainer'):
    assert os.path.exists(log_dir), 'log_dir {} does not exist.'.format(log_dir)
    logger = logging.getLogger(name)
    file_path = log_dir / '{}.log'.format(name)
    hdlr = logging.FileHandler(file_path)
    formatter = logging.Formatter('[%(asctime)s] %(levelname)s: %(message)s')
    hdlr.setFormatter(formatter)
    logger.addHandler(hdlr)
    logger.setLevel(logging.INFO)
    return logger

class TerminalColors:
    HEADER = '\033[95m'
    OKBLUE = '\033[94m'
    OKGREEN = '\033[92m'
    WARNING = '\033[93m'
    FAIL = '\033[91m'
    ENDC = '\033[0m'
    BOLD = '\033[1m'
    UNDERLINE = '\033[4m'

def get_time():
    return time.strftime('%Y-%m-%d %H:%M:%S')


def print_info(s):
    print(TerminalColors.OKBLUE + '[' + get_time() + '] ' + str(s) + TerminalColors.ENDC)


def print_warning(s):
    print(TerminalColors.WARNING + '[' + get_time() + '] WARN ' + str(s) + TerminalColors.ENDC)


def print_error(s):
    print(TerminalColors.FAIL + '[' + get_time() + '] ERROR ' + str(s) + TerminalColors.ENDC)
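A quick usage sketch for the logging helpers (illustrative only; the 'logs' directory is hypothetical, and create_logger joins log_dir with "/", so it expects an existing pathlib.Path):

from pathlib import Path
from primitive_anything.utils.logger import create_logger, print_log

log_dir = Path('logs')
log_dir.mkdir(exist_ok=True)
logger = create_logger(log_dir)          # writes to logs/trainer.log
print_log('training started', logger)    # echoes to stdout and the log file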
57  primitive_anything/utils/typing.py  (Executable file)
@@ -0,0 +1,57 @@
from environs import Env

from torch import Tensor

from beartype import beartype
from beartype.door import is_bearable

from jaxtyping import (
    Float,
    Int,
    Bool,
    jaxtyped
)

# environment

env = Env()
env.read_env()

# function

def always(value):
    def inner(*args, **kwargs):
        return value
    return inner

def identity(t):
    return t

# jaxtyping is a misnomer, works for pytorch

class TorchTyping:
    def __init__(self, abstract_dtype):
        self.abstract_dtype = abstract_dtype

    def __getitem__(self, shapes: str):
        return self.abstract_dtype[Tensor, shapes]

Float = TorchTyping(Float)
Int = TorchTyping(Int)
Bool = TorchTyping(Bool)

# use env variable TYPECHECK to control whether to use beartype + jaxtyping

should_typecheck = env.bool('TYPECHECK', False)

typecheck = jaxtyped(typechecker = beartype) if should_typecheck else identity

beartype_isinstance = is_bearable if should_typecheck else always(True)

__all__ = [
    'Float',
    'Int',
    'Bool',
    'typecheck',
    'beartype_isinstance',
]
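The typing helpers in action (a sketch; the scale function below is hypothetical). Set TYPECHECK=1 in the environment to enable runtime shape/dtype checks; otherwise the decorator is a no-op:

from torch import randn
from primitive_anything.utils.typing import Float, typecheck

@typecheck
def scale(x: Float['b n 3'], factor: float) -> Float['b n 3']:
    return x * factor

scale(randn(2, 16, 3), 2.0)   # checked against the 'b n 3' shape when TYPECHECK is enabled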