Mirror of https://github.com/PrimitiveAnything/PrimitiveAnything.git (synced 2025-12-28 11:00:33 +08:00)
init
primitive_anything/michelangelo/models/tsal/__init__.py (Executable file, 1 line)
@@ -0,0 +1 @@
# -*- coding: utf-8 -*-
primitive_anything/michelangelo/models/tsal/asl_pl_module.py (Executable file, 373 lines)
@@ -0,0 +1,373 @@
# -*- coding: utf-8 -*-

from typing import List, Tuple, Dict, Optional, Union
from omegaconf import DictConfig

import torch
import torch.nn.functional as F
from torch.optim import lr_scheduler
import pytorch_lightning as pl
from functools import partial

from ...utils import instantiate_from_config

from .inference_utils import extract_geometry
from .tsal_base import (
    AlignedShapeAsLatentModule,
    ShapeAsLatentModule,
    Latent2MeshOutput,
    AlignedMeshOutput
)


class AlignedShapeAsLatentPLModule(pl.LightningModule):

    def __init__(self, *,
                 shape_module_cfg,
                 aligned_module_cfg,
                 loss_cfg,
                 optimizer_cfg: Optional[DictConfig] = None,
                 ckpt_path: Optional[str] = None,
                 ignore_keys: Union[Tuple[str], List[str]] = ()):

        super().__init__()

        shape_model: ShapeAsLatentModule = instantiate_from_config(
            shape_module_cfg, device=None, dtype=None
        )
        self.model: AlignedShapeAsLatentModule = instantiate_from_config(
            aligned_module_cfg, shape_model=shape_model
        )

        self.loss = instantiate_from_config(loss_cfg)

        self.optimizer_cfg = optimizer_cfg

        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)

        self.save_hyperparameters()

    def set_shape_model_only(self):
        self.model.set_shape_model_only()

    @property
    def latent_shape(self):
        return self.model.shape_model.latent_shape

    @property
    def zero_rank(self):
        if self._trainer:
            zero_rank = self.trainer.local_rank == 0
        else:
            zero_rank = True

        return zero_rank

    def init_from_ckpt(self, path, ignore_keys=()):
        state_dict = torch.load(path, map_location="cpu")["state_dict"]

        keys = list(state_dict.keys())
        for k in keys:
            for ik in ignore_keys:
                if k.startswith(ik):
                    print("Deleting key {} from state_dict.".format(k))
                    del state_dict[k]

        missing, unexpected = self.load_state_dict(state_dict, strict=False)
        print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
        if len(missing) > 0:
            print(f"Missing Keys: {missing}")
        if len(unexpected) > 0:
            print(f"Unexpected Keys: {unexpected}")

    def configure_optimizers(self) -> Tuple[List, List]:
        lr = self.learning_rate

        trainable_parameters = list(self.model.parameters())

        if self.optimizer_cfg is None:
            optimizers = [torch.optim.AdamW(trainable_parameters, lr=lr, betas=(0.9, 0.99), weight_decay=1e-3)]
            schedulers = []
        else:
            optimizer = instantiate_from_config(self.optimizer_cfg.optimizer, params=trainable_parameters)
            scheduler_func = instantiate_from_config(
                self.optimizer_cfg.scheduler,
                max_decay_steps=self.trainer.max_steps,
                lr_max=lr
            )
            scheduler = {
                "scheduler": lr_scheduler.LambdaLR(optimizer, lr_lambda=scheduler_func.schedule),
                "interval": "step",
                "frequency": 1
            }
            optimizers = [optimizer]
            schedulers = [scheduler]

        return optimizers, schedulers

    def forward(self,
                surface: torch.FloatTensor,
                image: torch.FloatTensor,
                text: torch.FloatTensor,
                volume_queries: torch.FloatTensor):
        """

        Args:
            surface (torch.FloatTensor): [bs, n_surface, (3 + input_dim)]
            image (torch.FloatTensor): [bs, 3, 224, 224]
            text (torch.FloatTensor): [bs, num_templates, 77]
            volume_queries (torch.FloatTensor): [bs, n_pts, 3]

        Returns:
            embed_outputs (dict), logits (torch.FloatTensor), posterior (DiagonalGaussianDistribution or None).

        """

        embed_outputs, shape_z = self.model(surface, image, text)

        shape_zq, posterior = self.model.shape_model.encode_kl_embed(shape_z)
        latents = self.model.shape_model.decode(shape_zq)
        logits = self.model.shape_model.query_geometry(volume_queries, latents)

        return embed_outputs, logits, posterior

    def encode(self, surface: torch.FloatTensor, sample_posterior=True):

        pc = surface[..., 0:3]
        feats = surface[..., 3:6]

        shape_embed, shape_zq, posterior = self.model.shape_model.encode(
            pc=pc, feats=feats, sample_posterior=sample_posterior
        )

        return shape_zq

    def encode_latents(self, surface: torch.FloatTensor):

        pc = surface[..., 0:3]
        feats = surface[..., 3:6]

        shape_embed, shape_latents = self.model.shape_model.encode_latents(
            pc=pc, feats=feats
        )
        shape_embed = shape_embed.unsqueeze(1)
        assert shape_embed.shape[1] == 1 and shape_latents.shape[1] == 256
        cat_latents = torch.cat([shape_embed, shape_latents], dim=1)

        return cat_latents

    def to_shape_latents(self, latents):

        shape_zq, posterior = self.model.shape_model.encode_kl_embed(latents, sample_posterior=False)
        return self.model.shape_model.decode(shape_zq)

    def decode(self,
               z_q,
               bounds: Union[Tuple[float], List[float], float] = 1.1,
               octree_depth: int = 7,
               num_chunks: int = 10000) -> List[Latent2MeshOutput]:

        latents = self.model.shape_model.decode(z_q)  # latents: [bs, num_latents, dim]
        outputs = self.latent2mesh(latents, bounds=bounds, octree_depth=octree_depth, num_chunks=num_chunks)

        return outputs

    def training_step(self, batch: Dict[str, torch.FloatTensor],
                      batch_idx: int, optimizer_idx: int = 0) -> torch.FloatTensor:
        """

        Args:
            batch (dict): the batch sample; it contains:
                - surface (torch.FloatTensor): [bs, n_surface, (3 + input_dim)]
                - image (torch.FloatTensor): [bs, 3, 224, 224]
                - text (torch.FloatTensor): [bs, num_templates, 77]
                - geo_points (torch.FloatTensor): [bs, n_pts, (3 + 1)]
            batch_idx (int):
            optimizer_idx (int):

        Returns:
            loss (torch.FloatTensor):

        """

        surface = batch["surface"]
        image = batch["image"]
        text = batch["text"]

        volume_queries = batch["geo_points"][..., 0:3]
        shape_labels = batch["geo_points"][..., -1]

        embed_outputs, shape_logits, posteriors = self(surface, image, text, volume_queries)

        aeloss, log_dict_ae = self.loss(
            **embed_outputs,
            posteriors=posteriors,
            shape_logits=shape_logits,
            shape_labels=shape_labels,
            split="train"
        )

        self.log_dict(log_dict_ae, prog_bar=True, logger=True, batch_size=shape_logits.shape[0],
                      sync_dist=False, rank_zero_only=True)

        return aeloss

    def validation_step(self, batch: Dict[str, torch.FloatTensor], batch_idx: int) -> torch.FloatTensor:

        surface = batch["surface"]
        image = batch["image"]
        text = batch["text"]

        volume_queries = batch["geo_points"][..., 0:3]
        shape_labels = batch["geo_points"][..., -1]

        embed_outputs, shape_logits, posteriors = self(surface, image, text, volume_queries)

        aeloss, log_dict_ae = self.loss(
            **embed_outputs,
            posteriors=posteriors,
            shape_logits=shape_logits,
            shape_labels=shape_labels,
            split="val"
        )
        self.log_dict(log_dict_ae, prog_bar=True, logger=True, batch_size=shape_logits.shape[0],
                      sync_dist=False, rank_zero_only=True)

        return aeloss

    def visual_alignment(self,
                         surface: torch.FloatTensor,
                         image: torch.FloatTensor,
                         text: torch.FloatTensor,
                         description: Optional[List[str]] = None,
                         bounds: Union[Tuple[float], List[float]] = (-1.25, -1.25, -1.25, 1.25, 1.25, 1.25),
                         octree_depth: int = 7,
                         num_chunks: int = 10000) -> List[AlignedMeshOutput]:
        """

        Args:
            surface (torch.FloatTensor): [bs, n_surface, (3 + input_dim)]
            image (torch.FloatTensor): [bs, 3, 224, 224]
            text (torch.FloatTensor): [bs, num_templates, 77]
            description (List[str] or None): optional captions attached to the outputs.
            bounds: bounding box of the query grid.
            octree_depth: resolution of the query grid.
            num_chunks: number of query points evaluated per chunk.

        Returns:
            mesh_outputs (List[AlignedMeshOutput]): the mesh outputs list.

        """

        outputs = []

        device = surface.device
        bs = surface.shape[0]

        embed_outputs, shape_z = self.model(surface, image, text)

        # 1. calculate the shape-text / shape-image similarities
        image_embed = embed_outputs["image_embed"]
        text_embed = embed_outputs["text_embed"]
        shape_embed = embed_outputs["shape_embed"]

        # normalized features
        shape_embed = F.normalize(shape_embed, dim=-1, p=2)
        text_embed = F.normalize(text_embed, dim=-1, p=2)
        image_embed = F.normalize(image_embed, dim=-1, p=2)

        # B x B
        shape_text_similarity = (100.0 * shape_embed @ text_embed.T).softmax(dim=-1)

        # B x B
        shape_image_similarity = (100.0 * shape_embed @ image_embed.T).softmax(dim=-1)

        # 2. shape reconstruction: re-encode to the KL latent space and decode geometry
        shape_zq, posterior = self.model.shape_model.encode_kl_embed(shape_z)
        latents = self.model.shape_model.decode(shape_zq)
        geometric_func = partial(self.model.shape_model.query_geometry, latents=latents)

        mesh_v_f, has_surface = extract_geometry(
            geometric_func=geometric_func,
            device=device,
            batch_size=bs,
            bounds=bounds,
            octree_depth=octree_depth,
            num_chunks=num_chunks,
            disable=not self.zero_rank
        )

        # 3. assemble the aligned mesh outputs
        for i, ((mesh_v, mesh_f), is_surface) in enumerate(zip(mesh_v_f, has_surface)):
            if not is_surface:
                outputs.append(None)
                continue

            out = AlignedMeshOutput()
            out.mesh_v = mesh_v
            out.mesh_f = mesh_f
            out.surface = surface[i].cpu().numpy()
            out.image = image[i].cpu().numpy()
            if description is not None:
                out.text = description[i]
            out.shape_text_similarity = shape_text_similarity[i, i]
            out.shape_image_similarity = shape_image_similarity[i, i]

            outputs.append(out)

        return outputs

    def latent2mesh(self,
                    latents: torch.FloatTensor,
                    bounds: Union[Tuple[float], List[float], float] = 1.1,
                    octree_depth: int = 7,
                    num_chunks: int = 10000) -> List[Latent2MeshOutput]:
        """

        Args:
            latents: [bs, num_latents, dim]
            bounds: bounding box of the query grid.
            octree_depth: resolution of the query grid.
            num_chunks: number of query points evaluated per chunk.

        Returns:
            mesh_outputs (List[Latent2MeshOutput]): the mesh outputs list.

        """

        outputs = []

        geometric_func = partial(self.model.shape_model.query_geometry, latents=latents)

        # 1. decode geometry
        device = latents.device
        mesh_v_f, has_surface = extract_geometry(
            geometric_func=geometric_func,
            device=device,
            batch_size=len(latents),
            bounds=bounds,
            octree_depth=octree_depth,
            num_chunks=num_chunks,
            disable=not self.zero_rank
        )

        # 2. assemble the mesh outputs
        for i, ((mesh_v, mesh_f), is_surface) in enumerate(zip(mesh_v_f, has_surface)):
            if not is_surface:
                outputs.append(None)
                continue

            out = Latent2MeshOutput()
            out.mesh_v = mesh_v
            out.mesh_f = mesh_f

            outputs.append(out)

        return outputs
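Usage note (not part of the commit): a minimal sketch of the encode/decode round trip above. build_module is a hypothetical stand-in for however the project instantiates AlignedShapeAsLatentPLModule from its YAML configs; the surface layout (xyz in channels 0:3, features in 3:6) follows encode().

import torch

module = build_module()   # hypothetical: yields a configured AlignedShapeAsLatentPLModule
module.eval()

surface = torch.randn(2, 4096, 6)   # [bs, n_surface, xyz + feats], per encode()
with torch.no_grad():
    shape_zq = module.encode(surface, sample_posterior=False)
    meshes = module.decode(shape_zq, bounds=1.1, octree_depth=7)

for m in meshes:
    if m is not None:   # None entries mean marching cubes found no surface
        print(m.mesh_v.shape, m.mesh_f.shape)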
primitive_anything/michelangelo/models/tsal/clip_asl_module.py (Executable file, 114 lines)
@@ -0,0 +1,114 @@
# -*- coding: utf-8 -*-

import torch
from torch import nn
from einops import rearrange
from transformers import CLIPModel

from .tsal_base import AlignedShapeAsLatentModule


class CLIPAlignedShapeAsLatentModule(AlignedShapeAsLatentModule):

    def __init__(self, *,
                 shape_model,
                 projection_dim=768):

        super().__init__()

        self.shape_model = shape_model
        self.shape_projection = nn.Parameter(torch.empty(self.shape_model.width, projection_dim))
        nn.init.normal_(self.shape_projection, std=projection_dim ** -0.5)

    def set_shape_model_only(self):
        # release the CLIP model (assigned outside this module) and keep only the shape branch
        self.clip_model = None

    def encode_shape_embed(self, surface, return_latents: bool = False):
        """

        Args:
            surface (torch.FloatTensor): [bs, n, 3 + c]
            return_latents (bool):

        Returns:
            x (torch.FloatTensor): [bs, projection_dim]
            shape_latents (torch.FloatTensor): [bs, m, d]
        """

        pc = surface[..., 0:3]
        feats = surface[..., 3:]

        shape_embed, shape_latents = self.shape_model.encode_latents(pc, feats)
        x = shape_embed @ self.shape_projection

        if return_latents:
            return x, shape_latents
        else:
            return x

    def encode_image_embed(self, image):
        """

        Args:
            image (torch.FloatTensor): [bs, 3, h, w]

        Returns:
            x (torch.FloatTensor): [bs, projection_dim]
        """

        x = self.clip_model.get_image_features(image)

        return x

    def encode_text_embed(self, text):
        x = self.clip_model.get_text_features(text)
        return x

    def forward(self, surface, image, text):
        """

        Args:
            surface (torch.FloatTensor):
            image (torch.FloatTensor): [bs, 3, 224, 224]
            text (torch.LongTensor): [bs, num_templates, 77]

        Returns:
            embed_outputs (dict): the embedding outputs; it contains:
                - image_embed (torch.FloatTensor):
                - text_embed (torch.FloatTensor):
                - shape_embed (torch.FloatTensor):
                - logit_scale (float):
        """

        # text embedding: encode all templates in one batched pass, then average per sample
        b = text.shape[0]
        text_tokens = rearrange(text, "b t l -> (b t) l")
        text_embed = self.encode_text_embed(text_tokens)
        text_embed = rearrange(text_embed, "(b t) d -> b t d", b=b)
        text_embed = text_embed.mean(dim=1)
        text_embed = text_embed / text_embed.norm(dim=-1, keepdim=True)

        # image embedding
        image_embed = self.encode_image_embed(image)

        # shape embedding
        shape_embed, shape_latents = self.encode_shape_embed(surface, return_latents=True)

        embed_outputs = {
            "image_embed": image_embed,
            "text_embed": text_embed,
            "shape_embed": shape_embed,
            "logit_scale": self.clip_model.logit_scale.exp()
        }

        return embed_outputs, shape_latents
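Usage note (not part of the commit): the text branch of forward() folds the template axis into the batch so CLIP encodes every prompt template in a single pass, then averages per sample. A self-contained sketch of just that pooling step, with a stand-in for clip_model.get_text_features and illustrative shapes:

import torch
from einops import rearrange

b, t, l, d = 2, 4, 77, 768                     # batch, templates, token length, CLIP dim
text = torch.randint(0, 49408, (b, t, l))      # fake token ids

def fake_clip_text_features(tokens):           # stand-in for clip_model.get_text_features
    return torch.randn(tokens.shape[0], d)

tokens = rearrange(text, "b t l -> (b t) l")   # [b*t, l]: one CLIP call for all templates
emb = fake_clip_text_features(tokens)          # [b*t, d]
emb = rearrange(emb, "(b t) d -> b t d", b=b)  # restore the template axis
emb = emb.mean(dim=1)                          # average the prompt templates per sample
emb = emb / emb.norm(dim=-1, keepdim=True)     # unit-normalize, as in forward()
print(emb.shape)                               # torch.Size([2, 768])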
primitive_anything/michelangelo/models/tsal/inference_utils.py (Executable file, 80 lines)
@@ -0,0 +1,80 @@
# -*- coding: utf-8 -*-

import torch
from tqdm import tqdm
from einops import repeat
import numpy as np
from typing import Callable, Tuple, List, Union
from skimage import measure

from ...graphics.primitives import generate_dense_grid_points


@torch.no_grad()
def extract_geometry(geometric_func: Callable,
                     device: torch.device,
                     batch_size: int = 1,
                     bounds: Union[Tuple[float], List[float], float] = (-1.25, -1.25, -1.25, 1.25, 1.25, 1.25),
                     octree_depth: int = 7,
                     num_chunks: int = 10000,
                     disable: bool = True):
    """

    Args:
        geometric_func: maps query points [b, p, 3] to occupancy logits [b, p].
        device:
        batch_size:
        bounds: bounding box; a single float f expands to [-f, -f, -f, f, f, f].
        octree_depth: resolution of the dense query grid.
        num_chunks: number of query points evaluated per forward pass.
        disable: disable the progress bar.

    Returns:
        mesh_v_f (list): per-sample (vertices, faces) pairs, or (None, None) where no surface was found.
        has_surface (np.ndarray): [batch_size] boolean mask.

    """

    if isinstance(bounds, float):
        bounds = [-bounds, -bounds, -bounds, bounds, bounds, bounds]

    bbox_min = np.array(bounds[0:3])
    bbox_max = np.array(bounds[3:6])
    bbox_size = bbox_max - bbox_min

    xyz_samples, grid_size, length = generate_dense_grid_points(
        bbox_min=bbox_min,
        bbox_max=bbox_max,
        octree_depth=octree_depth,
        indexing="ij"
    )
    xyz_samples = torch.FloatTensor(xyz_samples)

    batch_logits = []
    for start in tqdm(range(0, xyz_samples.shape[0], num_chunks),
                      desc="Implicit Function:", disable=disable, leave=False):
        queries = xyz_samples[start: start + num_chunks, :].to(device)
        batch_queries = repeat(queries, "p c -> b p c", b=batch_size)

        logits = geometric_func(batch_queries)
        batch_logits.append(logits.cpu())

    grid_logits = torch.cat(batch_logits, dim=1).view((batch_size, grid_size[0], grid_size[1], grid_size[2])).numpy()

    mesh_v_f = []
    has_surface = np.zeros((batch_size,), dtype=np.bool_)
    for i in range(batch_size):
        try:
            vertices, faces, normals, _ = measure.marching_cubes(grid_logits[i], 0, method="lewiner")
            # map grid coordinates back into the bounding box
            vertices = vertices / grid_size * bbox_size + bbox_min
            mesh_v_f.append((vertices.astype(np.float32), np.ascontiguousarray(faces)))
            has_surface[i] = True

        except (ValueError, RuntimeError):
            # marching cubes found no zero crossing for this sample
            mesh_v_f.append((None, None))
            has_surface[i] = False

    return mesh_v_f, has_surface
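Usage note (not part of the commit): a sketch of driving extract_geometry with a hand-written implicit function whose logits are positive inside a radius-0.8 sphere, so the recovered zero level set is the sphere surface. It assumes the package is importable; the bounds and octree_depth values are illustrative.

import torch
from primitive_anything.michelangelo.models.tsal.inference_utils import extract_geometry

def sphere_logits(queries: torch.FloatTensor) -> torch.FloatTensor:
    # queries: [b, p, 3]; positive inside the sphere, negative outside
    return 0.8 - queries.norm(dim=-1)

mesh_v_f, has_surface = extract_geometry(
    geometric_func=sphere_logits,
    device=torch.device("cpu"),
    batch_size=1,
    bounds=1.0,          # expands to [-1, -1, -1, 1, 1, 1]
    octree_depth=6,      # grid resolution grows with octree_depth
    num_chunks=10000,
    disable=False,
)
vertices, faces = mesh_v_f[0]
print(has_surface[0], vertices.shape, faces.shape)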
primitive_anything/michelangelo/models/tsal/loss.py (Executable file, 303 lines)
@@ -0,0 +1,303 @@
# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
import torch.nn.functional as F

from typing import Optional, Tuple, Dict

from ..modules.distributions import DiagonalGaussianDistribution
from ...utils.eval import compute_psnr
from ...utils import misc


class KLNearFar(nn.Module):
    def __init__(self,
                 near_weight: float = 0.1,
                 kl_weight: float = 1.0,
                 num_near_samples: Optional[int] = None):

        super().__init__()

        self.near_weight = near_weight
        self.kl_weight = kl_weight
        self.num_near_samples = num_near_samples
        self.geo_criterion = nn.BCEWithLogitsLoss()

    def forward(self,
                posteriors: Optional[DiagonalGaussianDistribution],
                logits: torch.FloatTensor,
                labels: torch.FloatTensor,
                split: Optional[str] = "train", **kwargs) -> Tuple[torch.FloatTensor, Dict[str, float]]:
        """

        Args:
            posteriors (DiagonalGaussianDistribution or torch.distributions.Normal):
            logits (torch.FloatTensor): [B, 2*N]; logits[:, 0:N] are the volume points, logits[:, N:2N] are the near-surface points.
            labels (torch.FloatTensor): [B, 2*N]; labels[:, 0:N] are the volume points, labels[:, N:2N] are the near-surface points.
            split (str):
            **kwargs:

        Returns:
            loss (torch.Tensor): (,)
            log (dict):

        """

        if self.num_near_samples is None:
            num_vol = logits.shape[1] // 2
        else:
            num_vol = logits.shape[1] - self.num_near_samples

        vol_logits = logits[:, 0:num_vol]
        vol_labels = labels[:, 0:num_vol]

        near_logits = logits[:, num_vol:]
        near_labels = labels[:, num_vol:]

        # occupancy loss (cast to float so BCE also works under mixed precision)
        vol_bce = self.geo_criterion(vol_logits.float(), vol_labels.float())
        near_bce = self.geo_criterion(near_logits.float(), near_labels.float())

        if posteriors is None:
            kl_loss = torch.tensor(0.0, dtype=vol_logits.dtype, device=vol_logits.device)
        else:
            kl_loss = posteriors.kl(dims=(1, 2))
            kl_loss = torch.mean(kl_loss)

        loss = vol_bce + near_bce * self.near_weight + kl_loss * self.kl_weight

        with torch.no_grad():
            preds = logits >= 0
            accuracy = (preds == labels).float()
            accuracy = accuracy.mean()
            pos_ratio = torch.mean(labels)

        log = {
            "{}/total_loss".format(split): loss.clone().detach(),
            "{}/near".format(split): near_bce.detach(),
            "{}/far".format(split): vol_bce.detach(),
            "{}/kl".format(split): kl_loss.detach(),
            "{}/accuracy".format(split): accuracy,
            "{}/pos_ratio".format(split): pos_ratio
        }

        if posteriors is not None:
            log[f"{split}/mean"] = posteriors.mean.mean().detach()
            log[f"{split}/std_mean"] = posteriors.std.mean().detach()
            log[f"{split}/std_max"] = posteriors.std.max().detach()

        return loss, log


class KLNearFarColor(nn.Module):
    def __init__(self,
                 near_weight: float = 0.1,
                 kl_weight: float = 1.0,
                 color_weight: float = 1.0,
                 color_criterion: str = "mse",
                 num_near_samples: Optional[int] = None):

        super().__init__()

        self.color_weight = color_weight
        self.near_weight = near_weight
        self.kl_weight = kl_weight
        self.num_near_samples = num_near_samples

        if color_criterion == "mse":
            self.color_criterion = nn.MSELoss()
        elif color_criterion == "l1":
            self.color_criterion = nn.L1Loss()
        else:
            raise ValueError(f"color_criterion must be one of [`mse`, `l1`], got {color_criterion}.")

        self.geo_criterion = nn.BCEWithLogitsLoss()

    def forward(self,
                posteriors: Optional[DiagonalGaussianDistribution],
                logits: torch.FloatTensor,
                labels: torch.FloatTensor,
                pred_colors: torch.FloatTensor,
                gt_colors: torch.FloatTensor,
                split: Optional[str] = "train", **kwargs) -> Tuple[torch.FloatTensor, Dict[str, float]]:
        """

        Args:
            posteriors (DiagonalGaussianDistribution or torch.distributions.Normal):
            logits (torch.FloatTensor): [B, 2*N]; logits[:, 0:N] are the volume points, logits[:, N:2N] are the near-surface points.
            labels (torch.FloatTensor): [B, 2*N]; labels[:, 0:N] are the volume points, labels[:, N:2N] are the near-surface points.
            pred_colors (torch.FloatTensor): [B, M, 3]
            gt_colors (torch.FloatTensor): [B, M, 3]
            split (str):
            **kwargs:

        Returns:
            loss (torch.Tensor): (,)
            log (dict):

        """

        if self.num_near_samples is None:
            num_vol = logits.shape[1] // 2
        else:
            num_vol = logits.shape[1] - self.num_near_samples

        vol_logits = logits[:, 0:num_vol]
        vol_labels = labels[:, 0:num_vol]

        near_logits = logits[:, num_vol:]
        near_labels = labels[:, num_vol:]

        # occupancy loss (cast to float so BCE also works under mixed precision)
        vol_bce = self.geo_criterion(vol_logits.float(), vol_labels.float())
        near_bce = self.geo_criterion(near_logits.float(), near_labels.float())

        # surface color loss
        color = self.color_criterion(pred_colors, gt_colors)

        if posteriors is None:
            kl_loss = torch.tensor(0.0, dtype=pred_colors.dtype, device=pred_colors.device)
        else:
            kl_loss = posteriors.kl(dims=(1, 2))
            kl_loss = torch.mean(kl_loss)

        loss = vol_bce + near_bce * self.near_weight + color * self.color_weight + kl_loss * self.kl_weight

        with torch.no_grad():
            preds = logits >= 0
            accuracy = (preds == labels).float()
            accuracy = accuracy.mean()
            psnr = compute_psnr(pred_colors, gt_colors)

        log = {
            "{}/total_loss".format(split): loss.clone().detach(),
            "{}/near".format(split): near_bce.detach(),
            "{}/far".format(split): vol_bce.detach(),
            "{}/color".format(split): color.detach(),
            "{}/kl".format(split): kl_loss.detach(),
            "{}/psnr".format(split): psnr.detach(),
            "{}/accuracy".format(split): accuracy
        }

        return loss, log


class ContrastKLNearFar(nn.Module):
    def __init__(self,
                 contrast_weight: float = 1.0,
                 near_weight: float = 0.1,
                 kl_weight: float = 1.0,
                 num_near_samples: Optional[int] = None):

        super().__init__()

        self.labels = None
        self.last_local_batch_size = None

        self.contrast_weight = contrast_weight
        self.near_weight = near_weight
        self.kl_weight = kl_weight
        self.num_near_samples = num_near_samples
        self.geo_criterion = nn.BCEWithLogitsLoss()

    def forward(self,
                shape_embed: torch.FloatTensor,
                text_embed: torch.FloatTensor,
                image_embed: torch.FloatTensor,
                logit_scale: torch.FloatTensor,
                posteriors: Optional[DiagonalGaussianDistribution],
                shape_logits: torch.FloatTensor,
                shape_labels: torch.FloatTensor,
                split: Optional[str] = "train", **kwargs):

        local_batch_size = shape_embed.size(0)

        if local_batch_size != self.last_local_batch_size:
            self.labels = local_batch_size * misc.get_rank() + torch.arange(
                local_batch_size, device=shape_embed.device
            ).long()
            self.last_local_batch_size = local_batch_size

        # normalized features
        shape_embed = F.normalize(shape_embed, dim=-1, p=2)
        text_embed = F.normalize(text_embed, dim=-1, p=2)
        image_embed = F.normalize(image_embed, dim=-1, p=2)

        # gather features from all GPUs
        shape_embed_all, text_embed_all, image_embed_all = misc.all_gather_batch(
            [shape_embed, text_embed, image_embed]
        )

        # cosine similarity as logits
        logits_per_shape_text = logit_scale * shape_embed @ text_embed_all.t()
        logits_per_text_shape = logit_scale * text_embed @ shape_embed_all.t()
        logits_per_shape_image = logit_scale * shape_embed @ image_embed_all.t()
        logits_per_image_shape = logit_scale * image_embed @ shape_embed_all.t()
        contrast_loss = (F.cross_entropy(logits_per_shape_text, self.labels) +
                         F.cross_entropy(logits_per_text_shape, self.labels)) / 2 + \
                        (F.cross_entropy(logits_per_shape_image, self.labels) +
                         F.cross_entropy(logits_per_image_shape, self.labels)) / 2

        # shape reconstruction
        if self.num_near_samples is None:
            num_vol = shape_logits.shape[1] // 2
        else:
            num_vol = shape_logits.shape[1] - self.num_near_samples

        vol_logits = shape_logits[:, 0:num_vol]
        vol_labels = shape_labels[:, 0:num_vol]

        near_logits = shape_logits[:, num_vol:]
        near_labels = shape_labels[:, num_vol:]

        # occupancy loss
        vol_bce = self.geo_criterion(vol_logits.float(), vol_labels.float())
        near_bce = self.geo_criterion(near_logits.float(), near_labels.float())

        if posteriors is None:
            kl_loss = torch.tensor(0.0, dtype=vol_logits.dtype, device=vol_logits.device)
        else:
            kl_loss = posteriors.kl(dims=(1, 2))
            kl_loss = torch.mean(kl_loss)

        loss = vol_bce + near_bce * self.near_weight + kl_loss * self.kl_weight + contrast_loss * self.contrast_weight

        # compute accuracy
        with torch.no_grad():
            pred = torch.argmax(logits_per_shape_text, dim=-1)
            correct = pred.eq(self.labels).sum()
            shape_text_acc = 100 * correct / local_batch_size

            pred = torch.argmax(logits_per_shape_image, dim=-1)
            correct = pred.eq(self.labels).sum()
            shape_image_acc = 100 * correct / local_batch_size

            preds = shape_logits >= 0
            accuracy = (preds == shape_labels).float()
            accuracy = accuracy.mean()

        log = {
            "{}/contrast".format(split): contrast_loss.clone().detach(),
            "{}/near".format(split): near_bce.detach(),
            "{}/far".format(split): vol_bce.detach(),
            "{}/kl".format(split): kl_loss.detach(),
            "{}/shape_text_acc".format(split): shape_text_acc,
            "{}/shape_image_acc".format(split): shape_image_acc,
            "{}/total_loss".format(split): loss.clone().detach(),
            "{}/accuracy".format(split): accuracy,
        }

        if posteriors is not None:
            log[f"{split}/mean"] = posteriors.mean.mean().detach()
            log[f"{split}/std_mean"] = posteriors.std.mean().detach()
            log[f"{split}/std_max"] = posteriors.std.max().detach()

        return loss, log
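Usage note (not part of the commit): a standalone sanity check of KLNearFar with no posterior, i.e. pure occupancy BCE. Shapes follow the docstring: [B, 2*N] with volume points first and near-surface points last.

import torch
from primitive_anything.michelangelo.models.tsal.loss import KLNearFar

criterion = KLNearFar(near_weight=0.1, kl_weight=1e-3)

B, N = 4, 1024
logits = torch.randn(B, 2 * N)
labels = (torch.rand(B, 2 * N) > 0.5).float()   # binary occupancy targets

loss, log = criterion(posteriors=None, logits=logits, labels=labels, split="train")
print(loss.item(), log["train/accuracy"].item())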
primitive_anything/michelangelo/models/tsal/sal_perceiver.py (Executable file, 423 lines)
@@ -0,0 +1,423 @@
# -*- coding: utf-8 -*-

import math
from typing import Optional

import torch
import torch.nn as nn
from einops import repeat

from ..modules import checkpoint
from ..modules.embedder import FourierEmbedder
from ..modules.distributions import DiagonalGaussianDistribution
from ..modules.transformer_blocks import (
    ResidualCrossAttentionBlock,
    Transformer
)

from .tsal_base import ShapeAsLatentModule


class CrossAttentionEncoder(nn.Module):

    def __init__(self, *,
                 device: Optional[torch.device],
                 dtype: Optional[torch.dtype],
                 num_latents: int,
                 fourier_embedder: FourierEmbedder,
                 point_feats: int,
                 width: int,
                 heads: int,
                 layers: int,
                 init_scale: float = 0.25,
                 qkv_bias: bool = True,
                 flash: bool = False,
                 use_ln_post: bool = False,
                 use_checkpoint: bool = False):

        super().__init__()

        self.use_checkpoint = use_checkpoint
        self.num_latents = num_latents

        self.query = nn.Parameter(torch.randn((num_latents, width), device=device, dtype=dtype) * 0.02)

        self.fourier_embedder = fourier_embedder
        self.input_proj = nn.Linear(self.fourier_embedder.out_dim + point_feats, width, device=device, dtype=dtype)
        self.cross_attn = ResidualCrossAttentionBlock(
            device=device,
            dtype=dtype,
            width=width,
            heads=heads,
            init_scale=init_scale,
            qkv_bias=qkv_bias,
            flash=flash,
        )

        self.self_attn = Transformer(
            device=device,
            dtype=dtype,
            n_ctx=num_latents,
            width=width,
            layers=layers,
            heads=heads,
            init_scale=init_scale,
            qkv_bias=qkv_bias,
            flash=flash,
            use_checkpoint=False
        )

        if use_ln_post:
            self.ln_post = nn.LayerNorm(width, dtype=dtype, device=device)
        else:
            self.ln_post = None

    def _forward(self, pc, feats):
        """

        Args:
            pc (torch.FloatTensor): [B, N, 3]
            feats (torch.FloatTensor or None): [B, N, C]

        Returns:
            latents (torch.FloatTensor): [B, num_latents, width]
            pc (torch.FloatTensor): [B, N, 3], passed through unchanged.

        """

        bs = pc.shape[0]

        data = self.fourier_embedder(pc)
        if feats is not None:
            data = torch.cat([data, feats], dim=-1)
        data = self.input_proj(data)

        query = repeat(self.query, "m c -> b m c", b=bs)
        latents = self.cross_attn(query, data)
        latents = self.self_attn(latents)

        if self.ln_post is not None:
            latents = self.ln_post(latents)

        return latents, pc

    def forward(self, pc: torch.FloatTensor, feats: Optional[torch.FloatTensor] = None):
        """

        Args:
            pc (torch.FloatTensor): [B, N, 3]
            feats (torch.FloatTensor or None): [B, N, C]

        Returns:
            latents (torch.FloatTensor) and pc (torch.FloatTensor), as in _forward.
        """

        return checkpoint(self._forward, (pc, feats), self.parameters(), self.use_checkpoint)


class CrossAttentionDecoder(nn.Module):

    def __init__(self, *,
                 device: Optional[torch.device],
                 dtype: Optional[torch.dtype],
                 num_latents: int,
                 out_channels: int,
                 fourier_embedder: FourierEmbedder,
                 width: int,
                 heads: int,
                 init_scale: float = 0.25,
                 qkv_bias: bool = True,
                 flash: bool = False,
                 use_checkpoint: bool = False):

        super().__init__()

        self.use_checkpoint = use_checkpoint
        self.fourier_embedder = fourier_embedder

        self.query_proj = nn.Linear(self.fourier_embedder.out_dim, width, device=device, dtype=dtype)

        self.cross_attn_decoder = ResidualCrossAttentionBlock(
            device=device,
            dtype=dtype,
            n_data=num_latents,
            width=width,
            heads=heads,
            init_scale=init_scale,
            qkv_bias=qkv_bias,
            flash=flash
        )

        self.ln_post = nn.LayerNorm(width, device=device, dtype=dtype)
        self.output_proj = nn.Linear(width, out_channels, device=device, dtype=dtype)

    def _forward(self, queries: torch.FloatTensor, latents: torch.FloatTensor):
        queries = self.query_proj(self.fourier_embedder(queries))
        x = self.cross_attn_decoder(queries, latents)
        x = self.ln_post(x)
        x = self.output_proj(x)
        return x

    def forward(self, queries: torch.FloatTensor, latents: torch.FloatTensor):
        return checkpoint(self._forward, (queries, latents), self.parameters(), self.use_checkpoint)


class ShapeAsLatentPerceiver(ShapeAsLatentModule):
    def __init__(self, *,
                 device: Optional[torch.device],
                 dtype: Optional[torch.dtype],
                 num_latents: int,
                 point_feats: int = 0,
                 embed_dim: int = 0,
                 num_freqs: int = 8,
                 include_pi: bool = True,
                 width: int,
                 heads: int,
                 num_encoder_layers: int,
                 num_decoder_layers: int,
                 init_scale: float = 0.25,
                 qkv_bias: bool = True,
                 flash: bool = False,
                 use_ln_post: bool = False,
                 use_checkpoint: bool = False):

        super().__init__()

        self.use_checkpoint = use_checkpoint

        self.num_latents = num_latents
        self.fourier_embedder = FourierEmbedder(num_freqs=num_freqs, include_pi=include_pi)

        init_scale = init_scale * math.sqrt(1.0 / width)
        self.encoder = CrossAttentionEncoder(
            device=device,
            dtype=dtype,
            fourier_embedder=self.fourier_embedder,
            num_latents=num_latents,
            point_feats=point_feats,
            width=width,
            heads=heads,
            layers=num_encoder_layers,
            init_scale=init_scale,
            qkv_bias=qkv_bias,
            flash=flash,
            use_ln_post=use_ln_post,
            use_checkpoint=use_checkpoint
        )

        self.embed_dim = embed_dim
        if embed_dim > 0:
            # VAE embed
            self.pre_kl = nn.Linear(width, embed_dim * 2, device=device, dtype=dtype)
            self.post_kl = nn.Linear(embed_dim, width, device=device, dtype=dtype)
            self.latent_shape = (num_latents, embed_dim)
        else:
            self.latent_shape = (num_latents, width)

        self.transformer = Transformer(
            device=device,
            dtype=dtype,
            n_ctx=num_latents,
            width=width,
            layers=num_decoder_layers,
            heads=heads,
            init_scale=init_scale,
            qkv_bias=qkv_bias,
            flash=flash,
            use_checkpoint=use_checkpoint
        )

        # geometry decoder
        self.geo_decoder = CrossAttentionDecoder(
            device=device,
            dtype=dtype,
            fourier_embedder=self.fourier_embedder,
            out_channels=1,
            num_latents=num_latents,
            width=width,
            heads=heads,
            init_scale=init_scale,
            qkv_bias=qkv_bias,
            flash=flash,
            use_checkpoint=use_checkpoint
        )

    def encode(self,
               pc: torch.FloatTensor,
               feats: Optional[torch.FloatTensor] = None,
               sample_posterior: bool = True):
        """

        Args:
            pc (torch.FloatTensor): [B, N, 3]
            feats (torch.FloatTensor or None): [B, N, C]
            sample_posterior (bool):

        Returns:
            latents (torch.FloatTensor)
            center_pos (torch.FloatTensor or None):
            posterior (DiagonalGaussianDistribution or None):
        """

        latents, center_pos = self.encoder(pc, feats)

        posterior = None
        if self.embed_dim > 0:
            moments = self.pre_kl(latents)
            posterior = DiagonalGaussianDistribution(moments, feat_dim=-1)

            if sample_posterior:
                latents = posterior.sample()
            else:
                latents = posterior.mode()

        return latents, center_pos, posterior

    def decode(self, latents: torch.FloatTensor):
        latents = self.post_kl(latents)
        return self.transformer(latents)

    def query_geometry(self, queries: torch.FloatTensor, latents: torch.FloatTensor):
        logits = self.geo_decoder(queries, latents).squeeze(-1)
        return logits

    def forward(self,
                pc: torch.FloatTensor,
                feats: torch.FloatTensor,
                volume_queries: torch.FloatTensor,
                sample_posterior: bool = True):
        """

        Args:
            pc (torch.FloatTensor): [B, N, 3]
            feats (torch.FloatTensor or None): [B, N, C]
            volume_queries (torch.FloatTensor): [B, P, 3]
            sample_posterior (bool):

        Returns:
            logits (torch.FloatTensor): [B, P]
            center_pos (torch.FloatTensor): [B, M, 3]
            posterior (DiagonalGaussianDistribution or None).

        """

        latents, center_pos, posterior = self.encode(pc, feats, sample_posterior=sample_posterior)

        latents = self.decode(latents)
        logits = self.query_geometry(volume_queries, latents)

        return logits, center_pos, posterior


class AlignedShapeLatentPerceiver(ShapeAsLatentPerceiver):

    def __init__(self, *,
                 device: Optional[torch.device],
                 dtype: Optional[torch.dtype],
                 num_latents: int,
                 point_feats: int = 0,
                 embed_dim: int = 0,
                 num_freqs: int = 8,
                 include_pi: bool = True,
                 width: int,
                 heads: int,
                 num_encoder_layers: int,
                 num_decoder_layers: int,
                 init_scale: float = 0.25,
                 qkv_bias: bool = True,
                 flash: bool = False,
                 use_ln_post: bool = False,
                 use_checkpoint: bool = False):

        # one extra latent slot serves as the global shape embedding
        super().__init__(
            device=device,
            dtype=dtype,
            num_latents=1 + num_latents,
            point_feats=point_feats,
            embed_dim=embed_dim,
            num_freqs=num_freqs,
            include_pi=include_pi,
            width=width,
            heads=heads,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            init_scale=init_scale,
            qkv_bias=qkv_bias,
            flash=flash,
            use_ln_post=use_ln_post,
            use_checkpoint=use_checkpoint
        )

        self.width = width

    def encode(self,
               pc: torch.FloatTensor,
               feats: Optional[torch.FloatTensor] = None,
               sample_posterior: bool = True):
        """

        Args:
            pc (torch.FloatTensor): [B, N, 3]
            feats (torch.FloatTensor or None): [B, N, c]
            sample_posterior (bool):

        Returns:
            shape_embed (torch.FloatTensor)
            kl_embed (torch.FloatTensor):
            posterior (DiagonalGaussianDistribution or None):
        """

        shape_embed, latents = self.encode_latents(pc, feats)
        kl_embed, posterior = self.encode_kl_embed(latents, sample_posterior)

        return shape_embed, kl_embed, posterior

    def encode_latents(self,
                       pc: torch.FloatTensor,
                       feats: Optional[torch.FloatTensor] = None):

        x, _ = self.encoder(pc, feats)

        # split off the first latent as the global shape embedding
        shape_embed = x[:, 0]
        latents = x[:, 1:]

        return shape_embed, latents

    def encode_kl_embed(self, latents: torch.FloatTensor, sample_posterior: bool = True):
        posterior = None
        if self.embed_dim > 0:
            moments = self.pre_kl(latents)
            posterior = DiagonalGaussianDistribution(moments, feat_dim=-1)

            if sample_posterior:
                kl_embed = posterior.sample()
            else:
                kl_embed = posterior.mode()
        else:
            kl_embed = latents

        return kl_embed, posterior

    def forward(self,
                pc: torch.FloatTensor,
                feats: torch.FloatTensor,
                volume_queries: torch.FloatTensor,
                sample_posterior: bool = True):
        """

        Args:
            pc (torch.FloatTensor): [B, N, 3]
            feats (torch.FloatTensor or None): [B, N, C]
            volume_queries (torch.FloatTensor): [B, P, 3]
            sample_posterior (bool):

        Returns:
            shape_embed (torch.FloatTensor): [B, width]
            logits (torch.FloatTensor): [B, P]
            posterior (DiagonalGaussianDistribution or None).

        """

        shape_embed, kl_embed, posterior = self.encode(pc, feats, sample_posterior=sample_posterior)

        latents = self.decode(kl_embed)
        logits = self.query_geometry(volume_queries, latents)

        return shape_embed, logits, posterior
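Usage note (not part of the commit): a smoke test of AlignedShapeLatentPerceiver, assuming the repository's transformer modules are importable. The hyperparameters below are illustrative, not taken from the repo's configs.

import torch
from primitive_anything.michelangelo.models.tsal.sal_perceiver import AlignedShapeLatentPerceiver

model = AlignedShapeLatentPerceiver(
    device=None, dtype=None,
    num_latents=256, point_feats=3, embed_dim=64,
    num_freqs=8, include_pi=True,
    width=768, heads=12,
    num_encoder_layers=8, num_decoder_layers=16,
)

pc = torch.randn(2, 4096, 3)                 # surface point coordinates
feats = torch.randn(2, 4096, 3)              # per-point features (e.g. normals)
queries = torch.rand(2, 1024, 3) * 2 - 1     # occupancy query points in [-1, 1]

shape_embed, logits, posterior = model(pc, feats, queries)
print(shape_embed.shape)   # [2, 768]: the extra "global" latent slot
print(logits.shape)        # [2, 1024]: occupancy logits at the queries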
primitive_anything/michelangelo/models/tsal/sal_pl_module.py (Executable file, 290 lines)
@@ -0,0 +1,290 @@
# -*- coding: utf-8 -*-

from typing import List, Tuple, Dict, Optional, Union
from omegaconf import DictConfig

import torch
from torch.optim import lr_scheduler
import pytorch_lightning as pl
from functools import partial

from ...utils import instantiate_from_config

from .inference_utils import extract_geometry
from .tsal_base import (
    ShapeAsLatentModule,
    Latent2MeshOutput,
    Point2MeshOutput
)


class ShapeAsLatentPLModule(pl.LightningModule):

    def __init__(self, *,
                 module_cfg,
                 loss_cfg,
                 optimizer_cfg: Optional[DictConfig] = None,
                 ckpt_path: Optional[str] = None,
                 ignore_keys: Union[Tuple[str], List[str]] = ()):

        super().__init__()

        self.sal: ShapeAsLatentModule = instantiate_from_config(module_cfg, device=None, dtype=None)

        self.loss = instantiate_from_config(loss_cfg)

        self.optimizer_cfg = optimizer_cfg

        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)

        self.save_hyperparameters()

    @property
    def latent_shape(self):
        return self.sal.latent_shape

    @property
    def zero_rank(self):
        if self._trainer:
            zero_rank = self.trainer.local_rank == 0
        else:
            zero_rank = True

        return zero_rank

    def init_from_ckpt(self, path, ignore_keys=()):
        state_dict = torch.load(path, map_location="cpu")["state_dict"]

        keys = list(state_dict.keys())
        for k in keys:
            for ik in ignore_keys:
                if k.startswith(ik):
                    print("Deleting key {} from state_dict.".format(k))
                    del state_dict[k]

        missing, unexpected = self.load_state_dict(state_dict, strict=False)
        print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
        if len(missing) > 0:
            print(f"Missing Keys: {missing}")
        if len(unexpected) > 0:
            print(f"Unexpected Keys: {unexpected}")

    def configure_optimizers(self) -> Tuple[List, List]:
        lr = self.learning_rate

        if self.optimizer_cfg is None:
            optimizers = [torch.optim.AdamW(self.sal.parameters(), lr=lr, betas=(0.9, 0.99), weight_decay=1e-3)]
            schedulers = []
        else:
            optimizer = instantiate_from_config(self.optimizer_cfg.optimizer, params=self.sal.parameters())
            scheduler_func = instantiate_from_config(
                self.optimizer_cfg.scheduler,
                max_decay_steps=self.trainer.max_steps,
                lr_max=lr
            )
            scheduler = {
                "scheduler": lr_scheduler.LambdaLR(optimizer, lr_lambda=scheduler_func.schedule),
                "interval": "step",
                "frequency": 1
            }
            optimizers = [optimizer]
            schedulers = [scheduler]

        return optimizers, schedulers

    def forward(self,
                pc: torch.FloatTensor,
                feats: torch.FloatTensor,
                volume_queries: torch.FloatTensor):

        logits, center_pos, posterior = self.sal(pc, feats, volume_queries)

        return posterior, logits

    def encode(self, surface: torch.FloatTensor, sample_posterior=True):

        pc = surface[..., 0:3]
        feats = surface[..., 3:6]

        latents, center_pos, posterior = self.sal.encode(
            pc=pc, feats=feats, sample_posterior=sample_posterior
        )

        return latents

    def decode(self,
               z_q,
               bounds: Union[Tuple[float], List[float], float] = 1.1,
               octree_depth: int = 7,
               num_chunks: int = 10000) -> List[Latent2MeshOutput]:

        latents = self.sal.decode(z_q)  # latents: [bs, num_latents, dim]
        outputs = self.latent2mesh(latents, bounds=bounds, octree_depth=octree_depth, num_chunks=num_chunks)

        return outputs

    def training_step(self, batch: Dict[str, torch.FloatTensor],
                      batch_idx: int, optimizer_idx: int = 0) -> torch.FloatTensor:
        """

        Args:
            batch (dict): the batch sample; it contains:
                - surface (torch.FloatTensor): [bs, n_surface, (3 + input_dim)]
                - geo_points (torch.FloatTensor): [bs, n_pts, (3 + 1)]
            batch_idx (int):
            optimizer_idx (int):

        Returns:
            loss (torch.FloatTensor):

        """

        pc = batch["surface"][..., 0:3]
        feats = batch["surface"][..., 3:]

        volume_queries = batch["geo_points"][..., 0:3]
        volume_labels = batch["geo_points"][..., -1]

        posterior, logits = self(
            pc=pc, feats=feats, volume_queries=volume_queries
        )
        aeloss, log_dict_ae = self.loss(posterior, logits, volume_labels, split="train")

        self.log_dict(log_dict_ae, prog_bar=True, logger=True, batch_size=logits.shape[0],
                      sync_dist=False, rank_zero_only=True)

        return aeloss

    def validation_step(self, batch: Dict[str, torch.FloatTensor], batch_idx: int) -> torch.FloatTensor:

        pc = batch["surface"][..., 0:3]
        feats = batch["surface"][..., 3:]

        volume_queries = batch["geo_points"][..., 0:3]
        volume_labels = batch["geo_points"][..., -1]

        posterior, logits = self(
            pc=pc, feats=feats, volume_queries=volume_queries
        )
        aeloss, log_dict_ae = self.loss(posterior, logits, volume_labels, split="val")

        self.log_dict(log_dict_ae, prog_bar=True, logger=True, batch_size=logits.shape[0],
                      sync_dist=False, rank_zero_only=True)

        return aeloss

    def point2mesh(self,
                   pc: torch.FloatTensor,
                   feats: torch.FloatTensor,
                   bounds: Union[Tuple[float], List[float]] = (-1.25, -1.25, -1.25, 1.25, 1.25, 1.25),
                   octree_depth: int = 7,
                   num_chunks: int = 10000) -> List[Point2MeshOutput]:
        """

        Args:
            pc: [bs, n, 3]
            feats: [bs, n, c]
            bounds: bounding box of the query grid.
            octree_depth: resolution of the query grid.
            num_chunks: number of query points evaluated per chunk.

        Returns:
            mesh_outputs (List[Point2MeshOutput]): the mesh outputs list.

        """

        outputs = []

        device = pc.device
        bs = pc.shape[0]

        # 1. point encoder + latents transformer
        latents, center_pos, posterior = self.sal.encode(pc, feats)
        latents = self.sal.decode(latents)  # latents: [bs, num_latents, dim]

        geometric_func = partial(self.sal.query_geometry, latents=latents)

        # 2. decode geometry
        mesh_v_f, has_surface = extract_geometry(
            geometric_func=geometric_func,
            device=device,
            batch_size=bs,
            bounds=bounds,
            octree_depth=octree_depth,
            num_chunks=num_chunks,
            disable=not self.zero_rank
        )

        # 3. assemble the mesh outputs
        for i, ((mesh_v, mesh_f), is_surface) in enumerate(zip(mesh_v_f, has_surface)):
            if not is_surface:
                outputs.append(None)
                continue

            out = Point2MeshOutput()
            out.mesh_v = mesh_v
            out.mesh_f = mesh_f
            out.pc = torch.cat([pc[i], feats[i]], dim=-1).cpu().numpy()

            if center_pos is not None:
                out.center = center_pos[i].cpu().numpy()

            outputs.append(out)

        return outputs

    def latent2mesh(self,
                    latents: torch.FloatTensor,
                    bounds: Union[Tuple[float], List[float], float] = 1.1,
                    octree_depth: int = 7,
                    num_chunks: int = 10000) -> List[Latent2MeshOutput]:
        """

        Args:
            latents: [bs, num_latents, dim]
            bounds: bounding box of the query grid.
            octree_depth: resolution of the query grid.
            num_chunks: number of query points evaluated per chunk.

        Returns:
            mesh_outputs (List[Latent2MeshOutput]): the mesh outputs list.

        """

        outputs = []

        geometric_func = partial(self.sal.query_geometry, latents=latents)

        # 1. decode geometry
        device = latents.device
        mesh_v_f, has_surface = extract_geometry(
            geometric_func=geometric_func,
            device=device,
            batch_size=len(latents),
            bounds=bounds,
            octree_depth=octree_depth,
            num_chunks=num_chunks,
            disable=not self.zero_rank
        )

        # 2. assemble the mesh outputs
        for i, ((mesh_v, mesh_f), is_surface) in enumerate(zip(mesh_v_f, has_surface)):
            if not is_surface:
                outputs.append(None)
                continue

            out = Latent2MeshOutput()
            out.mesh_v = mesh_v
            out.mesh_f = mesh_f

            outputs.append(out)

        return outputs
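Usage note (not part of the commit): the point-to-mesh path in one sketch. module is a hypothetical, already-configured ShapeAsLatentPLModule; the surface channel layout follows training_step.

import torch

surface = torch.randn(1, 4096, 6)             # [bs, n_surface, xyz + feats]
pc, feats = surface[..., 0:3], surface[..., 3:]

with torch.no_grad():
    outputs = module.point2mesh(pc, feats, octree_depth=7)   # module: hypothetical ShapeAsLatentPLModule

for out in outputs:
    if out is not None:
        print(out.mesh_v.shape, out.mesh_f.shape, out.pc.shape)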
primitive_anything/michelangelo/models/tsal/tsal_base.py (Executable file, 121 lines)
@@ -0,0 +1,121 @@
# -*- coding: utf-8 -*-

import torch.nn as nn
from typing import Tuple, List, Optional
import pytorch_lightning as pl


class Point2MeshOutput(object):
    def __init__(self):
        self.mesh_v = None
        self.mesh_f = None
        self.center = None
        self.pc = None


class Latent2MeshOutput(object):

    def __init__(self):
        self.mesh_v = None
        self.mesh_f = None


class AlignedMeshOutput(object):

    def __init__(self):
        self.mesh_v = None
        self.mesh_f = None
        self.surface = None
        self.image = None
        self.text: Optional[str] = None
        self.shape_text_similarity: Optional[float] = None
        self.shape_image_similarity: Optional[float] = None


class ShapeAsLatentPLModule(pl.LightningModule):
    latent_shape: Tuple[int]

    def encode(self, surface, *args, **kwargs):
        raise NotImplementedError

    def decode(self, z_q, *args, **kwargs):
        raise NotImplementedError

    def latent2mesh(self, latents, *args, **kwargs) -> List[Latent2MeshOutput]:
        raise NotImplementedError

    def point2mesh(self, *args, **kwargs) -> List[Point2MeshOutput]:
        raise NotImplementedError


class ShapeAsLatentModule(nn.Module):
    latent_shape: Tuple[int, int]

    def __init__(self, *args, **kwargs):
        super().__init__()

    def encode(self, *args, **kwargs):
        raise NotImplementedError

    def decode(self, *args, **kwargs):
        raise NotImplementedError

    def query_geometry(self, *args, **kwargs):
        raise NotImplementedError


class AlignedShapeAsLatentPLModule(pl.LightningModule):
    latent_shape: Tuple[int]

    def set_shape_model_only(self):
        raise NotImplementedError

    def encode(self, surface, *args, **kwargs):
        raise NotImplementedError

    def decode(self, z_q, *args, **kwargs):
        raise NotImplementedError

    def latent2mesh(self, latents, *args, **kwargs) -> List[Latent2MeshOutput]:
        raise NotImplementedError

    def point2mesh(self, *args, **kwargs) -> List[Point2MeshOutput]:
        raise NotImplementedError


class AlignedShapeAsLatentModule(nn.Module):
    shape_model: ShapeAsLatentModule
    latent_shape: Tuple[int, int]

    def __init__(self, *args, **kwargs):
        super().__init__()

    def set_shape_model_only(self):
        raise NotImplementedError

    def encode_image_embed(self, *args, **kwargs):
        raise NotImplementedError

    def encode_text_embed(self, *args, **kwargs):
        raise NotImplementedError

    def encode_shape_embed(self, *args, **kwargs):
        raise NotImplementedError


class TexturedShapeAsLatentModule(nn.Module):

    def __init__(self, *args, **kwargs):
        super().__init__()

    def encode(self, *args, **kwargs):
        raise NotImplementedError

    def decode(self, *args, **kwargs):
        raise NotImplementedError

    def query_geometry(self, *args, **kwargs):
        raise NotImplementedError

    def query_color(self, *args, **kwargs):
        raise NotImplementedError
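Usage note (not part of the commit): the classes above are pure interfaces; a toy ShapeAsLatentModule subclass illustrating the contract that sal_perceiver.py implements for real. The geometry query is a hard-coded sphere, purely for illustration.

import torch
import torch.nn as nn
from primitive_anything.michelangelo.models.tsal.tsal_base import ShapeAsLatentModule


class ToyShapeAsLatent(ShapeAsLatentModule):
    latent_shape = (16, 8)   # (num_latents, dim), mirroring the real modules

    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8)

    def encode(self, pc, feats=None, sample_posterior=True):
        # Real encoders attend over the point cloud; here we just emit zeros.
        latents = torch.zeros(pc.shape[0], *self.latent_shape)
        return latents, None, None          # latents, center_pos, posterior

    def decode(self, latents):
        return self.proj(latents)

    def query_geometry(self, queries, latents):
        # [B, P] logits, positive inside a radius-0.5 sphere.
        return 0.5 - queries.norm(dim=-1)


toy = ToyShapeAsLatent()
latents = toy.decode(toy.encode(torch.randn(1, 128, 3))[0])
print(toy.query_geometry(torch.rand(1, 8, 3), latents).shape)   # [1, 8]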