mirror of
https://github.com/PrimitiveAnything/PrimitiveAnything.git
synced 2026-05-08 00:58:55 +08:00
init
This commit is contained in:
423
primitive_anything/michelangelo/models/tsal/sal_perceiver.py
Executable file
423
primitive_anything/michelangelo/models/tsal/sal_perceiver.py
Executable file
@@ -0,0 +1,423 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from typing import Optional
|
||||
from einops import repeat
|
||||
import math
|
||||
|
||||
from ..modules import checkpoint
|
||||
from ..modules.embedder import FourierEmbedder
|
||||
from ..modules.distributions import DiagonalGaussianDistribution
|
||||
from ..modules.transformer_blocks import (
|
||||
ResidualCrossAttentionBlock,
|
||||
Transformer
|
||||
)
|
||||
|
||||
from .tsal_base import ShapeAsLatentModule
|
||||
|
||||
|
||||
class CrossAttentionEncoder(nn.Module):
|
||||
|
||||
def __init__(self, *,
|
||||
device: Optional[torch.device],
|
||||
dtype: Optional[torch.dtype],
|
||||
num_latents: int,
|
||||
fourier_embedder: FourierEmbedder,
|
||||
point_feats: int,
|
||||
width: int,
|
||||
heads: int,
|
||||
layers: int,
|
||||
init_scale: float = 0.25,
|
||||
qkv_bias: bool = True,
|
||||
flash: bool = False,
|
||||
use_ln_post: bool = False,
|
||||
use_checkpoint: bool = False):
|
||||
|
||||
super().__init__()
|
||||
|
||||
self.use_checkpoint = use_checkpoint
|
||||
self.num_latents = num_latents
|
||||
|
||||
self.query = nn.Parameter(torch.randn((num_latents, width), device=device, dtype=dtype) * 0.02)
|
||||
|
||||
self.fourier_embedder = fourier_embedder
|
||||
self.input_proj = nn.Linear(self.fourier_embedder.out_dim + point_feats, width, device=device, dtype=dtype)
|
||||
self.cross_attn = ResidualCrossAttentionBlock(
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
width=width,
|
||||
heads=heads,
|
||||
init_scale=init_scale,
|
||||
qkv_bias=qkv_bias,
|
||||
flash=flash,
|
||||
)
|
||||
|
||||
self.self_attn = Transformer(
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
n_ctx=num_latents,
|
||||
width=width,
|
||||
layers=layers,
|
||||
heads=heads,
|
||||
init_scale=init_scale,
|
||||
qkv_bias=qkv_bias,
|
||||
flash=flash,
|
||||
use_checkpoint=False
|
||||
)
|
||||
|
||||
if use_ln_post:
|
||||
self.ln_post = nn.LayerNorm(width, dtype=dtype, device=device)
|
||||
else:
|
||||
self.ln_post = None
|
||||
|
||||
def _forward(self, pc, feats):
|
||||
"""
|
||||
|
||||
Args:
|
||||
pc (torch.FloatTensor): [B, N, 3]
|
||||
feats (torch.FloatTensor or None): [B, N, C]
|
||||
|
||||
Returns:
|
||||
|
||||
"""
|
||||
|
||||
bs = pc.shape[0]
|
||||
|
||||
data = self.fourier_embedder(pc)
|
||||
if feats is not None:
|
||||
data = torch.cat([data, feats], dim=-1)
|
||||
data = self.input_proj(data)
|
||||
|
||||
query = repeat(self.query, "m c -> b m c", b=bs)
|
||||
latents = self.cross_attn(query, data)
|
||||
latents = self.self_attn(latents)
|
||||
|
||||
if self.ln_post is not None:
|
||||
latents = self.ln_post(latents)
|
||||
|
||||
return latents, pc
|
||||
|
||||
def forward(self, pc: torch.FloatTensor, feats: Optional[torch.FloatTensor] = None):
|
||||
"""
|
||||
|
||||
Args:
|
||||
pc (torch.FloatTensor): [B, N, 3]
|
||||
feats (torch.FloatTensor or None): [B, N, C]
|
||||
|
||||
Returns:
|
||||
dict
|
||||
"""
|
||||
|
||||
return checkpoint(self._forward, (pc, feats), self.parameters(), self.use_checkpoint)
|
||||
|
||||
|
||||
class CrossAttentionDecoder(nn.Module):
|
||||
|
||||
def __init__(self, *,
|
||||
device: Optional[torch.device],
|
||||
dtype: Optional[torch.dtype],
|
||||
num_latents: int,
|
||||
out_channels: int,
|
||||
fourier_embedder: FourierEmbedder,
|
||||
width: int,
|
||||
heads: int,
|
||||
init_scale: float = 0.25,
|
||||
qkv_bias: bool = True,
|
||||
flash: bool = False,
|
||||
use_checkpoint: bool = False):
|
||||
|
||||
super().__init__()
|
||||
|
||||
self.use_checkpoint = use_checkpoint
|
||||
self.fourier_embedder = fourier_embedder
|
||||
|
||||
self.query_proj = nn.Linear(self.fourier_embedder.out_dim, width, device=device, dtype=dtype)
|
||||
|
||||
self.cross_attn_decoder = ResidualCrossAttentionBlock(
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
n_data=num_latents,
|
||||
width=width,
|
||||
heads=heads,
|
||||
init_scale=init_scale,
|
||||
qkv_bias=qkv_bias,
|
||||
flash=flash
|
||||
)
|
||||
|
||||
self.ln_post = nn.LayerNorm(width, device=device, dtype=dtype)
|
||||
self.output_proj = nn.Linear(width, out_channels, device=device, dtype=dtype)
|
||||
|
||||
def _forward(self, queries: torch.FloatTensor, latents: torch.FloatTensor):
|
||||
queries = self.query_proj(self.fourier_embedder(queries))
|
||||
x = self.cross_attn_decoder(queries, latents)
|
||||
x = self.ln_post(x)
|
||||
x = self.output_proj(x)
|
||||
return x
|
||||
|
||||
def forward(self, queries: torch.FloatTensor, latents: torch.FloatTensor):
|
||||
return checkpoint(self._forward, (queries, latents), self.parameters(), self.use_checkpoint)
|
||||
|
||||
|
||||
class ShapeAsLatentPerceiver(ShapeAsLatentModule):
|
||||
def __init__(self, *,
|
||||
device: Optional[torch.device],
|
||||
dtype: Optional[torch.dtype],
|
||||
num_latents: int,
|
||||
point_feats: int = 0,
|
||||
embed_dim: int = 0,
|
||||
num_freqs: int = 8,
|
||||
include_pi: bool = True,
|
||||
width: int,
|
||||
heads: int,
|
||||
num_encoder_layers: int,
|
||||
num_decoder_layers: int,
|
||||
init_scale: float = 0.25,
|
||||
qkv_bias: bool = True,
|
||||
flash: bool = False,
|
||||
use_ln_post: bool = False,
|
||||
use_checkpoint: bool = False):
|
||||
|
||||
super().__init__()
|
||||
|
||||
self.use_checkpoint = use_checkpoint
|
||||
|
||||
self.num_latents = num_latents
|
||||
self.fourier_embedder = FourierEmbedder(num_freqs=num_freqs, include_pi=include_pi)
|
||||
|
||||
init_scale = init_scale * math.sqrt(1.0 / width)
|
||||
self.encoder = CrossAttentionEncoder(
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
fourier_embedder=self.fourier_embedder,
|
||||
num_latents=num_latents,
|
||||
point_feats=point_feats,
|
||||
width=width,
|
||||
heads=heads,
|
||||
layers=num_encoder_layers,
|
||||
init_scale=init_scale,
|
||||
qkv_bias=qkv_bias,
|
||||
flash=flash,
|
||||
use_ln_post=use_ln_post,
|
||||
use_checkpoint=use_checkpoint
|
||||
)
|
||||
|
||||
self.embed_dim = embed_dim
|
||||
if embed_dim > 0:
|
||||
# VAE embed
|
||||
self.pre_kl = nn.Linear(width, embed_dim * 2, device=device, dtype=dtype)
|
||||
self.post_kl = nn.Linear(embed_dim, width, device=device, dtype=dtype)
|
||||
self.latent_shape = (num_latents, embed_dim)
|
||||
else:
|
||||
self.latent_shape = (num_latents, width)
|
||||
|
||||
self.transformer = Transformer(
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
n_ctx=num_latents,
|
||||
width=width,
|
||||
layers=num_decoder_layers,
|
||||
heads=heads,
|
||||
init_scale=init_scale,
|
||||
qkv_bias=qkv_bias,
|
||||
flash=flash,
|
||||
use_checkpoint=use_checkpoint
|
||||
)
|
||||
|
||||
# geometry decoder
|
||||
self.geo_decoder = CrossAttentionDecoder(
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
fourier_embedder=self.fourier_embedder,
|
||||
out_channels=1,
|
||||
num_latents=num_latents,
|
||||
width=width,
|
||||
heads=heads,
|
||||
init_scale=init_scale,
|
||||
qkv_bias=qkv_bias,
|
||||
flash=flash,
|
||||
use_checkpoint=use_checkpoint
|
||||
)
|
||||
|
||||
def encode(self,
|
||||
pc: torch.FloatTensor,
|
||||
feats: Optional[torch.FloatTensor] = None,
|
||||
sample_posterior: bool = True):
|
||||
"""
|
||||
|
||||
Args:
|
||||
pc (torch.FloatTensor): [B, N, 3]
|
||||
feats (torch.FloatTensor or None): [B, N, C]
|
||||
sample_posterior (bool):
|
||||
|
||||
Returns:
|
||||
latents (torch.FloatTensor)
|
||||
center_pos (torch.FloatTensor or None):
|
||||
posterior (DiagonalGaussianDistribution or None):
|
||||
"""
|
||||
|
||||
latents, center_pos = self.encoder(pc, feats)
|
||||
|
||||
posterior = None
|
||||
if self.embed_dim > 0:
|
||||
moments = self.pre_kl(latents)
|
||||
posterior = DiagonalGaussianDistribution(moments, feat_dim=-1)
|
||||
|
||||
if sample_posterior:
|
||||
latents = posterior.sample()
|
||||
else:
|
||||
latents = posterior.mode()
|
||||
|
||||
return latents, center_pos, posterior
|
||||
|
||||
def decode(self, latents: torch.FloatTensor):
|
||||
latents = self.post_kl(latents)
|
||||
return self.transformer(latents)
|
||||
|
||||
def query_geometry(self, queries: torch.FloatTensor, latents: torch.FloatTensor):
|
||||
logits = self.geo_decoder(queries, latents).squeeze(-1)
|
||||
return logits
|
||||
|
||||
def forward(self,
|
||||
pc: torch.FloatTensor,
|
||||
feats: torch.FloatTensor,
|
||||
volume_queries: torch.FloatTensor,
|
||||
sample_posterior: bool = True):
|
||||
"""
|
||||
|
||||
Args:
|
||||
pc (torch.FloatTensor): [B, N, 3]
|
||||
feats (torch.FloatTensor or None): [B, N, C]
|
||||
volume_queries (torch.FloatTensor): [B, P, 3]
|
||||
sample_posterior (bool):
|
||||
|
||||
Returns:
|
||||
logits (torch.FloatTensor): [B, P]
|
||||
center_pos (torch.FloatTensor): [B, M, 3]
|
||||
posterior (DiagonalGaussianDistribution or None).
|
||||
|
||||
"""
|
||||
|
||||
latents, center_pos, posterior = self.encode(pc, feats, sample_posterior=sample_posterior)
|
||||
|
||||
latents = self.decode(latents)
|
||||
logits = self.query_geometry(volume_queries, latents)
|
||||
|
||||
return logits, center_pos, posterior
|
||||
|
||||
|
||||
class AlignedShapeLatentPerceiver(ShapeAsLatentPerceiver):
|
||||
|
||||
def __init__(self, *,
|
||||
device: Optional[torch.device],
|
||||
dtype: Optional[torch.dtype],
|
||||
num_latents: int,
|
||||
point_feats: int = 0,
|
||||
embed_dim: int = 0,
|
||||
num_freqs: int = 8,
|
||||
include_pi: bool = True,
|
||||
width: int,
|
||||
heads: int,
|
||||
num_encoder_layers: int,
|
||||
num_decoder_layers: int,
|
||||
init_scale: float = 0.25,
|
||||
qkv_bias: bool = True,
|
||||
flash: bool = False,
|
||||
use_ln_post: bool = False,
|
||||
use_checkpoint: bool = False):
|
||||
|
||||
super().__init__(
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
num_latents=1 + num_latents,
|
||||
point_feats=point_feats,
|
||||
embed_dim=embed_dim,
|
||||
num_freqs=num_freqs,
|
||||
include_pi=include_pi,
|
||||
width=width,
|
||||
heads=heads,
|
||||
num_encoder_layers=num_encoder_layers,
|
||||
num_decoder_layers=num_decoder_layers,
|
||||
init_scale=init_scale,
|
||||
qkv_bias=qkv_bias,
|
||||
flash=flash,
|
||||
use_ln_post=use_ln_post,
|
||||
use_checkpoint=use_checkpoint
|
||||
)
|
||||
|
||||
self.width = width
|
||||
|
||||
def encode(self,
|
||||
pc: torch.FloatTensor,
|
||||
feats: Optional[torch.FloatTensor] = None,
|
||||
sample_posterior: bool = True):
|
||||
"""
|
||||
|
||||
Args:
|
||||
pc (torch.FloatTensor): [B, N, 3]
|
||||
feats (torch.FloatTensor or None): [B, N, c]
|
||||
sample_posterior (bool):
|
||||
|
||||
Returns:
|
||||
shape_embed (torch.FloatTensor)
|
||||
kl_embed (torch.FloatTensor):
|
||||
posterior (DiagonalGaussianDistribution or None):
|
||||
"""
|
||||
|
||||
shape_embed, latents = self.encode_latents(pc, feats)
|
||||
kl_embed, posterior = self.encode_kl_embed(latents, sample_posterior)
|
||||
|
||||
return shape_embed, kl_embed, posterior
|
||||
|
||||
def encode_latents(self,
|
||||
pc: torch.FloatTensor,
|
||||
feats: Optional[torch.FloatTensor] = None):
|
||||
|
||||
x, _ = self.encoder(pc, feats)
|
||||
|
||||
shape_embed = x[:, 0]
|
||||
latents = x[:, 1:]
|
||||
|
||||
return shape_embed, latents
|
||||
|
||||
def encode_kl_embed(self, latents: torch.FloatTensor, sample_posterior: bool = True):
|
||||
posterior = None
|
||||
if self.embed_dim > 0:
|
||||
moments = self.pre_kl(latents)
|
||||
posterior = DiagonalGaussianDistribution(moments, feat_dim=-1)
|
||||
|
||||
if sample_posterior:
|
||||
kl_embed = posterior.sample()
|
||||
else:
|
||||
kl_embed = posterior.mode()
|
||||
else:
|
||||
kl_embed = latents
|
||||
|
||||
return kl_embed, posterior
|
||||
|
||||
def forward(self,
|
||||
pc: torch.FloatTensor,
|
||||
feats: torch.FloatTensor,
|
||||
volume_queries: torch.FloatTensor,
|
||||
sample_posterior: bool = True):
|
||||
"""
|
||||
|
||||
Args:
|
||||
pc (torch.FloatTensor): [B, N, 3]
|
||||
feats (torch.FloatTensor or None): [B, N, C]
|
||||
volume_queries (torch.FloatTensor): [B, P, 3]
|
||||
sample_posterior (bool):
|
||||
|
||||
Returns:
|
||||
shape_embed (torch.FloatTensor): [B, projection_dim]
|
||||
logits (torch.FloatTensor): [B, M]
|
||||
posterior (DiagonalGaussianDistribution or None).
|
||||
|
||||
"""
|
||||
|
||||
shape_embed, kl_embed, posterior = self.encode(pc, feats, sample_posterior=sample_posterior)
|
||||
|
||||
latents = self.decode(kl_embed)
|
||||
logits = self.query_geometry(volume_queries, latents)
|
||||
|
||||
return shape_embed, logits, posterior
|
||||
Reference in New Issue
Block a user