hyz317
2025-05-07 16:51:22 +08:00
commit 87c3ed5e40
54 changed files with 8014 additions and 0 deletions


@@ -0,0 +1 @@
# -*- coding: utf-8 -*-


@@ -0,0 +1,373 @@
# -*- coding: utf-8 -*-
from typing import List, Tuple, Dict, Optional
from omegaconf import DictConfig
import torch
import torch.nn.functional as F
from torch.optim import lr_scheduler
import pytorch_lightning as pl
from typing import Union
from functools import partial
from ...utils import instantiate_from_config
from .inference_utils import extract_geometry
from .tsal_base import (
AlignedShapeAsLatentModule,
ShapeAsLatentModule,
Latent2MeshOutput,
AlignedMeshOutput
)
class AlignedShapeAsLatentPLModule(pl.LightningModule):
def __init__(self, *,
shape_module_cfg,
aligned_module_cfg,
loss_cfg,
optimizer_cfg: Optional[DictConfig] = None,
ckpt_path: Optional[str] = None,
ignore_keys: Union[Tuple[str], List[str]] = ()):
super().__init__()
shape_model: ShapeAsLatentModule = instantiate_from_config(
shape_module_cfg, device=None, dtype=None
)
self.model: AlignedShapeAsLatentModule = instantiate_from_config(
aligned_module_cfg, shape_model=shape_model
)
self.loss = instantiate_from_config(loss_cfg)
self.optimizer_cfg = optimizer_cfg
if ckpt_path is not None:
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
self.save_hyperparameters()
def set_shape_model_only(self):
self.model.set_shape_model_only()
@property
def latent_shape(self):
return self.model.shape_model.latent_shape
@property
def zero_rank(self):
if self._trainer:
zero_rank = self.trainer.local_rank == 0
else:
zero_rank = True
return zero_rank
def init_from_ckpt(self, path, ignore_keys=()):
state_dict = torch.load(path, map_location="cpu")["state_dict"]
keys = list(state_dict.keys())
for k in keys:
for ik in ignore_keys:
if k.startswith(ik):
print("Deleting key {} from state_dict.".format(k))
del state_dict[k]
missing, unexpected = self.load_state_dict(state_dict, strict=False)
print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
if len(missing) > 0:
print(f"Missing Keys: {missing}")
print(f"Unexpected Keys: {unexpected}")
def configure_optimizers(self) -> Tuple[List, List]:
lr = self.learning_rate
trainable_parameters = list(self.model.parameters())
if self.optimizer_cfg is None:
optimizers = [torch.optim.AdamW(trainable_parameters, lr=lr, betas=(0.9, 0.99), weight_decay=1e-3)]
schedulers = []
else:
optimizer = instantiate_from_config(self.optimizer_cfg.optimizer, params=trainable_parameters)
scheduler_func = instantiate_from_config(
self.optimizer_cfg.scheduler,
max_decay_steps=self.trainer.max_steps,
lr_max=lr
)
scheduler = {
"scheduler": lr_scheduler.LambdaLR(optimizer, lr_lambda=scheduler_func.schedule),
"interval": "step",
"frequency": 1
}
optimizers = [optimizer]
schedulers = [scheduler]
return optimizers, schedulers
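# A hypothetical optimizer_cfg sketch (assuming an instantiate_from_config schema with
# `target`/`params` keys; values are illustrative only). Note that `max_decay_steps` and
# `lr_max` are injected above from the trainer, not read from the config, and that
# `self.learning_rate` is assumed to be attached externally by the training script,
# since it is not set in __init__:
#   optimizer:
#     target: torch.optim.AdamW
#     params: {betas: [0.9, 0.99], weight_decay: 1.0e-3}
#   scheduler:
#     target: <a schedule class exposing .schedule(step) -> lr multiplier>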
def forward(self,
surface: torch.FloatTensor,
image: torch.FloatTensor,
text: torch.LongTensor,
volume_queries: torch.FloatTensor):
"""
Args:
surface (torch.FloatTensor):
image (torch.FloatTensor):
text (torch.FloatTensor):
volume_queries (torch.FloatTensor):
Returns:
"""
embed_outputs, shape_z = self.model(surface, image, text)
shape_zq, posterior = self.model.shape_model.encode_kl_embed(shape_z)
latents = self.model.shape_model.decode(shape_zq)
logits = self.model.shape_model.query_geometry(volume_queries, latents)
return embed_outputs, logits, posterior
def encode(self, surface: torch.FloatTensor, sample_posterior=True):
pc = surface[..., 0:3]
feats = surface[..., 3:6]
shape_embed, shape_zq, posterior = self.model.shape_model.encode(
pc=pc, feats=feats, sample_posterior=sample_posterior
)
return shape_zq
def encode_latents(self, surface: torch.FloatTensor):
pc = surface[..., 0:3]
feats = surface[..., 3:6]
shape_embed, shape_latents = self.model.shape_model.encode_latents(
pc=pc, feats=feats
)
shape_embed = shape_embed.unsqueeze(1)
assert shape_embed.shape[1] == 1 and shape_latents.shape[1] == 256
cat_latents = torch.cat([shape_embed, shape_latents], dim=1)
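# cat_latents is [bs, 1 + 256, d]: the global shape token prepended to the latent tokens.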
return cat_latents
def to_shape_latents(self, latents):
shape_zq, posterior = self.model.shape_model.encode_kl_embed(latents, sample_posterior=False)
return self.model.shape_model.decode(shape_zq)
def decode(self,
z_q,
bounds: Union[Tuple[float], List[float], float] = 1.1,
octree_depth: int = 7,
num_chunks: int = 10000) -> List[Latent2MeshOutput]:
latents = self.model.shape_model.decode(z_q) # latents: [bs, num_latents, dim]
outputs = self.latent2mesh(latents, bounds=bounds, octree_depth=octree_depth, num_chunks=num_chunks)
return outputs
def training_step(self, batch: Dict[str, torch.FloatTensor],
batch_idx: int, optimizer_idx: int = 0) -> torch.FloatTensor:
"""
Args:
batch (dict): the batch sample, containing:
- surface (torch.FloatTensor): [bs, n_surface, (3 + input_dim)]
- image (torch.FloatTensor): [bs, 3, 224, 224]
- text (torch.FloatTensor): [bs, num_templates, 77]
- geo_points (torch.FloatTensor): [bs, n_pts, (3 + 1)]
batch_idx (int):
optimizer_idx (int):
Returns:
loss (torch.FloatTensor):
"""
surface = batch["surface"]
image = batch["image"]
text = batch["text"]
volume_queries = batch["geo_points"][..., 0:3]
shape_labels = batch["geo_points"][..., -1]
embed_outputs, shape_logits, posteriors = self(surface, image, text, volume_queries)
aeloss, log_dict_ae = self.loss(
**embed_outputs,
posteriors=posteriors,
shape_logits=shape_logits,
shape_labels=shape_labels,
split="train"
)
self.log_dict(log_dict_ae, prog_bar=True, logger=True, batch_size=shape_logits.shape[0],
sync_dist=False, rank_zero_only=True)
return aeloss
def validation_step(self, batch: Dict[str, torch.FloatTensor], batch_idx: int) -> torch.FloatTensor:
surface = batch["surface"]
image = batch["image"]
text = batch["text"]
volume_queries = batch["geo_points"][..., 0:3]
shape_labels = batch["geo_points"][..., -1]
embed_outputs, shape_logits, posteriors = self(surface, image, text, volume_queries)
aeloss, log_dict_ae = self.loss(
**embed_outputs,
posteriors=posteriors,
shape_logits=shape_logits,
shape_labels=shape_labels,
split="val"
)
self.log_dict(log_dict_ae, prog_bar=True, logger=True, batch_size=shape_logits.shape[0],
sync_dist=False, rank_zero_only=True)
return aeloss
def visual_alignment(self,
surface: torch.FloatTensor,
image: torch.FloatTensor,
text: torch.LongTensor,
description: Optional[List[str]] = None,
bounds: Union[Tuple[float], List[float]] = (-1.25, -1.25, -1.25, 1.25, 1.25, 1.25),
octree_depth: int = 7,
num_chunks: int = 10000) -> List[AlignedMeshOutput]:
"""
Args:
surface:
image:
text:
description:
bounds:
octree_depth:
num_chunks:
Returns:
mesh_outputs (List[AlignedMeshOutput]): the mesh outputs list.
"""
outputs = []
device = surface.device
bs = surface.shape[0]
embed_outputs, shape_z = self.model(surface, image, text)
# calculate the similarity
image_embed = embed_outputs["image_embed"]
text_embed = embed_outputs["text_embed"]
shape_embed = embed_outputs["shape_embed"]
# normalized features
shape_embed = F.normalize(shape_embed, dim=-1, p=2)
text_embed = F.normalize(text_embed, dim=-1, p=2)
image_embed = F.normalize(image_embed, dim=-1, p=2)
# B x B
shape_text_similarity = (100.0 * shape_embed @ text_embed.T).softmax(dim=-1)
# B x B
shape_image_similarity = (100.0 * shape_embed @ image_embed.T).softmax(dim=-1)
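# Softmax is over the batch dimension, so row i is a distribution over candidates;
# the diagonal entries read off below score each sample against its own pair.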
# shape reconstruction
shape_zq, posterior = self.model.shape_model.encode_kl_embed(shape_z)
latents = self.model.shape_model.decode(shape_zq)
geometric_func = partial(self.model.shape_model.query_geometry, latents=latents)
# 2. decode geometry
mesh_v_f, has_surface = extract_geometry(
geometric_func=geometric_func,
device=device,
batch_size=bs,
bounds=bounds,
octree_depth=octree_depth,
num_chunks=num_chunks,
disable=not self.zero_rank
)
# 3. assemble per-sample mesh outputs (no texture is decoded here)
for i, ((mesh_v, mesh_f), is_surface) in enumerate(zip(mesh_v_f, has_surface)):
if not is_surface:
outputs.append(None)
continue
out = AlignedMeshOutput()
out.mesh_v = mesh_v
out.mesh_f = mesh_f
out.surface = surface[i].cpu().numpy()
out.image = image[i].cpu().numpy()
if description is not None:
out.text = description[i]
out.shape_text_similarity = shape_text_similarity[i, i]
out.shape_image_similarity = shape_image_similarity[i, i]
outputs.append(out)
return outputs
def latent2mesh(self,
latents: torch.FloatTensor,
bounds: Union[Tuple[float], List[float], float] = 1.1,
octree_depth: int = 7,
num_chunks: int = 10000) -> List[Latent2MeshOutput]:
"""
Args:
latents: [bs, num_latents, dim]
bounds:
octree_depth:
num_chunks:
Returns:
mesh_outputs (List[MeshOutput]): the mesh outputs list.
"""
outputs = []
geometric_func = partial(self.model.shape_model.query_geometry, latents=latents)
# 2. decode geometry
device = latents.device
mesh_v_f, has_surface = extract_geometry(
geometric_func=geometric_func,
device=device,
batch_size=len(latents),
bounds=bounds,
octree_depth=octree_depth,
num_chunks=num_chunks,
disable=not self.zero_rank
)
# 3. assemble per-sample mesh outputs (no texture is decoded here)
for i, ((mesh_v, mesh_f), is_surface) in enumerate(zip(mesh_v_f, has_surface)):
if not is_surface:
outputs.append(None)
continue
out = Latent2MeshOutput()
out.mesh_v = mesh_v
out.mesh_f = mesh_f
outputs.append(out)
return outputs
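# A minimal inference sketch (hypothetical checkpoint path and configs; surface is
# assumed to be [bs, n_surface, 6] with xyz in [..., 0:3] and features in [..., 3:6]):
#
#   model = AlignedShapeAsLatentPLModule(shape_module_cfg=..., aligned_module_cfg=...,
#                                        loss_cfg=..., ckpt_path="last.ckpt")
#   model.eval()
#   with torch.no_grad():
#       z_q = model.encode(surface)                    # [bs, num_latents, embed_dim]
#       meshes = model.decode(z_q, bounds=1.1, octree_depth=7)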


@@ -0,0 +1,114 @@
# -*- coding: utf-8 -*-
import torch
from torch import nn
from einops import rearrange
from transformers import CLIPModel
from .tsal_base import AlignedShapeAsLatentModule
class CLIPAlignedShapeAsLatentModule(AlignedShapeAsLatentModule):
def __init__(self, *,
shape_model,
projection_dim=768):
super().__init__()
self.shape_model = shape_model
self.shape_projection = nn.Parameter(torch.empty(self.shape_model.width, projection_dim))
nn.init.normal_(self.shape_projection, std=projection_dim ** -0.5)
def set_shape_model_only(self):
self.clip_model = None
def encode_shape_embed(self, surface, return_latents: bool = False):
"""
Args:
surface (torch.FloatTensor): [bs, n, 3 + c]
return_latents (bool):
Returns:
x (torch.FloatTensor): [bs, projection_dim]
shape_latents (torch.FloatTensor): [bs, m, d]
"""
pc = surface[..., 0:3]
feats = surface[..., 3:]
shape_embed, shape_latents = self.shape_model.encode_latents(pc, feats)
x = shape_embed @ self.shape_projection
if return_latents:
return x, shape_latents
else:
return x
def encode_image_embed(self, image):
"""
Args:
image (torch.FloatTensor): [bs, 3, h, w]
Returns:
x (torch.FloatTensor): [bs, projection_dim]
"""
x = self.clip_model.get_image_features(image)
return x
def encode_text_embed(self, text):
x = self.clip_model.get_text_features(text)
return x
def forward(self, surface, image, text):
"""
Args:
surface (torch.FloatTensor):
image (torch.FloatTensor): [bs, 3, 224, 224]
text (torch.LongTensor): [bs, num_templates, 77]
Returns:
embed_outputs (dict): the embedding outputs, containing:
- image_embed (torch.FloatTensor):
- text_embed (torch.FloatTensor):
- shape_embed (torch.FloatTensor):
- logit_scale (float):
"""
# # text embedding
# text_embed_all = []
# for i in range(text.shape[0]):
# text_for_one_sample = text[i]
# text_embed = self.encode_text_embed(text_for_one_sample)
# text_embed = text_embed / text_embed.norm(dim=-1, keepdim=True)
# text_embed = text_embed.mean(dim=0)
# text_embed = text_embed / text_embed.norm(dim=-1, keepdim=True)
# text_embed_all.append(text_embed)
# text_embed_all = torch.stack(text_embed_all)
b = text.shape[0]
text_tokens = rearrange(text, "b t l -> (b t) l")
text_embed = self.encode_text_embed(text_tokens)
text_embed = rearrange(text_embed, "(b t) d -> b t d", b=b)
text_embed = text_embed.mean(dim=1)
text_embed = text_embed / text_embed.norm(dim=-1, keepdim=True)
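# Unlike the commented loop above, templates are averaged in embedding space first and
# only the mean is re-normalized; the loop normalized each template before averaging.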
# image embedding
image_embed = self.encode_image_embed(image)
# shape embedding
shape_embed, shape_latents = self.encode_shape_embed(surface, return_latents=True)
embed_outputs = {
"image_embed": image_embed,
"text_embed": text_embed,
"shape_embed": shape_embed,
"logit_scale": self.clip_model.logit_scale.exp()
}
return embed_outputs, shape_latents


@@ -0,0 +1,80 @@
# -*- coding: utf-8 -*-
import torch
from tqdm import tqdm
from einops import repeat
import numpy as np
from typing import Callable, Tuple, List, Union, Optional
from skimage import measure
from ...graphics.primitives import generate_dense_grid_points
@torch.no_grad()
def extract_geometry(geometric_func: Callable,
device: torch.device,
batch_size: int = 1,
bounds: Union[Tuple[float], List[float], float] = (-1.25, -1.25, -1.25, 1.25, 1.25, 1.25),
octree_depth: int = 7,
num_chunks: int = 10000,
disable: bool = True):
"""
Args:
geometric_func:
device:
bounds:
octree_depth:
batch_size:
num_chunks:
disable:
Returns:
"""
if isinstance(bounds, float):
bounds = [-bounds, -bounds, -bounds, bounds, bounds, bounds]
bbox_min = np.array(bounds[0:3])
bbox_max = np.array(bounds[3:6])
bbox_size = bbox_max - bbox_min
xyz_samples, grid_size, length = generate_dense_grid_points(
bbox_min=bbox_min,
bbox_max=bbox_max,
octree_depth=octree_depth,
indexing="ij"
)
xyz_samples = torch.FloatTensor(xyz_samples)
batch_logits = []
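# Evaluate the implicit function chunk-by-chunk to bound peak memory: each chunk of
# query points is broadcast across the batch before being scored.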
for start in tqdm(range(0, xyz_samples.shape[0], num_chunks),
desc="Implicit Function:", disable=disable, leave=False):
queries = xyz_samples[start: start + num_chunks, :].to(device)
batch_queries = repeat(queries, "p c -> b p c", b=batch_size)
logits = geometric_func(batch_queries)
batch_logits.append(logits.cpu())
grid_logits = torch.cat(batch_logits, dim=1).view((batch_size, grid_size[0], grid_size[1], grid_size[2])).numpy()
mesh_v_f = []
has_surface = np.zeros((batch_size,), dtype=np.bool_)
for i in range(batch_size):
try:
vertices, faces, normals, _ = measure.marching_cubes(grid_logits[i], 0, method="lewiner")
vertices = vertices / grid_size * bbox_size + bbox_min
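# marching_cubes returns vertices in voxel-index coordinates; the line above rescales
# them into the world-space bounding box.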
# vertices[:, [0, 1]] = vertices[:, [1, 0]]
mesh_v_f.append((vertices.astype(np.float32), np.ascontiguousarray(faces)))
has_surface[i] = True
except (ValueError, RuntimeError):
# marching_cubes raises ValueError when the 0-level is outside the volume's value range
mesh_v_f.append((None, None))
has_surface[i] = False
return mesh_v_f, has_surface


@@ -0,0 +1,303 @@
# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple, Dict
from ..modules.distributions import DiagonalGaussianDistribution
from ...utils.eval import compute_psnr
from ...utils import misc
class KLNearFar(nn.Module):
def __init__(self,
near_weight: float = 0.1,
kl_weight: float = 1.0,
num_near_samples: Optional[int] = None):
super().__init__()
self.near_weight = near_weight
self.kl_weight = kl_weight
self.num_near_samples = num_near_samples
self.geo_criterion = nn.BCEWithLogitsLoss()
def forward(self,
posteriors: Optional[DiagonalGaussianDistribution],
logits: torch.FloatTensor,
labels: torch.FloatTensor,
split: Optional[str] = "train", **kwargs) -> Tuple[torch.FloatTensor, Dict[str, float]]:
"""
Args:
posteriors (DiagonalGaussianDistribution or torch.distributions.Normal):
logits (torch.FloatTensor): [B, 2*N]; logits[:, 0:N] are the volume-point logits, logits[:, N:2N] the near-surface logits
labels (torch.FloatTensor): [B, 2*N]; labels[:, 0:N] are the volume-point labels, labels[:, N:2N] the near-surface labels
split (str):
**kwargs:
Returns:
loss (torch.Tensor): (,)
log (dict):
"""
if self.num_near_samples is None:
num_vol = logits.shape[1] // 2
else:
num_vol = logits.shape[1] - self.num_near_samples
vol_logits = logits[:, 0:num_vol]
vol_labels = labels[:, 0:num_vol]
near_logits = logits[:, num_vol:]
near_labels = labels[:, num_vol:]
# occupancy loss
# vol_bce = self.geo_criterion(vol_logits, vol_labels)
# near_bce = self.geo_criterion(near_logits, near_labels)
vol_bce = self.geo_criterion(vol_logits.float(), vol_labels.float())
near_bce = self.geo_criterion(near_logits.float(), near_labels.float())
if posteriors is None:
kl_loss = torch.tensor(0.0, dtype=vol_logits.dtype, device=vol_logits.device)
else:
kl_loss = posteriors.kl(dims=(1, 2))
kl_loss = torch.mean(kl_loss)
loss = vol_bce + near_bce * self.near_weight + kl_loss * self.kl_weight
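# Total objective: BCE on uniform volume samples, BCE on near-surface samples scaled
# by near_weight (0.1 by default), plus a KL regularizer toward the unit Gaussian prior.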
with torch.no_grad():
preds = logits >= 0
accuracy = (preds == labels).float()
accuracy = accuracy.mean()
pos_ratio = torch.mean(labels)
log = {
"{}/total_loss".format(split): loss.clone().detach(),
"{}/near".format(split): near_bce.detach(),
"{}/far".format(split): vol_bce.detach(),
"{}/kl".format(split): kl_loss.detach(),
"{}/accuracy".format(split): accuracy,
"{}/pos_ratio".format(split): pos_ratio
}
if posteriors is not None:
log[f"{split}/mean"] = posteriors.mean.mean().detach()
log[f"{split}/std_mean"] = posteriors.std.mean().detach()
log[f"{split}/std_max"] = posteriors.std.max().detach()
return loss, log
class KLNearFarColor(nn.Module):
def __init__(self,
near_weight: float = 0.1,
kl_weight: float = 1.0,
color_weight: float = 1.0,
color_criterion: str = "mse",
num_near_samples: Optional[int] = None):
super().__init__()
self.color_weight = color_weight
self.near_weight = near_weight
self.kl_weight = kl_weight
self.num_near_samples = num_near_samples
if color_criterion == "mse":
self.color_criterion = nn.MSELoss()
elif color_criterion == "l1":
self.color_criterion = nn.L1Loss()
else:
raise ValueError(f"color_criterion must be one of ['mse', 'l1'], got '{color_criterion}'.")
self.geo_criterion = nn.BCEWithLogitsLoss()
def forward(self,
posteriors: Optional[DiagonalGaussianDistribution],
logits: torch.FloatTensor,
labels: torch.FloatTensor,
pred_colors: torch.FloatTensor,
gt_colors: torch.FloatTensor,
split: Optional[str] = "train", **kwargs) -> Tuple[torch.FloatTensor, Dict[str, float]]:
"""
Args:
posteriors (DiagonalGaussianDistribution or torch.distributions.Normal):
logits (torch.FloatTensor): [B, 2*N]; logits[:, 0:N] are the volume-point logits, logits[:, N:2N] the near-surface logits
labels (torch.FloatTensor): [B, 2*N]; labels[:, 0:N] are the volume-point labels, labels[:, N:2N] the near-surface labels
pred_colors (torch.FloatTensor): [B, M, 3]
gt_colors (torch.FloatTensor): [B, M, 3]
split (str):
**kwargs:
Returns:
loss (torch.Tensor): (,)
log (dict):
"""
if self.num_near_samples is None:
num_vol = logits.shape[1] // 2
else:
num_vol = logits.shape[1] - self.num_near_samples
vol_logits = logits[:, 0:num_vol]
vol_labels = labels[:, 0:num_vol]
near_logits = logits[:, num_vol:]
near_labels = labels[:, num_vol:]
# occupancy loss
# vol_bce = self.geo_criterion(vol_logits, vol_labels)
# near_bce = self.geo_criterion(near_logits, near_labels)
vol_bce = self.geo_criterion(vol_logits.float(), vol_labels.float())
near_bce = self.geo_criterion(near_logits.float(), near_labels.float())
# surface color loss
color = self.color_criterion(pred_colors, gt_colors)
if posteriors is None:
kl_loss = torch.tensor(0.0, dtype=pred_colors.dtype, device=pred_colors.device)
else:
kl_loss = posteriors.kl(dims=(1, 2))
kl_loss = torch.mean(kl_loss)
loss = vol_bce + near_bce * self.near_weight + color * self.color_weight + kl_loss * self.kl_weight
with torch.no_grad():
preds = logits >= 0
accuracy = (preds == labels).float()
accuracy = accuracy.mean()
psnr = compute_psnr(pred_colors, gt_colors)
log = {
"{}/total_loss".format(split): loss.clone().detach(),
"{}/near".format(split): near_bce.detach(),
"{}/far".format(split): vol_bce.detach(),
"{}/color".format(split): color.detach(),
"{}/kl".format(split): kl_loss.detach(),
"{}/psnr".format(split): psnr.detach(),
"{}/accuracy".format(split): accuracy
}
return loss, log
class ContrastKLNearFar(nn.Module):
def __init__(self,
contrast_weight: float = 1.0,
near_weight: float = 0.1,
kl_weight: float = 1.0,
num_near_samples: Optional[int] = None):
super().__init__()
self.labels = None
self.last_local_batch_size = None
self.contrast_weight = contrast_weight
self.near_weight = near_weight
self.kl_weight = kl_weight
self.num_near_samples = num_near_samples
self.geo_criterion = nn.BCEWithLogitsLoss()
def forward(self,
shape_embed: torch.FloatTensor,
text_embed: torch.FloatTensor,
image_embed: torch.FloatTensor,
logit_scale: torch.FloatTensor,
posteriors: Optional[DiagonalGaussianDistribution],
shape_logits: torch.FloatTensor,
shape_labels: torch.FloatTensor,
split: Optional[str] = "train", **kwargs):
local_batch_size = shape_embed.size(0)
if local_batch_size != self.last_local_batch_size:
self.labels = local_batch_size * misc.get_rank() + torch.arange(
local_batch_size, device=shape_embed.device
).long()
self.last_local_batch_size = local_batch_size
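# After the all_gather below, the positive for local sample i sits at global column
# rank * local_batch_size + i, which is exactly the index these cached labels encode.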
# normalized features
shape_embed = F.normalize(shape_embed, dim=-1, p=2)
text_embed = F.normalize(text_embed, dim=-1, p=2)
image_embed = F.normalize(image_embed, dim=-1, p=2)
# gather features from all GPUs
shape_embed_all, text_embed_all, image_embed_all = misc.all_gather_batch(
[shape_embed, text_embed, image_embed]
)
# cosine similarity as logits
logits_per_shape_text = logit_scale * shape_embed @ text_embed_all.t()
logits_per_text_shape = logit_scale * text_embed @ shape_embed_all.t()
logits_per_shape_image = logit_scale * shape_embed @ image_embed_all.t()
logits_per_image_shape = logit_scale * image_embed @ shape_embed_all.t()
contrast_loss = (F.cross_entropy(logits_per_shape_text, self.labels) +
F.cross_entropy(logits_per_text_shape, self.labels)) / 2 + \
(F.cross_entropy(logits_per_shape_image, self.labels) +
F.cross_entropy(logits_per_image_shape, self.labels)) / 2
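# Symmetric InfoNCE: cross-entropy averaged over both directions of the shape<->text
# pair, plus the same for the shape<->image pair.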
# shape reconstruction
if self.num_near_samples is None:
num_vol = shape_logits.shape[1] // 2
else:
num_vol = shape_logits.shape[1] - self.num_near_samples
vol_logits = shape_logits[:, 0:num_vol]
vol_labels = shape_labels[:, 0:num_vol]
near_logits = shape_logits[:, num_vol:]
near_labels = shape_labels[:, num_vol:]
# occupancy loss
vol_bce = self.geo_criterion(vol_logits.float(), vol_labels.float())
near_bce = self.geo_criterion(near_logits.float(), near_labels.float())
if posteriors is None:
kl_loss = torch.tensor(0.0, dtype=vol_logits.dtype, device=vol_logits.device)
else:
kl_loss = posteriors.kl(dims=(1, 2))
kl_loss = torch.mean(kl_loss)
loss = vol_bce + near_bce * self.near_weight + kl_loss * self.kl_weight + contrast_loss * self.contrast_weight
# compute accuracy
with torch.no_grad():
pred = torch.argmax(logits_per_shape_text, dim=-1)
correct = pred.eq(self.labels).sum()
shape_text_acc = 100 * correct / local_batch_size
pred = torch.argmax(logits_per_shape_image, dim=-1)
correct = pred.eq(self.labels).sum()
shape_image_acc = 100 * correct / local_batch_size
preds = shape_logits >= 0
accuracy = (preds == shape_labels).float()
accuracy = accuracy.mean()
log = {
"{}/contrast".format(split): contrast_loss.clone().detach(),
"{}/near".format(split): near_bce.detach(),
"{}/far".format(split): vol_bce.detach(),
"{}/kl".format(split): kl_loss.detach(),
"{}/shape_text_acc".format(split): shape_text_acc,
"{}/shape_image_acc".format(split): shape_image_acc,
"{}/total_loss".format(split): loss.clone().detach(),
"{}/accuracy".format(split): accuracy,
}
if posteriors is not None:
log[f"{split}/mean"] = posteriors.mean.mean().detach()
log[f"{split}/std_mean"] = posteriors.std.mean().detach()
log[f"{split}/std_max"] = posteriors.std.max().detach()
return loss, log


@@ -0,0 +1,423 @@
# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
from typing import Optional
from einops import repeat
import math
from ..modules import checkpoint
from ..modules.embedder import FourierEmbedder
from ..modules.distributions import DiagonalGaussianDistribution
from ..modules.transformer_blocks import (
ResidualCrossAttentionBlock,
Transformer
)
from .tsal_base import ShapeAsLatentModule
class CrossAttentionEncoder(nn.Module):
def __init__(self, *,
device: Optional[torch.device],
dtype: Optional[torch.dtype],
num_latents: int,
fourier_embedder: FourierEmbedder,
point_feats: int,
width: int,
heads: int,
layers: int,
init_scale: float = 0.25,
qkv_bias: bool = True,
flash: bool = False,
use_ln_post: bool = False,
use_checkpoint: bool = False):
super().__init__()
self.use_checkpoint = use_checkpoint
self.num_latents = num_latents
self.query = nn.Parameter(torch.randn((num_latents, width), device=device, dtype=dtype) * 0.02)
self.fourier_embedder = fourier_embedder
self.input_proj = nn.Linear(self.fourier_embedder.out_dim + point_feats, width, device=device, dtype=dtype)
self.cross_attn = ResidualCrossAttentionBlock(
device=device,
dtype=dtype,
width=width,
heads=heads,
init_scale=init_scale,
qkv_bias=qkv_bias,
flash=flash,
)
self.self_attn = Transformer(
device=device,
dtype=dtype,
n_ctx=num_latents,
width=width,
layers=layers,
heads=heads,
init_scale=init_scale,
qkv_bias=qkv_bias,
flash=flash,
use_checkpoint=False
)
if use_ln_post:
self.ln_post = nn.LayerNorm(width, dtype=dtype, device=device)
else:
self.ln_post = None
def _forward(self, pc, feats):
"""
Args:
pc (torch.FloatTensor): [B, N, 3]
feats (torch.FloatTensor or None): [B, N, C]
Returns:
"""
bs = pc.shape[0]
data = self.fourier_embedder(pc)
if feats is not None:
data = torch.cat([data, feats], dim=-1)
data = self.input_proj(data)
query = repeat(self.query, "m c -> b m c", b=bs)
latents = self.cross_attn(query, data)
latents = self.self_attn(latents)
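# Perceiver-style bottleneck: a fixed set of learned latent queries cross-attends to
# the (variable-length) point features once, then refines itself via self-attention.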
if self.ln_post is not None:
latents = self.ln_post(latents)
return latents, pc
def forward(self, pc: torch.FloatTensor, feats: Optional[torch.FloatTensor] = None):
"""
Args:
pc (torch.FloatTensor): [B, N, 3]
feats (torch.FloatTensor or None): [B, N, C]
Returns:
dict
"""
return checkpoint(self._forward, (pc, feats), self.parameters(), self.use_checkpoint)
class CrossAttentionDecoder(nn.Module):
def __init__(self, *,
device: Optional[torch.device],
dtype: Optional[torch.dtype],
num_latents: int,
out_channels: int,
fourier_embedder: FourierEmbedder,
width: int,
heads: int,
init_scale: float = 0.25,
qkv_bias: bool = True,
flash: bool = False,
use_checkpoint: bool = False):
super().__init__()
self.use_checkpoint = use_checkpoint
self.fourier_embedder = fourier_embedder
self.query_proj = nn.Linear(self.fourier_embedder.out_dim, width, device=device, dtype=dtype)
self.cross_attn_decoder = ResidualCrossAttentionBlock(
device=device,
dtype=dtype,
n_data=num_latents,
width=width,
heads=heads,
init_scale=init_scale,
qkv_bias=qkv_bias,
flash=flash
)
self.ln_post = nn.LayerNorm(width, device=device, dtype=dtype)
self.output_proj = nn.Linear(width, out_channels, device=device, dtype=dtype)
def _forward(self, queries: torch.FloatTensor, latents: torch.FloatTensor):
queries = self.query_proj(self.fourier_embedder(queries))
x = self.cross_attn_decoder(queries, latents)
x = self.ln_post(x)
x = self.output_proj(x)
return x
def forward(self, queries: torch.FloatTensor, latents: torch.FloatTensor):
return checkpoint(self._forward, (queries, latents), self.parameters(), self.use_checkpoint)
class ShapeAsLatentPerceiver(ShapeAsLatentModule):
def __init__(self, *,
device: Optional[torch.device],
dtype: Optional[torch.dtype],
num_latents: int,
point_feats: int = 0,
embed_dim: int = 0,
num_freqs: int = 8,
include_pi: bool = True,
width: int,
heads: int,
num_encoder_layers: int,
num_decoder_layers: int,
init_scale: float = 0.25,
qkv_bias: bool = True,
flash: bool = False,
use_ln_post: bool = False,
use_checkpoint: bool = False):
super().__init__()
self.use_checkpoint = use_checkpoint
self.num_latents = num_latents
self.fourier_embedder = FourierEmbedder(num_freqs=num_freqs, include_pi=include_pi)
init_scale = init_scale * math.sqrt(1.0 / width)
self.encoder = CrossAttentionEncoder(
device=device,
dtype=dtype,
fourier_embedder=self.fourier_embedder,
num_latents=num_latents,
point_feats=point_feats,
width=width,
heads=heads,
layers=num_encoder_layers,
init_scale=init_scale,
qkv_bias=qkv_bias,
flash=flash,
use_ln_post=use_ln_post,
use_checkpoint=use_checkpoint
)
self.embed_dim = embed_dim
if embed_dim > 0:
# VAE embed
self.pre_kl = nn.Linear(width, embed_dim * 2, device=device, dtype=dtype)
self.post_kl = nn.Linear(embed_dim, width, device=device, dtype=dtype)
self.latent_shape = (num_latents, embed_dim)
else:
self.latent_shape = (num_latents, width)
self.transformer = Transformer(
device=device,
dtype=dtype,
n_ctx=num_latents,
width=width,
layers=num_decoder_layers,
heads=heads,
init_scale=init_scale,
qkv_bias=qkv_bias,
flash=flash,
use_checkpoint=use_checkpoint
)
# geometry decoder
self.geo_decoder = CrossAttentionDecoder(
device=device,
dtype=dtype,
fourier_embedder=self.fourier_embedder,
out_channels=1,
num_latents=num_latents,
width=width,
heads=heads,
init_scale=init_scale,
qkv_bias=qkv_bias,
flash=flash,
use_checkpoint=use_checkpoint
)
def encode(self,
pc: torch.FloatTensor,
feats: Optional[torch.FloatTensor] = None,
sample_posterior: bool = True):
"""
Args:
pc (torch.FloatTensor): [B, N, 3]
feats (torch.FloatTensor or None): [B, N, C]
sample_posterior (bool):
Returns:
latents (torch.FloatTensor)
center_pos (torch.FloatTensor or None):
posterior (DiagonalGaussianDistribution or None):
"""
latents, center_pos = self.encoder(pc, feats)
posterior = None
if self.embed_dim > 0:
moments = self.pre_kl(latents)
posterior = DiagonalGaussianDistribution(moments, feat_dim=-1)
if sample_posterior:
latents = posterior.sample()
else:
latents = posterior.mode()
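# sample() draws z = mean + std * eps (reparameterization trick); mode() returns the
# posterior mean for a deterministic encoding.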
return latents, center_pos, posterior
def decode(self, latents: torch.FloatTensor):
if self.embed_dim > 0:
# map the KL embedding back to the transformer width (mirrors the encode path)
latents = self.post_kl(latents)
return self.transformer(latents)
def query_geometry(self, queries: torch.FloatTensor, latents: torch.FloatTensor):
logits = self.geo_decoder(queries, latents).squeeze(-1)
return logits
def forward(self,
pc: torch.FloatTensor,
feats: torch.FloatTensor,
volume_queries: torch.FloatTensor,
sample_posterior: bool = True):
"""
Args:
pc (torch.FloatTensor): [B, N, 3]
feats (torch.FloatTensor or None): [B, N, C]
volume_queries (torch.FloatTensor): [B, P, 3]
sample_posterior (bool):
Returns:
logits (torch.FloatTensor): [B, P]
center_pos (torch.FloatTensor): [B, N, 3], the input points returned by the encoder
posterior (DiagonalGaussianDistribution or None).
"""
latents, center_pos, posterior = self.encode(pc, feats, sample_posterior=sample_posterior)
latents = self.decode(latents)
logits = self.query_geometry(volume_queries, latents)
return logits, center_pos, posterior
class AlignedShapeLatentPerceiver(ShapeAsLatentPerceiver):
def __init__(self, *,
device: Optional[torch.device],
dtype: Optional[torch.dtype],
num_latents: int,
point_feats: int = 0,
embed_dim: int = 0,
num_freqs: int = 8,
include_pi: bool = True,
width: int,
heads: int,
num_encoder_layers: int,
num_decoder_layers: int,
init_scale: float = 0.25,
qkv_bias: bool = True,
flash: bool = False,
use_ln_post: bool = False,
use_checkpoint: bool = False):
super().__init__(
device=device,
dtype=dtype,
num_latents=1 + num_latents,
point_feats=point_feats,
embed_dim=embed_dim,
num_freqs=num_freqs,
include_pi=include_pi,
width=width,
heads=heads,
num_encoder_layers=num_encoder_layers,
num_decoder_layers=num_decoder_layers,
init_scale=init_scale,
qkv_bias=qkv_bias,
flash=flash,
use_ln_post=use_ln_post,
use_checkpoint=use_checkpoint
)
self.width = width
def encode(self,
pc: torch.FloatTensor,
feats: Optional[torch.FloatTensor] = None,
sample_posterior: bool = True):
"""
Args:
pc (torch.FloatTensor): [B, N, 3]
feats (torch.FloatTensor or None): [B, N, c]
sample_posterior (bool):
Returns:
shape_embed (torch.FloatTensor)
kl_embed (torch.FloatTensor):
posterior (DiagonalGaussianDistribution or None):
"""
shape_embed, latents = self.encode_latents(pc, feats)
kl_embed, posterior = self.encode_kl_embed(latents, sample_posterior)
return shape_embed, kl_embed, posterior
def encode_latents(self,
pc: torch.FloatTensor,
feats: Optional[torch.FloatTensor] = None):
x, _ = self.encoder(pc, feats)
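# The encoder was built with 1 + num_latents query tokens: the first acts as a global
# [CLS]-style shape embedding, the rest as the latent set.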
shape_embed = x[:, 0]
latents = x[:, 1:]
return shape_embed, latents
def encode_kl_embed(self, latents: torch.FloatTensor, sample_posterior: bool = True):
posterior = None
if self.embed_dim > 0:
moments = self.pre_kl(latents)
posterior = DiagonalGaussianDistribution(moments, feat_dim=-1)
if sample_posterior:
kl_embed = posterior.sample()
else:
kl_embed = posterior.mode()
else:
kl_embed = latents
return kl_embed, posterior
def forward(self,
pc: torch.FloatTensor,
feats: torch.FloatTensor,
volume_queries: torch.FloatTensor,
sample_posterior: bool = True):
"""
Args:
pc (torch.FloatTensor): [B, N, 3]
feats (torch.FloatTensor or None): [B, N, C]
volume_queries (torch.FloatTensor): [B, P, 3]
sample_posterior (bool):
Returns:
shape_embed (torch.FloatTensor): [B, width], the global shape token (projection to the CLIP space happens in the aligned module).
logits (torch.FloatTensor): [B, P]
posterior (DiagonalGaussianDistribution or None).
"""
shape_embed, kl_embed, posterior = self.encode(pc, feats, sample_posterior=sample_posterior)
latents = self.decode(kl_embed)
logits = self.query_geometry(volume_queries, latents)
return shape_embed, logits, posterior


@@ -0,0 +1,290 @@
# -*- coding: utf-8 -*-
from typing import List, Tuple, Dict, Optional
from omegaconf import DictConfig
import torch
from torch.optim import lr_scheduler
import pytorch_lightning as pl
from typing import Union
from functools import partial
from ...utils import instantiate_from_config
from .inference_utils import extract_geometry
from .tsal_base import (
ShapeAsLatentModule,
Latent2MeshOutput,
Point2MeshOutput
)
class ShapeAsLatentPLModule(pl.LightningModule):
def __init__(self, *,
module_cfg,
loss_cfg,
optimizer_cfg: Optional[DictConfig] = None,
ckpt_path: Optional[str] = None,
ignore_keys: Union[Tuple[str], List[str]] = ()):
super().__init__()
self.sal: ShapeAsLatentModule = instantiate_from_config(module_cfg, device=None, dtype=None)
self.loss = instantiate_from_config(loss_cfg)
self.optimizer_cfg = optimizer_cfg
if ckpt_path is not None:
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
self.save_hyperparameters()
@property
def latent_shape(self):
return self.sal.latent_shape
@property
def zero_rank(self):
if self._trainer:
zero_rank = self.trainer.local_rank == 0
else:
zero_rank = True
return zero_rank
def init_from_ckpt(self, path, ignore_keys=()):
state_dict = torch.load(path, map_location="cpu")["state_dict"]
keys = list(state_dict.keys())
for k in keys:
for ik in ignore_keys:
if k.startswith(ik):
print("Deleting key {} from state_dict.".format(k))
del state_dict[k]
missing, unexpected = self.load_state_dict(state_dict, strict=False)
print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
if len(missing) > 0:
print(f"Missing Keys: {missing}")
print(f"Unexpected Keys: {unexpected}")
def configure_optimizers(self) -> Tuple[List, List]:
lr = self.learning_rate
# optimizers = [torch.optim.AdamW(self.sal.parameters(), lr=lr, betas=(0.9, 0.99), weight_decay=1e-4)]
# optimizers = [torch.optim.AdamW(self.sal.parameters(), lr=lr, betas=(0.9, 0.99), weight_decay=1e-3)]
if self.optimizer_cfg is None:
optimizers = [torch.optim.AdamW(self.sal.parameters(), lr=lr, betas=(0.9, 0.99), weight_decay=1e-3)]
schedulers = []
else:
optimizer = instantiate_from_config(self.optimizer_cfg.optimizer, params=self.sal.parameters())
scheduler_func = instantiate_from_config(
self.optimizer_cfg.scheduler,
max_decay_steps=self.trainer.max_steps,
lr_max=lr
)
scheduler = {
"scheduler": lr_scheduler.LambdaLR(optimizer, lr_lambda=scheduler_func.schedule),
"interval": "step",
"frequency": 1
}
optimizers = [optimizer]
schedulers = [scheduler]
return optimizers, schedulers
def forward(self,
pc: torch.FloatTensor,
feats: torch.FloatTensor,
volume_queries: torch.FloatTensor):
logits, center_pos, posterior = self.sal(pc, feats, volume_queries)
return posterior, logits
def encode(self, surface: torch.FloatTensor, sample_posterior=True):
pc = surface[..., 0:3]
feats = surface[..., 3:6]
latents, center_pos, posterior = self.sal.encode(
pc=pc, feats=feats, sample_posterior=sample_posterior
)
return latents
def decode(self,
z_q,
bounds: Union[Tuple[float], List[float], float] = 1.1,
octree_depth: int = 7,
num_chunks: int = 10000) -> List[Latent2MeshOutput]:
latents = self.sal.decode(z_q) # latents: [bs, num_latents, dim]
outputs = self.latent2mesh(latents, bounds=bounds, octree_depth=octree_depth, num_chunks=num_chunks)
return outputs
def training_step(self, batch: Dict[str, torch.FloatTensor],
batch_idx: int, optimizer_idx: int = 0) -> torch.FloatTensor:
"""
Args:
batch (dict): the batch sample, containing:
- surface (torch.FloatTensor): [bs, n_surface, (3 + input_dim)]
- geo_points (torch.FloatTensor): [bs, n_pts, (3 + 1)]
batch_idx (int):
optimizer_idx (int):
Returns:
loss (torch.FloatTensor):
"""
pc = batch["surface"][..., 0:3]
feats = batch["surface"][..., 3:]
volume_queries = batch["geo_points"][..., 0:3]
volume_labels = batch["geo_points"][..., -1]
posterior, logits = self(
pc=pc, feats=feats, volume_queries=volume_queries
)
aeloss, log_dict_ae = self.loss(posterior, logits, volume_labels, split="train")
self.log_dict(log_dict_ae, prog_bar=True, logger=True, batch_size=logits.shape[0],
sync_dist=False, rank_zero_only=True)
return aeloss
def validation_step(self, batch: Dict[str, torch.FloatTensor], batch_idx: int) -> torch.FloatTensor:
pc = batch["surface"][..., 0:3]
feats = batch["surface"][..., 3:]
volume_queries = batch["geo_points"][..., 0:3]
volume_labels = batch["geo_points"][..., -1]
posterior, logits = self(
pc=pc, feats=feats, volume_queries=volume_queries,
)
aeloss, log_dict_ae = self.loss(posterior, logits, volume_labels, split="val")
self.log_dict(log_dict_ae, prog_bar=True, logger=True, batch_size=logits.shape[0],
sync_dist=False, rank_zero_only=True)
return aeloss
def point2mesh(self,
pc: torch.FloatTensor,
feats: torch.FloatTensor,
bounds: Union[Tuple[float], List[float]] = (-1.25, -1.25, -1.25, 1.25, 1.25, 1.25),
octree_depth: int = 7,
num_chunks: int = 10000) -> List[Point2MeshOutput]:
"""
Args:
pc:
feats:
bounds:
octree_depth:
num_chunks:
Returns:
mesh_outputs (List[MeshOutput]): the mesh outputs list.
"""
outputs = []
device = pc.device
bs = pc.shape[0]
# 1. point encoder + latents transformer
latents, center_pos, posterior = self.sal.encode(pc, feats)
latents = self.sal.decode(latents) # latents: [bs, num_latents, dim]
geometric_func = partial(self.sal.query_geometry, latents=latents)
# 2. decode geometry
mesh_v_f, has_surface = extract_geometry(
geometric_func=geometric_func,
device=device,
batch_size=bs,
bounds=bounds,
octree_depth=octree_depth,
num_chunks=num_chunks,
disable=not self.zero_rank
)
# 3. assemble per-sample mesh outputs (no texture is decoded here)
for i, ((mesh_v, mesh_f), is_surface) in enumerate(zip(mesh_v_f, has_surface)):
if not is_surface:
outputs.append(None)
continue
out = Point2MeshOutput()
out.mesh_v = mesh_v
out.mesh_f = mesh_f
out.pc = torch.cat([pc[i], feats[i]], dim=-1).cpu().numpy()
if center_pos is not None:
out.center = center_pos[i].cpu().numpy()
outputs.append(out)
return outputs
def latent2mesh(self,
latents: torch.FloatTensor,
bounds: Union[Tuple[float], List[float], float] = 1.1,
octree_depth: int = 7,
num_chunks: int = 10000) -> List[Latent2MeshOutput]:
"""
Args:
latents: [bs, num_latents, dim]
bounds:
octree_depth:
num_chunks:
Returns:
mesh_outputs (List[MeshOutput]): the mesh outputs list.
"""
outputs = []
geometric_func = partial(self.sal.query_geometry, latents=latents)
# 2. decode geometry
device = latents.device
mesh_v_f, has_surface = extract_geometry(
geometric_func=geometric_func,
device=device,
batch_size=len(latents),
bounds=bounds,
octree_depth=octree_depth,
num_chunks=num_chunks,
disable=not self.zero_rank
)
# 3. assemble per-sample mesh outputs (no texture is decoded here)
for i, ((mesh_v, mesh_f), is_surface) in enumerate(zip(mesh_v_f, has_surface)):
if not is_surface:
outputs.append(None)
continue
out = Latent2MeshOutput()
out.mesh_v = mesh_v
out.mesh_f = mesh_f
outputs.append(out)
return outputs


@@ -0,0 +1,121 @@
# -*- coding: utf-8 -*-
import torch.nn as nn
from typing import Tuple, List, Optional
import pytorch_lightning as pl
class Point2MeshOutput(object):
def __init__(self):
self.mesh_v = None
self.mesh_f = None
self.center = None
self.pc = None
class Latent2MeshOutput(object):
def __init__(self):
self.mesh_v = None
self.mesh_f = None
class AlignedMeshOutput(object):
def __init__(self):
self.mesh_v = None
self.mesh_f = None
self.surface = None
self.image = None
self.text: Optional[str] = None
self.shape_text_similarity: Optional[float] = None
self.shape_image_similarity: Optional[float] = None
class ShapeAsLatentPLModule(pl.LightningModule):
latent_shape: Tuple[int, int]
def encode(self, surface, *args, **kwargs):
raise NotImplementedError
def decode(self, z_q, *args, **kwargs):
raise NotImplementedError
def latent2mesh(self, latents, *args, **kwargs) -> List[Latent2MeshOutput]:
raise NotImplementedError
def point2mesh(self, *args, **kwargs) -> List[Point2MeshOutput]:
raise NotImplementedError
class ShapeAsLatentModule(nn.Module):
latent_shape: Tuple[int, int]
def __init__(self, *args, **kwargs):
super().__init__()
def encode(self, *args, **kwargs):
raise NotImplementedError
def decode(self, *args, **kwargs):
raise NotImplementedError
def query_geometry(self, *args, **kwargs):
raise NotImplementedError
class AlignedShapeAsLatentPLModule(pl.LightningModule):
latent_shape: Tuple[int, int]
def set_shape_model_only(self):
raise NotImplementedError
def encode(self, surface, *args, **kwargs):
raise NotImplementedError
def decode(self, z_q, *args, **kwargs):
raise NotImplementedError
def latent2mesh(self, latents, *args, **kwargs) -> List[Latent2MeshOutput]:
raise NotImplementedError
def point2mesh(self, *args, **kwargs) -> List[Point2MeshOutput]:
raise NotImplementedError
class AlignedShapeAsLatentModule(nn.Module):
shape_model: ShapeAsLatentModule
latent_shape: Tuple[int, int]
def __init__(self, *args, **kwargs):
super().__init__()
def set_shape_model_only(self):
raise NotImplementedError
def encode_image_embed(self, *args, **kwargs):
raise NotImplementedError
def encode_text_embed(self, *args, **kwargs):
raise NotImplementedError
def encode_shape_embed(self, *args, **kwargs):
raise NotImplementedError
class TexturedShapeAsLatentModule(nn.Module):
def __init__(self, *args, **kwargs):
super().__init__()
def encode(self, *args, **kwargs):
raise NotImplementedError
def decode(self, *args, **kwargs):
raise NotImplementedError
def query_geometry(self, *args, **kwargs):
raise NotImplementedError
def query_color(self, *args, **kwargs):
raise NotImplementedError