mirror of
https://github.com/PrimitiveAnything/PrimitiveAnything.git
synced 2025-09-18 05:22:48 +08:00
948 lines
36 KiB
Python
Executable File
948 lines
36 KiB
Python
Executable File
from __future__ import annotations
|
|
|
|
from functools import partial
|
|
from math import ceil
|
|
import os
|
|
|
|
from accelerate.utils import DistributedDataParallelKwargs
|
|
from beartype.typing import Tuple, Callable, List
|
|
|
|
from einops import rearrange, repeat, reduce, pack
|
|
from gateloop_transformer import SimpleGateLoopLayer
|
|
from huggingface_hub import PyTorchModelHubMixin
|
|
import numpy as np
|
|
import open3d as o3d
|
|
from tqdm import tqdm
|
|
import torch
|
|
from torch import nn, Tensor
|
|
from torch.nn import Module, ModuleList
|
|
import torch.nn.functional as F
|
|
from pytorch3d.loss import chamfer_distance
|
|
from pytorch3d.transforms import euler_angles_to_matrix
|
|
from x_transformers import Decoder
|
|
from x_transformers.x_transformers import LayerIntermediates
|
|
from x_transformers.autoregressive_wrapper import eval_decorator
|
|
|
|
from .michelangelo import ShapeConditioner as ShapeConditioner_miche
|
|
from .utils import (
|
|
discretize,
|
|
undiscretize,
|
|
set_module_requires_grad_,
|
|
default,
|
|
exists,
|
|
safe_cat,
|
|
identity,
|
|
is_tensor_empty,
|
|
)
|
|
from .utils.typing import Float, Int, Bool, typecheck
|
|
|
|
|
|
# constants

# Shared DistributedDataParallel settings. find_unused_parameters=True makes DDP
# tolerate parameters that receive no gradient in a given step.
DEFAULT_DDP_KWARGS = DistributedDataParallelKwargs(find_unused_parameters = True)
|
|
# basic-shape name -> integer type code used throughout the model
SHAPE_CODE = {
    'CubeBevel': 0,
    'SphereSharp': 1,
    'CylinderSharp': 2,
}

# inverse lookup: integer type code -> basic-shape name
BS_NAME = {code: name for name, code in SHAPE_CODE.items()}
|
|
|
|
# FiLM block

class FiLM(Module):
    """Feature-wise linear modulation of ``x`` by a per-sample condition vector.

    Output is ``x * (1 + gamma_mult * tanh(W_g cond)) + tanh(W_b cond + b) * beta_mult``.
    ``gamma_mult`` and ``beta_mult`` are zero-initialized parameters, so a freshly
    constructed layer returns ``x`` unchanged.
    """

    def __init__(self, dim, dim_out = None):
        super().__init__()
        out_features = default(dim_out, dim)

        self.to_gamma = nn.Linear(dim, out_features, bias = False)
        self.to_beta = nn.Linear(dim, out_features)

        self.gamma_mult = nn.Parameter(torch.zeros(1,))
        self.beta_mult = nn.Parameter(torch.zeros(1,))

    def forward(self, x, cond):
        # cond is (b, d); insert a length-1 axis so the modulation broadcasts over
        # the second axis of x (assumes x is (b, n, out_features) -- TODO confirm)
        scale_raw = rearrange(self.to_gamma(cond), 'b d -> b 1 d')
        shift_raw = rearrange(self.to_beta(cond), 'b d -> b 1 d')

        # zero-initialized multipliers make the block start as the identity map
        scale = 1 + self.gamma_mult * scale_raw.tanh()
        shift = shift_raw.tanh() * self.beta_mult

        return x * scale + shift
|
|
|
|
# gateloop layers

class GateLoopBlock(Module):
    """A stack of ``depth`` residual ``SimpleGateLoopLayer`` layers.

    With a cache, only the last position of ``x`` is pushed through the layers
    (each layer receives its own cached state); the earlier positions are passed
    through untouched and concatenated back in front.
    """

    def __init__(
        self,
        dim,
        *,
        depth,
        use_heinsen = True
    ):
        super().__init__()
        self.gateloops = ModuleList([
            SimpleGateLoopLayer(dim = dim, use_heinsen = use_heinsen)
            for _ in range(depth)
        ])

    def forward(
        self,
        x,
        cache = None
    ):
        """Return ``(x, new_caches)``; an empty ``x`` short-circuits to ``(x, None)``."""
        has_cache = exists(cache)

        if is_tensor_empty(x):
            return x, None

        if has_cache:
            # only the newest position needs processing; keep the rest aside
            head, x = x[:, :-1], x[:, -1:]

        layer_caches = iter(default(cache, []))
        new_caches = []

        for layer in self.gateloops:
            out, layer_new_cache = layer(x, cache = next(layer_caches, None), return_cache = True)
            new_caches.append(layer_new_cache)
            x = x + out

        if has_cache:
            x = torch.cat((head, x), dim = -2)

        return x, new_caches
|
|
|
|
|
|
def top_k_2(logits, frac_num_tokens=0.1, k=None):
    """Keep the ``k`` largest logits along the last axis and set the rest to -inf.

    Args:
        logits: tensor of any rank >= 1; filtering happens along the last axis.
        frac_num_tokens: fraction of the last-axis size to keep when ``k`` is None
            (rounded up).
        k: explicit number of entries to keep; clamped to ``[1, num_tokens]``.

    Returns:
        Tensor with the same shape as ``logits``: the kept values in place, -inf
        everywhere else.
    """
    num_tokens = logits.shape[-1]

    if k is None:
        k = ceil(frac_num_tokens * num_tokens)
    # at least one survivor (an all -inf row becomes NaN after softmax), at most all
    k = max(1, min(k, num_tokens))

    val, ind = torch.topk(logits, k, dim=-1)
    probs = torch.full_like(logits, float('-inf'))
    # scatter along the last axis so 2-D, 3-D and 4-D logits all work
    probs.scatter_(-1, ind, val)
    return probs
|
|
|
|
|
|
def soft_argmax(labels):
    """Return the weighted mean index along the last axis: ``sum_i labels[..., i] * i``.

    ``labels`` is not normalized here; for a probability distribution over the last
    axis this is its expected index (presumably the intended use -- verify callers).
    """
    positions = torch.arange(labels.size(-1), dtype=labels.dtype, device=labels.device)
    return (labels * positions).sum(dim=-1)
|
|
|
|
|
|
class PrimitiveTransformerDiscrete(Module, PyTorchModelHubMixin):
    """Autoregressive transformer over sequences of discretized shape primitives.

    Every primitive has a type code (see ``SHAPE_CODE``) and three 3-component
    continuous attributes: scale, rotation (degrees; ``compute_chamfer_distance``
    converts with ``deg2rad``) and translation. Each component is discretized into
    bins and embedded; the embeddings are concatenated and projected to ``dim``.
    Decoding runs an optional gateloop block and an x-transformers ``Decoder`` over
    ``[sos tokens, primitive codes]``. Per step it predicts an EOS flag, the type,
    then translation, rotation and scale, each head also seeing the embeddings of
    the attributes sampled before it.
    """

    @typecheck
    def __init__(
        self,
        *,
        num_discrete_scale = 128,
        continuous_range_scale: List[float, float] = [0, 1],
        dim_scale_embed = 64,
        num_discrete_rotation = 180,
        continuous_range_rotation: List[float, float] = [-180, 180],
        dim_rotation_embed = 64,
        num_discrete_translation = 128,
        continuous_range_translation: List[float, float] = [-1, 1],
        dim_translation_embed = 64,
        num_type = 3,
        dim_type_embed = 64,
        embed_order = 'ctrs',
        bin_smooth_blur_sigma = 0.4,
        dim: int | Tuple[int, int] = 512,
        flash_attn = True,
        attn_depth = 12,
        attn_dim_head = 64,
        attn_heads = 16,
        attn_kwargs: dict = dict(
            ff_glu = True,
            attn_num_mem_kv = 4
        ),
        max_primitive_len = 144,
        dropout = 0.,
        coarse_pre_gateloop_depth = 2,
        coarse_post_gateloop_depth = 0,
        coarse_adaptive_rmsnorm = False,
        gateloop_use_heinsen = False,
        pad_id = -1,
        num_sos_tokens = None,
        condition_on_shape = True,
        shape_cond_with_cross_attn = False,
        shape_cond_with_film = False,
        shape_cond_with_cat = False,
        shape_condition_model_type = 'michelangelo',
        shape_condition_len = 1,
        shape_condition_dim = None,
        cross_attn_num_mem_kv = 4, # needed for preventing nan when dropping out shape condition
        loss_weight: dict = dict(
            eos = 1.0,
            type = 1.0,
            scale = 1.0,
            rotation = 1.0,
            translation = 1.0,
            reconstruction = 1.0,
            scale_huber = 1.0,
            rotation_huber = 1.0,
            translation_huber = 1.0,
        ),
        bs_pc_dir=None,
    ):
        """Build embeddings, shape conditioner, decoder and prediction heads.

        Notes on the less obvious arguments:
            embed_order: ``'srtc'`` concatenates scale/rotation/translation/type
                embeddings in that order; any other value uses
                type/translation/rotation/scale.
            num_sos_tokens: defaults to 4 when conditioning on shape with FiLM, else 1.
            shape_cond_with_*: at least one must be True when ``condition_on_shape``.
            bin_smooth_blur_sigma, shape_condition_len, loss_weight: accepted (the
                first is stored) but not read by any other method of this class.
            bs_pc_dir: directory containing ``SM_GR_BS_{name}_001.ply`` for every name
                in ``SHAPE_CODE``. When None, the template point clouds are not loaded
                and ``compute_chamfer_distance`` / ``generate_w_recon_loss`` are
                unavailable.
        """
        super().__init__()

        # feature embedding: a discretizer, undiscretizer and embedding table per attribute
        self.num_discrete_scale = num_discrete_scale
        self.continuous_range_scale = continuous_range_scale
        self.discretize_scale = partial(discretize, num_discrete=num_discrete_scale, continuous_range=continuous_range_scale)
        self.undiscretize_scale = partial(undiscretize, num_discrete=num_discrete_scale, continuous_range=continuous_range_scale)
        self.scale_embed = nn.Embedding(num_discrete_scale, dim_scale_embed)

        self.num_discrete_rotation = num_discrete_rotation
        self.continuous_range_rotation = continuous_range_rotation
        self.discretize_rotation = partial(discretize, num_discrete=num_discrete_rotation, continuous_range=continuous_range_rotation)
        self.undiscretize_rotation = partial(undiscretize, num_discrete=num_discrete_rotation, continuous_range=continuous_range_rotation)
        self.rotation_embed = nn.Embedding(num_discrete_rotation, dim_rotation_embed)

        self.num_discrete_translation = num_discrete_translation
        self.continuous_range_translation = continuous_range_translation
        self.discretize_translation = partial(discretize, num_discrete=num_discrete_translation, continuous_range=continuous_range_translation)
        self.undiscretize_translation = partial(undiscretize, num_discrete=num_discrete_translation, continuous_range=continuous_range_translation)
        self.translation_embed = nn.Embedding(num_discrete_translation, dim_translation_embed)

        self.num_type = num_type
        self.type_embed = nn.Embedding(num_type, dim_type_embed)

        self.embed_order = embed_order
        self.bin_smooth_blur_sigma = bin_smooth_blur_sigma

        # initial dimension: three components per continuous attribute plus the type embedding
        self.dim = dim
        init_dim = 3 * (dim_scale_embed + dim_rotation_embed + dim_translation_embed) + dim_type_embed

        # project into model dimension
        self.project_in = nn.Linear(init_dim, dim)

        num_sos_tokens = default(num_sos_tokens, 1 if not condition_on_shape or not shape_cond_with_film else 4)
        assert num_sos_tokens > 0

        self.num_sos_tokens = num_sos_tokens
        self.sos_token = nn.Parameter(torch.randn(num_sos_tokens, dim))

        # the transformer eos token
        self.eos_token = nn.Parameter(torch.randn(1, dim))

        self.emb_layernorm = nn.LayerNorm(dim)
        self.max_seq_len = max_primitive_len

        # shape condition
        self.condition_on_shape = condition_on_shape
        self.shape_cond_with_cross_attn = False
        self.shape_cond_with_cat = False
        self.shape_condition_model_type = ''
        self.conditioner = None
        dim_shape = None

        if condition_on_shape:
            assert shape_cond_with_cross_attn or shape_cond_with_film or shape_cond_with_cat
            self.shape_cond_with_cross_attn = shape_cond_with_cross_attn
            self.shape_cond_with_cat = shape_cond_with_cat
            self.shape_condition_model_type = shape_condition_model_type
            if 'michelangelo' in shape_condition_model_type:
                self.conditioner = ShapeConditioner_miche(dim_latent=shape_condition_dim)
                self.to_cond_dim = nn.Linear(self.conditioner.dim_model_out * 2, self.conditioner.dim_latent)
                self.to_cond_dim_head = nn.Linear(self.conditioner.dim_model_out, self.conditioner.dim_latent)
            else:
                raise ValueError(f'unknown shape_condition_model_type {self.shape_condition_model_type}')

            dim_shape = self.conditioner.dim_latent
            # the conditioner itself is frozen; only the projections above are trainable
            set_module_requires_grad_(self.conditioner, False)

            self.shape_coarse_film_cond = FiLM(dim_shape, dim) if shape_cond_with_film else identity

        self.coarse_gateloop_block = GateLoopBlock(dim, depth=coarse_pre_gateloop_depth, use_heinsen=gateloop_use_heinsen) if coarse_pre_gateloop_depth > 0 else None
        self.coarse_post_gateloop_block = GateLoopBlock(dim, depth=coarse_post_gateloop_depth, use_heinsen=gateloop_use_heinsen) if coarse_post_gateloop_depth > 0 else None
        self.coarse_adaptive_rmsnorm = coarse_adaptive_rmsnorm

        self.decoder = Decoder(
            dim=dim,
            depth=attn_depth,
            heads=attn_heads,
            attn_dim_head=attn_dim_head,
            attn_flash=flash_attn,
            attn_dropout=dropout,
            ff_dropout=dropout,
            use_adaptive_rmsnorm=coarse_adaptive_rmsnorm,
            dim_condition=dim_shape,
            cross_attend=self.shape_cond_with_cross_attn,
            cross_attn_dim_context=dim_shape,
            cross_attn_num_mem_kv=cross_attn_num_mem_kv,
            **attn_kwargs
        )

        # to logits; each later head also receives the embeddings of earlier-sampled attributes
        self.to_eos_logits = nn.Sequential(
            nn.Linear(dim, dim),
            nn.ReLU(),
            nn.Linear(dim, 1)
        )
        self.to_type_logits = nn.Sequential(
            nn.Linear(dim, dim),
            nn.ReLU(),
            nn.Linear(dim, num_type)
        )
        self.to_translation_logits = nn.Sequential(
            nn.Linear(dim + dim_type_embed, dim),
            nn.ReLU(),
            nn.Linear(dim, 3 * num_discrete_translation)
        )
        self.to_rotation_logits = nn.Sequential(
            nn.Linear(dim + dim_type_embed + 3 * dim_translation_embed, dim),
            nn.ReLU(),
            nn.Linear(dim, 3 * num_discrete_rotation)
        )
        self.to_scale_logits = nn.Sequential(
            nn.Linear(dim + dim_type_embed + 3 * (dim_translation_embed + dim_rotation_embed), dim),
            nn.ReLU(),
            nn.Linear(dim, 3 * num_discrete_scale)
        )

        self.pad_id = pad_id

        # template point clouds of the basic shapes, stacked so that self.bs_pc[type_code]
        # selects the template for that code; optional because only the chamfer loss needs them
        self.bs_pc = None
        if exists(bs_pc_dir):
            bs_pc_map = {}
            for bs_name, type_code in SHAPE_CODE.items():
                pc = o3d.io.read_point_cloud(os.path.join(bs_pc_dir, f'SM_GR_BS_{bs_name}_001.ply'))
                bs_pc_map[type_code] = torch.from_numpy(np.asarray(pc.points)).float()
            self.bs_pc = torch.stack([bs_pc_map[i] for i in range(len(bs_pc_map))], dim=0)

        # fixed +90 degree rotation about X (XYZ Euler), shaped (1, 1, 3, 3) for broadcasting
        self.rotation_matrix_align_coord = euler_angles_to_matrix(
            torch.Tensor([np.pi/2, 0, 0]), 'XYZ').unsqueeze(0).unsqueeze(0)

    @property
    def device(self):
        """Device of the model's first parameter."""
        return next(self.parameters()).device

    def _shape_condition_embed(self, pc: Tensor) -> Tensor:
        """Run the shape conditioner on ``pc`` and project its outputs to the condition dim.

        The projected head output and the projected latent tokens are concatenated
        along dim -2.
        """
        if 'michelangelo' in self.shape_condition_model_type:
            pc_head, pc_embed = self.conditioner(shape=pc)
            return torch.cat([self.to_cond_dim_head(pc_head), self.to_cond_dim(pc_embed)], dim=-2)
        raise ValueError(f'unknown shape_condition_model_type {self.shape_condition_model_type}')

    @typecheck
    @torch.no_grad()
    def embed_pc(self, pc: Tensor):
        """Shape-condition tokens for ``pc``, computed without gradients and detached."""
        return self._shape_condition_embed(pc).detach()

    @typecheck
    def recon_primitives(
        self,
        scale_logits: Float['b np 3 ns'],
        rotation_logits: Float['b np 3 nr'],
        translation_logits: Float['b np 3 nt'],
        type_logits: Float['b np ntype'],
        primitive_mask: Bool['b np']
    ):
        """Greedy (argmax) decode of per-attribute logits back to primitive values.

        Continuous attributes are undiscretized. Entries where ``primitive_mask`` is
        False become NaN; their type code becomes -1.

        Returns:
            dict with keys ``scale``, ``rotation``, ``translation``, ``type_code``.
        """
        invalid = ~primitive_mask

        recon_scale = self.undiscretize_scale(scale_logits.argmax(dim=-1))
        recon_scale = recon_scale.masked_fill(invalid.unsqueeze(-1), float('nan'))
        recon_rotation = self.undiscretize_rotation(rotation_logits.argmax(dim=-1))
        recon_rotation = recon_rotation.masked_fill(invalid.unsqueeze(-1), float('nan'))
        recon_translation = self.undiscretize_translation(translation_logits.argmax(dim=-1))
        recon_translation = recon_translation.masked_fill(invalid.unsqueeze(-1), float('nan'))
        recon_type_code = type_logits.argmax(dim=-1)
        recon_type_code = recon_type_code.masked_fill(invalid, -1)

        return {
            'scale': recon_scale,
            'rotation': recon_rotation,
            'translation': recon_translation,
            'type_code': recon_type_code
        }

    @typecheck
    def sample_primitives(
        self,
        scale: Float['b np 3'],
        rotation: Float['b np 3'],
        translation: Float['b np 3'],
        type_code: Int['b np'],
        next_embed: Float['b 1 d'],
        temperature: float = 1.,
        filter_logits_fn: Callable = top_k_2,
        filter_kwargs: dict = dict()
    ):
        """Sample one more primitive from ``next_embed`` and append it to the sequences.

        The heads are chained:
            type
            -> translation (sees the type embedding)
            -> rotation (also sees the translation embeddings)
            -> scale (sees all of them).
        Logits go through ``filter_logits_fn(logits, **filter_kwargs)``. With
        ``temperature == 0`` the argmax is taken; otherwise values are sampled from
        ``softmax(filtered / temperature)``.

        Returns:
            (scale, rotation, translation, type_code), each one primitive longer.
        """
        def sample_func(logits):
            # 4-D per-component logits (b, 1, 3, nd) are sampled as (b, 3, nd)
            if logits.ndim == 4:
                enable_squeeze = True
                logits = logits.squeeze(1)
            else:
                enable_squeeze = False

            filtered_logits = filter_logits_fn(logits, **filter_kwargs)

            if temperature == 0.:
                sample = filtered_logits.argmax(dim=-1)
            else:
                probs = F.softmax(filtered_logits / temperature, dim=-1)

                sample = torch.zeros((probs.shape[0], probs.shape[1]), dtype=torch.long, device=probs.device)
                for b_i in range(probs.shape[0]):
                    sample[b_i] = torch.multinomial(probs[b_i], 1).squeeze()

            if enable_squeeze:
                sample = sample.unsqueeze(1)

            return sample

        next_type_logits = self.to_type_logits(next_embed)
        next_type_code = sample_func(next_type_logits)
        type_code_new, _ = pack([type_code, next_type_code], 'b *')

        type_embed = self.type_embed(next_type_code)
        next_embed_packed, _ = pack([next_embed, type_embed], 'b np *')
        next_translation_logits = rearrange(self.to_translation_logits(next_embed_packed), 'b np (c nd) -> b np c nd', nd=self.num_discrete_translation)
        next_discretize_translation = sample_func(next_translation_logits)
        next_translation = self.undiscretize_translation(next_discretize_translation)
        translation_new, _ = pack([translation, next_translation], 'b * nd')

        next_translation_embed = self.translation_embed(next_discretize_translation)
        next_embed_packed, _ = pack([next_embed_packed, next_translation_embed], 'b np *')
        next_rotation_logits = rearrange(self.to_rotation_logits(next_embed_packed), 'b np (c nd) -> b np c nd', nd=self.num_discrete_rotation)
        next_discretize_rotation = sample_func(next_rotation_logits)
        next_rotation = self.undiscretize_rotation(next_discretize_rotation)
        rotation_new, _ = pack([rotation, next_rotation], 'b * nd')

        next_rotation_embed = self.rotation_embed(next_discretize_rotation)
        next_embed_packed, _ = pack([next_embed_packed, next_rotation_embed], 'b np *')
        next_scale_logits = rearrange(self.to_scale_logits(next_embed_packed), 'b np (c nd) -> b np c nd', nd=self.num_discrete_scale)
        next_discretize_scale = sample_func(next_scale_logits)
        next_scale = self.undiscretize_scale(next_discretize_scale)
        scale_new, _ = pack([scale, next_scale], 'b * nd')

        return (
            scale_new,
            rotation_new,
            translation_new,
            type_code_new
        )

    @eval_decorator
    @torch.no_grad()
    @typecheck
    def generate(
        self,
        batch_size: int | None = None,
        filter_logits_fn: Callable = top_k_2,
        filter_kwargs: dict = dict(),
        temperature: float = 1.,
        scale: Float['b np 3'] | None = None,
        rotation: Float['b np 3'] | None = None,
        translation: Float['b np 3'] | None = None,
        type_code: Int['b np'] | None = None,
        pc: Tensor | None = None,
        pc_embed: Tensor | None = None,
        cache_kv = True,
        max_seq_len = None,
    ):
        """Autoregressively sample primitives, optionally continuing a given prefix.

        Args:
            batch_size: number of sequences; taken from the prefix or the shape
                condition when omitted, else 1.
            filter_logits_fn, filter_kwargs, temperature: forwarded to
                ``sample_primitives``.
            scale, rotation, translation, type_code: optional prefix to continue from
                (pass all four together).
            pc / pc_embed: exactly one is required when ``condition_on_shape``.
            cache_kv: reuse decoder / gateloop caches between steps.
            max_seq_len: total length limit, defaulting to ``self.max_seq_len``.

        Generation stops early once every batch item has predicted EOS (EOS is not
        honored at step 0).

        Returns:
            (recon_primitives, primitive_mask):
                recon_primitives: dict of ``scale``, ``rotation``, ``translation``,
                    ``type_code``. Every generated entry at or after a sequence's first
                    EOS is set to ``pad_id``.
                primitive_mask: bool tensor in the same (b, np) layout, True exactly
                    for the kept primitives (prefix included).
        """
        max_seq_len = default(max_seq_len, self.max_seq_len)

        if exists(scale) and exists(rotation) and exists(translation) and exists(type_code):
            assert not exists(batch_size)
            assert scale.shape[1] == rotation.shape[1] == translation.shape[1] == type_code.shape[1]
            assert scale.shape[1] <= self.max_seq_len

            batch_size = scale.shape[0]

        if self.condition_on_shape:
            assert exists(pc) ^ exists(pc_embed), '`pc` or `pc_embed` must be passed in'
            if exists(pc):
                pc_embed = self.embed_pc(pc)

            batch_size = default(batch_size, pc_embed.shape[0])

        batch_size = default(batch_size, 1)

        scale = default(scale, torch.empty((batch_size, 0, 3), dtype=torch.float64, device=self.device))
        rotation = default(rotation, torch.empty((batch_size, 0, 3), dtype=torch.float64, device=self.device))
        translation = default(translation, torch.empty((batch_size, 0, 3), dtype=torch.float64, device=self.device))
        type_code = default(type_code, torch.empty((batch_size, 0), dtype=torch.int64, device=self.device))

        curr_length = scale.shape[1]

        cache = None
        eos_codes = None

        for i in tqdm(range(curr_length, max_seq_len)):
            can_eos = i != 0
            output = self.forward(
                scale=scale,
                rotation=rotation,
                translation=translation,
                type_code=type_code,
                pc_embed=pc_embed,
                return_loss=False,
                return_cache=cache_kv,
                append_eos=False,
                cache=cache
            )
            if cache_kv:
                next_embed, cache = output
            else:
                next_embed = output
            (
                scale,
                rotation,
                translation,
                type_code
            ) = self.sample_primitives(
                scale,
                rotation,
                translation,
                type_code,
                next_embed,
                temperature=temperature,
                filter_logits_fn=filter_logits_fn,
                filter_kwargs=filter_kwargs
            )

            next_eos_logits = self.to_eos_logits(next_embed).squeeze(-1)
            next_eos_code = (F.sigmoid(next_eos_logits) > 0.5)
            eos_codes = safe_cat([eos_codes, next_eos_code], 1)
            if can_eos and eos_codes.any(dim=-1).all():
                break

        if eos_codes is None:
            # the loop did not run (prefix already at max_seq_len): nothing was generated
            eos_codes = torch.zeros((batch_size, 0), dtype=torch.bool, device=self.device)

        # mask out to padding anything at or after the first eos
        mask = eos_codes.float().cumsum(dim=-1) >= 1

        # the caller-supplied prefix (curr_length primitives) is never masked
        mask = torch.cat((torch.zeros((batch_size, curr_length), dtype=torch.bool, device=self.device), mask), dim=-1)
        type_code = type_code.masked_fill(mask, self.pad_id)
        scale = scale.masked_fill(mask.unsqueeze(-1), self.pad_id)
        rotation = rotation.masked_fill(mask.unsqueeze(-1), self.pad_id)
        translation = translation.masked_fill(mask.unsqueeze(-1), self.pad_id)

        recon_primitives = {
            'scale': scale,
            'rotation': rotation,
            'translation': translation,
            'type_code': type_code
        }
        # aligned with recon_primitives and consistent with the padding applied above
        primitive_mask = ~mask

        return recon_primitives, primitive_mask

    @eval_decorator
    @torch.no_grad()
    @typecheck
    def generate_w_recon_loss(
        self,
        batch_size: int | None = None,
        filter_logits_fn: Callable = top_k_2,
        filter_kwargs: dict = dict(),
        temperature: float = 1.,
        scale: Float['b np 3'] | None = None,
        rotation: Float['b np 3'] | None = None,
        translation: Float['b np 3'] | None = None,
        type_code: Int['b np'] | None = None,
        pc: Tensor | None = None,
        pc_embed: Tensor | None = None,
        cache_kv = True,
        max_seq_len = None,
        single_directional = True,
    ):
        """Like ``generate``, but rejects steps that do not lower the chamfer loss to ``pc``.

        After each step the chamfer distance (``compute_chamfer_distance`` with
        ``single_directional``) between the primitives so far and ``pc`` is computed.
        If it does not beat the best loss so far (initially 1), up to 5 re-samples at
        temperature 1.0 are tried. The first improving candidate is kept; otherwise
        the lowest-loss candidate is kept.

        Only batch size 1 is supported. ``pc`` itself is required because the loss is
        measured against it.

        Returns:
            Same as ``generate``.
        """
        max_seq_len = default(max_seq_len, self.max_seq_len)

        if exists(scale) and exists(rotation) and exists(translation) and exists(type_code):
            assert not exists(batch_size)
            assert scale.shape[1] == rotation.shape[1] == translation.shape[1] == type_code.shape[1]
            assert scale.shape[1] <= self.max_seq_len

            batch_size = scale.shape[0]

        if self.condition_on_shape:
            assert exists(pc) ^ exists(pc_embed), '`pc` or `pc_embed` must be passed in'
            if exists(pc):
                pc_embed = self.embed_pc(pc)

            batch_size = default(batch_size, pc_embed.shape[0])

        # the reconstruction loss is computed against the raw point cloud
        assert exists(pc), '`pc` must be passed in to compute the reconstruction loss'

        batch_size = default(batch_size, 1)
        assert batch_size == 1 # TODO: support any batch size

        scale = default(scale, torch.empty((batch_size, 0, 3), dtype=torch.float32, device=self.device))
        rotation = default(rotation, torch.empty((batch_size, 0, 3), dtype=torch.float32, device=self.device))
        translation = default(translation, torch.empty((batch_size, 0, 3), dtype=torch.float32, device=self.device))
        type_code = default(type_code, torch.empty((batch_size, 0), dtype=torch.int64, device=self.device))

        curr_length = scale.shape[1]
        # prefix primitives always count as valid in the chamfer mask
        prefix_valid = torch.ones((batch_size, curr_length), dtype=torch.bool, device=self.device)

        cache = None
        eos_codes = None
        last_recon_loss = 1
        for i in tqdm(range(curr_length, max_seq_len)):
            can_eos = i != 0
            output = self.forward(
                scale=scale,
                rotation=rotation,
                translation=translation,
                type_code=type_code,
                pc_embed=pc_embed,
                return_loss=False,
                return_cache=cache_kv,
                append_eos=False,
                cache=cache
            )
            if cache_kv:
                next_embed, cache = output
            else:
                next_embed = output
            (
                scale_new,
                rotation_new,
                translation_new,
                type_code_new
            ) = self.sample_primitives(
                scale,
                rotation,
                translation,
                type_code,
                next_embed,
                temperature=temperature,
                filter_logits_fn=filter_logits_fn,
                filter_kwargs=filter_kwargs
            )

            next_eos_logits = self.to_eos_logits(next_embed).squeeze(-1)
            next_eos_code = (F.sigmoid(next_eos_logits) > 0.5)
            eos_codes = safe_cat([eos_codes, next_eos_code], 1)
            if can_eos and eos_codes.any(dim=-1).all():
                scale, rotation, translation, type_code = (
                    scale_new, rotation_new, translation_new, type_code_new)
                break

            # validity mask aligned with the candidate sequence (prefix + generated steps)
            # NOTE(review): if EOS is predicted at step 0, this mask has no True entry
            # and random_sample_pc cannot sample points -- confirm how that case should behave
            valid_mask = torch.cat((prefix_valid, ~eos_codes), dim=-1)

            recon_loss = self.compute_chamfer_distance(scale_new, rotation_new, translation_new, type_code_new, valid_mask, pc, single_directional)
            if recon_loss < last_recon_loss:
                last_recon_loss = recon_loss
                scale, rotation, translation, type_code = (
                    scale_new, rotation_new, translation_new, type_code_new)
            else:
                best_recon_loss = recon_loss
                best_primitives = dict(
                    scale=scale_new, rotation=rotation_new, translation=translation_new, type_code=type_code_new)
                success_flag = False
                print(f'last_recon_loss:{last_recon_loss}, recon_loss:{recon_loss} -> to find better primitive')
                for try_i in range(5):
                    (
                        scale_new,
                        rotation_new,
                        translation_new,
                        type_code_new
                    ) = self.sample_primitives(
                        scale,
                        rotation,
                        translation,
                        type_code,
                        next_embed,
                        temperature=1.0,
                        filter_logits_fn=filter_logits_fn,
                        filter_kwargs=filter_kwargs
                    )
                    # use the caller's single_directional here too (previously ignored in retries)
                    recon_loss = self.compute_chamfer_distance(scale_new, rotation_new, translation_new, type_code_new, valid_mask, pc, single_directional)
                    print(f'[try_{try_i}] last_recon_loss:{last_recon_loss}, best_recon_loss:{best_recon_loss}, cur_recon_loss:{recon_loss}')
                    if recon_loss < last_recon_loss:
                        last_recon_loss = recon_loss
                        scale, rotation, translation, type_code = (
                            scale_new, rotation_new, translation_new, type_code_new)
                        success_flag = True
                        break
                    else:
                        if recon_loss < best_recon_loss:
                            best_recon_loss = recon_loss
                            best_primitives = dict(
                                scale=scale_new, rotation=rotation_new, translation=translation_new, type_code=type_code_new)

                if not success_flag:
                    last_recon_loss = best_recon_loss
                    scale, rotation, translation, type_code = (
                        best_primitives['scale'], best_primitives['rotation'], best_primitives['translation'], best_primitives['type_code'])
                print(f'new_last_recon_loss:{last_recon_loss}')

        if eos_codes is None:
            # the loop did not run (prefix already at max_seq_len): nothing was generated
            eos_codes = torch.zeros((batch_size, 0), dtype=torch.bool, device=self.device)

        # mask out to padding anything at or after the first eos; the prefix is never masked
        mask = eos_codes.float().cumsum(dim=-1) >= 1
        mask = torch.cat((torch.zeros((batch_size, curr_length), dtype=torch.bool, device=self.device), mask), dim=-1)
        type_code = type_code.masked_fill(mask, self.pad_id)
        scale = scale.masked_fill(mask.unsqueeze(-1), self.pad_id)
        rotation = rotation.masked_fill(mask.unsqueeze(-1), self.pad_id)
        translation = translation.masked_fill(mask.unsqueeze(-1), self.pad_id)

        recon_primitives = {
            'scale': scale,
            'rotation': rotation,
            'translation': translation,
            'type_code': type_code
        }
        primitive_mask = ~mask

        return recon_primitives, primitive_mask

    @typecheck
    def encode(
        self,
        *,
        scale: Float['b np 3'],
        rotation: Float['b np 3'],
        translation: Float['b np 3'],
        type_code: Int['b np'],
        primitive_mask: Bool['b np'],
        return_primitives = False
    ):
        """Embed primitives and project them to the model dimension.

        einops:
        b - batch
        np - number of primitives
        c - coordinates (3)
        d - embed dim

        Masked-out primitives get a zero code; their type code is replaced by 0
        before embedding.

        Returns:
            ``primitive_embed`` of shape (b, np, dim). When ``return_primitives`` is
            set, it is followed by a dict of per-attribute embeddings and a dict of the
            discretized targets.
        """
        # compute feature embedding
        discretize_scale = self.discretize_scale(scale)
        scale_embed = self.scale_embed(discretize_scale)
        scale_embed = rearrange(scale_embed, 'b np c d -> b np (c d)')

        discretize_rotation = self.discretize_rotation(rotation)
        rotation_embed = self.rotation_embed(discretize_rotation)
        rotation_embed = rearrange(rotation_embed, 'b np c d -> b np (c d)')

        discretize_translation = self.discretize_translation(translation)
        translation_embed = self.translation_embed(discretize_translation)
        translation_embed = rearrange(translation_embed, 'b np c d -> b np (c d)')

        # padded type codes (pad_id) are not valid embedding indices, so map them to 0
        type_embed = self.type_embed(type_code.masked_fill(~primitive_mask, 0))

        # combine all features and project into model dimension
        if self.embed_order == 'srtc':
            primitive_embed, _ = pack([scale_embed, rotation_embed, translation_embed, type_embed], 'b np *')
        else:
            primitive_embed, _ = pack([type_embed, translation_embed, rotation_embed, scale_embed], 'b np *')

        primitive_embed = self.project_in(primitive_embed)
        primitive_embed = primitive_embed.masked_fill(~primitive_mask.unsqueeze(-1), 0.)

        if not return_primitives:
            return primitive_embed

        primitive_embed_unpacked = {
            'scale': scale_embed,
            'rotation': rotation_embed,
            'translation': translation_embed,
            'type_code': type_embed
        }

        primitives_gt = {
            'scale': discretize_scale,
            'rotation': discretize_rotation,
            'translation': discretize_translation,
            'type_code': type_code
        }

        return primitive_embed, primitive_embed_unpacked, primitives_gt

    @typecheck
    def compute_chamfer_distance(
        self,
        scale_pred: Float['b np 3'],
        rotation_pred: Float['b np 3'],
        translation_pred: Float['b np 3'],
        type_pred: Int['b np'],
        primitive_mask: Bool['b np'],
        pc: Tensor, # b, num_points, c
        single_directional = True
    ):
        """Chamfer distance between the predicted primitives' surface points and ``pc``.

        Procedure:
            1. Index the template point clouds in ``self.bs_pc`` by type code.
            2. Place them with ``apply_transformation`` (rotations converted from
               degrees to radians).
            3. Rotate them by ``rotation_matrix_align_coord``.
            4. Sample ``self.bs_pc.shape[1]`` points from the first
               ``primitive_mask.sum()`` primitives.

        Only the first 3 channels of ``pc`` are used. With ``single_directional``,
        only distances from ``pc`` to the prediction are measured.
        """
        assert exists(self.bs_pc), 'basic-shape point clouds were not loaded (pass `bs_pc_dir` to the constructor)'

        scale_pred = scale_pred.float()
        rotation_pred = rotation_pred.float()
        translation_pred = translation_pred.float()

        pc_pred = apply_transformation(self.bs_pc.to(type_pred.device)[type_pred], scale_pred, torch.deg2rad(rotation_pred), translation_pred)
        pc_pred = torch.matmul(pc_pred, self.rotation_matrix_align_coord.to(type_pred.device))
        pc_pred_flat = rearrange(pc_pred, 'b np p c -> b (np p) c')
        pc_pred_sampled = random_sample_pc(pc_pred_flat, primitive_mask.sum(dim=-1, keepdim=True), n_points=self.bs_pc.shape[1])

        if single_directional:
            recon_loss, _ = chamfer_distance(pc[:, :, :3].float(), pc_pred_sampled.float(), single_directional=True) # single directional
        else:
            recon_loss, _ = chamfer_distance(pc_pred_sampled.float(), pc[:, :, :3].float())

        return recon_loss

    def forward(
        self,
        *,
        scale: Float['b np 3'],
        rotation: Float['b np 3'],
        translation: Float['b np 3'],
        type_code: Int['b np'],
        loss_reduction: str = 'mean',
        return_cache = False,
        append_eos = True,
        cache: LayerIntermediates | None = None,
        pc: Tensor | None = None,
        pc_embed: Tensor | None = None,
        **kwargs
    ):
        """Encode the primitive sequence and return the decoder output at its last position.

        Args:
            scale, rotation, translation, type_code: primitive attributes. A primitive
                whose scale contains ``pad_id`` in any component is treated as padding.
            return_cache: also return the caches for incremental decoding.
            append_eos: write the EOS token right after each sequence's last valid
                primitive.
            cache: tuple returned by a previous call with ``return_cache=True``.
            pc / pc_embed: exactly one is required when ``condition_on_shape``.
            loss_reduction, **kwargs: accepted for interface compatibility and ignored
                (e.g. ``return_loss``).

        Returns:
            (b, 1, dim) output at the last position, plus the new cache tuple when
            ``return_cache``.
        """
        # a primitive is valid only if none of its scale components equals pad_id
        primitive_mask = reduce(scale != self.pad_id, 'b np 3 -> b np', 'all')

        if scale.shape[1] > 0:
            codes = self.encode(
                scale=scale,
                rotation=rotation,
                translation=translation,
                type_code=type_code,
                primitive_mask=primitive_mask,
            )
        else:
            codes = torch.empty((scale.shape[0], 0, self.dim), dtype=torch.float32, device=self.device)

        # handle shape conditions
        attn_context_kwargs = dict()

        if self.condition_on_shape:
            assert exists(pc) ^ exists(pc_embed), '`pc` or `pc_embed` must be passed in'

            if exists(pc):
                pc_embed = self._shape_condition_embed(pc)

            assert pc_embed.shape[0] == codes.shape[0], 'batch size of point cloud is not equal to the batch size of the primitive codes'

            pooled_pc_embed = pc_embed.mean(dim=1) # (b, shape_condition_dim)

            if self.shape_cond_with_cross_attn:
                attn_context_kwargs = dict(
                    context=pc_embed
                )

            if self.coarse_adaptive_rmsnorm:
                attn_context_kwargs.update(
                    condition=pooled_pc_embed
                )

        batch, seq_len, _ = codes.shape # (b, np, dim)
        device = codes.device
        assert seq_len <= self.max_seq_len, f'received codes of length {seq_len} but needs to be less than or equal to set max_seq_len {self.max_seq_len}'

        if append_eos:
            assert exists(codes)
            code_lens = primitive_mask.sum(dim=-1)
            codes = pad_tensor(codes)

            batch_arange = torch.arange(batch, device=device)
            batch_arange = rearrange(batch_arange, '... -> ... 1')
            code_lens = rearrange(code_lens, '... -> ... 1')
            codes[batch_arange, code_lens] = self.eos_token # (b, np+1, dim)

        primitive_codes = codes # (b, np, dim)

        primitive_codes_len = primitive_codes.shape[-2]

        (
            coarse_cache,
            coarse_gateloop_cache,
            coarse_post_gateloop_cache,
        ) = cache if exists(cache) else ((None,) * 3)

        if not exists(cache):
            # sos tokens (optionally preceded by the shape tokens) are only prepended on
            # the first, uncached call
            sos = repeat(self.sos_token, 'n d -> b n d', b=batch)

            if self.shape_cond_with_cat:
                sos, _ = pack([pc_embed, sos], 'b * d')
            primitive_codes, packed_sos_shape = pack([sos, primitive_codes], 'b * d') # (b, n_sos+np, dim)

        # condition primitive codes with shape if needed
        if self.condition_on_shape:
            primitive_codes = self.shape_coarse_film_cond(primitive_codes, pooled_pc_embed)

        # attention on primitive codes (coarse)
        if exists(self.coarse_gateloop_block):
            primitive_codes, coarse_gateloop_cache = self.coarse_gateloop_block(primitive_codes, cache=coarse_gateloop_cache)

        attended_primitive_codes, coarse_cache = self.decoder( # (b, n_sos+np, dim)
            primitive_codes,
            cache=coarse_cache,
            return_hiddens=True,
            **attn_context_kwargs
        )

        if exists(self.coarse_post_gateloop_block):
            # NOTE(review): the post block runs on the pre-decoder `primitive_codes`, and its
            # output is never used -- `embed` below comes from `attended_primitive_codes`.
            # Only its cache is propagated. Kept as-is so existing checkpoints produce the
            # same outputs; confirm the intent before routing `attended_primitive_codes`
            # through it.
            primitive_codes, coarse_post_gateloop_cache = self.coarse_post_gateloop_block(primitive_codes, cache=coarse_post_gateloop_cache)

        embed = attended_primitive_codes[:, -(primitive_codes_len + 1):] # (b, np+1, dim)

        if not return_cache:
            return embed[:, -1:]

        next_cache = (
            coarse_cache,
            coarse_gateloop_cache,
            coarse_post_gateloop_cache
        )

        return embed[:, -1:], next_cache
|
|
|
|
|
|
def pad_tensor(tensor):
    """Append one zero entry along dim 1 of a 2-D or 3-D tensor.

    The zeros share the input's dtype and device.

    Raises:
        ValueError: if the tensor is neither 2-D nor 3-D.
    """
    if tensor.dim() not in (2, 3):
        raise ValueError('Unsupported tensor shape: {}'.format(tensor.shape))

    # (batch, 1) for 2-D input, (batch, 1, dim) for 3-D input
    pad_shape = (tensor.shape[0], 1) + tuple(tensor.shape[2:])
    zeros = torch.zeros(pad_shape, dtype=tensor.dtype, device=tensor.device)
    return torch.cat([tensor, zeros], dim=1)
|
|
|
|
|
|
def apply_transformation(pc, scale, rotation_vector, translation):
    """Scale, rotate and translate per-primitive point clouds.

    Args:
        pc: (bs, n_prim, n_points, 3) template points for each primitive.
        scale: (bs, n_prim, 3) per-axis scale, broadcast over the points.
        rotation_vector: (bs, n_prim, 3) Euler angles, converted with
            ``euler_angles_to_matrix(..., 'XYZ')`` (which expects radians).
        translation: (bs, n_prim, 3) offset added after rotation.

    Returns:
        (bs, n_prim, n_points, 3) transformed points.
    """
    # `num_prims` instead of `np` so the module-level numpy alias is not shadowed
    batch_size, num_prims, _, _ = pc.shape

    scaled_pc = pc * scale.unsqueeze(2)

    # reshape (not view) so non-contiguous angle tensors are accepted too
    rotation_matrix = euler_angles_to_matrix(rotation_vector.reshape(-1, 3), 'XYZ').reshape(batch_size, num_prims, 3, 3) # euler tmp
    # apply each primitive's 3x3 matrix to every one of its points
    rotated_pc = torch.einsum('bnij,bnpj->bnpi', rotation_matrix, scaled_pc)

    return rotated_pc + translation.unsqueeze(2)
|
|
|
|
|
|
def random_sample_pc(pc, max_lens, n_points=10000):
    """Sample ``n_points`` distinct points per batch item from its valid prefix.

    Args:
        pc: (bs, total_points, c) flattened point cloud. The points of primitive j are
            assumed to occupy rows ``[j * n_points, (j + 1) * n_points)``.
        max_lens: (bs, 1) integer count of valid primitives per item; valid primitives
            are assumed to form a prefix.
        n_points: points per primitive, which is also the number of points returned.

    Returns:
        (bs, n_points, c) sampled points. Selection is without replacement. If an item
        has fewer than ``n_points`` valid points, invalid positions fill the remainder.
    """
    bs = max_lens.shape[0]
    device = max_lens.device
    max_len = max_lens.max().item() * n_points

    random_values = torch.rand(bs, max_len, device=device)
    valid = torch.arange(max_len, device=device).expand(bs, max_len) < (max_lens * n_points)
    # rand() is in [0, 1), so -1 guarantees invalid slots never outrank valid ones
    masked_random_values = random_values.masked_fill(~valid, -1.)
    _, indices = torch.topk(masked_random_values, n_points, dim=1)

    batch_index = torch.arange(bs, device=indices.device).unsqueeze(1)
    return pc[batch_index, indices]