mirror of
https://github.com/PrimitiveAnything/PrimitiveAnything.git
synced 2025-09-18 05:22:48 +08:00
214 lines
7.2 KiB
Python
Executable File
214 lines
7.2 KiB
Python
Executable File
# -*- coding: utf-8 -*-
|
|
|
|
import numpy as np
|
|
import torch
|
|
import torch.nn as nn
|
|
import math
|
|
|
|
# Embedding type names reported in get_embedder's error message. get_embedder only
# builds "identity" and "fourier"; requests for the other names raise an exception.
VALID_EMBED_TYPES = ["identity", "fourier", "hashgrid", "sphere_harmonic", "triplane_fourier"]
|
|
|
|
|
|
class FourierEmbedder(nn.Module):
    """The sin/cosine positional embedding.

    Given an input tensor `x` of shape [..., input_dim], every feature `x[..., i]` is multiplied by
    each frequency `f_k` (k = 0 .. N - 1, N = num_freqs). The results are concatenated along the
    last dimension as

        [x,                                          # only present if include_input is True
         sin(f_0 * x[..., 0]), ..., sin(f_{N-1} * x[..., 0]),
         sin(f_0 * x[..., 1]), ..., sin(f_{N-1} * x[..., input_dim - 1]),
         cos(f_0 * x[..., 0]), ..., cos(f_{N-1} * x[..., input_dim - 1])]

    The raw input comes first, then all sines, then all cosines. The sines and cosines are each
    ordered feature-major (all frequencies of feature 0, then all frequencies of feature 1, ...).

    Frequencies:
        * logspace=True:  f_k = 2^k, i.e. [1, 2, 4, ..., 2^(N-1)];
        * logspace=False: N values linearly spaced in [1.0, 2^(N-1)];
        * include_pi=True additionally multiplies every f_k by pi.

    If num_freqs == 0, `forward` returns `x` unchanged.

    Args:
        num_freqs (int): the number of frequencies N, default is 6.
        logspace (bool): use powers of two (True, default) or linear spacing (False).
        input_dim (int): the size of the last dimension of the input, default is 3.
        include_input (bool): prepend the raw input to the embedding, default is True.
        include_pi (bool): scale all frequencies by pi, default is True.

    Attributes:
        frequencies (torch.Tensor): float32 buffer of shape [num_freqs] holding f_0 .. f_{N-1}.
            It is registered with persistent=False, so it is not part of the state dict.
        include_input (bool): see Args.
        num_freqs (int): see Args.
        out_dim (int): the embedding size. It is input_dim * (num_freqs * 2 + 1) if include_input
            is True or num_freqs == 0, and input_dim * num_freqs * 2 otherwise.
    """

    def __init__(self,
                 num_freqs: int = 6,
                 logspace: bool = True,
                 input_dim: int = 3,
                 include_input: bool = True,
                 include_pi: bool = True) -> None:
        """Build the frequency buffer and compute `out_dim` (see the class docstring)."""
        super().__init__()

        if logspace:
            # f_k = 2^k for k = 0 .. num_freqs - 1
            frequencies = 2.0 ** torch.arange(
                num_freqs,
                dtype=torch.float32
            )
        else:
            frequencies = torch.linspace(
                1.0,
                2.0 ** (num_freqs - 1),
                num_freqs,
                dtype=torch.float32
            )

        if include_pi:
            frequencies *= torch.pi

        self.register_buffer("frequencies", frequencies, persistent=False)
        self.include_input = include_input
        self.num_freqs = num_freqs

        self.out_dim = self.get_dims(input_dim)

    def get_dims(self, input_dim: int) -> int:
        """Return the size of the last dimension produced by `forward` for `input_dim` features.

        With num_freqs == 0 the input is passed through unchanged, so it counts as included.
        """
        temp = 1 if self.include_input or self.num_freqs == 0 else 0
        out_dim = input_dim * (self.num_freqs * 2 + temp)

        return out_dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """ Forward process.

        Args:
            x: tensor of shape [..., dim]

        Returns:
            embedding: an embedding of `x` of shape [..., dim * (num_freqs * 2 + temp)]
                where temp is 1 if include_input is True and 0 otherwise. When num_freqs == 0,
                `x` itself is returned.
        """

        if self.num_freqs > 0:
            # [..., c, 1] * [N] -> [..., c, N] -> [..., c * N].
            # .contiguous() keeps the product's memory layout compatible with .view even when
            # `x` is non-contiguous (e.g. transposed).
            embed = (x[..., None].contiguous() * self.frequencies).view(*x.shape[:-1], -1)
            if self.include_input:
                return torch.cat((x, embed.sin(), embed.cos()), dim=-1)
            else:
                return torch.cat((embed.sin(), embed.cos()), dim=-1)
        else:
            return x
|
|
|
|
|
|
class LearnedFourierEmbedder(nn.Module):
    """Sinusoidal embedding with learned frequencies.

    Follows @crowsonkb's lead with a learned sinusoidal positional embedding:
    https://github.com/crowsonkb/v-diffusion-jax/blob/master/diffusion/models/danbooru_128.py#L8

    Each of the `in_channels` input features is multiplied by the same
    `per_channel_dim = (dim // 2) // in_channels` learned frequencies (times 2*pi). The sines and
    cosines of those products are concatenated after the raw input.

    Args:
        in_channels (int): size of the last dimension of the input.
        dim (int): requested embedding size; must be even. The actual output size is
            `in_channels + 2 * in_channels * per_channel_dim`, which equals `in_channels + dim`
            only when `dim // 2` is divisible by `in_channels`.

    Attributes:
        weights (nn.Parameter): learned frequencies of shape [per_channel_dim], initialized
            with torch.randn.
        out_dim (int): size of the last dimension returned by `forward`.
    """

    def __init__(self, in_channels, dim):
        super().__init__()
        assert (dim % 2) == 0
        half_dim = dim // 2
        per_channel_dim = half_dim // in_channels
        self.weights = nn.Parameter(torch.randn(per_channel_dim))
        # Real output width: the floor division above can drop part of the requested `dim`.
        self.out_dim = in_channels + 2 * in_channels * per_channel_dim

    def forward(self, x):
        """
        Args:
            x (torch.FloatTensor): [..., c]

        Returns:
            x (torch.FloatTensor): [..., c + 2 * c * d], where d = per_channel_dim
                (the raw input, then all sines, then all cosines).
        """
        # [..., c, 1] * [1, d] = [..., c, d] -> [..., c * d].
        # .contiguous() keeps the product's layout compatible with .view for non-contiguous `x`.
        freqs = (x[..., None].contiguous() * self.weights[None] * 2 * np.pi).view(*x.shape[:-1], -1)
        fouriered = torch.cat((x, freqs.sin(), freqs.cos()), dim=-1)
        return fouriered


class TriplaneLearnedFourierEmbedder(nn.Module):
    """Sum of three independent LearnedFourierEmbedder modules applied to the same input.

    Args:
        in_channels (int): size of the last dimension of the input.
        dim (int): requested embedding size passed to each sub-embedder; must be even.

    Attributes:
        out_dim (int): size of the last dimension returned by `forward`. It equals each
            sub-embedder's `out_dim`, which is `in_channels + dim` when `dim // 2` is divisible
            by `in_channels`.

    Note:
        Each sub-embedder returns `x` as its leading `in_channels` outputs, so those channels of
        the sum equal `3 * x`. NOTE(review): despite the yz/xz/xy names, all three sub-embedders
        receive the full `x`, and no per-plane projection happens in this class. Confirm this is
        intended.
    """

    def __init__(self, in_channels, dim):
        super().__init__()

        self.yz_plane_embedder = LearnedFourierEmbedder(in_channels, dim)
        self.xz_plane_embedder = LearnedFourierEmbedder(in_channels, dim)
        self.xy_plane_embedder = LearnedFourierEmbedder(in_channels, dim)

        # Use the sub-embedders' real width. `in_channels + dim` overstates it when
        # (dim // 2) is not a multiple of in_channels.
        self.out_dim = self.yz_plane_embedder.out_dim

    def forward(self, x):
        """Return the element-wise sum of the three sub-embeddings of `x` ([..., out_dim])."""
        yz_embed = self.yz_plane_embedder(x)
        xz_embed = self.xz_plane_embedder(x)
        xy_embed = self.xy_plane_embedder(x)

        embed = yz_embed + xz_embed + xy_embed

        return embed
|
|
|
|
|
|
def sequential_pos_embed(num_len, embed_dim):
    """Build a fixed 1-D sin/cos positional-embedding table.

    Row `m` is `[sin(m * w_0), ..., sin(m * w_{D/2-1}), cos(m * w_0), ..., cos(m * w_{D/2-1})]`
    with `w_k = 1 / 10000 ** (k / (D / 2))` and `D = embed_dim`.

    Args:
        num_len: number of positions M (positions are 0 .. M - 1).
        embed_dim: embedding size D; must be even (checked with an assert).

    Returns:
        torch.Tensor: float32 tensor of shape (M, D).
    """
    assert embed_dim % 2 == 0

    half_dim = embed_dim // 2
    exponents = torch.arange(half_dim, dtype=torch.float32) / (embed_dim / 2.)
    inv_freq = 1. / 10000 ** exponents                        # (D/2,)

    positions = torch.arange(num_len, dtype=torch.float32)     # (M,)
    angles = positions[:, None] * inv_freq[None, :]             # (M, D/2) outer product

    return torch.cat([angles.sin(), angles.cos()], dim=1)       # (M, D)
|
|
|
|
|
|
def timestep_embedding(timesteps, dim, max_period=10000):
    """
    Create sinusoidal timestep embeddings.

    :param timesteps: a 1-D Tensor of N indices, one per batch element.
                      These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an [N x dim] Tensor of positional embeddings. The first dim // 2 columns are cosines
             and the next dim // 2 are sines. When dim is odd, one trailing zero column is added.
    """
    half = dim // 2
    # Frequencies are computed on the CPU in float32, then moved to the timesteps' device.
    freqs = torch.exp(
        -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
    ).to(device=timesteps.device)
    args = timesteps[:, None] * freqs[None]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:
        # Build the pad column explicitly. Slicing `embedding[:, :1]` would yield zero columns
        # when dim == 1, because `embedding` is then [N x 0].
        embedding = torch.cat([embedding, embedding.new_zeros(embedding.shape[0], 1)], dim=-1)
    return embedding
|
|
|
|
|
|
def get_embedder(embed_type="fourier", num_freqs=-1, input_dim=3, degree=4,
                 num_levels=16, level_dim=2, per_level_scale=2, base_resolution=16,
                 log2_hashmap_size=19, desired_resolution=None):
    """Build a positional embedder by name.

    Args:
        embed_type (str): one of VALID_EMBED_TYPES. Only "identity" and "fourier" are built here.
        num_freqs (int): number of frequencies for "fourier". -1 means no embedding, and an
            identity module is returned.
        input_dim (int): size of the last dimension of the input.
        degree, num_levels, level_dim, per_level_scale, base_resolution, log2_hashmap_size,
        desired_resolution: not used by the implemented branches. They appear intended for the
            unimplemented "sphere_harmonic" / "hashgrid" types.

    Returns:
        tuple[nn.Module, int]: the embedder module and its output dimension.

    Raises:
        NotImplementedError: for "hashgrid", "sphere_harmonic" and "triplane_fourier".
        ValueError: for any other unknown `embed_type`.
    """
    if embed_type == "identity" or (embed_type == "fourier" and num_freqs == -1):
        return nn.Identity(), input_dim

    if embed_type == "fourier":
        embedder_obj = FourierEmbedder(num_freqs=num_freqs, input_dim=input_dim,
                                       logspace=True, include_input=True)
        return embedder_obj, embedder_obj.out_dim

    if embed_type in ("hashgrid", "sphere_harmonic", "triplane_fourier"):
        # These names appear in VALID_EMBED_TYPES, but this factory does not construct them.
        raise NotImplementedError(f"embed_type '{embed_type}' is not implemented in get_embedder")

    raise ValueError(f"{embed_type} is not valid. Currently only supports {VALID_EMBED_TYPES}")
|