This commit is contained in:
hyz317 2025-05-07 16:51:22 +08:00
commit 87c3ed5e40
54 changed files with 8014 additions and 0 deletions

4
.gitignore vendored Normal file
View File

@ -0,0 +1,4 @@
**/__pycache__/
ckpt
data
results

112
README.md Normal file
View File

@ -0,0 +1,112 @@
# PrimitiveAnything: Human-Crafted 3D Primitive Assembly Generation with Auto-Regressive Transformer
<a href="https://primitiveanything.github.io/"><img src="https://img.shields.io/static/v1?label=Project%20Page&message=Github&color=blue&logo=github-pages"></a>&ensp;<a href="#"><img src="https://img.shields.io/badge/ArXiv-250x.xxxxx-brightgreen"></a>&ensp;<a href="https://huggingface.co/hyz317/PrimitiveAnything"><img src="https://img.shields.io/badge/%F0%9F%A4%97%20Model_Card-Huggingface-orange"></a>&ensp;<a href="https://huggingface.co/spaces/hyz317/PrimitiveAnything"><img src="https://img.shields.io/badge/%F0%9F%A4%97%20Gradio%20Demo-Huggingface-orange"></a>
<img src="./assets/teaser.jpg" width="100%">
## 🔥 Updates
**[2025/05/07]** Test dataset, code, pretrained checkpoints, and Gradio demo are released!
## 🔍 Table of Contents
- [⚙️ Deployment](#deployment)
- [🖥️ Run PrimitiveAnything](#run-pa)
- [📝 Citation](#citation)
<a name="deployment"></a>
## ⚙️ Deployment
Set up a Python environment and install the required packages:
```bash
conda create -n primitiveanything python=3.9 -y
conda activate primitiveanything
# Install torch, torchvision based on your machine configuration
pip install torch==2.1.0 torchvision==0.16.0 --index-url https://download.pytorch.org/whl/cu118
# Install other dependencies
pip install -r requirements.txt
```
Then download the data and pretrained weights (example download commands are sketched after this list):
1. **Our Model Weights**:
Download from our 🤗 Hugging Face repository ([download here](https://huggingface.co/hyz317/PrimitiveAnything)) and place them in `./ckpt/`.
2. **Michelangelo's Point Cloud Encoder**:
Download weights from [Michelangelo's Hugging Face repo](https://huggingface.co/Maikou/Michelangelo/tree/main/checkpoints/aligned_shape_latents) and save them to `./ckpt/`.
3. **Demo and test data**:
Download from this [Google Drive link](https://drive.google.com/file/d/1FZZjk0OvzETD5j4OODEghS_YppcS_ZbM/view?usp=sharing), then decompress the files into `./data/`.
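These downloads can also be scripted. The following is only a rough sketch, assuming `huggingface_hub` (for the `huggingface-cli` tool) and `gdown` are installed and that the checkpoint paths inside the Hugging Face repos match the listing below; double-check the flags against your installed versions:
```bash
mkdir -p ckpt data

# 1. PrimitiveAnything weights -> ./ckpt/
huggingface-cli download hyz317/PrimitiveAnything --local-dir ./ckpt

# 2. Michelangelo point-cloud encoder checkpoint -> ./ckpt/shapevae-256.ckpt
huggingface-cli download Maikou/Michelangelo \
  checkpoints/aligned_shape_latents/shapevae-256.ckpt --local-dir ./michelangelo_tmp
cp ./michelangelo_tmp/checkpoints/aligned_shape_latents/shapevae-256.ckpt ./ckpt/

# 3. Demo and test data (Google Drive); decompress the downloaded archive into ./data/
gdown 1FZZjk0OvzETD5j4OODEghS_YppcS_ZbM
```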
After downloading and organizing the files, your project directory should look like this:
```
- data/
├── basic_shapes_norm/
├── basic_shapes_norm_pc10000/
├── demo_glb/ # Demo files in GLB format
└── test_pc/ # Test point cloud data
- ckpt/
├── mesh-transformer.ckpt.60.pt # Our model checkpoint
└── shapevae-256.ckpt # Michelangelo ShapeVAE checkpoint
```
<a name="run-pa"></a>
## 🖥️ Run PrimitiveAnything
### Demo
```bash
python demo.py --input ./data/demo_glb --log_path ./results/demo
```
**Notes:**
- `--input` accepts either:
  - any standard 3D file (GLB, OBJ, etc.)
  - a directory containing multiple 3D files
- For better results on fine structures, the demo automatically applies marching-cubes remeshing and a small dilation offset to the input mesh (this preprocessing is not used during testing and evaluation); it prevents quality degradation in thin areas. A minimal sketch of this step is shown below.
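For reference, the core of this preprocessing (implemented in full in `demo.py`) looks roughly like the sketch below. The function name `watertight_and_dilate` is chosen here for illustration only, and the input mesh is assumed to already be centered and scaled into roughly the `[-1, 1]` cube:
```python
import numpy as np
import trimesh
import mesh2sdf.core
import skimage.measure

def watertight_and_dilate(mesh: trimesh.Trimesh, octree_depth: int = 7, offset: float = 0.015) -> trimesh.Trimesh:
    """Remesh via SDF + marching cubes, then push vertices outward to thicken thin parts."""
    size = 2 ** octree_depth
    # signed distance field on a size^3 grid; run marching cubes on |sdf| at level 2/size
    sdf = mesh2sdf.core.compute(mesh.vertices, mesh.faces, size=size)
    verts, faces, normals, _ = skimage.measure.marching_cubes(np.abs(sdf), 2 / size)
    verts = verts / size * 2 - 1  # grid indices back to [-1, 1] coordinates
    out = trimesh.Trimesh(verts, faces, vertex_normals=normals)
    # dilation: offset every vertex along its normal to prevent collapse of thin areas
    out.vertices = out.vertices + out.vertex_normals * offset
    return out
```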
### Testing and Evaluation
```bash
# Autoregressive generation
python infer.py
# Sample point clouds from predictions
python sample.py
# Calculate evaluation metrics
python eval.py
```
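`eval.py` compares point clouds sampled from the predictions against the test set. Its argument defaults are `--input_dir ./results/infer/PointClouds` and `--target_dir ./data/test_pc`; they can also be passed explicitly:
```bash
python eval.py --input_dir ./results/infer/PointClouds --target_dir ./data/test_pc
```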
<a name="citation"></a>
## 📝 Citation
If you find our work useful, please cite:
```
@article{ye2025primitiveanything,
title={PrimitiveAnything: Human-Crafted 3D Primitive Assembly Generation with Auto-Regressive Transformer},
author={Ye, Jingwen and He, Yuze and Zhou, Yanning and Zhu, Yiqin and Xiao, Kaiwen and Liu, Yong-Jin and Yang, Wei and Han, Xiao},
journal={arXiv preprint arXiv:250x.xxxxx},
year={2025}
}
```

BIN
assets/teaser.jpg Normal file

Binary file not shown.

Size: 1.3 MiB

52
configs/infer.yml Executable file
View File

@ -0,0 +1,52 @@
dataset:
name: base
pc_dir: ./data/test_pc
bs_dir: data/basic_shapes_norm
max_length: 144
range_scale: [0, 1]
range_rotation: [-180, 180]
range_translation: [-1, 1]
rotation_type: euler
pc_format: pn
model:
attn_depth: 6
attn_heads: 6
bin_smooth_blur_sigma: -1
bs_pc_dir: data/basic_shapes_norm_pc10000
coarse_pre_gateloop_depth: 3
continuous_range_rotation:
- -181
- 181
continuous_range_scale:
- 0
- 1
continuous_range_translation:
- -1
- 1
dim: 768
dim_rotation_embed: 16
dim_scale_embed: 16
dim_translation_embed: 16
dim_type_embed: 48
dropout: 0.0
embed_order: ctrs
gateloop_use_heinsen: false
loss_weight:
eos: 1.0
reconstruction: 1.0
rotation: 1.0
scale: 1.0
translation: 1.0
type: 1.0
max_primitive_len: 144
name: discrete
num_discrete_rotation: 181
num_discrete_scale: 128
num_discrete_translation: 128
num_type: 3
shape_cond_with_cat: true
shape_cond_with_cross_attn: false
shape_cond_with_film: false
shape_condition_dim: 768
shape_condition_len: 77
shape_condition_model_type: michelangelo
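For context, this config is consumed by the inference code roughly as follows; this is a condensed sketch of the loading logic in `infer.py` (added later in this commit), not a standalone script:
```python
import yaml
import torch

with open('./configs/infer.yml', mode='r') as fp:
    cfg = yaml.load(fp, Loader=yaml.FullLoader)

dataset = create_dataset(cfg['dataset'])    # from primitive_anything.primitive_dataset
transformer = create_model(cfg['model'])    # pops 'name' and builds PrimitiveTransformerDiscrete
transformer.load_state_dict(torch.load('./ckpt/mesh-transformer.ckpt.60.pt'))
```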

342
demo.py Executable file
View File

@ -0,0 +1,342 @@
import os
import time
import glob
import json
import yaml
import torch
import trimesh
import argparse
import mesh2sdf.core
import numpy as np
import skimage.measure
import seaborn as sns
from scipy.spatial.transform import Rotation
from mesh_to_sdf import get_surface_point_cloud
from accelerate.utils import set_seed
from accelerate import Accelerator
from primitive_anything.utils import path_mkdir, count_parameters
from primitive_anything.utils.logger import print_log
os.environ['PYOPENGL_PLATFORM'] = 'egl'
def parse_args():
parser = argparse.ArgumentParser(description='Process 3D model files')
parser.add_argument(
'--input',
type=str,
default='./data/demo_glb/',
help='Input file or directory path (default: ./data/demo_glb/)'
)
parser.add_argument(
'--log_path',
type=str,
default='./results/demo',
help='Output directory path (default: results/demo)'
)
return parser.parse_args()
def get_input_files(input_path):
if os.path.isfile(input_path):
return [input_path]
elif os.path.isdir(input_path):
return glob.glob(os.path.join(input_path, '*'))
else:
raise ValueError(f"Input path {input_path} is neither a file nor a directory")
args = parse_args()
# Gather input files
input_3ds = get_input_files(args.input)
if not input_3ds:
raise FileNotFoundError(f"No files found at input path: {args.input}")
# Create output directory
LOG_PATH = args.log_path
os.makedirs(LOG_PATH, exist_ok=True)
print(f"Found {len(input_3ds)} input files")
print(f"Output directory: {LOG_PATH}")
CODE_SHAPE = {
0: 'SM_GR_BS_CubeBevel_001.ply',
1: 'SM_GR_BS_SphereSharp_001.ply',
2: 'SM_GR_BS_CylinderSharp_001.ply',
}
shapename_map = {
'SM_GR_BS_CubeBevel_001.ply': 1101002001034001,
'SM_GR_BS_SphereSharp_001.ply': 1101002001034010,
'SM_GR_BS_CylinderSharp_001.ply': 1101002001034002,
}
#### config
bs_dir = 'data/basic_shapes_norm'
config_path = './configs/infer.yml'
AR_checkpoint_path = './ckpt/mesh-transformer.ckpt.60.pt'
temperature = 0.0
#### init model
mesh_bs = {}
for bs_path in glob.glob(os.path.join(bs_dir, '*.ply')):
bs_name = os.path.basename(bs_path)
bs = trimesh.load(bs_path)
bs.visual.uv = np.clip(bs.visual.uv, 0, 1)
bs.visual = bs.visual.to_color()
mesh_bs[bs_name] = bs
def create_model(cfg_model):
kwargs = cfg_model
name = kwargs.pop('name')
model = get_model(name)(**kwargs)
print_log("Model '{}' init: nb_params={:,}, kwargs={}".format(name, count_parameters(model), kwargs))
return model
from primitive_anything.primitive_transformer import PrimitiveTransformerDiscrete
def get_model(name):
return {
'discrete': PrimitiveTransformerDiscrete,
}[name]
with open(config_path, mode='r') as fp:
AR_train_cfg = yaml.load(fp, Loader=yaml.FullLoader)
AR_checkpoint = torch.load(AR_checkpoint_path)
transformer = create_model(AR_train_cfg['model'])
transformer.load_state_dict(AR_checkpoint)
device = torch.device('cuda')
accelerator = Accelerator(
mixed_precision='fp16',
)
transformer = accelerator.prepare(transformer)
transformer.eval()
transformer.bs_pc = transformer.bs_pc.cuda()
transformer.rotation_matrix_align_coord = transformer.rotation_matrix_align_coord.cuda()
print('model loaded to device')
def sample_surface_points(mesh, number_of_points=500000, surface_point_method='scan', sign_method='normal',
scan_count=100, scan_resolution=400, sample_point_count=10000000, return_gradients=False,
return_surface_pc_normals=False, normalized=False):
sample_start = time.time()
if surface_point_method == 'sample' and sign_method == 'depth':
print("Incompatible methods for sampling points and determining sign, using sign_method='normal' instead.")
sign_method = 'normal'
surface_start = time.time()
bound_radius = 1 if normalized else None
surface_point_cloud = get_surface_point_cloud(mesh, surface_point_method, bound_radius, scan_count, scan_resolution,
sample_point_count,
calculate_normals=sign_method == 'normal' or return_gradients)
surface_end = time.time()
print('surface point cloud time cost :', surface_end - surface_start)
normal_start = time.time()
if return_surface_pc_normals:
rng = np.random.default_rng()
assert surface_point_cloud.points.shape[0] == surface_point_cloud.normals.shape[0]
indices = rng.choice(surface_point_cloud.points.shape[0], number_of_points, replace=True)
points = surface_point_cloud.points[indices]
normals = surface_point_cloud.normals[indices]
surface_points = np.concatenate([points, normals], axis=-1)
else:
surface_points = surface_point_cloud.get_random_surface_points(number_of_points, use_scans=True)
normal_end = time.time()
print('normal time cost :', normal_end - normal_start)
sample_end = time.time()
print('sample surface point time cost :', sample_end - sample_start)
return surface_points
def normalize_vertices(vertices, scale=0.9):
bbmin, bbmax = vertices.min(0), vertices.max(0)
center = (bbmin + bbmax) * 0.5
scale = 2.0 * scale / (bbmax - bbmin).max()
vertices = (vertices - center) * scale
return vertices, center, scale
def export_to_watertight(normalized_mesh, octree_depth: int = 7):
"""
Convert the non-watertight mesh to watertight.
Args:
input_path (str): normalized path
octree_depth (int):
Returns:
mesh(trimesh.Trimesh): watertight mesh
"""
size = 2 ** octree_depth
level = 2 / size
scaled_vertices, to_orig_center, to_orig_scale = normalize_vertices(normalized_mesh.vertices)
sdf = mesh2sdf.core.compute(scaled_vertices, normalized_mesh.faces, size=size)
vertices, faces, normals, _ = skimage.measure.marching_cubes(np.abs(sdf), level)
# watertight mesh
vertices = vertices / size * 2 - 1 # -1 to 1
vertices = vertices / to_orig_scale + to_orig_center
mesh = trimesh.Trimesh(vertices, faces, normals=normals)
return mesh
def process_mesh_to_surface_pc(mesh_list, marching_cubes=False, dilated_offset=0.0, sample_num=10000):
# mesh_list : list of trimesh
pc_normal_list = []
return_mesh_list = []
for mesh in mesh_list:
if marching_cubes:
mesh = export_to_watertight(mesh)
print("MC over!")
if dilated_offset > 0:
new_vertices = mesh.vertices + mesh.vertex_normals * dilated_offset
mesh.vertices = new_vertices
print("dilate over!")
mesh.merge_vertices()
mesh.update_faces(mesh.unique_faces())
mesh.fix_normals()
return_mesh_list.append(mesh)
pc_normal = np.asarray(sample_surface_points(mesh, sample_num, return_surface_pc_normals=True))
pc_normal_list.append(pc_normal)
print("process mesh success")
return pc_normal_list, return_mesh_list
#### utils
def euler_to_quat(euler):
return Rotation.from_euler('XYZ', euler, degrees=True).as_quat()
def SRT_quat_to_matrix(scale, quat, translation):
rotation_matrix = Rotation.from_quat(quat).as_matrix()
transform_matrix = np.eye(4)
transform_matrix[:3, :3] = rotation_matrix * scale
transform_matrix[:3, 3] = translation
return transform_matrix
def write_output(primitives, name):
out_json = {}
out_json['operation'] = 0
out_json['type'] = 1
out_json['scene_id'] = None
new_group = []
model_scene = trimesh.Scene()
color_map = sns.color_palette("hls", primitives['type_code'].squeeze().shape[0])
color_map = (np.array(color_map) * 255).astype("uint8")
for idx, (scale, rotation, translation, type_code) in enumerate(zip(
primitives['scale'].squeeze().cpu().numpy(),
primitives['rotation'].squeeze().cpu().numpy(),
primitives['translation'].squeeze().cpu().numpy(),
primitives['type_code'].squeeze().cpu().numpy()
)):
if type_code == -1:
break
bs_name = CODE_SHAPE[type_code]
new_block = {}
new_block['type_id'] = shapename_map[bs_name]
new_block['data'] = {}
new_block['data']['location'] = translation.tolist()
new_block['data']['rotation'] = euler_to_quat(rotation).tolist()
new_block['data']['scale'] = scale.tolist()
new_block['data']['color'] = ['808080']
new_group.append(new_block)
trans = SRT_quat_to_matrix(scale, euler_to_quat(rotation), translation)
bs = mesh_bs[bs_name].copy().apply_transform(trans)
new_vertex_colors = np.repeat(color_map[idx:idx+1], bs.visual.vertex_colors.shape[0], axis=0)
bs.visual.vertex_colors[:, :3] = new_vertex_colors
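# swap axes before export: new Y = old Z, new Z = -old Y (a -90 degree rotation about X, i.e. Z-up to Y-up)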
vertices = bs.vertices.copy()
vertices[:, 1] = bs.vertices[:, 2]
vertices[:, 2] = -bs.vertices[:, 1]
bs.vertices = vertices
model_scene.add_geometry(bs)
out_json['group'] = new_group
json_path = os.path.join(LOG_PATH, f'output_{name}.json')
with open(json_path, 'w') as json_file:
json.dump(out_json, json_file, indent=4)
glb_path = os.path.join(LOG_PATH, f'output_{name}.glb')
model_scene.export(glb_path)
return glb_path, out_json
@torch.no_grad()
def do_inference(input_3d, dilated_offset=0.0, sample_seed=0, do_sampling=False, do_marching_cubes=False, postprocess='none'):
t1 = time.time()
set_seed(sample_seed)
input_mesh = trimesh.load(input_3d, force='mesh')
# center the mesh and scale its longest side to 1.6 so it fits within [-0.8, 0.8]^3
vertices = input_mesh.vertices
bounds = np.array([vertices.min(axis=0), vertices.max(axis=0)])
vertices = vertices - (bounds[0] + bounds[1])[None, :] / 2
vertices = vertices / (bounds[1] - bounds[0]).max() * 1.6
input_mesh.vertices = vertices
pc_list, mesh_list = process_mesh_to_surface_pc(
[input_mesh],
marching_cubes=do_marching_cubes,
dilated_offset=dilated_offset
)
pc_normal = pc_list[0] # 10000, 6
mesh = mesh_list[0]
pc_coor = pc_normal[:, :3]
normals = pc_normal[:, 3:]
if dilated_offset > 0:
# scale mesh and pc
vertices = mesh.vertices
bounds = np.array([vertices.min(axis=0), vertices.max(axis=0)])
vertices = vertices - (bounds[0] + bounds[1])[None, :] / 2
vertices = vertices / (bounds[1] - bounds[0]).max() * 1.6
mesh.vertices = vertices
pc_coor = pc_coor - (bounds[0] + bounds[1])[None, :] / 2
pc_coor = pc_coor / (bounds[1] - bounds[0]).max() * 1.6
input_save_name = os.path.join(LOG_PATH, f'processed_{os.path.basename(input_3d)}')
mesh.export(input_save_name)
assert (np.linalg.norm(normals, axis=-1) > 0.99).all(), 'normals should be unit vectors, something wrong'
normalized_pc_normal = np.concatenate([pc_coor, normals], axis=-1, dtype=np.float16)
input_pc = torch.tensor(normalized_pc_normal, dtype=torch.float16, device=device)[None]
with accelerator.autocast():
if postprocess == 'postprocess1':
recon_primitives, mask = transformer.generate_w_recon_loss(pc=input_pc, temperature=temperature, single_directional=True)
else:
recon_primitives, mask = transformer.generate(pc=input_pc, temperature=temperature)
output_glb, output_json = write_output(recon_primitives, os.path.splitext(os.path.basename(input_3d))[0])
return input_save_name, output_glb, output_json
dilated_offset = 0.015
do_marching_cubes = True
postprocess = 'postprocess1'
for input_3d in input_3ds:
print(f"processing: {input_3d}")
preprocess_model_obj, output_model_obj, output_model_json = do_inference(
input_3d,
dilated_offset=dilated_offset,
do_marching_cubes=do_marching_cubes,
postprocess=postprocess
)

113
eval.py Normal file
View File

@ -0,0 +1,113 @@
import argparse
from collections import defaultdict
import glob
import json
import os
import numpy as np
import trimesh
import point_cloud_utils as pcu
from tqdm import tqdm
def voxelize(points, voxel_size=0.1):
"""
Converts a set of 3D points to a voxel grid representation.
Points are quantized to the nearest voxel center.
Parameters:
- points: (N, 3) numpy array of 3D points
- voxel_size: the size of each voxel (adjust this depending on your data)
Returns:
- voxels: Set of unique voxel coordinates
"""
# Quantize the points to the nearest voxel
quantized_points = np.floor(points / voxel_size).astype(int)
# Use a set to get unique voxel coordinates
voxels = set(map(tuple, quantized_points))
return voxels
def calculate_iou(model_vox, target_vox, voxel_size=0.1):
"""
Calculate the IoU (Intersection over Union) between two point clouds.
Parameters:
- model_vox: (N, 3) numpy array of the first point cloud
- target_vox: (M, 3) numpy array of the second point cloud
- voxel_size: Size of the voxels (default is 0.1)
Returns:
- iou: Intersection over Union (IoU) score
"""
# Voxelize both point clouds
model_voxels = voxelize(model_vox, voxel_size)
target_voxels = voxelize(target_vox, voxel_size)
# Calculate intersection and union
intersection = len(model_voxels.intersection(target_voxels))
union = len(model_voxels.union(target_voxels))
# Compute IoU
iou = intersection / union if union > 0 else 0.0
return iou
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--input_dir", type=str, default="./results/infer/PointClouds")
parser.add_argument("--target_dir", type=str, default="./data/test_pc")
parser.add_argument("--detail", action="store_true")
args = parser.parse_args()
if not os.path.exists(args.input_dir) or not os.path.exists(args.target_dir):
print("Invalid input!")
exit(1)
model_prefix = os.path.join(args.input_dir, "*.ply")
model_path_list = sorted(list(glob.glob(model_prefix)))
distance_json = {}
distance_list = []
emu_distance_list = []
hausdorff_distance_list = []
voxel_iou_list = []
for model_path in tqdm(model_path_list):
model_name = os.path.basename(model_path)
target_path = os.path.join(args.target_dir, model_name)
if not os.path.exists(target_path):
print(f"{target_path}: not found!")
exit(1)
model_pc = np.array(trimesh.load(model_path).vertices)
target_pc = np.array(trimesh.load(target_path).vertices)
distance = pcu.chamfer_distance(model_pc, target_pc)
model_pc_downsampled = model_pc[np.random.choice(model_pc.shape[0], 1000, replace=False)]
target_pc_downsampled = target_pc[np.random.choice(target_pc.shape[0], 1000, replace=False)]
emu_distance, _ = pcu.earth_movers_distance(model_pc_downsampled, target_pc_downsampled)
hausdorff_distance = pcu.hausdorff_distance(model_pc, target_pc)
iou = calculate_iou(model_pc, target_pc, voxel_size=1/32.)
distance_list.append(distance)
emu_distance_list.append(emu_distance)
hausdorff_distance_list.append(hausdorff_distance)
voxel_iou_list.append(iou)
model_id = os.path.splitext(model_name)[0]
distance_json[model_id] = distance
print(f"{model_id}: chamfer distance: {distance:.3f}, earth movers distance: {emu_distance:.3f}, hausdorff distance: {hausdorff_distance:.3f}, voxel IoU: {iou:.3f}")
distance_json["mean"] = np.mean(distance_list)
distance_json["mean_emu"] = np.mean(emu_distance_list)
distance_json["mean_hausdorff"] = np.mean(hausdorff_distance_list)
distance_json["mean_voxel_iou"] = np.mean(voxel_iou_list)
print(f"mean chamfer distance: {np.mean(distance_list)}")
print(f"mean earth movers distance: {np.mean(emu_distance_list)}")
print(f"mean hausdorff distance: {np.mean(hausdorff_distance_list)}")
print(f"mean voxel IoU: {np.mean(voxel_iou_list)}")
with open(os.path.join(args.input_dir, "distance.json"), "w") as json_file:
json.dump(distance_json, json_file, indent=4)

172
infer.py Executable file
View File

@ -0,0 +1,172 @@
import argparse
import glob
import json
import yaml
from pathlib import Path
import os
import re
import numpy as np
from scipy.spatial.transform import Rotation
from tqdm import tqdm
import torch
import trimesh
from primitive_anything.primitive_dataset import create_dataset
from primitive_anything.utils import torch_to, count_parameters
from primitive_anything.utils.logger import create_logger, print_log
CODE_SHAPE = {
0: 'SM_GR_BS_CubeBevel_001.ply',
1: 'SM_GR_BS_SphereSharp_001.ply',
2: 'SM_GR_BS_CylinderSharp_001.ply',
}
shapename_map = {
'SM_GR_BS_CubeBevel_001.ply': 1101002001034001,
'SM_GR_BS_SphereSharp_001.ply': 1101002001034010,
'SM_GR_BS_CylinderSharp_001.ply': 1101002001034002,
}
bs_dir = 'data/basic_shapes_norm'
mesh_bs = {}
for bs_path in glob.glob(os.path.join(bs_dir, '*.ply')):
bs_name = os.path.basename(bs_path)
bs = trimesh.load(bs_path)
bs.visual.uv = np.clip(bs.visual.uv, 0, 1)
bs.visual = bs.visual.to_color()
mesh_bs[bs_name] = bs
def create_model(cfg_model):
kwargs = cfg_model
name = kwargs.pop('name')
model = get_model(name)(**kwargs)
print_log("Model '{}' init: nb_params={:,}, kwargs={}".format(name, count_parameters(model), kwargs))
return model
from primitive_anything.primitive_transformer import PrimitiveTransformerDiscrete
def get_model(name):
return {
'discrete': PrimitiveTransformerDiscrete,
}[name]
def euler_to_quat(euler):
return Rotation.from_euler('XYZ', euler, degrees=True).as_quat()
def rotvec_to_quat(rotvec):
return Rotation.from_rotvec(rotvec, degrees=True).as_quat()
def SRT_quat_to_matrix(scale, quat, translation):
rotation_matrix = Rotation.from_quat(quat).as_matrix()
transform_matrix = np.eye(4)
transform_matrix[:3, :3] = rotation_matrix * scale
transform_matrix[:3, 3] = translation
return transform_matrix
def write_json(primitives, shapename_map, out_path):
out_json = {}
out_json['operation'] = 0
out_json['type'] = 1
out_json['scene_id'] = None
new_group = []
model_scene = trimesh.Scene()
for scale, rotation, translation, type_code in zip(
primitives['scale'].squeeze().cpu().numpy(),
primitives['rotation'].squeeze().cpu().numpy(),
primitives['translation'].squeeze().cpu().numpy(),
primitives['type_code'].squeeze().cpu().numpy()
):
if type_code == -1:
break
bs_name = CODE_SHAPE[type_code]
new_block = {}
new_block['type_id'] = shapename_map[bs_name]
new_block['data'] = {}
new_block['data']['location'] = translation.tolist()
new_block['data']['rotation'] = euler_to_quat(rotation).tolist()
new_block['data']['scale'] = scale.tolist()
new_block['data']['color'] = ['808080']
new_group.append(new_block)
if new_block['type_id'] == 1101002001034001:
cur_color = "#2FA9FF"
elif new_block['type_id'] == 1101002001034002:
cur_color = "#FFC203"
elif new_block['type_id'] == 1101002001034010:
cur_color = "#FF8A9C"
def hex_to_rgb(hex_color):
hex_color = hex_color.lstrip('#')
return np.array([
int(hex_color[0:2], 16), # R
int(hex_color[2:4], 16), # G
int(hex_color[4:6], 16), # B
], dtype=np.uint8)[None]
trans = SRT_quat_to_matrix(scale, euler_to_quat(rotation), translation)
bs = mesh_bs[bs_name].copy().apply_transform(trans)
new_vertex_colors = np.repeat(hex_to_rgb(cur_color), bs.visual.vertex_colors.shape[0], axis=0)
bs.visual.vertex_colors[:, :3] = new_vertex_colors
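# swap axes before export: new Y = old Z, new Z = -old Y (a -90 degree rotation about X, i.e. Z-up to Y-up)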
vertices = bs.vertices.copy()
vertices[:, 1] = bs.vertices[:, 2]
vertices[:, 2] = -bs.vertices[:, 1]
bs.vertices = vertices
model_scene.add_geometry(bs)
out_json['group'] = new_group
with open(out_path, 'w') as json_file:
json.dump(out_json, json_file, indent=4)
glb_path = out_path.replace('.json', '.glb')
model_scene.export(glb_path)
return glb_path, out_json
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('-c', '--config', type=str, default='./configs/infer.yml', help='Config file path')
parser.add_argument('-ck', '--AR_ckpt', type=str, default='./ckpt/mesh-transformer.ckpt.60.pt')
parser.add_argument('-o', '--output', type=str, default='./results/infer')
parser.add_argument('--bs_dir', type=str, default='data/basic_shapes_norm')
parser.add_argument('--temperature', type=float, default=0.0)
args = parser.parse_args()
bs_names = []
for bs_path in glob.glob(os.path.join(args.bs_dir, '*.ply')):
bs_names.append(os.path.basename(bs_path))
with open(args.config, mode='r') as fp:
cfg = yaml.load(fp, Loader=yaml.FullLoader)
AR_checkpoint = torch.load(args.AR_ckpt)
os.makedirs(args.output, exist_ok=True)
json_result_folder = os.path.join(args.output, 'JsonResults')
os.makedirs(json_result_folder, exist_ok=True)
create_logger(Path(args.output))
dataset = create_dataset(cfg['dataset'])
transformer = create_model(cfg['model'])
transformer.load_state_dict(AR_checkpoint)
for item_i, item in tqdm(enumerate(dataset)):
pc = item.pop('pc')
item_filename = dataset.data_filename[item_i]
if torch.cuda.is_available():
pc = pc.cuda()
item = torch_to(item, torch.device('cuda'))
transformer = transformer.cuda()
recon_primitives, mask = transformer.generate(pc=pc.unsqueeze(0), temperature=args.temperature)
out_path = os.path.join(json_result_folder, os.path.basename(item_filename).replace('.ply', '.json'))
write_json(recon_primitives, shapename_map, out_path)

0
primitive_anything/__init__.py Executable file
View File

View File

@ -0,0 +1,51 @@
import os
from omegaconf import OmegaConf
import torch
from torch import nn
from .utils.misc import instantiate_from_config
from ..utils import default, exists
def load_model():
model_config = OmegaConf.load(os.path.join(os.path.dirname(__file__), "shapevae-256.yaml"))
# print(model_config)
if hasattr(model_config, "model"):
model_config = model_config.model
ckpt_path = "./ckpt/shapevae-256.ckpt"
model = instantiate_from_config(model_config, ckpt_path=ckpt_path)
# model = model.cuda()
model = model.eval()
return model
class ShapeConditioner(nn.Module):
def __init__(
self,
*,
dim_latent = None
):
super().__init__()
self.model = load_model()
self.dim_model_out = 768
dim_latent = default(dim_latent, self.dim_model_out)
self.dim_latent = dim_latent
def forward(
self,
shape = None,
shape_embed = None,
):
assert exists(shape) ^ exists(shape_embed)
if not exists(shape_embed):
point_feature = self.model.encode_latents(shape)
shape_latents = self.model.to_shape_latents(point_feature[:, 1:])
shape_head = point_feature[:, 0:1]
shape_embed = torch.cat([point_feature[:, 1:], shape_latents], dim=-1)
# shape_embed = torch.cat([point_feature[:, 1:], shape_latents], dim=-2) # cat tmp
return shape_head, shape_embed

View File

@ -0,0 +1 @@
# -*- coding: utf-8 -*-

View File

@ -0,0 +1,9 @@
# -*- coding: utf-8 -*-
from .volume import generate_dense_grid_points
from .mesh import (
MeshOutput,
save_obj,
savemeshtes2
)

View File

@ -0,0 +1,114 @@
# -*- coding: utf-8 -*-
import os
import cv2
import numpy as np
import PIL.Image
from typing import Optional
import trimesh
def save_obj(pointnp_px3, facenp_fx3, fname):
fid = open(fname, "w")
write_str = ""
for pidx, p in enumerate(pointnp_px3):
pp = p
write_str += "v %f %f %f\n" % (pp[0], pp[1], pp[2])
for i, f in enumerate(facenp_fx3):
f1 = f + 1
write_str += "f %d %d %d\n" % (f1[0], f1[1], f1[2])
fid.write(write_str)
fid.close()
return
def savemeshtes2(pointnp_px3, tcoords_px2, facenp_fx3, facetex_fx3, tex_map, fname):
fol, na = os.path.split(fname)
na, _ = os.path.splitext(na)
matname = "%s/%s.mtl" % (fol, na)
fid = open(matname, "w")
fid.write("newmtl material_0\n")
fid.write("Kd 1 1 1\n")
fid.write("Ka 0 0 0\n")
fid.write("Ks 0.4 0.4 0.4\n")
fid.write("Ns 10\n")
fid.write("illum 2\n")
fid.write("map_Kd %s.png\n" % na)
fid.close()
####
fid = open(fname, "w")
fid.write("mtllib %s.mtl\n" % na)
for pidx, p in enumerate(pointnp_px3):
pp = p
fid.write("v %f %f %f\n" % (pp[0], pp[1], pp[2]))
for pidx, p in enumerate(tcoords_px2):
pp = p
fid.write("vt %f %f\n" % (pp[0], pp[1]))
fid.write("usemtl material_0\n")
for i, f in enumerate(facenp_fx3):
f1 = f + 1
f2 = facetex_fx3[i] + 1
fid.write("f %d/%d %d/%d %d/%d\n" % (f1[0], f2[0], f1[1], f2[1], f1[2], f2[2]))
fid.close()
PIL.Image.fromarray(np.ascontiguousarray(tex_map), "RGB").save(
os.path.join(fol, "%s.png" % na))
return
class MeshOutput(object):
def __init__(self,
mesh_v: np.ndarray,
mesh_f: np.ndarray,
vertex_colors: Optional[np.ndarray] = None,
uvs: Optional[np.ndarray] = None,
mesh_tex_idx: Optional[np.ndarray] = None,
tex_map: Optional[np.ndarray] = None):
self.mesh_v = mesh_v
self.mesh_f = mesh_f
self.vertex_colors = vertex_colors
self.uvs = uvs
self.mesh_tex_idx = mesh_tex_idx
self.tex_map = tex_map
def contain_uv_texture(self):
return (self.uvs is not None) and (self.mesh_tex_idx is not None) and (self.tex_map is not None)
def contain_vertex_colors(self):
return self.vertex_colors is not None
def export(self, fname):
if self.contain_uv_texture():
savemeshtes2(
self.mesh_v,
self.uvs,
self.mesh_f,
self.mesh_tex_idx,
self.tex_map,
fname
)
elif self.contain_vertex_colors():
mesh_obj = trimesh.Trimesh(vertices=self.mesh_v, faces=self.mesh_f, vertex_colors=self.vertex_colors)
mesh_obj.export(fname)
else:
save_obj(
self.mesh_v,
self.mesh_f,
fname
)

View File

@ -0,0 +1,21 @@
# -*- coding: utf-8 -*-
import numpy as np
def generate_dense_grid_points(bbox_min: np.ndarray,
bbox_max: np.ndarray,
octree_depth: int,
indexing: str = "ij"):
length = bbox_max - bbox_min
num_cells = np.exp2(octree_depth)
x = np.linspace(bbox_min[0], bbox_max[0], int(num_cells) + 1, dtype=np.float32)
y = np.linspace(bbox_min[1], bbox_max[1], int(num_cells) + 1, dtype=np.float32)
z = np.linspace(bbox_min[2], bbox_max[2], int(num_cells) + 1, dtype=np.float32)
[xs, ys, zs] = np.meshgrid(x, y, z, indexing=indexing)
xyz = np.stack((xs, ys, zs), axis=-1)
xyz = xyz.reshape(-1, 3)
grid_size = [int(num_cells) + 1, int(num_cells) + 1, int(num_cells) + 1]
return xyz, grid_size, length

View File

@ -0,0 +1 @@
# -*- coding: utf-8 -*-

View File

@ -0,0 +1 @@
# -*- coding: utf-8 -*-

View File

@ -0,0 +1,483 @@
# -*- coding: utf-8 -*-
from omegaconf import DictConfig
from typing import List, Tuple, Dict, Optional, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import lr_scheduler
import pytorch_lightning as pl
from pytorch_lightning.utilities import rank_zero_only
from einops import rearrange
from diffusers.schedulers import (
DDPMScheduler,
DDIMScheduler,
KarrasVeScheduler,
DPMSolverMultistepScheduler
)
from ...utils import instantiate_from_config
# from ..tsal.tsal_base import ShapeAsLatentPLModule
from ..tsal.tsal_base import AlignedShapeAsLatentPLModule
from .inference_utils import ddim_sample
SchedulerType = Union[DDIMScheduler, KarrasVeScheduler, DPMSolverMultistepScheduler]
def disabled_train(self, mode=True):
"""Overwrite model.train with this function to make sure train/eval mode
does not change anymore."""
return self
class ASLDiffuser(pl.LightningModule):
first_stage_model: Optional[AlignedShapeAsLatentPLModule]
# cond_stage_model: Optional[Union[nn.Module, pl.LightningModule]]
model: nn.Module
def __init__(self, *,
first_stage_config,
denoiser_cfg,
scheduler_cfg,
optimizer_cfg,
loss_cfg,
first_stage_key: str = "surface",
cond_stage_key: str = "image",
cond_stage_trainable: bool = True,
scale_by_std: bool = False,
z_scale_factor: float = 1.0,
ckpt_path: Optional[str] = None,
ignore_keys: Union[Tuple[str], List[str]] = ()):
super().__init__()
self.first_stage_key = first_stage_key
self.cond_stage_key = cond_stage_key
self.cond_stage_trainable = cond_stage_trainable
# 1. initialize first stage.
# Note: the condition model contained in the first stage model.
self.first_stage_config = first_stage_config
self.first_stage_model = None
# self.instantiate_first_stage(first_stage_config)
# 2. initialize conditional stage
# self.instantiate_cond_stage(cond_stage_config)
self.cond_stage_model = {
"image": self.encode_image,
"image_unconditional_embedding": self.empty_img_cond,
"text": self.encode_text,
"text_unconditional_embedding": self.empty_text_cond,
"surface": self.encode_surface,
"surface_unconditional_embedding": self.empty_surface_cond,
}
# 3. diffusion model
self.model = instantiate_from_config(
denoiser_cfg, device=None, dtype=None
)
self.optimizer_cfg = optimizer_cfg
# 4. scheduling strategy
self.scheduler_cfg = scheduler_cfg
self.noise_scheduler: DDPMScheduler = instantiate_from_config(scheduler_cfg.noise)
self.denoise_scheduler: SchedulerType = instantiate_from_config(scheduler_cfg.denoise)
# 5. loss configures
self.loss_cfg = loss_cfg
self.scale_by_std = scale_by_std
if scale_by_std:
self.register_buffer("z_scale_factor", torch.tensor(z_scale_factor))
else:
self.z_scale_factor = z_scale_factor
self.ckpt_path = ckpt_path
if ckpt_path is not None:
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
def instantiate_first_stage(self, config):
model = instantiate_from_config(config)
self.first_stage_model = model.eval()
self.first_stage_model.train = disabled_train
for param in self.first_stage_model.parameters():
param.requires_grad = False
self.first_stage_model = self.first_stage_model.to(self.device)
# def instantiate_cond_stage(self, config):
# if not self.cond_stage_trainable:
# if config == "__is_first_stage__":
# print("Using first stage also as cond stage.")
# self.cond_stage_model = self.first_stage_model
# elif config == "__is_unconditional__":
# print(f"Training {self.__class__.__name__} as an unconditional model.")
# self.cond_stage_model = None
# # self.be_unconditional = True
# else:
# model = instantiate_from_config(config)
# self.cond_stage_model = model.eval()
# self.cond_stage_model.train = disabled_train
# for param in self.cond_stage_model.parameters():
# param.requires_grad = False
# else:
# assert config != "__is_first_stage__"
# assert config != "__is_unconditional__"
# model = instantiate_from_config(config)
# self.cond_stage_model = model
def init_from_ckpt(self, path, ignore_keys=()):
state_dict = torch.load(path, map_location="cpu")["state_dict"]
keys = list(state_dict.keys())
for k in keys:
for ik in ignore_keys:
if k.startswith(ik):
print("Deleting key {} from state_dict.".format(k))
del state_dict[k]
missing, unexpected = self.load_state_dict(state_dict, strict=False)
print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
if len(missing) > 0:
print(f"Missing Keys: {missing}")
print(f"Unexpected Keys: {unexpected}")
@property
def zero_rank(self):
if self._trainer:
zero_rank = self.trainer.local_rank == 0
else:
zero_rank = True
return zero_rank
def configure_optimizers(self) -> Tuple[List, List]:
lr = self.learning_rate
trainable_parameters = list(self.model.parameters())
# if the conditional encoder is trainable
# if self.cond_stage_trainable:
# conditioner_params = [p for p in self.cond_stage_model.parameters() if p.requires_grad]
# trainable_parameters += conditioner_params
# print(f"number of trainable conditional parameters: {len(conditioner_params)}.")
if self.optimizer_cfg is None:
optimizers = [torch.optim.AdamW(trainable_parameters, lr=lr, betas=(0.9, 0.99), weight_decay=1e-3)]
schedulers = []
else:
optimizer = instantiate_from_config(self.optimizer_cfg.optimizer, params=trainable_parameters)
scheduler_func = instantiate_from_config(
self.optimizer_cfg.scheduler,
max_decay_steps=self.trainer.max_steps,
lr_max=lr
)
scheduler = {
"scheduler": lr_scheduler.LambdaLR(optimizer, lr_lambda=scheduler_func.schedule),
"interval": "step",
"frequency": 1
}
optimizers = [optimizer]
schedulers = [scheduler]
return optimizers, schedulers
@torch.no_grad()
def encode_text(self, text):
b = text.shape[0]
text_tokens = rearrange(text, "b t l -> (b t) l")
text_embed = self.first_stage_model.model.encode_text_embed(text_tokens)
text_embed = rearrange(text_embed, "(b t) d -> b t d", b=b)
text_embed = text_embed.mean(dim=1)
text_embed = text_embed / text_embed.norm(dim=-1, keepdim=True)
return text_embed
@torch.no_grad()
def encode_image(self, img):
return self.first_stage_model.model.encode_image_embed(img)
@torch.no_grad()
def encode_surface(self, surface):
return self.first_stage_model.model.encode_shape_embed(surface, return_latents=False)
@torch.no_grad()
def empty_text_cond(self, cond):
return torch.zeros_like(cond, device=cond.device)
@torch.no_grad()
def empty_img_cond(self, cond):
return torch.zeros_like(cond, device=cond.device)
@torch.no_grad()
def empty_surface_cond(self, cond):
return torch.zeros_like(cond, device=cond.device)
@torch.no_grad()
def encode_first_stage(self, surface: torch.FloatTensor, sample_posterior=True):
z_q = self.first_stage_model.encode(surface, sample_posterior)
z_q = self.z_scale_factor * z_q
return z_q
@torch.no_grad()
def decode_first_stage(self, z_q: torch.FloatTensor, **kwargs):
z_q = 1. / self.z_scale_factor * z_q
latents = self.first_stage_model.decode(z_q, **kwargs)
return latents
@rank_zero_only
@torch.no_grad()
def on_train_batch_start(self, batch, batch_idx):
# only for very first batch
if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 \
and batch_idx == 0 and self.ckpt_path is None:
# set rescale weight to 1./std of encodings
print("### USING STD-RESCALING ###")
z_q = self.encode_first_stage(batch[self.first_stage_key])
z = z_q.detach()
del self.z_scale_factor
self.register_buffer("z_scale_factor", 1. / z.flatten().std())
print(f"setting self.z_scale_factor to {self.z_scale_factor}")
print("### USING STD-RESCALING ###")
def compute_loss(self, model_outputs, split):
"""
Args:
model_outputs (dict):
- x_0:
- noise:
- pred:
split (str):
Returns:
"""
pred = model_outputs["pred"]
if self.noise_scheduler.prediction_type == "epsilon":
target = model_outputs["noise"]
elif self.noise_scheduler.prediction_type == "sample":
target = model_outputs["x_0"]
else:
raise NotImplementedError(f"Prediction Type: {self.noise_scheduler.prediction_type} not yet supported.")
if self.loss_cfg.loss_type == "l1":
simple = F.l1_loss(pred, target, reduction="mean")
elif self.loss_cfg.loss_type in ["mse", "l2"]:
simple = F.mse_loss(pred, target, reduction="mean")
else:
raise NotImplementedError(f"Loss Type: {self.loss_cfg.loss_type} not yet supported.")
total_loss = simple
loss_dict = {
f"{split}/total_loss": total_loss.clone().detach(),
f"{split}/simple": simple.detach(),
}
return total_loss, loss_dict
def forward(self, batch):
"""
Args:
batch:
Returns:
"""
if self.first_stage_model is None:
self.instantiate_first_stage(self.first_stage_config)
latents = self.encode_first_stage(batch[self.first_stage_key])
# conditions = self.cond_stage_model.encode(batch[self.cond_stage_key])
conditions = self.cond_stage_model[self.cond_stage_key](batch[self.cond_stage_key]).unsqueeze(1)
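# condition dropout for classifier-free guidance: zero out the conditions for roughly 10% of the batch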
mask = torch.rand((len(conditions), 1, 1), device=conditions.device, dtype=conditions.dtype) >= 0.1
conditions = conditions * mask.to(conditions)
# Sample noise that we'll add to the latents
# [batch_size, n_token, latent_dim]
noise = torch.randn_like(latents)
bs = latents.shape[0]
# Sample a random timestep for each motion
timesteps = torch.randint(
0,
self.noise_scheduler.config.num_train_timesteps,
(bs,),
device=latents.device,
)
timesteps = timesteps.long()
# Add noise to the latents according to the noise magnitude at each timestep
noisy_z = self.noise_scheduler.add_noise(latents, noise, timesteps)
# diffusion model forward
noise_pred = self.model(noisy_z, timesteps, conditions)
diffusion_outputs = {
"x_0": noisy_z,
"noise": noise,
"pred": noise_pred
}
return diffusion_outputs
def training_step(self, batch: Dict[str, Union[torch.FloatTensor, List[str]]],
batch_idx: int, optimizer_idx: int = 0) -> torch.FloatTensor:
"""
Args:
batch (dict): the batch sample, and it contains:
- surface (torch.FloatTensor):
- image (torch.FloatTensor): if provide, [bs, 3, h, w], item range [0, 1]
- depth (torch.FloatTensor): if provide, [bs, 1, h, w], item range [-1, 1]
- normal (torch.FloatTensor): if provide, [bs, 3, h, w], item range [-1, 1]
- text (list of str):
batch_idx (int):
optimizer_idx (int):
Returns:
loss (torch.FloatTensor):
"""
diffusion_outputs = self(batch)
loss, loss_dict = self.compute_loss(diffusion_outputs, "train")
self.log_dict(loss_dict, prog_bar=True, logger=True, sync_dist=False, rank_zero_only=True)
return loss
def validation_step(self, batch: Dict[str, torch.FloatTensor],
batch_idx: int, optimizer_idx: int = 0) -> torch.FloatTensor:
"""
Args:
batch (dict): the batch sample, and it contains:
- surface_pc (torch.FloatTensor): [n_pts, 4]
- surface_feats (torch.FloatTensor): [n_pts, c]
- text (list of str):
batch_idx (int):
optimizer_idx (int):
Returns:
loss (torch.FloatTensor):
"""
diffusion_outputs = self(batch)
loss, loss_dict = self.compute_loss(diffusion_outputs, "val")
self.log_dict(loss_dict, prog_bar=True, logger=True, sync_dist=False, rank_zero_only=True)
return loss
@torch.no_grad()
def sample(self,
batch: Dict[str, Union[torch.FloatTensor, List[str]]],
sample_times: int = 1,
steps: Optional[int] = None,
guidance_scale: Optional[float] = None,
eta: float = 0.0,
return_intermediates: bool = False, **kwargs):
if self.first_stage_model is None:
self.instantiate_first_stage(self.first_stage_config)
if steps is None:
steps = self.scheduler_cfg.num_inference_steps
if guidance_scale is None:
guidance_scale = self.scheduler_cfg.guidance_scale
do_classifier_free_guidance = guidance_scale > 0
# conditional encode
xc = batch[self.cond_stage_key]
# cond = self.cond_stage_model[self.cond_stage_key](xc)
cond = self.cond_stage_model[self.cond_stage_key](xc).unsqueeze(1)
if do_classifier_free_guidance:
"""
Note: There are two kinds of uncond for text.
1: using "" as uncond text; (in SAL diffusion)
2: zeros_like(cond) as uncond text; (in MDM)
"""
# un_cond = self.cond_stage_model.unconditional_embedding(batch_size=len(xc))
un_cond = self.cond_stage_model[f"{self.cond_stage_key}_unconditional_embedding"](cond)
# un_cond = torch.zeros_like(cond, device=cond.device)
cond = torch.cat([un_cond, cond], dim=0)
outputs = []
latents = None
if not return_intermediates:
for _ in range(sample_times):
sample_loop = ddim_sample(
self.denoise_scheduler,
self.model,
shape=self.first_stage_model.latent_shape,
cond=cond,
steps=steps,
guidance_scale=guidance_scale,
do_classifier_free_guidance=do_classifier_free_guidance,
device=self.device,
eta=eta,
disable_prog=not self.zero_rank
)
for sample, t in sample_loop:
latents = sample
outputs.append(self.decode_first_stage(latents, **kwargs))
else:
sample_loop = ddim_sample(
self.denoise_scheduler,
self.model,
shape=self.first_stage_model.latent_shape,
cond=cond,
steps=steps,
guidance_scale=guidance_scale,
do_classifier_free_guidance=do_classifier_free_guidance,
device=self.device,
eta=eta,
disable_prog=not self.zero_rank
)
iter_size = steps // sample_times
i = 0
for sample, t in sample_loop:
latents = sample
if i % iter_size == 0 or i == steps - 1:
outputs.append(self.decode_first_stage(latents, **kwargs))
i += 1
return outputs

View File

@ -0,0 +1,104 @@
# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
from typing import Optional
from diffusers.models.embeddings import Timesteps
import math
from ..modules.transformer_blocks import MLP
from ..modules.diffusion_transformer import UNetDiffusionTransformer
class ConditionalASLUDTDenoiser(nn.Module):
def __init__(self, *,
device: Optional[torch.device],
dtype: Optional[torch.dtype],
input_channels: int,
output_channels: int,
n_ctx: int,
width: int,
layers: int,
heads: int,
context_dim: int,
context_ln: bool = True,
skip_ln: bool = False,
init_scale: float = 0.25,
flip_sin_to_cos: bool = False,
use_checkpoint: bool = False):
super().__init__()
self.use_checkpoint = use_checkpoint
init_scale = init_scale * math.sqrt(1.0 / width)
self.backbone = UNetDiffusionTransformer(
device=device,
dtype=dtype,
n_ctx=n_ctx,
width=width,
layers=layers,
heads=heads,
skip_ln=skip_ln,
init_scale=init_scale,
use_checkpoint=use_checkpoint
)
self.ln_post = nn.LayerNorm(width, device=device, dtype=dtype)
self.input_proj = nn.Linear(input_channels, width, device=device, dtype=dtype)
self.output_proj = nn.Linear(width, output_channels, device=device, dtype=dtype)
# timestep embedding
self.time_embed = Timesteps(width, flip_sin_to_cos=flip_sin_to_cos, downscale_freq_shift=0)
self.time_proj = MLP(
device=device, dtype=dtype, width=width, init_scale=init_scale
)
if context_ln:
self.context_embed = nn.Sequential(
nn.LayerNorm(context_dim, device=device, dtype=dtype),
nn.Linear(context_dim, width, device=device, dtype=dtype),
)
else:
self.context_embed = nn.Linear(context_dim, width, device=device, dtype=dtype)
def forward(self,
model_input: torch.FloatTensor,
timestep: torch.LongTensor,
context: torch.FloatTensor):
r"""
Args:
model_input (torch.FloatTensor): [bs, n_data, c]
timestep (torch.LongTensor): [bs,]
context (torch.FloatTensor): [bs, context_tokens, c]
Returns:
sample (torch.FloatTensor): [bs, n_data, c]
"""
_, n_data, _ = model_input.shape
# 1. time
t_emb = self.time_proj(self.time_embed(timestep)).unsqueeze(dim=1)
# 2. conditions projector
context = self.context_embed(context)
# 3. denoiser
x = self.input_proj(model_input)
x = torch.cat([t_emb, context, x], dim=1)
x = self.backbone(x)
x = self.ln_post(x)
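# keep only the data tokens; drop the timestep and context tokens prepended above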
x = x[:, -n_data:]
sample = self.output_proj(x)
return sample

View File

@ -0,0 +1,13 @@
# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
class BaseDenoiser(nn.Module):
def __init__(self):
super().__init__()
def forward(self, x, t, context):
raise NotImplementedError

View File

@ -0,0 +1,393 @@
# -*- coding: utf-8 -*-
from omegaconf import DictConfig
from typing import List, Tuple, Dict, Optional, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import lr_scheduler
import pytorch_lightning as pl
from pytorch_lightning.utilities import rank_zero_only
from diffusers.schedulers import (
DDPMScheduler,
DDIMScheduler,
KarrasVeScheduler,
DPMSolverMultistepScheduler
)
from ...utils import instantiate_from_config
from ..tsal.tsal_base import AlignedShapeAsLatentPLModule
from .inference_utils import ddim_sample
SchedulerType = Union[DDIMScheduler, KarrasVeScheduler, DPMSolverMultistepScheduler]
def disabled_train(self, mode=True):
"""Overwrite model.train with this function to make sure train/eval mode
does not change anymore."""
return self
class ClipASLDiffuser(pl.LightningModule):
first_stage_model: Optional[AlignedShapeAsLatentPLModule]
cond_stage_model: Optional[Union[nn.Module, pl.LightningModule]]
model: nn.Module
def __init__(self, *,
first_stage_config,
cond_stage_config,
denoiser_cfg,
scheduler_cfg,
optimizer_cfg,
loss_cfg,
first_stage_key: str = "surface",
cond_stage_key: str = "image",
scale_by_std: bool = False,
z_scale_factor: float = 1.0,
ckpt_path: Optional[str] = None,
ignore_keys: Union[Tuple[str], List[str]] = ()):
super().__init__()
self.first_stage_key = first_stage_key
self.cond_stage_key = cond_stage_key
# 1. lazy initialize first stage
self.instantiate_first_stage(first_stage_config)
# 2. initialize conditional stage
self.instantiate_cond_stage(cond_stage_config)
# 3. diffusion model
self.model = instantiate_from_config(
denoiser_cfg, device=None, dtype=None
)
self.optimizer_cfg = optimizer_cfg
# 4. scheduling strategy
self.scheduler_cfg = scheduler_cfg
self.noise_scheduler: DDPMScheduler = instantiate_from_config(scheduler_cfg.noise)
self.denoise_scheduler: SchedulerType = instantiate_from_config(scheduler_cfg.denoise)
# 5. loss configures
self.loss_cfg = loss_cfg
self.scale_by_std = scale_by_std
if scale_by_std:
self.register_buffer("z_scale_factor", torch.tensor(z_scale_factor))
else:
self.z_scale_factor = z_scale_factor
self.ckpt_path = ckpt_path
if ckpt_path is not None:
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
def instantiate_non_trainable_model(self, config):
model = instantiate_from_config(config)
model = model.eval()
model.train = disabled_train
for param in model.parameters():
param.requires_grad = False
return model
def instantiate_first_stage(self, first_stage_config):
self.first_stage_model = self.instantiate_non_trainable_model(first_stage_config)
self.first_stage_model.set_shape_model_only()
def instantiate_cond_stage(self, cond_stage_config):
self.cond_stage_model = self.instantiate_non_trainable_model(cond_stage_config)
def init_from_ckpt(self, path, ignore_keys=()):
state_dict = torch.load(path, map_location="cpu")["state_dict"]
keys = list(state_dict.keys())
for k in keys:
for ik in ignore_keys:
if k.startswith(ik):
print("Deleting key {} from state_dict.".format(k))
del state_dict[k]
missing, unexpected = self.load_state_dict(state_dict, strict=False)
print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
if len(missing) > 0:
print(f"Missing Keys: {missing}")
print(f"Unexpected Keys: {unexpected}")
@property
def zero_rank(self):
if self._trainer:
zero_rank = self.trainer.local_rank == 0
else:
zero_rank = True
return zero_rank
def configure_optimizers(self) -> Tuple[List, List]:
lr = self.learning_rate
trainable_parameters = list(self.model.parameters())
if self.optimizer_cfg is None:
optimizers = [torch.optim.AdamW(trainable_parameters, lr=lr, betas=(0.9, 0.99), weight_decay=1e-3)]
schedulers = []
else:
optimizer = instantiate_from_config(self.optimizer_cfg.optimizer, params=trainable_parameters)
scheduler_func = instantiate_from_config(
self.optimizer_cfg.scheduler,
max_decay_steps=self.trainer.max_steps,
lr_max=lr
)
scheduler = {
"scheduler": lr_scheduler.LambdaLR(optimizer, lr_lambda=scheduler_func.schedule),
"interval": "step",
"frequency": 1
}
optimizers = [optimizer]
schedulers = [scheduler]
return optimizers, schedulers
@torch.no_grad()
def encode_first_stage(self, surface: torch.FloatTensor, sample_posterior=True):
z_q = self.first_stage_model.encode(surface, sample_posterior)
z_q = self.z_scale_factor * z_q
return z_q
@torch.no_grad()
def decode_first_stage(self, z_q: torch.FloatTensor, **kwargs):
z_q = 1. / self.z_scale_factor * z_q
latents = self.first_stage_model.decode(z_q, **kwargs)
return latents
@rank_zero_only
@torch.no_grad()
def on_train_batch_start(self, batch, batch_idx):
# only for very first batch
if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 \
and batch_idx == 0 and self.ckpt_path is None:
# set rescale weight to 1./std of encodings
print("### USING STD-RESCALING ###")
z_q = self.encode_first_stage(batch[self.first_stage_key])
z = z_q.detach()
del self.z_scale_factor
self.register_buffer("z_scale_factor", 1. / z.flatten().std())
print(f"setting self.z_scale_factor to {self.z_scale_factor}")
print("### USING STD-RESCALING ###")
def compute_loss(self, model_outputs, split):
"""
Args:
model_outputs (dict):
- x_0:
- noise:
- pred:
split (str):
Returns:
"""
pred = model_outputs["pred"]
if self.noise_scheduler.prediction_type == "epsilon":
target = model_outputs["noise"]
elif self.noise_scheduler.prediction_type == "sample":
target = model_outputs["x_0"]
else:
raise NotImplementedError(f"Prediction Type: {self.noise_scheduler.prediction_type} not yet supported.")
if self.loss_cfg.loss_type == "l1":
simple = F.l1_loss(pred, target, reduction="mean")
elif self.loss_cfg.loss_type in ["mse", "l2"]:
simple = F.mse_loss(pred, target, reduction="mean")
else:
raise NotImplementedError(f"Loss Type: {self.loss_cfg.loss_type} not yet supported.")
total_loss = simple
loss_dict = {
f"{split}/total_loss": total_loss.clone().detach(),
f"{split}/simple": simple.detach(),
}
return total_loss, loss_dict
def forward(self, batch):
"""
Args:
batch:
Returns:
"""
latents = self.encode_first_stage(batch[self.first_stage_key])
conditions = self.cond_stage_model.encode(batch[self.cond_stage_key])
# Sample noise that we'll add to the latents
# [batch_size, n_token, latent_dim]
noise = torch.randn_like(latents)
bs = latents.shape[0]
# Sample a random timestep for each motion
timesteps = torch.randint(
0,
self.noise_scheduler.config.num_train_timesteps,
(bs,),
device=latents.device,
)
timesteps = timesteps.long()
# Add noise to the latents according to the noise magnitude at each timestep
noisy_z = self.noise_scheduler.add_noise(latents, noise, timesteps)
# diffusion model forward
noise_pred = self.model(noisy_z, timesteps, conditions)
diffusion_outputs = {
"x_0": noisy_z,
"noise": noise,
"pred": noise_pred
}
return diffusion_outputs
def training_step(self, batch: Dict[str, Union[torch.FloatTensor, List[str]]],
batch_idx: int, optimizer_idx: int = 0) -> torch.FloatTensor:
"""
Args:
batch (dict): the batch sample, and it contains:
- surface (torch.FloatTensor):
- image (torch.FloatTensor): if provide, [bs, 3, h, w], item range [0, 1]
- depth (torch.FloatTensor): if provide, [bs, 1, h, w], item range [-1, 1]
- normal (torch.FloatTensor): if provide, [bs, 3, h, w], item range [-1, 1]
- text (list of str):
batch_idx (int):
optimizer_idx (int):
Returns:
loss (torch.FloatTensor):
"""
diffusion_outputs = self(batch)
loss, loss_dict = self.compute_loss(diffusion_outputs, "train")
self.log_dict(loss_dict, prog_bar=True, logger=True, sync_dist=False, rank_zero_only=True)
return loss
def validation_step(self, batch: Dict[str, torch.FloatTensor],
batch_idx: int, optimizer_idx: int = 0) -> torch.FloatTensor:
"""
Args:
batch (dict): the batch sample, and it contains:
- surface_pc (torch.FloatTensor): [n_pts, 4]
- surface_feats (torch.FloatTensor): [n_pts, c]
- text (list of str):
batch_idx (int):
optimizer_idx (int):
Returns:
loss (torch.FloatTensor):
"""
diffusion_outputs = self(batch)
loss, loss_dict = self.compute_loss(diffusion_outputs, "val")
self.log_dict(loss_dict, prog_bar=True, logger=True, sync_dist=False, rank_zero_only=True)
return loss
@torch.no_grad()
def sample(self,
batch: Dict[str, Union[torch.FloatTensor, List[str]]],
sample_times: int = 1,
steps: Optional[int] = None,
guidance_scale: Optional[float] = None,
eta: float = 0.0,
return_intermediates: bool = False, **kwargs):
if steps is None:
steps = self.scheduler_cfg.num_inference_steps
if guidance_scale is None:
guidance_scale = self.scheduler_cfg.guidance_scale
do_classifier_free_guidance = guidance_scale > 0
# conditional encode
xc = batch[self.cond_stage_key]
# print(self.first_stage_model.device, self.cond_stage_model.device, self.device)
cond = self.cond_stage_model(xc)
if do_classifier_free_guidance:
un_cond = self.cond_stage_model.unconditional_embedding(batch_size=len(xc))
cond = torch.cat([un_cond, cond], dim=0)
outputs = []
latents = None
if not return_intermediates:
for _ in range(sample_times):
sample_loop = ddim_sample(
self.denoise_scheduler,
self.model,
shape=self.first_stage_model.latent_shape,
cond=cond,
steps=steps,
guidance_scale=guidance_scale,
do_classifier_free_guidance=do_classifier_free_guidance,
device=self.device,
eta=eta,
disable_prog=not self.zero_rank
)
for sample, t in sample_loop:
latents = sample
outputs.append(self.decode_first_stage(latents, **kwargs))
else:
sample_loop = ddim_sample(
self.denoise_scheduler,
self.model,
shape=self.first_stage_model.latent_shape,
cond=cond,
steps=steps,
guidance_scale=guidance_scale,
do_classifier_free_guidance=do_classifier_free_guidance,
device=self.device,
eta=eta,
disable_prog=not self.zero_rank
)
iter_size = steps // sample_times
i = 0
for sample, t in sample_loop:
latents = sample
if i % iter_size == 0 or i == steps - 1:
outputs.append(self.decode_first_stage(latents, **kwargs))
i += 1
return outputs

View File

@ -0,0 +1,80 @@
# -*- coding: utf-8 -*-
import torch
from tqdm import tqdm
from typing import Tuple, List, Union, Optional
from diffusers.schedulers import DDIMScheduler
__all__ = ["ddim_sample"]
def ddim_sample(ddim_scheduler: DDIMScheduler,
diffusion_model: torch.nn.Module,
shape: Union[List[int], Tuple[int]],
cond: torch.FloatTensor,
steps: int,
eta: float = 0.0,
guidance_scale: float = 3.0,
do_classifier_free_guidance: bool = True,
generator: Optional[torch.Generator] = None,
device: torch.device = "cuda:0",
disable_prog: bool = True):
assert steps > 0, f"{steps} must > 0."
# init latents
bsz = cond.shape[0]
if do_classifier_free_guidance:
bsz = bsz // 2
latents = torch.randn(
(bsz, *shape),
generator=generator,
device=cond.device,
dtype=cond.dtype,
)
# scale the initial noise by the standard deviation required by the scheduler
latents = latents * ddim_scheduler.init_noise_sigma
# set timesteps
ddim_scheduler.set_timesteps(steps)
timesteps = ddim_scheduler.timesteps.to(device)
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, and between [0, 1]
extra_step_kwargs = {
"eta": eta,
"generator": generator
}
# reverse
for i, t in enumerate(tqdm(timesteps, disable=disable_prog, desc="DDIM Sampling:", leave=False)):
# expand the latents if we are doing classifier free guidance
latent_model_input = (
torch.cat([latents] * 2)
if do_classifier_free_guidance
else latents
)
# latent_model_input = scheduler.scale_model_input(latent_model_input, t)
# predict the noise residual
timestep_tensor = torch.tensor([t], dtype=torch.long, device=device)
timestep_tensor = timestep_tensor.expand(latent_model_input.shape[0])
noise_pred = diffusion_model.forward(latent_model_input, timestep_tensor, cond)
# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (
noise_pred_text - noise_pred_uncond
)
# text_embeddings_for_guidance = encoder_hidden_states.chunk(
# 2)[1] if do_classifier_free_guidance else encoder_hidden_states
# compute the previous noisy sample x_t -> x_t-1
latents = ddim_scheduler.step(
noise_pred, t, latents, **extra_step_kwargs
).prev_sample
yield latents, t
def karra_sample():
pass
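
# Usage sketch (illustrative only; `_ToyDenoiser` is a hypothetical stand-in for
# the real diffusion backbone): ddim_sample is a generator that yields the running
# latents after every scheduler step, so callers keep the last value they receive.
class _ToyDenoiser(torch.nn.Module):
    def forward(self, x, t, cond):
        # Predict zero noise so the sketch runs end-to-end without a checkpoint.
        return torch.zeros_like(x)


def _ddim_sample_example():
    scheduler = DDIMScheduler(num_train_timesteps=1000)
    cond = torch.randn(2, 4, 8)  # already stacked as [uncond | cond]
    latents = None
    for latents, t in ddim_sample(scheduler, _ToyDenoiser(), shape=(16, 8),
                                  cond=cond, steps=10, guidance_scale=3.0,
                                  do_classifier_free_guidance=True,
                                  device=cond.device):
        pass
    return latents  # shape: (1, 16, 8)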

View File

@ -0,0 +1,3 @@
# -*- coding: utf-8 -*-
from .clip import CLIPEncoder

View File

@ -0,0 +1,89 @@
# -*- coding: utf-8 -*-
import torch
import numpy as np
from PIL import Image
from dataclasses import dataclass
from torchvision.transforms import Normalize
from transformers import CLIPModel, CLIPTokenizer
from transformers.utils import ModelOutput
from typing import Iterable, Optional, Union, List
ImageType = Union[np.ndarray, torch.Tensor, Image.Image]
@dataclass
class CLIPEmbedOutput(ModelOutput):
last_hidden_state: torch.FloatTensor = None
pooler_output: torch.FloatTensor = None
embeds: torch.FloatTensor = None
class CLIPEncoder(torch.nn.Module):
def __init__(self, model_path="openai/clip-vit-base-patch32"):
super().__init__()
# Load the CLIP model and processor
self.model: CLIPModel = CLIPModel.from_pretrained(model_path)
self.tokenizer = CLIPTokenizer.from_pretrained(model_path)
self.image_preprocess = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
self.model.eval()
for p in self.model.parameters():
p.requires_grad = False
@torch.no_grad()
def encode_image(self, images: Iterable[Optional[ImageType]]):
pixel_values = self.image_preprocess(images)
vision_outputs = self.model.vision_model(pixel_values=pixel_values)
pooler_output = vision_outputs[1] # pooled_output
image_features = self.model.visual_projection(pooler_output)
visual_embeds = CLIPEmbedOutput(
last_hidden_state=vision_outputs.last_hidden_state,
pooler_output=pooler_output,
embeds=image_features
)
return visual_embeds
@torch.no_grad()
def encode_text(self, texts: List[str]):
text_inputs = self.tokenizer(texts, padding=True, return_tensors="pt")
text_outputs = self.model.text_model(input_ids=text_inputs.input_ids, attention_mask=text_inputs.attention_mask)
pooler_output = text_outputs[1] # pooled_output
text_features = self.model.text_projection(pooler_output)
text_embeds = CLIPEmbedOutput(
last_hidden_state=text_outputs.last_hidden_state,
pooler_output=pooler_output,
embeds=text_features
)
return text_embeds
def forward(self,
images: Iterable[Optional[ImageType]],
texts: List[str]):
visual_embeds = self.encode_image(images)
text_embeds = self.encode_text(texts)
return visual_embeds, text_embeds
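
# Usage sketch (illustrative only; requires the Hugging Face weights to be
# available locally or downloadable): encoding text prompts with the frozen
# CLIP text tower. The projected embedding width is 512 for ViT-B/32.
def _clip_encoder_usage_example():
    encoder = CLIPEncoder()  # loads openai/clip-vit-base-patch32
    text_out = encoder.encode_text(["a red chair", "a wooden table"])
    return text_out.embeds  # [2, 512]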

View File

@ -0,0 +1,562 @@
# -*- coding: utf-8 -*-
import os
import torch
import torch.nn as nn
from torchvision import transforms
from transformers import CLIPModel, CLIPTokenizer
from collections import OrderedDict
import clip  # OpenAI CLIP package (provides clip.load), used by MoECLIPImageEncoder below
from ...data.transforms import RandomResize
class AbstractEncoder(nn.Module):
embedding_dim: int
def __init__(self):
super().__init__()
def encode(self, *args, **kwargs):
raise NotImplementedError
class ClassEmbedder(nn.Module):
def __init__(self, embed_dim, n_classes=1000, key="class"):
super().__init__()
self.key = key
self.embedding = nn.Embedding(n_classes, embed_dim)
def forward(self, batch, key=None):
if key is None:
key = self.key
# this is for use in crossattn
c = batch[key][:, None]
c = self.embedding(c)
return c
class FrozenCLIPTextEmbedder(AbstractEncoder):
"""Uses the CLIP transformer encoder for text (from Hugging Face)"""
def __init__(
self,
version="openai/clip-vit-large-patch14",
tokenizer_version=None,
device="cuda",
max_length=77,
zero_embedding_radio: float = 0.1,
):
super().__init__()
self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_version or version)
self.device = device
self.max_length = max_length
self.zero_embedding_radio = zero_embedding_radio
self.clip_dict = OrderedDict()
self.clip_name = os.path.split(version)[-1]
transformer = CLIPModel.from_pretrained(version).text_model
for param in transformer.parameters():
param.requires_grad = False
self.clip_dict[self.clip_name] = transformer
self._move_flag = False
@property
def clip(self):
return self.clip_dict[self.clip_name]
def move(self):
if self._move_flag:
return
self.clip_dict[self.clip_name] = self.clip_dict[self.clip_name].to(self.device)
self._move_flag = True
def unconditional_embedding(self, batch_size):
empty_text = [""] * batch_size
empty_z = self.forward(empty_text)
return empty_z
def forward(self, text):
self.move()
batch_encoding = self.tokenizer(
text,
truncation=True,
max_length=self.max_length,
return_length=True,
return_overflowing_tokens=False,
padding="max_length",
return_tensors="pt",
)
tokens = batch_encoding["input_ids"].to(self.device)
outputs = self.clip(input_ids=tokens)
z = outputs.last_hidden_state
return z
def encode(self, text):
batch_size = len(text)
batch_mask = torch.rand((batch_size,))
for i in range(batch_size):
if batch_mask[i] < self.zero_embedding_radio:
text[i] = ""
return self(text)
class FrozenAlignedCLIPTextEmbedder(AbstractEncoder):
"""Uses the CLIP transformer encoder for text (from Hugging Face)"""
def __init__(
self,
version="openai/clip-vit-large-patch14",
tokenizer_version=None,
device="cuda",
max_length=77,
zero_embedding_radio: float = 0.1,
):
super().__init__()
self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_version or version)
self.device = device
self.max_length = max_length
self.zero_embedding_radio = zero_embedding_radio
self.clip_dict = OrderedDict()
self.clip_name = os.path.split(version)[-1]
transformer = CLIPModel.from_pretrained(version).text_model
for param in transformer.parameters():
param.requires_grad = False
self.clip_dict[self.clip_name] = transformer
self._move_flag = False
@property
def clip(self):
return self.clip_dict[self.clip_name]
def move(self):
if self._move_flag:
return
self.clip_dict[self.clip_name] = self.clip_dict[self.clip_name].to(self.device)
self._move_flag = True
def unconditional_embedding(self, batch_size):
empty_text = [""] * batch_size
empty_z = self.forward(empty_text)
return empty_z
def forward(self, text):
self.move()
batch_encoding = self.tokenizer(
text,
truncation=True,
max_length=self.max_length,
return_length=True,
return_overflowing_tokens=False,
padding="max_length",
return_tensors="pt",
)
tokens = batch_encoding["input_ids"].to(self.device)
outputs = self.clip(input_ids=tokens)
z = outputs.last_hidden_state
return z
def encode(self, text):
batch_size = len(text)
batch_mask = torch.rand((batch_size,))
for i in range(batch_size):
if batch_mask[i] < self.zero_embedding_radio:
text[i] = ""
return self(text)
class FrozenCLIPImageEmbedder(AbstractEncoder):
"""Uses the CLIP transformer encoder for text (from Hugging Face)"""
def __init__(
self,
version="openai/clip-vit-large-patch14",
device="cuda",
zero_embedding_radio=0.1,
normalize_embedding=True,
num_projection_vector=0,
linear_mapping_bias=True,
reverse_visual_projection=False,
):
super().__init__()
self.device = device
self.clip_dict = OrderedDict()
self.clip_name = os.path.split(version)[-1]
clip_model = CLIPModel.from_pretrained(version)
clip_model.text_model = None
clip_model.text_projection = None
clip_model = clip_model.eval()
for param in clip_model.parameters():
param.requires_grad = False
self.clip_dict[self.clip_name] = clip_model
self.transform = transforms.Compose(
[
transforms.Resize(224, transforms.InterpolationMode.BICUBIC, antialias=True),
transforms.CenterCrop(224), # crop a (224, 224) square
transforms.Normalize(
mean=[0.48145466, 0.4578275, 0.40821073],
std=[0.26862954, 0.26130258, 0.27577711],
),
]
)
self.zero_embedding_radio = zero_embedding_radio
self.num_projection_vector = num_projection_vector
self.reverse_visual_projection = reverse_visual_projection
self.normalize_embedding = normalize_embedding
embedding_dim = (
clip_model.visual_projection.in_features
if reverse_visual_projection
else clip_model.visual_projection.out_features
)
self.embedding_dim = embedding_dim
if self.num_projection_vector > 0:
self.projection = nn.Linear(
embedding_dim,
clip_model.visual_projection.out_features * num_projection_vector,
bias=linear_mapping_bias,
)
nn.init.normal_(self.projection.weight, std=embedding_dim ** -0.5)
self._move_flag = False
@property
def clip(self):
return self.clip_dict[self.clip_name]
def unconditional_embedding(self, batch_size):
zero = torch.zeros(
batch_size,
1,
self.embedding_dim,
device=self.device,
dtype=self.clip.visual_projection.weight.dtype,
)
if self.num_projection_vector > 0:
zero = self.projection(zero).view(batch_size, self.num_projection_vector, -1)
return zero
def forward(self, image, value_range=(-1, 1), zero_embedding_radio=0):
if value_range is not None:
low, high = value_range
image = (image - low) / (high - low)
image = image.to(self.device, dtype=self.clip.visual_projection.weight.dtype)
if self.reverse_visual_projection:
z = self.clip.vision_model(self.transform(image))[1]
else:
z = self.clip.get_image_features(self.transform(image))
if self.normalize_embedding:
z = z / z.norm(dim=-1, keepdim=True)
if z.ndim == 2:
z = z.unsqueeze(dim=-2)
if zero_embedding_radio > 0:
mask = torch.rand((len(image), 1, 1), device=z.device, dtype=z.dtype) >= zero_embedding_radio
z = z * mask.to(z)
if self.num_projection_vector > 0:
z = self.projection(z).view(len(image), self.num_projection_vector, -1)
return z
def move(self):
if self._move_flag:
return
self.clip_dict[self.clip_name] = self.clip_dict[self.clip_name].to(self.device)
self._move_flag = True
def encode(self, image):
self.move()
return self(image, zero_embedding_radio=self.zero_embedding_radio)
class FrozenCLIPImageGridEmbedder(AbstractEncoder):
def __init__(
self,
version="openai/clip-vit-large-patch14",
device="cuda",
zero_embedding_radio=0.1,
):
super().__init__()
self.device = device
self.clip_dict = OrderedDict()
self.clip_name = os.path.split(version)[-1]
clip_model: CLIPModel = CLIPModel.from_pretrained(version)
clip_model.text_model = None
clip_model.text_projection = None
clip_model = clip_model.eval()
for param in clip_model.parameters():
param.requires_grad = False
self.clip_dict[self.clip_name] = clip_model
self.transform = transforms.Compose(
[
transforms.Resize(224, transforms.InterpolationMode.BILINEAR, antialias=True),
transforms.CenterCrop(224), # crop a (224, 224) square
transforms.Normalize(
mean=[0.48145466, 0.4578275, 0.40821073],
std=[0.26862954, 0.26130258, 0.27577711],
),
]
)
self.zero_embedding_radio = zero_embedding_radio
self.embedding_dim = clip_model.vision_embed_dim
self._move_flag = False
@property
def clip(self):
return self.clip_dict[self.clip_name]
def move(self):
if self._move_flag:
return
self.clip_dict[self.clip_name] = self.clip_dict[self.clip_name].to(self.device)
self._move_flag = True
def unconditional_embedding(self, batch_size):
zero = torch.zeros(
batch_size,
self.clip.vision_model.embeddings.num_positions,
self.embedding_dim,
device=self.device,
dtype=self.clip.visual_projection.weight.dtype,
)
return zero
def forward(self, image, value_range=(-1, 1), zero_embedding_radio=0):
self.move()
if value_range is not None:
low, high = value_range
image = (image - low) / (high - low)
image = image.to(self.device, dtype=self.clip.visual_projection.weight.dtype)
z = self.clip.vision_model(self.transform(image)).last_hidden_state
if zero_embedding_radio > 0:
mask = torch.rand((len(image), 1, 1), device=z.device, dtype=z.dtype) >= zero_embedding_radio
z = z * mask.to(z)
return z
def encode(self, image):
return self(image, zero_embedding_radio=self.zero_embedding_radio)
class MoECLIPImageEncoder(nn.Module):
def __init__(
self,
versions,
hidden_state_dim,
num_projection_vector=8,
zero_embedding_radio=0.1,
device="cuda",
precision="fp16",
normalize=False,
clip_max=0,
transform_type="base",
argument_p=0.2,
):
super().__init__()
self.device = torch.device(device)
self.hidden_state_dim = hidden_state_dim
self.zero_embedding_radio = zero_embedding_radio
self.num_projection_vector = num_projection_vector
self.dtype = dict(fp16=torch.float16, fp32=torch.float32, bf16=torch.bfloat16)[precision]
self.normalize = normalize
self.clip_max = clip_max
if transform_type == "base":
self.transform = transforms.Compose(
[
transforms.Resize(224, transforms.InterpolationMode.BICUBIC, antialias=True),
transforms.CenterCrop(224), # crop a (224, 224) square
transforms.Normalize(
mean=[0.48145466, 0.4578275, 0.40821073],
std=[0.26862954, 0.26130258, 0.27577711],
),
]
)
elif transform_type == "crop_blur_resize":
self.transform = transforms.Compose(
[
transforms.Resize(224, transforms.InterpolationMode.BICUBIC, antialias=True),
transforms.CenterCrop(224), # crop a (224, 224) square
transforms.RandomApply(
transforms=[
transforms.RandomResizedCrop(
size=224,
scale=(0.8, 1.0),
ratio=(0.99, 1.01),
interpolation=transforms.InterpolationMode.BICUBIC,
),
],
p=argument_p,
),
transforms.RandomApply(
transforms=[
transforms.GaussianBlur(kernel_size=9, sigma=(0.1, 5)),
],
p=argument_p,
),
transforms.RandomApply(
transforms=[
RandomResize(size=224, resize_radio=(0.2, 1)),
],
p=argument_p,
),
transforms.Normalize(
mean=[0.48145466, 0.4578275, 0.40821073],
std=[0.26862954, 0.26130258, 0.27577711],
),
]
)
else:
raise ValueError(f"invalid {transform_type=}")
if isinstance(versions, str):
versions = (versions,)
# If the CLIP models were registered as submodules of this class directly: 1. multiple redundant copies of their weights would be saved in the checkpoint; 2. Lightning would call .to() on them and the layer_norm weights would also be cast to fp16.
clips = OrderedDict()
for v in versions:
# Because the CLIP models are not submodules, passing device="cuda" here would incorrectly place all of their weights on cuda:0.
clips[v], _ = clip.load(name=v, device="cpu", jit=False, download_root=None)
delattr(clips[v], "transformer")
clips[v].eval()
clips[v].requires_grad_(False)
self.clips_hidden_dim = sum(clips[v].ln_final.weight.size(0) for v in clips)
if self.num_projection_vector == 0:
self.projection = nn.Identity()
else:
self.projection = nn.Linear(self.clips_hidden_dim, hidden_state_dim * self.num_projection_vector, bias=True)
self.projection.to(dtype=self.dtype)
nn.init.normal_(self.projection.weight, std=self.clips_hidden_dim ** -0.5)
self.clips = clips
self._move_flag = False
def move(self):
if self._move_flag:
return
def convert_weights(model: nn.Module):
"""Convert applicable model parameters to fp16"""
def _convert_weights_to_fp16(l):
if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
l.weight.data = l.weight.data.type(self.dtype)
if l.bias is not None:
l.bias.data = l.bias.data.type(self.dtype)
if isinstance(l, nn.MultiheadAttention):
for attr in [
*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]],
"in_proj_bias",
"bias_k",
"bias_v",
]:
tensor = getattr(l, attr)
if tensor is not None:
tensor.data = tensor.data.type(self.dtype)
for name in ["text_projection", "proj"]:
if hasattr(l, name):
attr = getattr(l, name)
if attr is not None:
attr.data = attr.data.type(self.dtype)
model.apply(_convert_weights_to_fp16)
for k in self.clips:
self.clips[k].to(self.device)
convert_weights(self.clips[k]) # fp32 -> self.dtype
self._move_flag = True
def unconditional_embedding(self, batch_size=None):
zero = torch.zeros(
batch_size,
self.clips_hidden_dim,
device=self.device,
dtype=self.dtype,
)
if self.num_projection_vector > 0:
zero = self.projection(zero).view(batch_size, self.num_projection_vector, -1)
return zero
def convert_embedding(self, z):
if self.num_projection_vector > 0:
z = self.projection(z.type(self.projection.weight.dtype)).view(len(z), self.num_projection_vector, -1)
return z
def forward(self, image, value_range=(-1, 1), zero_embedding_radio=0):
if value_range is not None:
low, high = value_range
image = (image - low) / (high - low)
image = self.transform(image)
with torch.no_grad():
embs = []
for v in self.clips:
x = self.clips[v].encode_image(image)
if self.normalize:
x = x / x.norm(p=2, dim=-1, keepdim=True) * (x.size(-1) ** 0.5)
# clip_max only works with normalization
if self.clip_max > 0:
x = x.clamp(-self.clip_max, self.clip_max)
embs.append(x)
z = torch.cat(embs, dim=-1)
if self.normalize:
z /= z.size(-1) ** 0.5
if zero_embedding_radio > 0:
mask = torch.rand((len(image), 1, 1), device=z.device, dtype=z.dtype) >= zero_embedding_radio
z = z * mask.to(z)
if self.num_projection_vector > 0:
z = self.projection(z).view(len(image), self.num_projection_vector, -1)
return z
def encode(self, image):
self.move()
return self(image, zero_embedding_radio=self.zero_embedding_radio)
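
# Illustrative sketch (not part of the original file): the zero_embedding_radio
# mechanism shared by the frozen encoders above. With probability `radio` (the
# repository's spelling of "ratio") a sample's conditioning is dropped, so the
# diffusion model also learns the unconditional branch needed for
# classifier-free guidance.
def _drop_text_conditioning(texts, radio=0.1):
    drop = torch.rand(len(texts)) < radio
    return ["" if d else t for t, d in zip(texts, drop.tolist())]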

View File

@ -0,0 +1,3 @@
# -*- coding: utf-8 -*-
from .checkpoint import checkpoint

View File

@ -0,0 +1,69 @@
# -*- coding: utf-8 -*-
"""
Adapted from: https://github.com/openai/guided-diffusion/blob/22e0df8183507e13a7813f8d38d51b072ca1e67c/guided_diffusion/nn.py#L124
"""
import torch
from typing import Callable, Iterable, Sequence, Union
def checkpoint(
func: Callable[..., Union[torch.Tensor, Sequence[torch.Tensor]]],
inputs: Sequence[torch.Tensor],
params: Iterable[torch.Tensor],
flag: bool,
use_deepspeed: bool = False
):
"""
Evaluate a function without caching intermediate activations, allowing for
reduced memory at the expense of extra compute in the backward pass.
:param func: the function to evaluate.
:param inputs: the argument sequence to pass to `func`.
:param params: a sequence of parameters `func` depends on but does not
explicitly take as arguments.
:param flag: if False, disable gradient checkpointing.
:param use_deepspeed: if True, use deepspeed
"""
if flag:
if use_deepspeed:
import deepspeed
return deepspeed.checkpointing.checkpoint(func, *inputs)
args = tuple(inputs) + tuple(params)
return CheckpointFunction.apply(func, len(inputs), *args)
else:
return func(*inputs)
class CheckpointFunction(torch.autograd.Function):
@staticmethod
@torch.cuda.amp.custom_fwd
def forward(ctx, run_function, length, *args):
ctx.run_function = run_function
ctx.input_tensors = list(args[:length])
ctx.input_params = list(args[length:])
with torch.no_grad():
output_tensors = ctx.run_function(*ctx.input_tensors)
return output_tensors
@staticmethod
@torch.cuda.amp.custom_bwd
def backward(ctx, *output_grads):
ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
with torch.enable_grad():
# Fixes a bug where the first op in run_function modifies the
# Tensor storage in place, which is not allowed for detach()'d
# Tensors.
shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
output_tensors = ctx.run_function(*shallow_copies)
input_grads = torch.autograd.grad(
output_tensors,
ctx.input_tensors + ctx.input_params,
output_grads,
allow_unused=True,
)
del ctx.input_tensors
del ctx.input_params
del output_tensors
return (None, None) + input_grads
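
# Usage sketch (illustrative only): wrapping a module call in `checkpoint` so its
# activations are recomputed during backward instead of being stored in forward.
def _checkpoint_usage_example():
    block = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.GELU())
    x = torch.randn(4, 8, requires_grad=True)
    y = checkpoint(lambda inp: block(inp), (x,), block.parameters(), flag=True)
    y.sum().backward()
    return x.grad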

View File

@ -0,0 +1,218 @@
# -*- coding: utf-8 -*-
import math
import torch
import torch.nn as nn
from typing import Optional
from .checkpoint import checkpoint
from .transformer_blocks import (
init_linear,
MLP,
MultiheadCrossAttention,
MultiheadAttention,
ResidualAttentionBlock
)
class AdaLayerNorm(nn.Module):
def __init__(self,
device: torch.device,
dtype: torch.dtype,
width: int):
super().__init__()
self.silu = nn.SiLU(inplace=True)
self.linear = nn.Linear(width, width * 2, device=device, dtype=dtype)
self.layernorm = nn.LayerNorm(width, elementwise_affine=False, device=device, dtype=dtype)
def forward(self, x, timestep):
emb = self.linear(timestep)
scale, shift = torch.chunk(emb, 2, dim=2)
x = self.layernorm(x) * (1 + scale) + shift
return x
class DitBlock(nn.Module):
def __init__(
self,
*,
device: torch.device,
dtype: torch.dtype,
n_ctx: int,
width: int,
heads: int,
context_dim: int,
qkv_bias: bool = False,
init_scale: float = 1.0,
use_checkpoint: bool = False
):
super().__init__()
self.use_checkpoint = use_checkpoint
self.attn = MultiheadAttention(
device=device,
dtype=dtype,
n_ctx=n_ctx,
width=width,
heads=heads,
init_scale=init_scale,
qkv_bias=qkv_bias
)
self.ln_1 = AdaLayerNorm(device, dtype, width)
if context_dim is not None:
self.ln_2 = AdaLayerNorm(device, dtype, width)
self.cross_attn = MultiheadCrossAttention(
device=device,
dtype=dtype,
width=width,
heads=heads,
data_width=context_dim,
init_scale=init_scale,
qkv_bias=qkv_bias
)
self.mlp = MLP(device=device, dtype=dtype, width=width, init_scale=init_scale)
self.ln_3 = AdaLayerNorm(device, dtype, width)
def forward(self, x: torch.Tensor, t: torch.Tensor, context: Optional[torch.Tensor] = None):
return checkpoint(self._forward, (x, t, context), self.parameters(), self.use_checkpoint)
def _forward(self, x: torch.Tensor, t: torch.Tensor, context: Optional[torch.Tensor] = None):
x = x + self.attn(self.ln_1(x, t))
if context is not None:
x = x + self.cross_attn(self.ln_2(x, t), context)
x = x + self.mlp(self.ln_3(x, t))
return x
class DiT(nn.Module):
def __init__(
self,
*,
device: Optional[torch.device],
dtype: Optional[torch.dtype],
n_ctx: int,
width: int,
layers: int,
heads: int,
context_dim: int,
init_scale: float = 0.25,
qkv_bias: bool = False,
use_checkpoint: bool = False
):
super().__init__()
self.n_ctx = n_ctx
self.width = width
self.layers = layers
self.resblocks = nn.ModuleList(
[
DitBlock(
device=device,
dtype=dtype,
n_ctx=n_ctx,
width=width,
heads=heads,
context_dim=context_dim,
qkv_bias=qkv_bias,
init_scale=init_scale,
use_checkpoint=use_checkpoint
)
for _ in range(layers)
]
)
def forward(self, x: torch.Tensor, t: torch.Tensor, context: Optional[torch.Tensor] = None):
for block in self.resblocks:
x = block(x, t, context)
return x
class UNetDiffusionTransformer(nn.Module):
def __init__(
self,
*,
device: Optional[torch.device],
dtype: Optional[torch.dtype],
n_ctx: int,
width: int,
layers: int,
heads: int,
init_scale: float = 0.25,
qkv_bias: bool = False,
skip_ln: bool = False,
use_checkpoint: bool = False
):
super().__init__()
self.n_ctx = n_ctx
self.width = width
self.layers = layers
self.encoder = nn.ModuleList()
for _ in range(layers):
resblock = ResidualAttentionBlock(
device=device,
dtype=dtype,
n_ctx=n_ctx,
width=width,
heads=heads,
init_scale=init_scale,
qkv_bias=qkv_bias,
use_checkpoint=use_checkpoint
)
self.encoder.append(resblock)
self.middle_block = ResidualAttentionBlock(
device=device,
dtype=dtype,
n_ctx=n_ctx,
width=width,
heads=heads,
init_scale=init_scale,
qkv_bias=qkv_bias,
use_checkpoint=use_checkpoint
)
self.decoder = nn.ModuleList()
for _ in range(layers):
resblock = ResidualAttentionBlock(
device=device,
dtype=dtype,
n_ctx=n_ctx,
width=width,
heads=heads,
init_scale=init_scale,
qkv_bias=qkv_bias,
use_checkpoint=use_checkpoint
)
linear = nn.Linear(width * 2, width, device=device, dtype=dtype)
init_linear(linear, init_scale)
layer_norm = nn.LayerNorm(width, device=device, dtype=dtype) if skip_ln else None
self.decoder.append(nn.ModuleList([resblock, linear, layer_norm]))
def forward(self, x: torch.Tensor):
enc_outputs = []
for block in self.encoder:
x = block(x)
enc_outputs.append(x)
x = self.middle_block(x)
for i, (resblock, linear, layer_norm) in enumerate(self.decoder):
x = torch.cat([enc_outputs.pop(), x], dim=-1)
x = linear(x)
if layer_norm is not None:
x = layer_norm(x)
x = resblock(x)
return x
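
# Illustrative sketch (not part of the original file): the adaLN modulation used
# in AdaLayerNorm.forward above. The timestep embedding is projected to per-token
# (scale, shift) pairs and applied after a parameter-free LayerNorm.
def _adaln_modulation_example():
    width = 8
    x = torch.randn(2, 5, width)      # [batch, tokens, width]
    t_emb = torch.randn(2, 1, width)  # timestep embedding, broadcast over tokens
    to_scale_shift = nn.Linear(width, width * 2)
    scale, shift = torch.chunk(to_scale_shift(t_emb), 2, dim=2)
    normed = nn.functional.layer_norm(x, (width,))
    return normed * (1 + scale) + shift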

View File

@ -0,0 +1,100 @@
import torch
import numpy as np
from typing import Union, List
class AbstractDistribution(object):
def sample(self):
raise NotImplementedError()
def mode(self):
raise NotImplementedError()
class DiracDistribution(AbstractDistribution):
def __init__(self, value):
self.value = value
def sample(self):
return self.value
def mode(self):
return self.value
class DiagonalGaussianDistribution(object):
def __init__(self, parameters: Union[torch.Tensor, List[torch.Tensor]], deterministic=False, feat_dim=1):
self.feat_dim = feat_dim
self.parameters = parameters
if isinstance(parameters, list):
self.mean = parameters[0]
self.logvar = parameters[1]
else:
self.mean, self.logvar = torch.chunk(parameters, 2, dim=feat_dim)
self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
self.deterministic = deterministic
self.std = torch.exp(0.5 * self.logvar)
self.var = torch.exp(self.logvar)
if self.deterministic:
self.var = self.std = torch.zeros_like(self.mean)
def sample(self):
x = self.mean + self.std * torch.randn_like(self.mean)
return x
def kl(self, other=None, dims=(1, 2, 3)):
if self.deterministic:
return torch.Tensor([0.])
else:
if other is None:
return 0.5 * torch.mean(torch.pow(self.mean, 2)
+ self.var - 1.0 - self.logvar,
dim=dims)
else:
return 0.5 * torch.mean(
torch.pow(self.mean - other.mean, 2) / other.var
+ self.var / other.var - 1.0 - self.logvar + other.logvar,
dim=dims)
def nll(self, sample, dims=(1, 2, 3)):
if self.deterministic:
return torch.Tensor([0.])
logtwopi = np.log(2.0 * np.pi)
return 0.5 * torch.sum(
logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
dim=dims)
def mode(self):
return self.mean
def normal_kl(mean1, logvar1, mean2, logvar2):
"""
source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
Compute the KL divergence between two gaussians.
Shapes are automatically broadcasted, so batches can be compared to
scalars, among other use cases.
"""
tensor = None
for obj in (mean1, logvar1, mean2, logvar2):
if isinstance(obj, torch.Tensor):
tensor = obj
break
assert tensor is not None, "at least one argument must be a Tensor"
# Force variances to be Tensors. Broadcasting helps convert scalars to
# Tensors, but it does not work for torch.exp().
logvar1, logvar2 = [
x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
for x in (logvar1, logvar2)
]
return 0.5 * (
-1.0
+ logvar2
- logvar1
+ torch.exp(logvar1 - logvar2)
+ ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
)
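
# Usage sketch (illustrative only): the diagonal-Gaussian bottleneck used by the
# shape VAE. `params` stacks the mean and log-variance along the feature axis.
def _diag_gaussian_example():
    params = torch.randn(2, 16, 8)            # mean (4) + logvar (4) on dim -1
    posterior = DiagonalGaussianDistribution(params, feat_dim=-1)
    z = posterior.sample()                    # [2, 16, 4]
    kl = posterior.kl(dims=(1, 2))            # one KL value per batch item
    return z, kl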

View File

@ -0,0 +1,213 @@
# -*- coding: utf-8 -*-
import numpy as np
import torch
import torch.nn as nn
import math
VALID_EMBED_TYPES = ["identity", "fourier", "hashgrid", "sphere_harmonic", "triplane_fourier"]
class FourierEmbedder(nn.Module):
"""The sin/cosine positional embedding. Given an input tensor `x` of shape [n_batch, ..., c_dim], it converts
each feature dimension of `x[..., i]` into:
[
sin(x[..., i]),
sin(f_1*x[..., i]),
sin(f_2*x[..., i]),
...
sin(f_N * x[..., i]),
cos(x[..., i]),
cos(f_1*x[..., i]),
cos(f_2*x[..., i]),
...
cos(f_N * x[..., i]),
x[..., i] # only present if include_input is True.
], here f_i is the frequency.
If logspace is True, the frequencies are f_i = 2^i for i = 0, ..., num_freqs - 1;
otherwise, the frequencies are linearly spaced between 1.0 and 2^(num_freqs - 1).
If include_pi is True, every frequency is additionally multiplied by pi.
Args:
num_freqs (int): the number of frequencies, default is 6;
logspace (bool): if True, the frequencies are f_i = 2^i;
otherwise, the frequencies are linearly spaced between 1.0 and 2^(num_freqs - 1);
input_dim (int): the input dimension, default is 3;
include_input (bool): include the input tensor or not, default is True.
Attributes:
frequencies (torch.Tensor): the frequency tensor described above, registered as a
non-persistent buffer;
out_dim (int): the embedding size, if include_input is True, it is input_dim * (num_freqs * 2 + 1),
otherwise, it is input_dim * num_freqs * 2.
"""
def __init__(self,
num_freqs: int = 6,
logspace: bool = True,
input_dim: int = 3,
include_input: bool = True,
include_pi: bool = True) -> None:
"""The initialization"""
super().__init__()
if logspace:
frequencies = 2.0 ** torch.arange(
num_freqs,
dtype=torch.float32
)
else:
frequencies = torch.linspace(
1.0,
2.0 ** (num_freqs - 1),
num_freqs,
dtype=torch.float32
)
if include_pi:
frequencies *= torch.pi
self.register_buffer("frequencies", frequencies, persistent=False)
self.include_input = include_input
self.num_freqs = num_freqs
self.out_dim = self.get_dims(input_dim)
def get_dims(self, input_dim):
temp = 1 if self.include_input or self.num_freqs == 0 else 0
out_dim = input_dim * (self.num_freqs * 2 + temp)
return out_dim
def forward(self, x: torch.Tensor) -> torch.Tensor:
""" Forward process.
Args:
x: tensor of shape [..., dim]
Returns:
embedding: an embedding of `x` of shape [..., dim * (num_freqs * 2 + temp)]
where temp is 1 if include_input is True and 0 otherwise.
"""
if self.num_freqs > 0:
embed = (x[..., None].contiguous() * self.frequencies).view(*x.shape[:-1], -1)
if self.include_input:
return torch.cat((x, embed.sin(), embed.cos()), dim=-1)
else:
return torch.cat((embed.sin(), embed.cos()), dim=-1)
else:
return x
class LearnedFourierEmbedder(nn.Module):
""" following @crowsonkb "s lead with learned sinusoidal pos emb """
""" https://github.com/crowsonkb/v-diffusion-jax/blob/master/diffusion/models/danbooru_128.py#L8 """
def __init__(self, in_channels, dim):
super().__init__()
assert (dim % 2) == 0
half_dim = dim // 2
per_channel_dim = half_dim // in_channels
self.weights = nn.Parameter(torch.randn(per_channel_dim))
def forward(self, x):
"""
Args:
x (torch.FloatTensor): [..., c]
Returns:
x (torch.FloatTensor): [..., d]
"""
# [b, t, c, 1] * [1, d] = [b, t, c, d] -> [b, t, c * d]
freqs = (x[..., None] * self.weights[None] * 2 * np.pi).view(*x.shape[:-1], -1)
fouriered = torch.cat((x, freqs.sin(), freqs.cos()), dim=-1)
return fouriered
class TriplaneLearnedFourierEmbedder(nn.Module):
def __init__(self, in_channels, dim):
super().__init__()
self.yz_plane_embedder = LearnedFourierEmbedder(in_channels, dim)
self.xz_plane_embedder = LearnedFourierEmbedder(in_channels, dim)
self.xy_plane_embedder = LearnedFourierEmbedder(in_channels, dim)
self.out_dim = in_channels + dim
def forward(self, x):
yz_embed = self.yz_plane_embedder(x)
xz_embed = self.xz_plane_embedder(x)
xy_embed = self.xy_plane_embedder(x)
embed = yz_embed + xz_embed + xy_embed
return embed
def sequential_pos_embed(num_len, embed_dim):
assert embed_dim % 2 == 0
pos = torch.arange(num_len, dtype=torch.float32)
omega = torch.arange(embed_dim // 2, dtype=torch.float32)
omega /= embed_dim / 2.
omega = 1. / 10000 ** omega # (D/2,)
pos = pos.reshape(-1) # (M,)
out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product
emb_sin = torch.sin(out) # (M, D/2)
emb_cos = torch.cos(out) # (M, D/2)
embeddings = torch.cat([emb_sin, emb_cos], dim=1) # (M, D)
return embeddings
def timestep_embedding(timesteps, dim, max_period=10000):
"""
Create sinusoidal timestep embeddings.
:param timesteps: a 1-D Tensor of N indices, one per batch element.
These may be fractional.
:param dim: the dimension of the output.
:param max_period: controls the minimum frequency of the embeddings.
:return: an [N x dim] Tensor of positional embeddings.
"""
half = dim // 2
freqs = torch.exp(
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
).to(device=timesteps.device)
args = timesteps[:, None].to(timesteps.dtype) * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
return embedding
def get_embedder(embed_type="fourier", num_freqs=-1, input_dim=3, degree=4,
num_levels=16, level_dim=2, per_level_scale=2, base_resolution=16,
log2_hashmap_size=19, desired_resolution=None):
if embed_type == "identity" or (embed_type == "fourier" and num_freqs == -1):
return nn.Identity(), input_dim
elif embed_type == "fourier":
embedder_obj = FourierEmbedder(num_freqs=num_freqs, input_dim=input_dim,
logspace=True, include_input=True)
return embedder_obj, embedder_obj.out_dim
elif embed_type == "hashgrid":
raise NotImplementedError
elif embed_type == "sphere_harmonic":
raise NotImplementedError
else:
raise ValueError(f"{embed_type} is not valid. Currently only supprts {VALID_EMBED_TYPES}")
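
# Usage sketch (illustrative only): the Fourier embedding applied to xyz query
# points. With num_freqs=8, input_dim=3 and include_input=True the output width
# is 3 * (2 * 8 + 1) = 51.
def _fourier_embedder_example():
    embedder, out_dim = get_embedder(embed_type="fourier", num_freqs=8, input_dim=3)
    xyz = torch.rand(1024, 3)
    emb = embedder(xyz)
    assert emb.shape == (1024, out_dim) and out_dim == 51
    return emb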

View File

@ -0,0 +1,286 @@
# -*- coding: utf-8 -*-
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional
from .checkpoint import checkpoint
def init_linear(l, stddev):
nn.init.normal_(l.weight, std=stddev)
if l.bias is not None:
nn.init.constant_(l.bias, 0.0)
class MultiheadAttention(nn.Module):
def __init__(
self,
*,
device: torch.device,
dtype: torch.dtype,
n_ctx: int,
width: int,
heads: int,
init_scale: float,
qkv_bias: bool,
flash: bool = False
):
super().__init__()
self.n_ctx = n_ctx
self.width = width
self.heads = heads
self.c_qkv = nn.Linear(width, width * 3, bias=qkv_bias, device=device, dtype=dtype)
self.c_proj = nn.Linear(width, width, device=device, dtype=dtype)
self.attention = QKVMultiheadAttention(device=device, dtype=dtype, heads=heads, n_ctx=n_ctx, flash=flash)
init_linear(self.c_qkv, init_scale)
init_linear(self.c_proj, init_scale)
def forward(self, x):
x = self.c_qkv(x)
x = checkpoint(self.attention, (x,), (), True)
x = self.c_proj(x)
return x
class QKVMultiheadAttention(nn.Module):
def __init__(self, *, device: torch.device, dtype: torch.dtype, heads: int, n_ctx: int, flash: bool = False):
super().__init__()
self.device = device
self.dtype = dtype
self.heads = heads
self.n_ctx = n_ctx
self.flash = flash
def forward(self, qkv):
bs, n_ctx, width = qkv.shape
attn_ch = width // self.heads // 3
scale = 1 / math.sqrt(math.sqrt(attn_ch))
qkv = qkv.view(bs, n_ctx, self.heads, -1)
q, k, v = torch.split(qkv, attn_ch, dim=-1)
if self.flash:
out = F.scaled_dot_product_attention(q, k, v)
else:
weight = torch.einsum(
"bthc,bshc->bhts", q * scale, k * scale
) # More stable with f16 than dividing afterwards
wdtype = weight.dtype
weight = torch.softmax(weight.float(), dim=-1).type(wdtype)
out = torch.einsum("bhts,bshc->bthc", weight, v).reshape(bs, n_ctx, -1)
return out
class ResidualAttentionBlock(nn.Module):
def __init__(
self,
*,
device: torch.device,
dtype: torch.dtype,
n_ctx: int,
width: int,
heads: int,
init_scale: float = 1.0,
qkv_bias: bool = True,
flash: bool = False,
use_checkpoint: bool = False
):
super().__init__()
self.use_checkpoint = use_checkpoint
self.attn = MultiheadAttention(
device=device,
dtype=dtype,
n_ctx=n_ctx,
width=width,
heads=heads,
init_scale=init_scale,
qkv_bias=qkv_bias,
flash=flash
)
self.ln_1 = nn.LayerNorm(width, device=device, dtype=dtype)
self.mlp = MLP(device=device, dtype=dtype, width=width, init_scale=init_scale)
self.ln_2 = nn.LayerNorm(width, device=device, dtype=dtype)
def _forward(self, x: torch.Tensor):
x = x + self.attn(self.ln_1(x))
x = x + self.mlp(self.ln_2(x))
return x
def forward(self, x: torch.Tensor):
return checkpoint(self._forward, (x,), self.parameters(), self.use_checkpoint)
class MultiheadCrossAttention(nn.Module):
def __init__(
self,
*,
device: torch.device,
dtype: torch.dtype,
width: int,
heads: int,
init_scale: float,
qkv_bias: bool = True,
flash: bool = False,
n_data: Optional[int] = None,
data_width: Optional[int] = None,
):
super().__init__()
self.n_data = n_data
self.width = width
self.heads = heads
self.data_width = width if data_width is None else data_width
self.c_q = nn.Linear(width, width, bias=qkv_bias, device=device, dtype=dtype)
self.c_kv = nn.Linear(self.data_width, width * 2, bias=qkv_bias, device=device, dtype=dtype)
self.c_proj = nn.Linear(width, width, device=device, dtype=dtype)
self.attention = QKVMultiheadCrossAttention(
device=device, dtype=dtype, heads=heads, n_data=n_data, flash=flash
)
init_linear(self.c_q, init_scale)
init_linear(self.c_kv, init_scale)
init_linear(self.c_proj, init_scale)
def forward(self, x, data):
x = self.c_q(x)
data = self.c_kv(data)
x = checkpoint(self.attention, (x, data), (), True)
x = self.c_proj(x)
return x
class QKVMultiheadCrossAttention(nn.Module):
def __init__(self, *, device: torch.device, dtype: torch.dtype, heads: int,
flash: bool = False, n_data: Optional[int] = None):
super().__init__()
self.device = device
self.dtype = dtype
self.heads = heads
self.n_data = n_data
self.flash = flash
def forward(self, q, kv):
_, n_ctx, _ = q.shape
bs, n_data, width = kv.shape
attn_ch = width // self.heads // 2
scale = 1 / math.sqrt(math.sqrt(attn_ch))
q = q.view(bs, n_ctx, self.heads, -1)
kv = kv.view(bs, n_data, self.heads, -1)
k, v = torch.split(kv, attn_ch, dim=-1)
if self.flash:
out = F.scaled_dot_product_attention(q, k, v)
else:
weight = torch.einsum(
"bthc,bshc->bhts", q * scale, k * scale
) # More stable with f16 than dividing afterwards
wdtype = weight.dtype
weight = torch.softmax(weight.float(), dim=-1).type(wdtype)
out = torch.einsum("bhts,bshc->bthc", weight, v).reshape(bs, n_ctx, -1)
return out
class ResidualCrossAttentionBlock(nn.Module):
def __init__(
self,
*,
device: Optional[torch.device],
dtype: Optional[torch.dtype],
n_data: Optional[int] = None,
width: int,
heads: int,
data_width: Optional[int] = None,
init_scale: float = 0.25,
qkv_bias: bool = True,
flash: bool = False
):
super().__init__()
if data_width is None:
data_width = width
self.attn = MultiheadCrossAttention(
device=device,
dtype=dtype,
n_data=n_data,
width=width,
heads=heads,
data_width=data_width,
init_scale=init_scale,
qkv_bias=qkv_bias,
flash=flash,
)
self.ln_1 = nn.LayerNorm(width, device=device, dtype=dtype)
self.ln_2 = nn.LayerNorm(data_width, device=device, dtype=dtype)
self.mlp = MLP(device=device, dtype=dtype, width=width, init_scale=init_scale)
self.ln_3 = nn.LayerNorm(width, device=device, dtype=dtype)
def forward(self, x: torch.Tensor, data: torch.Tensor):
x = x + self.attn(self.ln_1(x), self.ln_2(data))
x = x + self.mlp(self.ln_3(x))
return x
class MLP(nn.Module):
def __init__(self, *,
device: Optional[torch.device],
dtype: Optional[torch.dtype],
width: int,
init_scale: float):
super().__init__()
self.width = width
self.c_fc = nn.Linear(width, width * 4, device=device, dtype=dtype)
self.c_proj = nn.Linear(width * 4, width, device=device, dtype=dtype)
self.gelu = nn.GELU()
init_linear(self.c_fc, init_scale)
init_linear(self.c_proj, init_scale)
def forward(self, x):
return self.c_proj(self.gelu(self.c_fc(x)))
class Transformer(nn.Module):
def __init__(
self,
*,
device: Optional[torch.device],
dtype: Optional[torch.dtype],
n_ctx: int,
width: int,
layers: int,
heads: int,
init_scale: float = 0.25,
qkv_bias: bool = True,
flash: bool = False,
use_checkpoint: bool = False
):
super().__init__()
self.n_ctx = n_ctx
self.width = width
self.layers = layers
self.resblocks = nn.ModuleList(
[
ResidualAttentionBlock(
device=device,
dtype=dtype,
n_ctx=n_ctx,
width=width,
heads=heads,
init_scale=init_scale,
qkv_bias=qkv_bias,
flash=flash,
use_checkpoint=use_checkpoint
)
for _ in range(layers)
]
)
def forward(self, x: torch.Tensor):
for block in self.resblocks:
x = block(x)
return x
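
# Usage sketch (illustrative only): a tiny self-attention stack built from the
# blocks above; device/dtype are passed explicitly, as elsewhere in the repo.
def _transformer_usage_example():
    model = Transformer(device=torch.device("cpu"), dtype=torch.float32,
                        n_ctx=16, width=64, layers=2, heads=4)
    x = torch.randn(2, 16, 64)
    return model(x)  # [2, 16, 64]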

View File

@ -0,0 +1,308 @@
# -*- coding: utf-8 -*-
import math
import torch
import torch.nn as nn
from typing import Optional
import warnings
from .checkpoint import checkpoint
def _trunc_normal_(tensor, mean, std, a, b):
# Cut & paste from PyTorch official master until it's in a few official releases - RW
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
def norm_cdf(x):
# Computes standard normal cumulative distribution function
return (1. + math.erf(x / math.sqrt(2.))) / 2.
if (mean < a - 2 * std) or (mean > b + 2 * std):
warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
"The distribution of values may be incorrect.",
stacklevel=2)
# Values are generated by using a truncated uniform distribution and
# then using the inverse CDF for the normal distribution.
# Get upper and lower cdf values
l = norm_cdf((a - mean) / std)
u = norm_cdf((b - mean) / std)
# Uniformly fill tensor with values from [l, u], then translate to
# [2l-1, 2u-1].
tensor.uniform_(2 * l - 1, 2 * u - 1)
# Use inverse cdf transform for normal distribution to get truncated
# standard normal
tensor.erfinv_()
# Transform to proper mean, std
tensor.mul_(std * math.sqrt(2.))
tensor.add_(mean)
# Clamp to ensure it's in the proper range
tensor.clamp_(min=a, max=b)
return tensor
def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
# type: (Tensor | nn.Parameter, float, float, float, float) -> Tensor
r"""Fills the input Tensor with values drawn from a truncated
normal distribution. The values are effectively drawn from the
normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
with values outside :math:`[a, b]` redrawn until they are within
the bounds. The method used for generating the random values works
best when :math:`a \leq \text{mean} \leq b`.
NOTE: this impl is similar to the PyTorch trunc_normal_, the bounds [a, b] are
applied while sampling the normal with mean/std applied, therefore a, b args
should be adjusted to match the range of mean, std args.
Args:
tensor: an n-dimensional `torch.Tensor`
mean: the mean of the normal distribution
std: the standard deviation of the normal distribution
a: the minimum cutoff value
b: the maximum cutoff value
Examples:
>>> w = torch.empty(3, 5)
>>> nn.init.trunc_normal_(w)
"""
with torch.no_grad():
return _trunc_normal_(tensor, mean, std, a, b)
def init_weights(m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
class MultiheadAttention(nn.Module):
def __init__(
self,
*,
device: torch.device,
dtype: torch.dtype,
n_ctx: int,
width: int,
heads: int,
qkv_bias: bool
):
super().__init__()
self.n_ctx = n_ctx
self.width = width
self.heads = heads
self.c_qkv = nn.Linear(width, width * 3, bias=qkv_bias, device=device, dtype=dtype)
self.c_proj = nn.Linear(width, width, device=device, dtype=dtype)
self.attention = QKVMultiheadAttention(device=device, dtype=dtype, heads=heads, n_ctx=n_ctx)
def forward(self, x):
x = self.c_qkv(x)
x = checkpoint(self.attention, (x,), (), True)
x = self.c_proj(x)
return x
class QKVMultiheadAttention(nn.Module):
def __init__(self, *, device: torch.device, dtype: torch.dtype, heads: int, n_ctx: int):
super().__init__()
self.device = device
self.dtype = dtype
self.heads = heads
self.n_ctx = n_ctx
def forward(self, qkv):
bs, n_ctx, width = qkv.shape
attn_ch = width // self.heads // 3
scale = 1 / math.sqrt(attn_ch)
qkv = qkv.view(bs, n_ctx, self.heads, -1)
q, k, v = torch.split(qkv, attn_ch, dim=-1)
weight = torch.einsum("bthc,bshc->bhts", q, k) * scale
wdtype = weight.dtype
weight = torch.softmax(weight.float(), dim=-1).type(wdtype)
return torch.einsum("bhts,bshc->bthc", weight, v).reshape(bs, n_ctx, -1)
class ResidualAttentionBlock(nn.Module):
def __init__(
self,
*,
device: torch.device,
dtype: torch.dtype,
n_ctx: int,
width: int,
heads: int,
qkv_bias: bool = True,
use_checkpoint: bool = False
):
super().__init__()
self.use_checkpoint = use_checkpoint
self.attn = MultiheadAttention(
device=device,
dtype=dtype,
n_ctx=n_ctx,
width=width,
heads=heads,
qkv_bias=qkv_bias
)
self.ln_1 = nn.LayerNorm(width, device=device, dtype=dtype)
self.mlp = MLP(device=device, dtype=dtype, width=width)
self.ln_2 = nn.LayerNorm(width, device=device, dtype=dtype)
def _forward(self, x: torch.Tensor):
x = x + self.attn(self.ln_1(x))
x = x + self.mlp(self.ln_2(x))
return x
def forward(self, x: torch.Tensor):
return checkpoint(self._forward, (x,), self.parameters(), self.use_checkpoint)
class MultiheadCrossAttention(nn.Module):
def __init__(
self,
*,
device: torch.device,
dtype: torch.dtype,
width: int,
heads: int,
qkv_bias: bool = True,
n_data: Optional[int] = None,
data_width: Optional[int] = None,
):
super().__init__()
self.n_data = n_data
self.width = width
self.heads = heads
self.data_width = width if data_width is None else data_width
self.c_q = nn.Linear(width, width, bias=qkv_bias, device=device, dtype=dtype)
self.c_kv = nn.Linear(self.data_width, width * 2, bias=qkv_bias, device=device, dtype=dtype)
self.c_proj = nn.Linear(width, width, device=device, dtype=dtype)
self.attention = QKVMultiheadCrossAttention(
device=device, dtype=dtype, heads=heads, n_data=n_data
)
def forward(self, x, data):
x = self.c_q(x)
data = self.c_kv(data)
x = checkpoint(self.attention, (x, data), (), True)
x = self.c_proj(x)
return x
class QKVMultiheadCrossAttention(nn.Module):
def __init__(self, *, device: torch.device, dtype: torch.dtype, heads: int, n_data: Optional[int] = None):
super().__init__()
self.device = device
self.dtype = dtype
self.heads = heads
self.n_data = n_data
def forward(self, q, kv):
_, n_ctx, _ = q.shape
bs, n_data, width = kv.shape
attn_ch = width // self.heads // 2
scale = 1 / math.sqrt(attn_ch)
q = q.view(bs, n_ctx, self.heads, -1)
kv = kv.view(bs, n_data, self.heads, -1)
k, v = torch.split(kv, attn_ch, dim=-1)
weight = torch.einsum("bthc,bshc->bhts", q, k) * scale
wdtype = weight.dtype
weight = torch.softmax(weight.float(), dim=-1).type(wdtype)
return torch.einsum("bhts,bshc->bthc", weight, v).reshape(bs, n_ctx, -1)
class ResidualCrossAttentionBlock(nn.Module):
def __init__(
self,
*,
device: Optional[torch.device],
dtype: Optional[torch.dtype],
n_data: Optional[int] = None,
width: int,
heads: int,
data_width: Optional[int] = None,
qkv_bias: bool = True
):
super().__init__()
if data_width is None:
data_width = width
self.attn = MultiheadCrossAttention(
device=device,
dtype=dtype,
n_data=n_data,
width=width,
heads=heads,
data_width=data_width,
qkv_bias=qkv_bias
)
self.ln_1 = nn.LayerNorm(width, device=device, dtype=dtype)
self.ln_2 = nn.LayerNorm(data_width, device=device, dtype=dtype)
self.mlp = MLP(device=device, dtype=dtype, width=width)
self.ln_3 = nn.LayerNorm(width, device=device, dtype=dtype)
def forward(self, x: torch.Tensor, data: torch.Tensor):
x = x + self.attn(self.ln_1(x), self.ln_2(data))
x = x + self.mlp(self.ln_3(x))
return x
class MLP(nn.Module):
def __init__(self, *,
device: Optional[torch.device],
dtype: Optional[torch.dtype],
width: int):
super().__init__()
self.width = width
self.c_fc = nn.Linear(width, width * 4, device=device, dtype=dtype)
self.c_proj = nn.Linear(width * 4, width, device=device, dtype=dtype)
self.gelu = nn.GELU()
def forward(self, x):
return self.c_proj(self.gelu(self.c_fc(x)))
class Transformer(nn.Module):
def __init__(
self,
*,
device: Optional[torch.device],
dtype: Optional[torch.dtype],
n_ctx: int,
width: int,
layers: int,
heads: int,
qkv_bias: bool = True,
use_checkpoint: bool = False
):
super().__init__()
self.n_ctx = n_ctx
self.width = width
self.layers = layers
self.resblocks = nn.ModuleList(
[
ResidualAttentionBlock(
device=device,
dtype=dtype,
n_ctx=n_ctx,
width=width,
heads=heads,
qkv_bias=qkv_bias,
use_checkpoint=use_checkpoint
)
for _ in range(layers)
]
)
self.apply(init_weights)
def forward(self, x: torch.Tensor):
for block in self.resblocks:
x = block(x)
return x

View File

@ -0,0 +1 @@
# -*- coding: utf-8 -*-

View File

@ -0,0 +1,373 @@
# -*- coding: utf-8 -*-
from typing import List, Tuple, Dict, Optional
from omegaconf import DictConfig
import torch
import torch.nn.functional as F
from torch.optim import lr_scheduler
import pytorch_lightning as pl
from typing import Union
from functools import partial
from ...utils import instantiate_from_config
from .inference_utils import extract_geometry
from .tsal_base import (
AlignedShapeAsLatentModule,
ShapeAsLatentModule,
Latent2MeshOutput,
AlignedMeshOutput
)
class AlignedShapeAsLatentPLModule(pl.LightningModule):
def __init__(self, *,
shape_module_cfg,
aligned_module_cfg,
loss_cfg,
optimizer_cfg: Optional[DictConfig] = None,
ckpt_path: Optional[str] = None,
ignore_keys: Union[Tuple[str], List[str]] = ()):
super().__init__()
shape_model: ShapeAsLatentModule = instantiate_from_config(
shape_module_cfg, device=None, dtype=None
)
self.model: AlignedShapeAsLatentModule = instantiate_from_config(
aligned_module_cfg, shape_model=shape_model
)
self.loss = instantiate_from_config(loss_cfg)
self.optimizer_cfg = optimizer_cfg
if ckpt_path is not None:
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
self.save_hyperparameters()
def set_shape_model_only(self):
self.model.set_shape_model_only()
@property
def latent_shape(self):
return self.model.shape_model.latent_shape
@property
def zero_rank(self):
if self._trainer:
zero_rank = self.trainer.local_rank == 0
else:
zero_rank = True
return zero_rank
def init_from_ckpt(self, path, ignore_keys=()):
state_dict = torch.load(path, map_location="cpu")["state_dict"]
keys = list(state_dict.keys())
for k in keys:
for ik in ignore_keys:
if k.startswith(ik):
print("Deleting key {} from state_dict.".format(k))
del state_dict[k]
missing, unexpected = self.load_state_dict(state_dict, strict=False)
print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
if len(missing) > 0:
print(f"Missing Keys: {missing}")
print(f"Unexpected Keys: {unexpected}")
def configure_optimizers(self) -> Tuple[List, List]:
lr = self.learning_rate
trainable_parameters = list(self.model.parameters())
if self.optimizer_cfg is None:
optimizers = [torch.optim.AdamW(trainable_parameters, lr=lr, betas=(0.9, 0.99), weight_decay=1e-3)]
schedulers = []
else:
optimizer = instantiate_from_config(self.optimizer_cfg.optimizer, params=trainable_parameters)
scheduler_func = instantiate_from_config(
self.optimizer_cfg.scheduler,
max_decay_steps=self.trainer.max_steps,
lr_max=lr
)
scheduler = {
"scheduler": lr_scheduler.LambdaLR(optimizer, lr_lambda=scheduler_func.schedule),
"interval": "step",
"frequency": 1
}
optimizers = [optimizer]
schedulers = [scheduler]
return optimizers, schedulers
def forward(self,
surface: torch.FloatTensor,
image: torch.FloatTensor,
text: torch.FloatTensor,
volume_queries: torch.FloatTensor):
"""
Args:
surface (torch.FloatTensor):
image (torch.FloatTensor):
text (torch.FloatTensor):
volume_queries (torch.FloatTensor):
Returns:
"""
embed_outputs, shape_z = self.model(surface, image, text)
shape_zq, posterior = self.model.shape_model.encode_kl_embed(shape_z)
latents = self.model.shape_model.decode(shape_zq)
logits = self.model.shape_model.query_geometry(volume_queries, latents)
return embed_outputs, logits, posterior
def encode(self, surface: torch.FloatTensor, sample_posterior=True):
pc = surface[..., 0:3]
feats = surface[..., 3:6]
shape_embed, shape_zq, posterior = self.model.shape_model.encode(
pc=pc, feats=feats, sample_posterior=sample_posterior
)
return shape_zq
def encode_latents(self, surface: torch.FloatTensor):
pc = surface[..., 0:3]
feats = surface[..., 3:6]
shape_embed, shape_latents = self.model.shape_model.encode_latents(
pc=pc, feats=feats
)
shape_embed = shape_embed.unsqueeze(1)
assert shape_embed.shape[1] == 1 and shape_latents.shape[1] == 256
cat_latents = torch.cat([shape_embed, shape_latents], dim=1)
return cat_latents
def to_shape_latents(self, latents):
shape_zq, posterior = self.model.shape_model.encode_kl_embed(latents, sample_posterior = False)
return self.model.shape_model.decode(shape_zq)
def decode(self,
z_q,
bounds: Union[Tuple[float], List[float], float] = 1.1,
octree_depth: int = 7,
num_chunks: int = 10000) -> List[Latent2MeshOutput]:
latents = self.model.shape_model.decode(z_q) # latents: [bs, num_latents, dim]
outputs = self.latent2mesh(latents, bounds=bounds, octree_depth=octree_depth, num_chunks=num_chunks)
return outputs
def training_step(self, batch: Dict[str, torch.FloatTensor],
batch_idx: int, optimizer_idx: int = 0) -> torch.FloatTensor:
"""
Args:
batch (dict): the batch sample, and it contains:
- surface (torch.FloatTensor): [bs, n_surface, (3 + input_dim)]
- image (torch.FloatTensor): [bs, 3, 224, 224]
- text (torch.FloatTensor): [bs, num_templates, 77]
- geo_points (torch.FloatTensor): [bs, n_pts, (3 + 1)]
batch_idx (int):
optimizer_idx (int):
Returns:
loss (torch.FloatTensor):
"""
surface = batch["surface"]
image = batch["image"]
text = batch["text"]
volume_queries = batch["geo_points"][..., 0:3]
shape_labels = batch["geo_points"][..., -1]
embed_outputs, shape_logits, posteriors = self(surface, image, text, volume_queries)
aeloss, log_dict_ae = self.loss(
**embed_outputs,
posteriors=posteriors,
shape_logits=shape_logits,
shape_labels=shape_labels,
split="train"
)
self.log_dict(log_dict_ae, prog_bar=True, logger=True, batch_size=shape_logits.shape[0],
sync_dist=False, rank_zero_only=True)
return aeloss
def validation_step(self, batch: Dict[str, torch.FloatTensor], batch_idx: int) -> torch.FloatTensor:
surface = batch["surface"]
image = batch["image"]
text = batch["text"]
volume_queries = batch["geo_points"][..., 0:3]
shape_labels = batch["geo_points"][..., -1]
embed_outputs, shape_logits, posteriors = self(surface, image, text, volume_queries)
aeloss, log_dict_ae = self.loss(
**embed_outputs,
posteriors=posteriors,
shape_logits=shape_logits,
shape_labels=shape_labels,
split="val"
)
self.log_dict(log_dict_ae, prog_bar=True, logger=True, batch_size=shape_logits.shape[0],
sync_dist=False, rank_zero_only=True)
return aeloss
def visual_alignment(self,
surface: torch.FloatTensor,
image: torch.FloatTensor,
text: torch.FloatTensor,
description: Optional[List[str]] = None,
bounds: Union[Tuple[float], List[float]] = (-1.25, -1.25, -1.25, 1.25, 1.25, 1.25),
octree_depth: int = 7,
num_chunks: int = 10000) -> List[AlignedMeshOutput]:
"""
Args:
surface:
image:
text:
description:
bounds:
octree_depth:
num_chunks:
Returns:
mesh_outputs (List[AlignedMeshOutput]): the mesh outputs list.
"""
outputs = []
device = surface.device
bs = surface.shape[0]
embed_outputs, shape_z = self.model(surface, image, text)
# calculate the similarity
image_embed = embed_outputs["image_embed"]
text_embed = embed_outputs["text_embed"]
shape_embed = embed_outputs["shape_embed"]
# normalized features
shape_embed = F.normalize(shape_embed, dim=-1, p=2)
text_embed = F.normalize(text_embed, dim=-1, p=2)
image_embed = F.normalize(image_embed, dim=-1, p=2)
# B x B
shape_text_similarity = (100.0 * shape_embed @ text_embed.T).softmax(dim=-1)
# B x B
shape_image_similarity = (100.0 * shape_embed @ image_embed.T).softmax(dim=-1)
# shape reconstruction
shape_zq, posterior = self.model.shape_model.encode_kl_embed(shape_z)
latents = self.model.shape_model.decode(shape_zq)
geometric_func = partial(self.model.shape_model.query_geometry, latents=latents)
# 2. decode geometry
mesh_v_f, has_surface = extract_geometry(
geometric_func=geometric_func,
device=device,
batch_size=bs,
bounds=bounds,
octree_depth=octree_depth,
num_chunks=num_chunks,
disable=not self.zero_rank
)
# 3. decode texture
for i, ((mesh_v, mesh_f), is_surface) in enumerate(zip(mesh_v_f, has_surface)):
if not is_surface:
outputs.append(None)
continue
out = AlignedMeshOutput()
out.mesh_v = mesh_v
out.mesh_f = mesh_f
out.surface = surface[i].cpu().numpy()
out.image = image[i].cpu().numpy()
if description is not None:
out.text = description[i]
out.shape_text_similarity = shape_text_similarity[i, i]
out.shape_image_similarity = shape_image_similarity[i, i]
outputs.append(out)
return outputs
def latent2mesh(self,
latents: torch.FloatTensor,
bounds: Union[Tuple[float], List[float], float] = 1.1,
octree_depth: int = 7,
num_chunks: int = 10000) -> List[Latent2MeshOutput]:
"""
Args:
latents: [bs, num_latents, dim]
bounds:
octree_depth:
num_chunks:
Returns:
mesh_outputs (List[MeshOutput]): the mesh outputs list.
"""
outputs = []
geometric_func = partial(self.model.shape_model.query_geometry, latents=latents)
# 2. decode geometry
device = latents.device
mesh_v_f, has_surface = extract_geometry(
geometric_func=geometric_func,
device=device,
batch_size=len(latents),
bounds=bounds,
octree_depth=octree_depth,
num_chunks=num_chunks,
disable=not self.zero_rank
)
# 3. decode texture
for i, ((mesh_v, mesh_f), is_surface) in enumerate(zip(mesh_v_f, has_surface)):
if not is_surface:
outputs.append(None)
continue
out = Latent2MeshOutput()
out.mesh_v = mesh_v
out.mesh_f = mesh_f
outputs.append(out)
return outputs
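
# Usage sketch (illustrative only; `asl_module` stands for an
# AlignedShapeAsLatentPLModule restored from the ShapeVAE checkpoint):
#
#     surface = torch.randn(1, 4096, 6)            # xyz (3) + per-point features (3)
#     z_q = asl_module.encode(surface)             # sampled KL latent
#     mesh_outputs = asl_module.decode(z_q, bounds=1.1, octree_depth=7)
#     mesh_outputs[0].mesh_v, mesh_outputs[0].mesh_f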

View File

@ -0,0 +1,114 @@
# -*- coding: utf-8 -*-
import torch
from torch import nn
from einops import rearrange
from transformers import CLIPModel
from .tsal_base import AlignedShapeAsLatentModule
class CLIPAlignedShapeAsLatentModule(AlignedShapeAsLatentModule):
def __init__(self, *,
shape_model,
projection_dim=768):
super().__init__()
self.shape_model = shape_model
self.shape_projection = nn.Parameter(torch.empty(self.shape_model.width, projection_dim))
nn.init.normal_(self.shape_projection, std=projection_dim ** -0.5)
def set_shape_model_only(self):
self.clip_model = None
def encode_shape_embed(self, surface, return_latents: bool = False):
"""
Args:
surface (torch.FloatTensor): [bs, n, 3 + c]
return_latents (bool):
Returns:
x (torch.FloatTensor): [bs, projection_dim]
shape_latents (torch.FloatTensor): [bs, m, d]
"""
pc = surface[..., 0:3]
feats = surface[..., 3:]
shape_embed, shape_latents = self.shape_model.encode_latents(pc, feats)
x = shape_embed @ self.shape_projection
if return_latents:
return x, shape_latents
else:
return x
def encode_image_embed(self, image):
"""
Args:
image (torch.FloatTensor): [bs, 3, h, w]
Returns:
x (torch.FloatTensor): [bs, projection_dim]
"""
x = self.clip_model.get_image_features(image)
return x
def encode_text_embed(self, text):
x = self.clip_model.get_text_features(text)
return x
def forward(self, surface, image, text):
"""
Args:
surface (torch.FloatTensor):
image (torch.FloatTensor): [bs, 3, 224, 224]
text (torch.LongTensor): [bs, num_templates, 77]
Returns:
embed_outputs (dict): the embedding outputs, and it contains:
- image_embed (torch.FloatTensor):
- text_embed (torch.FloatTensor):
- shape_embed (torch.FloatTensor):
- logit_scale (float):
"""
# # text embedding
# text_embed_all = []
# for i in range(text.shape[0]):
# text_for_one_sample = text[i]
# text_embed = self.encode_text_embed(text_for_one_sample)
# text_embed = text_embed / text_embed.norm(dim=-1, keepdim=True)
# text_embed = text_embed.mean(dim=0)
# text_embed = text_embed / text_embed.norm(dim=-1, keepdim=True)
# text_embed_all.append(text_embed)
# text_embed_all = torch.stack(text_embed_all)
b = text.shape[0]
text_tokens = rearrange(text, "b t l -> (b t) l")
text_embed = self.encode_text_embed(text_tokens)
text_embed = rearrange(text_embed, "(b t) d -> b t d", b=b)
text_embed = text_embed.mean(dim=1)
text_embed = text_embed / text_embed.norm(dim=-1, keepdim=True)
# image embedding
image_embed = self.encode_image_embed(image)
# shape embedding
shape_embed, shape_latents = self.encode_shape_embed(surface, return_latents=True)
embed_outputs = {
"image_embed": image_embed,
"text_embed": text_embed,
"shape_embed": shape_embed,
"logit_scale": self.clip_model.logit_scale.exp()
}
return embed_outputs, shape_latents
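A small usage sketch for the shape branch alone (assuming `aligned` is a `CLIPAlignedShapeAsLatentModule` whose `shape_model` has been loaded; sizes are illustrative):

```python
import torch

surface = torch.randn(2, 4096, 6)   # [bs, n, xyz + normal]
shape_embed, shape_latents = aligned.encode_shape_embed(surface, return_latents=True)
print(shape_embed.shape)             # [2, projection_dim]
print(shape_latents.shape)           # [2, num_latents, width]
```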

View File

@ -0,0 +1,80 @@
# -*- coding: utf-8 -*-
import torch
from tqdm import tqdm
from einops import repeat
import numpy as np
from typing import Callable, Tuple, List, Union, Optional
from skimage import measure
from ...graphics.primitives import generate_dense_grid_points
@torch.no_grad()
def extract_geometry(geometric_func: Callable,
device: torch.device,
batch_size: int = 1,
bounds: Union[Tuple[float], List[float], float] = (-1.25, -1.25, -1.25, 1.25, 1.25, 1.25),
octree_depth: int = 7,
num_chunks: int = 10000,
disable: bool = True):
"""
Args:
geometric_func:
device:
bounds:
octree_depth:
batch_size:
num_chunks:
disable:
Returns:
"""
if isinstance(bounds, float):
bounds = [-bounds, -bounds, -bounds, bounds, bounds, bounds]
bbox_min = np.array(bounds[0:3])
bbox_max = np.array(bounds[3:6])
bbox_size = bbox_max - bbox_min
xyz_samples, grid_size, length = generate_dense_grid_points(
bbox_min=bbox_min,
bbox_max=bbox_max,
octree_depth=octree_depth,
indexing="ij"
)
xyz_samples = torch.FloatTensor(xyz_samples)
batch_logits = []
for start in tqdm(range(0, xyz_samples.shape[0], num_chunks),
desc="Implicit Function:", disable=disable, leave=False):
queries = xyz_samples[start: start + num_chunks, :].to(device)
batch_queries = repeat(queries, "p c -> b p c", b=batch_size)
logits = geometric_func(batch_queries)
batch_logits.append(logits.cpu())
grid_logits = torch.cat(batch_logits, dim=1).view((batch_size, grid_size[0], grid_size[1], grid_size[2])).numpy()
mesh_v_f = []
has_surface = np.zeros((batch_size,), dtype=np.bool_)
for i in range(batch_size):
try:
vertices, faces, normals, _ = measure.marching_cubes(grid_logits[i], 0, method="lewiner")
vertices = vertices / grid_size * bbox_size + bbox_min
# vertices[:, [0, 1]] = vertices[:, [1, 0]]
mesh_v_f.append((vertices.astype(np.float32), np.ascontiguousarray(faces)))
has_surface[i] = True
except ValueError:
mesh_v_f.append((None, None))
has_surface[i] = False
except RuntimeError:
mesh_v_f.append((None, None))
has_surface[i] = False
return mesh_v_f, has_surface
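`extract_geometry` only needs a callable that maps query points to occupancy logits, so it can be exercised without a network. A minimal sketch using a hypothetical analytic sphere (positive logits inside the surface):

```python
import torch

def sphere_logits(queries):              # queries: [b, p, 3]
    return 0.5 - queries.norm(dim=-1)    # radius-0.5 sphere centered at the origin

mesh_v_f, has_surface = extract_geometry(
    geometric_func=sphere_logits,
    device=torch.device("cpu"),
    batch_size=1,
    bounds=1.0,                           # expands to (-1, -1, -1, 1, 1, 1)
    octree_depth=6,
    num_chunks=10000,
    disable=False,
)
vertices, faces = mesh_v_f[0]
print(has_surface[0], vertices.shape, faces.shape)
```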

View File

@ -0,0 +1,303 @@
# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple, Dict
from ..modules.distributions import DiagonalGaussianDistribution
from ...utils.eval import compute_psnr
from ...utils import misc
class KLNearFar(nn.Module):
def __init__(self,
near_weight: float = 0.1,
kl_weight: float = 1.0,
num_near_samples: Optional[int] = None):
super().__init__()
self.near_weight = near_weight
self.kl_weight = kl_weight
self.num_near_samples = num_near_samples
self.geo_criterion = nn.BCEWithLogitsLoss()
def forward(self,
posteriors: Optional[DiagonalGaussianDistribution],
logits: torch.FloatTensor,
labels: torch.FloatTensor,
split: Optional[str] = "train", **kwargs) -> Tuple[torch.FloatTensor, Dict[str, float]]:
"""
Args:
posteriors (DiagonalGaussianDistribution or torch.distributions.Normal):
logits (torch.FloatTensor): [B, 2*N], logits[:, 0:N] is the volume points; logits[:, N:2N] is the near points;
labels (torch.FloatTensor): [B, 2*N], labels[:, 0:N] is the volume points; labels[:, N:2N] is the near points;
split (str):
**kwargs:
Returns:
loss (torch.Tensor): (,)
log (dict):
"""
if self.num_near_samples is None:
num_vol = logits.shape[1] // 2
else:
num_vol = logits.shape[1] - self.num_near_samples
vol_logits = logits[:, 0:num_vol]
vol_labels = labels[:, 0:num_vol]
near_logits = logits[:, num_vol:]
near_labels = labels[:, num_vol:]
# occupancy loss
# vol_bce = self.geo_criterion(vol_logits, vol_labels)
# near_bce = self.geo_criterion(near_logits, near_labels)
vol_bce = self.geo_criterion(vol_logits.float(), vol_labels.float())
near_bce = self.geo_criterion(near_logits.float(), near_labels.float())
if posteriors is None:
kl_loss = torch.tensor(0.0, dtype=vol_logits.dtype, device=vol_logits.device)
else:
kl_loss = posteriors.kl(dims=(1, 2))
kl_loss = torch.mean(kl_loss)
loss = vol_bce + near_bce * self.near_weight + kl_loss * self.kl_weight
with torch.no_grad():
preds = logits >= 0
accuracy = (preds == labels).float()
accuracy = accuracy.mean()
pos_ratio = torch.mean(labels)
log = {
"{}/total_loss".format(split): loss.clone().detach(),
"{}/near".format(split): near_bce.detach(),
"{}/far".format(split): vol_bce.detach(),
"{}/kl".format(split): kl_loss.detach(),
"{}/accuracy".format(split): accuracy,
"{}/pos_ratio".format(split): pos_ratio
}
if posteriors is not None:
log[f"{split}/mean"] = posteriors.mean.mean().detach()
log[f"{split}/std_mean"] = posteriors.std.mean().detach()
log[f"{split}/std_max"] = posteriors.std.max().detach()
return loss, log
class KLNearFarColor(nn.Module):
def __init__(self,
near_weight: float = 0.1,
kl_weight: float = 1.0,
color_weight: float = 1.0,
color_criterion: str = "mse",
num_near_samples: Optional[int] = None):
super().__init__()
self.color_weight = color_weight
self.near_weight = near_weight
self.kl_weight = kl_weight
self.num_near_samples = num_near_samples
if color_criterion == "mse":
self.color_criterion = nn.MSELoss()
elif color_criterion == "l1":
self.color_criterion = nn.L1Loss()
else:
raise ValueError(f"{color_criterion} must be [`mse`, `l1`].")
self.geo_criterion = nn.BCEWithLogitsLoss()
def forward(self,
posteriors: Optional[DiagonalGaussianDistribution],
logits: torch.FloatTensor,
labels: torch.FloatTensor,
pred_colors: torch.FloatTensor,
gt_colors: torch.FloatTensor,
split: Optional[str] = "train", **kwargs) -> Tuple[torch.FloatTensor, Dict[str, float]]:
"""
Args:
posteriors (DiagonalGaussianDistribution or torch.distributions.Normal):
logits (torch.FloatTensor): [B, 2*N], logits[:, 0:N] is the volume points; logits[:, N:2N] is the near points;
labels (torch.FloatTensor): [B, 2*N], labels[:, 0:N] is the volume points; labels[:, N:2N] is the near points;
pred_colors (torch.FloatTensor): [B, M, 3]
gt_colors (torch.FloatTensor): [B, M, 3]
split (str):
**kwargs:
Returns:
loss (torch.Tensor): (,)
log (dict):
"""
if self.num_near_samples is None:
num_vol = logits.shape[1] // 2
else:
num_vol = logits.shape[1] - self.num_near_samples
vol_logits = logits[:, 0:num_vol]
vol_labels = labels[:, 0:num_vol]
near_logits = logits[:, num_vol:]
near_labels = labels[:, num_vol:]
# occupancy loss
# vol_bce = self.geo_criterion(vol_logits, vol_labels)
# near_bce = self.geo_criterion(near_logits, near_labels)
vol_bce = self.geo_criterion(vol_logits.float(), vol_labels.float())
near_bce = self.geo_criterion(near_logits.float(), near_labels.float())
# surface color loss
color = self.color_criterion(pred_colors, gt_colors)
if posteriors is None:
kl_loss = torch.tensor(0.0, dtype=pred_colors.dtype, device=pred_colors.device)
else:
kl_loss = posteriors.kl(dims=(1, 2))
kl_loss = torch.mean(kl_loss)
loss = vol_bce + near_bce * self.near_weight + color * self.color_weight + kl_loss * self.kl_weight
with torch.no_grad():
preds = logits >= 0
accuracy = (preds == labels).float()
accuracy = accuracy.mean()
psnr = compute_psnr(pred_colors, gt_colors)
log = {
"{}/total_loss".format(split): loss.clone().detach(),
"{}/near".format(split): near_bce.detach(),
"{}/far".format(split): vol_bce.detach(),
"{}/color".format(split): color.detach(),
"{}/kl".format(split): kl_loss.detach(),
"{}/psnr".format(split): psnr.detach(),
"{}/accuracy".format(split): accuracy
}
return loss, log
class ContrastKLNearFar(nn.Module):
def __init__(self,
contrast_weight: float = 1.0,
near_weight: float = 0.1,
kl_weight: float = 1.0,
num_near_samples: Optional[int] = None):
super().__init__()
self.labels = None
self.last_local_batch_size = None
self.contrast_weight = contrast_weight
self.near_weight = near_weight
self.kl_weight = kl_weight
self.num_near_samples = num_near_samples
self.geo_criterion = nn.BCEWithLogitsLoss()
def forward(self,
shape_embed: torch.FloatTensor,
text_embed: torch.FloatTensor,
image_embed: torch.FloatTensor,
logit_scale: torch.FloatTensor,
posteriors: Optional[DiagonalGaussianDistribution],
shape_logits: torch.FloatTensor,
shape_labels: torch.FloatTensor,
split: Optional[str] = "train", **kwargs):
local_batch_size = shape_embed.size(0)
if local_batch_size != self.last_local_batch_size:
self.labels = local_batch_size * misc.get_rank() + torch.arange(
local_batch_size, device=shape_embed.device
).long()
self.last_local_batch_size = local_batch_size
# normalized features
shape_embed = F.normalize(shape_embed, dim=-1, p=2)
text_embed = F.normalize(text_embed, dim=-1, p=2)
image_embed = F.normalize(image_embed, dim=-1, p=2)
# gather features from all GPUs
shape_embed_all, text_embed_all, image_embed_all = misc.all_gather_batch(
[shape_embed, text_embed, image_embed]
)
# cosine similarity as logits
logits_per_shape_text = logit_scale * shape_embed @ text_embed_all.t()
logits_per_text_shape = logit_scale * text_embed @ shape_embed_all.t()
logits_per_shape_image = logit_scale * shape_embed @ image_embed_all.t()
logits_per_image_shape = logit_scale * image_embed @ shape_embed_all.t()
contrast_loss = (F.cross_entropy(logits_per_shape_text, self.labels) +
F.cross_entropy(logits_per_text_shape, self.labels)) / 2 + \
(F.cross_entropy(logits_per_shape_image, self.labels) +
F.cross_entropy(logits_per_image_shape, self.labels)) / 2
# shape reconstruction
if self.num_near_samples is None:
num_vol = shape_logits.shape[1] // 2
else:
num_vol = shape_logits.shape[1] - self.num_near_samples
vol_logits = shape_logits[:, 0:num_vol]
vol_labels = shape_labels[:, 0:num_vol]
near_logits = shape_logits[:, num_vol:]
near_labels = shape_labels[:, num_vol:]
# occupancy loss
vol_bce = self.geo_criterion(vol_logits.float(), vol_labels.float())
near_bce = self.geo_criterion(near_logits.float(), near_labels.float())
if posteriors is None:
kl_loss = torch.tensor(0.0, dtype=vol_logits.dtype, device=vol_logits.device)
else:
kl_loss = posteriors.kl(dims=(1, 2))
kl_loss = torch.mean(kl_loss)
loss = vol_bce + near_bce * self.near_weight + kl_loss * self.kl_weight + contrast_loss * self.contrast_weight
# compute accuracy
with torch.no_grad():
pred = torch.argmax(logits_per_shape_text, dim=-1)
correct = pred.eq(self.labels).sum()
shape_text_acc = 100 * correct / local_batch_size
pred = torch.argmax(logits_per_shape_image, dim=-1)
correct = pred.eq(self.labels).sum()
shape_image_acc = 100 * correct / local_batch_size
preds = shape_logits >= 0
accuracy = (preds == shape_labels).float()
accuracy = accuracy.mean()
log = {
"{}/contrast".format(split): contrast_loss.clone().detach(),
"{}/near".format(split): near_bce.detach(),
"{}/far".format(split): vol_bce.detach(),
"{}/kl".format(split): kl_loss.detach(),
"{}/shape_text_acc".format(split): shape_text_acc,
"{}/shape_image_acc".format(split): shape_image_acc,
"{}/total_loss".format(split): loss.clone().detach(),
"{}/accuracy".format(split): accuracy,
}
if posteriors is not None:
log[f"{split}/mean"] = posteriors.mean.mean().detach()
log[f"{split}/std_mean"] = posteriors.std.mean().detach()
log[f"{split}/std_max"] = posteriors.std.max().detach()
return loss, log
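A quick sanity-check sketch of the plain reconstruction loss with random logits and labels (`posteriors=None` takes the KL-free branch):

```python
import torch

criterion = KLNearFar(near_weight=0.1, kl_weight=1e-3)
logits = torch.randn(4, 2048)                     # first half: volume points, second half: near-surface points
labels = (torch.rand(4, 2048) > 0.5).float()
loss, log = criterion(posteriors=None, logits=logits, labels=labels, split="train")
print(loss.item(), log["train/accuracy"].item())
```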

View File

@ -0,0 +1,423 @@
# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
from typing import Optional
from einops import repeat
import math
from ..modules import checkpoint
from ..modules.embedder import FourierEmbedder
from ..modules.distributions import DiagonalGaussianDistribution
from ..modules.transformer_blocks import (
ResidualCrossAttentionBlock,
Transformer
)
from .tsal_base import ShapeAsLatentModule
class CrossAttentionEncoder(nn.Module):
def __init__(self, *,
device: Optional[torch.device],
dtype: Optional[torch.dtype],
num_latents: int,
fourier_embedder: FourierEmbedder,
point_feats: int,
width: int,
heads: int,
layers: int,
init_scale: float = 0.25,
qkv_bias: bool = True,
flash: bool = False,
use_ln_post: bool = False,
use_checkpoint: bool = False):
super().__init__()
self.use_checkpoint = use_checkpoint
self.num_latents = num_latents
self.query = nn.Parameter(torch.randn((num_latents, width), device=device, dtype=dtype) * 0.02)
self.fourier_embedder = fourier_embedder
self.input_proj = nn.Linear(self.fourier_embedder.out_dim + point_feats, width, device=device, dtype=dtype)
self.cross_attn = ResidualCrossAttentionBlock(
device=device,
dtype=dtype,
width=width,
heads=heads,
init_scale=init_scale,
qkv_bias=qkv_bias,
flash=flash,
)
self.self_attn = Transformer(
device=device,
dtype=dtype,
n_ctx=num_latents,
width=width,
layers=layers,
heads=heads,
init_scale=init_scale,
qkv_bias=qkv_bias,
flash=flash,
use_checkpoint=False
)
if use_ln_post:
self.ln_post = nn.LayerNorm(width, dtype=dtype, device=device)
else:
self.ln_post = None
def _forward(self, pc, feats):
"""
Args:
pc (torch.FloatTensor): [B, N, 3]
feats (torch.FloatTensor or None): [B, N, C]
Returns:
"""
bs = pc.shape[0]
data = self.fourier_embedder(pc)
if feats is not None:
data = torch.cat([data, feats], dim=-1)
data = self.input_proj(data)
query = repeat(self.query, "m c -> b m c", b=bs)
latents = self.cross_attn(query, data)
latents = self.self_attn(latents)
if self.ln_post is not None:
latents = self.ln_post(latents)
return latents, pc
def forward(self, pc: torch.FloatTensor, feats: Optional[torch.FloatTensor] = None):
"""
Args:
pc (torch.FloatTensor): [B, N, 3]
feats (torch.FloatTensor or None): [B, N, C]
Returns:
dict
"""
return checkpoint(self._forward, (pc, feats), self.parameters(), self.use_checkpoint)
class CrossAttentionDecoder(nn.Module):
def __init__(self, *,
device: Optional[torch.device],
dtype: Optional[torch.dtype],
num_latents: int,
out_channels: int,
fourier_embedder: FourierEmbedder,
width: int,
heads: int,
init_scale: float = 0.25,
qkv_bias: bool = True,
flash: bool = False,
use_checkpoint: bool = False):
super().__init__()
self.use_checkpoint = use_checkpoint
self.fourier_embedder = fourier_embedder
self.query_proj = nn.Linear(self.fourier_embedder.out_dim, width, device=device, dtype=dtype)
self.cross_attn_decoder = ResidualCrossAttentionBlock(
device=device,
dtype=dtype,
n_data=num_latents,
width=width,
heads=heads,
init_scale=init_scale,
qkv_bias=qkv_bias,
flash=flash
)
self.ln_post = nn.LayerNorm(width, device=device, dtype=dtype)
self.output_proj = nn.Linear(width, out_channels, device=device, dtype=dtype)
def _forward(self, queries: torch.FloatTensor, latents: torch.FloatTensor):
queries = self.query_proj(self.fourier_embedder(queries))
x = self.cross_attn_decoder(queries, latents)
x = self.ln_post(x)
x = self.output_proj(x)
return x
def forward(self, queries: torch.FloatTensor, latents: torch.FloatTensor):
return checkpoint(self._forward, (queries, latents), self.parameters(), self.use_checkpoint)
class ShapeAsLatentPerceiver(ShapeAsLatentModule):
def __init__(self, *,
device: Optional[torch.device],
dtype: Optional[torch.dtype],
num_latents: int,
point_feats: int = 0,
embed_dim: int = 0,
num_freqs: int = 8,
include_pi: bool = True,
width: int,
heads: int,
num_encoder_layers: int,
num_decoder_layers: int,
init_scale: float = 0.25,
qkv_bias: bool = True,
flash: bool = False,
use_ln_post: bool = False,
use_checkpoint: bool = False):
super().__init__()
self.use_checkpoint = use_checkpoint
self.num_latents = num_latents
self.fourier_embedder = FourierEmbedder(num_freqs=num_freqs, include_pi=include_pi)
init_scale = init_scale * math.sqrt(1.0 / width)
self.encoder = CrossAttentionEncoder(
device=device,
dtype=dtype,
fourier_embedder=self.fourier_embedder,
num_latents=num_latents,
point_feats=point_feats,
width=width,
heads=heads,
layers=num_encoder_layers,
init_scale=init_scale,
qkv_bias=qkv_bias,
flash=flash,
use_ln_post=use_ln_post,
use_checkpoint=use_checkpoint
)
self.embed_dim = embed_dim
if embed_dim > 0:
# VAE embed
self.pre_kl = nn.Linear(width, embed_dim * 2, device=device, dtype=dtype)
self.post_kl = nn.Linear(embed_dim, width, device=device, dtype=dtype)
self.latent_shape = (num_latents, embed_dim)
else:
self.latent_shape = (num_latents, width)
self.transformer = Transformer(
device=device,
dtype=dtype,
n_ctx=num_latents,
width=width,
layers=num_decoder_layers,
heads=heads,
init_scale=init_scale,
qkv_bias=qkv_bias,
flash=flash,
use_checkpoint=use_checkpoint
)
# geometry decoder
self.geo_decoder = CrossAttentionDecoder(
device=device,
dtype=dtype,
fourier_embedder=self.fourier_embedder,
out_channels=1,
num_latents=num_latents,
width=width,
heads=heads,
init_scale=init_scale,
qkv_bias=qkv_bias,
flash=flash,
use_checkpoint=use_checkpoint
)
def encode(self,
pc: torch.FloatTensor,
feats: Optional[torch.FloatTensor] = None,
sample_posterior: bool = True):
"""
Args:
pc (torch.FloatTensor): [B, N, 3]
feats (torch.FloatTensor or None): [B, N, C]
sample_posterior (bool):
Returns:
latents (torch.FloatTensor)
center_pos (torch.FloatTensor or None):
posterior (DiagonalGaussianDistribution or None):
"""
latents, center_pos = self.encoder(pc, feats)
posterior = None
if self.embed_dim > 0:
moments = self.pre_kl(latents)
posterior = DiagonalGaussianDistribution(moments, feat_dim=-1)
if sample_posterior:
latents = posterior.sample()
else:
latents = posterior.mode()
return latents, center_pos, posterior
def decode(self, latents: torch.FloatTensor):
latents = self.post_kl(latents)
return self.transformer(latents)
def query_geometry(self, queries: torch.FloatTensor, latents: torch.FloatTensor):
logits = self.geo_decoder(queries, latents).squeeze(-1)
return logits
def forward(self,
pc: torch.FloatTensor,
feats: torch.FloatTensor,
volume_queries: torch.FloatTensor,
sample_posterior: bool = True):
"""
Args:
pc (torch.FloatTensor): [B, N, 3]
feats (torch.FloatTensor or None): [B, N, C]
volume_queries (torch.FloatTensor): [B, P, 3]
sample_posterior (bool):
Returns:
logits (torch.FloatTensor): [B, P]
center_pos (torch.FloatTensor): [B, M, 3]
posterior (DiagonalGaussianDistribution or None).
"""
latents, center_pos, posterior = self.encode(pc, feats, sample_posterior=sample_posterior)
latents = self.decode(latents)
logits = self.query_geometry(volume_queries, latents)
return logits, center_pos, posterior
class AlignedShapeLatentPerceiver(ShapeAsLatentPerceiver):
def __init__(self, *,
device: Optional[torch.device],
dtype: Optional[torch.dtype],
num_latents: int,
point_feats: int = 0,
embed_dim: int = 0,
num_freqs: int = 8,
include_pi: bool = True,
width: int,
heads: int,
num_encoder_layers: int,
num_decoder_layers: int,
init_scale: float = 0.25,
qkv_bias: bool = True,
flash: bool = False,
use_ln_post: bool = False,
use_checkpoint: bool = False):
super().__init__(
device=device,
dtype=dtype,
num_latents=1 + num_latents,
point_feats=point_feats,
embed_dim=embed_dim,
num_freqs=num_freqs,
include_pi=include_pi,
width=width,
heads=heads,
num_encoder_layers=num_encoder_layers,
num_decoder_layers=num_decoder_layers,
init_scale=init_scale,
qkv_bias=qkv_bias,
flash=flash,
use_ln_post=use_ln_post,
use_checkpoint=use_checkpoint
)
self.width = width
def encode(self,
pc: torch.FloatTensor,
feats: Optional[torch.FloatTensor] = None,
sample_posterior: bool = True):
"""
Args:
pc (torch.FloatTensor): [B, N, 3]
feats (torch.FloatTensor or None): [B, N, c]
sample_posterior (bool):
Returns:
shape_embed (torch.FloatTensor)
kl_embed (torch.FloatTensor):
posterior (DiagonalGaussianDistribution or None):
"""
shape_embed, latents = self.encode_latents(pc, feats)
kl_embed, posterior = self.encode_kl_embed(latents, sample_posterior)
return shape_embed, kl_embed, posterior
def encode_latents(self,
pc: torch.FloatTensor,
feats: Optional[torch.FloatTensor] = None):
x, _ = self.encoder(pc, feats)
shape_embed = x[:, 0]
latents = x[:, 1:]
return shape_embed, latents
def encode_kl_embed(self, latents: torch.FloatTensor, sample_posterior: bool = True):
posterior = None
if self.embed_dim > 0:
moments = self.pre_kl(latents)
posterior = DiagonalGaussianDistribution(moments, feat_dim=-1)
if sample_posterior:
kl_embed = posterior.sample()
else:
kl_embed = posterior.mode()
else:
kl_embed = latents
return kl_embed, posterior
def forward(self,
pc: torch.FloatTensor,
feats: torch.FloatTensor,
volume_queries: torch.FloatTensor,
sample_posterior: bool = True):
"""
Args:
pc (torch.FloatTensor): [B, N, 3]
feats (torch.FloatTensor or None): [B, N, C]
volume_queries (torch.FloatTensor): [B, P, 3]
sample_posterior (bool):
Returns:
shape_embed (torch.FloatTensor): [B, projection_dim]
logits (torch.FloatTensor): [B, M]
posterior (DiagonalGaussianDistribution or None).
"""
shape_embed, kl_embed, posterior = self.encode(pc, feats, sample_posterior=sample_posterior)
latents = self.decode(kl_embed)
logits = self.query_geometry(volume_queries, latents)
return shape_embed, logits, posterior
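A toy-sized instantiation sketch of the aligned perceiver; the hyper-parameters are deliberately small and purely illustrative (the released checkpoint uses the values from the shapevae config):

```python
import torch

shape_model = AlignedShapeLatentPerceiver(
    device=None, dtype=None,
    num_latents=32, point_feats=3, embed_dim=8,
    num_freqs=8, include_pi=False,
    width=64, heads=4,
    num_encoder_layers=2, num_decoder_layers=2,
    use_ln_post=True,
)
pc = torch.rand(2, 1024, 3) * 2 - 1       # surface points in [-1, 1]
feats = torch.randn(2, 1024, 3)           # per-point normals
queries = torch.rand(2, 512, 3) * 2 - 1   # volume query points
shape_embed, logits, posterior = shape_model(pc, feats, queries)
print(shape_embed.shape, logits.shape)    # torch.Size([2, 64]) torch.Size([2, 512])
```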

View File

@ -0,0 +1,290 @@
# -*- coding: utf-8 -*-
from typing import List, Tuple, Dict, Optional
from omegaconf import DictConfig
import torch
from torch.optim import lr_scheduler
import pytorch_lightning as pl
from typing import Union
from functools import partial
from ...utils import instantiate_from_config
from .inference_utils import extract_geometry
from .tsal_base import (
ShapeAsLatentModule,
Latent2MeshOutput,
Point2MeshOutput
)
class ShapeAsLatentPLModule(pl.LightningModule):
def __init__(self, *,
module_cfg,
loss_cfg,
optimizer_cfg: Optional[DictConfig] = None,
ckpt_path: Optional[str] = None,
ignore_keys: Union[Tuple[str], List[str]] = ()):
super().__init__()
self.sal: ShapeAsLatentModule = instantiate_from_config(module_cfg, device=None, dtype=None)
self.loss = instantiate_from_config(loss_cfg)
self.optimizer_cfg = optimizer_cfg
if ckpt_path is not None:
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
self.save_hyperparameters()
@property
def latent_shape(self):
return self.sal.latent_shape
@property
def zero_rank(self):
if self._trainer:
zero_rank = self.trainer.local_rank == 0
else:
zero_rank = True
return zero_rank
def init_from_ckpt(self, path, ignore_keys=()):
state_dict = torch.load(path, map_location="cpu")["state_dict"]
keys = list(state_dict.keys())
for k in keys:
for ik in ignore_keys:
if k.startswith(ik):
print("Deleting key {} from state_dict.".format(k))
del state_dict[k]
missing, unexpected = self.load_state_dict(state_dict, strict=False)
print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
if len(missing) > 0:
print(f"Missing Keys: {missing}")
print(f"Unexpected Keys: {unexpected}")
def configure_optimizers(self) -> Tuple[List, List]:
lr = self.learning_rate
# optimizers = [torch.optim.AdamW(self.sal.parameters(), lr=lr, betas=(0.9, 0.99), weight_decay=1e-4)]
# optimizers = [torch.optim.AdamW(self.sal.parameters(), lr=lr, betas=(0.9, 0.99), weight_decay=1e-3)]
if self.optimizer_cfg is None:
optimizers = [torch.optim.AdamW(self.sal.parameters(), lr=lr, betas=(0.9, 0.99), weight_decay=1e-3)]
schedulers = []
else:
optimizer = instantiate_from_config(self.optimizer_cfg.optimizer, params=self.sal.parameters())
scheduler_func = instantiate_from_config(
self.optimizer_cfg.scheduler,
max_decay_steps=self.trainer.max_steps,
lr_max=lr
)
scheduler = {
"scheduler": lr_scheduler.LambdaLR(optimizer, lr_lambda=scheduler_func.schedule),
"interval": "step",
"frequency": 1
}
optimizers = [optimizer]
schedulers = [scheduler]
return optimizers, schedulers
def forward(self,
pc: torch.FloatTensor,
feats: torch.FloatTensor,
volume_queries: torch.FloatTensor):
logits, center_pos, posterior = self.sal(pc, feats, volume_queries)
return posterior, logits
def encode(self, surface: torch.FloatTensor, sample_posterior=True):
pc = surface[..., 0:3]
feats = surface[..., 3:6]
latents, center_pos, posterior = self.sal.encode(
pc=pc, feats=feats, sample_posterior=sample_posterior
)
return latents
def decode(self,
z_q,
bounds: Union[Tuple[float], List[float], float] = 1.1,
octree_depth: int = 7,
num_chunks: int = 10000) -> List[Latent2MeshOutput]:
latents = self.sal.decode(z_q) # latents: [bs, num_latents, dim]
outputs = self.latent2mesh(latents, bounds=bounds, octree_depth=octree_depth, num_chunks=num_chunks)
return outputs
def training_step(self, batch: Dict[str, torch.FloatTensor],
batch_idx: int, optimizer_idx: int = 0) -> torch.FloatTensor:
"""
Args:
batch (dict): the batch sample, and it contains:
- surface (torch.FloatTensor): [bs, n_surface, (3 + input_dim)]
- geo_points (torch.FloatTensor): [bs, n_pts, (3 + 1)]
batch_idx (int):
optimizer_idx (int):
Returns:
loss (torch.FloatTensor):
"""
pc = batch["surface"][..., 0:3]
feats = batch["surface"][..., 3:]
volume_queries = batch["geo_points"][..., 0:3]
volume_labels = batch["geo_points"][..., -1]
posterior, logits = self(
pc=pc, feats=feats, volume_queries=volume_queries
)
aeloss, log_dict_ae = self.loss(posterior, logits, volume_labels, split="train")
self.log_dict(log_dict_ae, prog_bar=True, logger=True, batch_size=logits.shape[0],
sync_dist=False, rank_zero_only=True)
return aeloss
def validation_step(self, batch: Dict[str, torch.FloatTensor], batch_idx: int) -> torch.FloatTensor:
pc = batch["surface"][..., 0:3]
feats = batch["surface"][..., 3:]
volume_queries = batch["geo_points"][..., 0:3]
volume_labels = batch["geo_points"][..., -1]
posterior, logits = self(
pc=pc, feats=feats, volume_queries=volume_queries,
)
aeloss, log_dict_ae = self.loss(posterior, logits, volume_labels, split="val")
self.log_dict(log_dict_ae, prog_bar=True, logger=True, batch_size=logits.shape[0],
sync_dist=False, rank_zero_only=True)
return aeloss
def point2mesh(self,
pc: torch.FloatTensor,
feats: torch.FloatTensor,
bounds: Union[Tuple[float], List[float]] = (-1.25, -1.25, -1.25, 1.25, 1.25, 1.25),
octree_depth: int = 7,
num_chunks: int = 10000) -> List[Point2MeshOutput]:
"""
Args:
pc:
feats:
bounds:
octree_depth:
num_chunks:
Returns:
mesh_outputs (List[MeshOutput]): the mesh outputs list.
"""
outputs = []
device = pc.device
bs = pc.shape[0]
# 1. point encoder + latents transformer
latents, center_pos, posterior = self.sal.encode(pc, feats)
latents = self.sal.decode(latents) # latents: [bs, num_latents, dim]
geometric_func = partial(self.sal.query_geometry, latents=latents)
# 2. decode geometry
mesh_v_f, has_surface = extract_geometry(
geometric_func=geometric_func,
device=device,
batch_size=bs,
bounds=bounds,
octree_depth=octree_depth,
num_chunks=num_chunks,
disable=not self.zero_rank
)
# 3. assemble mesh outputs
for i, ((mesh_v, mesh_f), is_surface) in enumerate(zip(mesh_v_f, has_surface)):
if not is_surface:
outputs.append(None)
continue
out = Point2MeshOutput()
out.mesh_v = mesh_v
out.mesh_f = mesh_f
out.pc = torch.cat([pc[i], feats[i]], dim=-1).cpu().numpy()
if center_pos is not None:
out.center = center_pos[i].cpu().numpy()
outputs.append(out)
return outputs
def latent2mesh(self,
latents: torch.FloatTensor,
bounds: Union[Tuple[float], List[float], float] = 1.1,
octree_depth: int = 7,
num_chunks: int = 10000) -> List[Latent2MeshOutput]:
"""
Args:
latents: [bs, num_latents, dim]
bounds:
octree_depth:
num_chunks:
Returns:
mesh_outputs (List[MeshOutput]): the mesh outputs list.
"""
outputs = []
geometric_func = partial(self.sal.query_geometry, latents=latents)
# 2. decode geometry
device = latents.device
mesh_v_f, has_surface = extract_geometry(
geometric_func=geometric_func,
device=device,
batch_size=len(latents),
bounds=bounds,
octree_depth=octree_depth,
num_chunks=num_chunks,
disable=not self.zero_rank
)
# 3. assemble mesh outputs
for i, ((mesh_v, mesh_f), is_surface) in enumerate(zip(mesh_v_f, has_surface)):
if not is_surface:
outputs.append(None)
continue
out = Latent2MeshOutput()
out.mesh_v = mesh_v
out.mesh_f = mesh_f
outputs.append(out)
return outputs
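For completeness, a sketch of the `point2mesh` path (assuming `pl_module` is a trained `ShapeAsLatentPLModule`; the random surface here only exercises the API):

```python
import torch

surface = torch.randn(1, 4096, 6)            # xyz + normals, illustrative
outputs = pl_module.point2mesh(surface[..., 0:3], surface[..., 3:6], octree_depth=7)
for out in outputs:
    if out is not None:                      # None means no surface was extracted
        print(out.mesh_v.shape, out.mesh_f.shape, out.pc.shape)
```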

View File

@ -0,0 +1,121 @@
# -*- coding: utf-8 -*-
import torch.nn as nn
from typing import Tuple, List, Optional
import pytorch_lightning as pl
class Point2MeshOutput(object):
def __init__(self):
self.mesh_v = None
self.mesh_f = None
self.center = None
self.pc = None
class Latent2MeshOutput(object):
def __init__(self):
self.mesh_v = None
self.mesh_f = None
class AlignedMeshOutput(object):
def __init__(self):
self.mesh_v = None
self.mesh_f = None
self.surface = None
self.image = None
self.text: Optional[str] = None
self.shape_text_similarity: Optional[float] = None
self.shape_image_similarity: Optional[float] = None
class ShapeAsLatentPLModule(pl.LightningModule):
latent_shape: Tuple[int]
def encode(self, surface, *args, **kwargs):
raise NotImplementedError
def decode(self, z_q, *args, **kwargs):
raise NotImplementedError
def latent2mesh(self, latents, *args, **kwargs) -> List[Latent2MeshOutput]:
raise NotImplementedError
def point2mesh(self, *args, **kwargs) -> List[Point2MeshOutput]:
raise NotImplementedError
class ShapeAsLatentModule(nn.Module):
latent_shape: Tuple[int, int]
def __init__(self, *args, **kwargs):
super().__init__()
def encode(self, *args, **kwargs):
raise NotImplementedError
def decode(self, *args, **kwargs):
raise NotImplementedError
def query_geometry(self, *args, **kwargs):
raise NotImplementedError
class AlignedShapeAsLatentPLModule(pl.LightningModule):
latent_shape: Tuple[int]
def set_shape_model_only(self):
raise NotImplementedError
def encode(self, surface, *args, **kwargs):
raise NotImplementedError
def decode(self, z_q, *args, **kwargs):
raise NotImplementedError
def latent2mesh(self, latents, *args, **kwargs) -> List[Latent2MeshOutput]:
raise NotImplementedError
def point2mesh(self, *args, **kwargs) -> List[Point2MeshOutput]:
raise NotImplementedError
class AlignedShapeAsLatentModule(nn.Module):
shape_model: ShapeAsLatentModule
latent_shape: Tuple[int, int]
def __init__(self, *args, **kwargs):
super().__init__()
def set_shape_model_only(self):
raise NotImplementedError
def encode_image_embed(self, *args, **kwargs):
raise NotImplementedError
def encode_text_embed(self, *args, **kwargs):
raise NotImplementedError
def encode_shape_embed(self, *args, **kwargs):
raise NotImplementedError
class TexturedShapeAsLatentModule(nn.Module):
def __init__(self, *args, **kwargs):
super().__init__()
def encode(self, *args, **kwargs):
raise NotImplementedError
def decode(self, *args, **kwargs):
raise NotImplementedError
def query_geometry(self, *args, **kwargs):
raise NotImplementedError
def query_color(self, *args, **kwargs):
raise NotImplementedError

View File

@ -0,0 +1,42 @@
model:
target: primitive_anything.michelangelo.models.tsal.asl_pl_module.AlignedShapeAsLatentPLModule
params:
shape_module_cfg:
target: primitive_anything.michelangelo.models.tsal.sal_perceiver.AlignedShapeLatentPerceiver
params:
num_latents: 256
embed_dim: 64
point_feats: 3 # normal
num_freqs: 8
include_pi: false
heads: 12
width: 768
num_encoder_layers: 8
num_decoder_layers: 16
use_ln_post: true
init_scale: 0.25
qkv_bias: false
use_checkpoint: true
aligned_module_cfg:
target: primitive_anything.michelangelo.models.tsal.clip_asl_module.CLIPAlignedShapeAsLatentModule
loss_cfg:
target: primitive_anything.michelangelo.models.tsal.loss.ContrastKLNearFar
params:
contrast_weight: 0.1
near_weight: 0.1
kl_weight: 0.001
optimizer_cfg:
optimizer:
target: torch.optim.AdamW
params:
betas: [0.9, 0.99]
eps: 1.e-6
weight_decay: 1.e-2
scheduler:
target: primitive_anything.michelangelo.utils.trainings.lr_scheduler.LambdaWarmUpCosineFactorScheduler
params:
warm_up_steps: 5000
f_start: 1.e-6
f_min: 1.e-3
f_max: 1.0
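The config above can be loaded and turned into a module with the helpers from `primitive_anything.michelangelo.utils`. A sketch, assuming the YAML is saved at a path such as `./configs/shapevae-256.yaml` (the exact location, and any extra constructor arguments such as a checkpoint path, depend on how the repository wires things up):

```python
from primitive_anything.michelangelo.utils import get_config_from_file, instantiate_from_config

cfg = get_config_from_file("./configs/shapevae-256.yaml")   # hypothetical path
model = instantiate_from_config(cfg.model)
print(sum(p.numel() for p in model.parameters()) / 1e6, "M parameters")
```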

View File

@ -0,0 +1,4 @@
# -*- coding: utf-8 -*-
from .misc import get_config_from_file
from .misc import instantiate_from_config

View File

@ -0,0 +1,12 @@
# -*- coding: utf-8 -*-
import torch
def compute_psnr(x, y, data_range: float = 2, eps: float = 1e-7):
mse = torch.mean((x - y) ** 2)
psnr = 10 * torch.log10(data_range / (mse + eps))
return psnr
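A tiny usage sketch; note that `data_range` defaults to 2, i.e. values spanning [-1, 1]:

```python
import torch

pred = torch.rand(2, 1024, 3) * 2 - 1        # predicted colors in [-1, 1]
gt = pred + 0.05 * torch.randn_like(pred)    # slightly perturbed "ground truth"
print(compute_psnr(pred, gt))
```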

View File

@ -0,0 +1,47 @@
# -*- coding: utf-8 -*-
import os
import io
import tarfile
import json
import numpy as np
import numpy.lib.format
def mkdir(path):
os.makedirs(path, exist_ok=True)
return path
def npy_loads(data):
stream = io.BytesIO(data)
return np.lib.format.read_array(stream)
def npz_loads(data):
return np.load(io.BytesIO(data))
def json_loads(data):
return json.loads(data)
def load_json(filepath):
with open(filepath, "r") as f:
data = json.load(f)
return data
def write_json(filepath, data):
with open(filepath, "w") as f:
json.dump(data, f, indent=2)
def extract_tar(tar_path, tar_cache_folder):
with tarfile.open(tar_path, "r") as tar:
tar.extractall(path=tar_cache_folder)
tar_uids = sorted(os.listdir(tar_cache_folder))
print(f"extract tar: {tar_path} to {tar_cache_folder}")
return tar_uids

View File

@ -0,0 +1,103 @@
# -*- coding: utf-8 -*-
import importlib
from omegaconf import OmegaConf, DictConfig, ListConfig
import torch
import torch.distributed as dist
from typing import Union
def get_config_from_file(config_file: str) -> Union[DictConfig, ListConfig]:
config_file = OmegaConf.load(config_file)
if 'base_config' in config_file.keys():
if config_file['base_config'] == "default_base":
base_config = OmegaConf.create()
# base_config = get_default_config()
elif config_file['base_config'].endswith(".yaml"):
base_config = get_config_from_file(config_file['base_config'])
else:
raise ValueError(f"{config_file} must be `.yaml` file or it contains `base_config` key.")
config_file = {key: value for key, value in config_file.items() if key != "base_config"}
return OmegaConf.merge(base_config, config_file)
return config_file
def get_obj_from_str(string, reload=False):
module, cls = string.rsplit(".", 1)
if reload:
module_imp = importlib.import_module(module)
importlib.reload(module_imp)
return getattr(importlib.import_module(module, package=None), cls)
def get_obj_from_config(config):
if "target" not in config:
raise KeyError("Expected key `target` to instantiate.")
return get_obj_from_str(config["target"])
def instantiate_from_config(config, **kwargs):
if "target" not in config:
raise KeyError("Expected key `target` to instantiate.")
cls = get_obj_from_str(config["target"])
params = config.get("params", dict())
# params.update(kwargs)
# instance = cls(**params)
kwargs.update(params)
instance = cls(**kwargs)
return instance
def is_dist_avail_and_initialized():
if not dist.is_available():
return False
if not dist.is_initialized():
return False
return True
def get_rank():
if not is_dist_avail_and_initialized():
return 0
return dist.get_rank()
def get_world_size():
if not is_dist_avail_and_initialized():
return 1
return dist.get_world_size()
def all_gather_batch(tensors):
"""
Performs all_gather operation on the provided tensors.
"""
# Queue the gathered tensors
world_size = get_world_size()
# There is no need for reduction in the single-proc case
if world_size == 1:
return tensors
tensor_list = []
output_tensor = []
for tensor in tensors:
tensor_all = [torch.ones_like(tensor) for _ in range(world_size)]
dist.all_gather(
tensor_all,
tensor,
async_op=False # performance opt
)
tensor_list.append(tensor_all)
for tensor_all in tensor_list:
output_tensor.append(torch.cat(tensor_all, dim=0))
return output_tensor

View File

@ -0,0 +1 @@
# -*- coding: utf-8 -*-

View File

@ -0,0 +1,43 @@
import numpy as np
import matplotlib.pyplot as plt
# Helper functions
def get_colors(inp, colormap="viridis", normalize=True, vmin=None, vmax=None):
colormap = plt.cm.get_cmap(colormap)
if normalize:
vmin = np.min(inp)
vmax = np.max(inp)
norm = plt.Normalize(vmin, vmax)
return colormap(norm(inp))[:, :3]
def gen_checkers(n_checkers_x, n_checkers_y, width=256, height=256):
# tex dims need to be power of two.
array = np.ones((width, height, 3), dtype='float32')
# width in texels of each checker
checker_w = width / n_checkers_x
checker_h = height / n_checkers_y
for y in range(height):
for x in range(width):
color_key = int(x / checker_w) + int(y / checker_h)
if color_key % 2 == 0:
array[x, y, :] = [1., 0.874, 0.0]
else:
array[x, y, :] = [0., 0., 0.]
return array
def gen_circle(width=256, height=256):
xx, yy = np.mgrid[:width, :height]
circle = (xx - width / 2 + 0.5) ** 2 + (yy - height / 2 + 0.5) ** 2
array = np.ones((width, height, 4), dtype='float32')
array[:, :, 0] = (circle <= width)
array[:, :, 1] = (circle <= width)
array[:, :, 2] = (circle <= width)
array[:, :, 3] = circle <= width
return array
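A small sketch of the color helper, mapping one scalar per element to RGB:

```python
import numpy as np

values = np.linspace(0.0, 1.0, 100)            # one scalar per vertex or face
rgb = get_colors(values, colormap="viridis")   # [100, 3], values in [0, 1]
print(rgb.shape, rgb.min(), rgb.max())
```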

View File

@ -0,0 +1,49 @@
# -*- coding: utf-8 -*-
import io
import base64
import numpy as np
from PIL import Image
def to_html_frame(content):
html_frame = f"""
<html>
<body>
{content}
</body>
</html>
"""
return html_frame
def to_single_row_table(caption: str, content: str):
table_html = f"""
<table border = "1">
<caption>{caption}</caption>
<tr>
<td>{content}</td>
</tr>
</table>
"""
return table_html
def to_image_embed_tag(image: np.ndarray):
# Convert np.ndarray to bytes
img = Image.fromarray(image)
raw_bytes = io.BytesIO()
img.save(raw_bytes, "PNG")
# Encode bytes to base64
image_base64 = base64.b64encode(raw_bytes.getvalue()).decode("utf-8")
image_tag = f"""
<img src="data:image/png;base64,{image_base64}" alt="Embedded Image">
"""
return image_tag
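A minimal sketch chaining the helpers above into a standalone HTML preview of a (random) image:

```python
import numpy as np

image = (np.random.rand(64, 64, 3) * 255).astype(np.uint8)
html = to_html_frame(to_single_row_table("random preview", to_image_embed_tag(image)))
with open("preview.html", "w") as f:
    f.write(html)
```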

View File

@ -0,0 +1,534 @@
import numpy as np
from ipywidgets import embed
import pythreejs as p3s
import uuid
from .color_util import get_colors, gen_circle, gen_checkers
EMBED_URL = "https://cdn.jsdelivr.net/npm/@jupyter-widgets/html-manager@1.0.1/dist/embed-amd.js"
class PyThreeJSViewer(object):
def __init__(self, settings, render_mode="WEBSITE"):
self.render_mode = render_mode
self.__update_settings(settings)
self._light = p3s.DirectionalLight(color='white', position=[0, 0, 1], intensity=0.6)
self._light2 = p3s.AmbientLight(intensity=0.5)
self._cam = p3s.PerspectiveCamera(position=[0, 0, 1], lookAt=[0, 0, 0], fov=self.__s["fov"],
aspect=self.__s["width"] / self.__s["height"], children=[self._light])
self._orbit = p3s.OrbitControls(controlling=self._cam)
self._scene = p3s.Scene(children=[self._cam, self._light2], background=self.__s["background"]) # "#4c4c80"
self._renderer = p3s.Renderer(camera=self._cam, scene=self._scene, controls=[self._orbit],
width=self.__s["width"], height=self.__s["height"],
antialias=self.__s["antialias"])
self.__objects = {}
self.__cnt = 0
def jupyter_mode(self):
self.render_mode = "JUPYTER"
def offline(self):
self.render_mode = "OFFLINE"
def website(self):
self.render_mode = "WEBSITE"
def __get_shading(self, shading):
shad = {"flat": True, "wireframe": False, "wire_width": 0.03, "wire_color": "black",
"side": 'DoubleSide', "colormap": "viridis", "normalize": [None, None],
"bbox": False, "roughness": 0.5, "metalness": 0.25, "reflectivity": 1.0,
"line_width": 1.0, "line_color": "black",
"point_color": "red", "point_size": 0.01, "point_shape": "circle",
"text_color": "red"
}
for k in shading:
shad[k] = shading[k]
return shad
def __update_settings(self, settings={}):
sett = {"width": 600, "height": 600, "antialias": True, "scale": 1.5, "background": "#ffffff",
"fov": 30}
for k in settings:
sett[k] = settings[k]
self.__s = sett
def __add_object(self, obj, parent=None):
if not parent: # Object is added to global scene and objects dict
self.__objects[self.__cnt] = obj
self.__cnt += 1
self._scene.add(obj["mesh"])
else: # Object is added to parent object and NOT to objects dict
parent.add(obj["mesh"])
self.__update_view()
if self.render_mode == "JUPYTER":
return self.__cnt - 1
elif self.render_mode == "WEBSITE":
return self
def __add_line_geometry(self, lines, shading, obj=None):
lines = lines.astype("float32", copy=False)
mi = np.min(lines, axis=0)
ma = np.max(lines, axis=0)
geometry = p3s.LineSegmentsGeometry(positions=lines.reshape((-1, 2, 3)))
material = p3s.LineMaterial(linewidth=shading["line_width"], color=shading["line_color"])
# , vertexColors='VertexColors'),
lines = p3s.LineSegments2(geometry=geometry, material=material) # type='LinePieces')
line_obj = {"geometry": geometry, "mesh": lines, "material": material,
"max": ma, "min": mi, "type": "Lines", "wireframe": None}
if obj:
return self.__add_object(line_obj, obj), line_obj
else:
return self.__add_object(line_obj)
def __update_view(self):
if len(self.__objects) == 0:
return
ma = np.zeros((len(self.__objects), 3))
mi = np.zeros((len(self.__objects), 3))
for r, obj in enumerate(self.__objects):
ma[r] = self.__objects[obj]["max"]
mi[r] = self.__objects[obj]["min"]
ma = np.max(ma, axis=0)
mi = np.min(mi, axis=0)
diag = np.linalg.norm(ma - mi)
mean = ((ma - mi) / 2 + mi).tolist()
scale = self.__s["scale"] * (diag)
self._orbit.target = mean
self._cam.lookAt(mean)
self._cam.position = [mean[0], mean[1], mean[2] + scale]
self._light.position = [mean[0], mean[1], mean[2] + scale]
self._orbit.exec_three_obj_method('update')
self._cam.exec_three_obj_method('updateProjectionMatrix')
def __get_bbox(self, v):
m = np.min(v, axis=0)
M = np.max(v, axis=0)
# Corners of the bounding box
v_box = np.array([[m[0], m[1], m[2]], [M[0], m[1], m[2]], [M[0], M[1], m[2]], [m[0], M[1], m[2]],
[m[0], m[1], M[2]], [M[0], m[1], M[2]], [M[0], M[1], M[2]], [m[0], M[1], M[2]]])
f_box = np.array([[0, 1], [1, 2], [2, 3], [3, 0], [4, 5], [5, 6], [6, 7], [7, 4],
[0, 4], [1, 5], [2, 6], [7, 3]], dtype=np.uint32)
return v_box, f_box
def __get_colors(self, v, f, c, sh):
coloring = "VertexColors"
if type(c) == np.ndarray and c.size == 3: # Single color
colors = np.ones_like(v)
colors[:, 0] = c[0]
colors[:, 1] = c[1]
colors[:, 2] = c[2]
# print("Single colors")
elif type(c) == np.ndarray and len(c.shape) == 2 and c.shape[1] == 3: # Color values for
if c.shape[0] == f.shape[0]: # faces
colors = np.hstack([c, c, c]).reshape((-1, 3))
coloring = "FaceColors"
# print("Face color values")
elif c.shape[0] == v.shape[0]: # vertices
colors = c
# print("Vertex color values")
else: # Wrong size, fallback
print("Invalid color array given! Supported are numpy arrays.", type(c))
colors = np.ones_like(v)
colors[:, 0] = 1.0
colors[:, 1] = 0.874
colors[:, 2] = 0.0
elif type(c) == np.ndarray and c.size == f.shape[0]: # Function values for faces
normalize = sh["normalize"][0] != None and sh["normalize"][1] != None
cc = get_colors(c, sh["colormap"], normalize=normalize,
vmin=sh["normalize"][0], vmax=sh["normalize"][1])
# print(cc.shape)
colors = np.hstack([cc, cc, cc]).reshape((-1, 3))
coloring = "FaceColors"
# print("Face function values")
elif type(c) == np.ndarray and c.size == v.shape[0]: # Function values for vertices
normalize = sh["normalize"][0] != None and sh["normalize"][1] != None
colors = get_colors(c, sh["colormap"], normalize=normalize,
vmin=sh["normalize"][0], vmax=sh["normalize"][1])
# print("Vertex function values")
else:
colors = np.ones_like(v)
colors[:, 0] = 1.0
colors[:, 1] = 0.874
colors[:, 2] = 0.0
# No color
if c is not None:
print("Invalid color array given! Supported are numpy arrays.", type(c))
return colors, coloring
def __get_point_colors(self, v, c, sh):
v_color = True
if c is None: # No color given, use global color
# conv = mpl.colors.ColorConverter()
colors = sh["point_color"] # np.array(conv.to_rgb(sh["point_color"]))
v_color = False
elif isinstance(c, str): # No color given, use global color
# conv = mpl.colors.ColorConverter()
colors = c # np.array(conv.to_rgb(c))
v_color = False
elif type(c) == np.ndarray and len(c.shape) == 2 and c.shape[0] == v.shape[0] and c.shape[1] == 3:
# Point color
colors = c.astype("float32", copy=False)
elif isinstance(c, np.ndarray) and len(c.shape) == 2 and c.shape[0] == v.shape[0] and c.shape[1] != 3:
# Function values for vertices, but the colors are features
c_norm = np.linalg.norm(c, ord=2, axis=-1)
normalize = sh["normalize"][0] != None and sh["normalize"][1] != None
colors = get_colors(c_norm, sh["colormap"], normalize=normalize,
vmin=sh["normalize"][0], vmax=sh["normalize"][1])
colors = colors.astype("float32", copy=False)
elif type(c) == np.ndarray and c.size == v.shape[0]: # Function color
normalize = sh["normalize"][0] != None and sh["normalize"][1] != None
colors = get_colors(c, sh["colormap"], normalize=normalize,
vmin=sh["normalize"][0], vmax=sh["normalize"][1])
colors = colors.astype("float32", copy=False)
# print("Vertex function values")
else:
print("Invalid color array given! Supported are numpy arrays.", type(c))
colors = sh["point_color"]
v_color = False
return colors, v_color
def add_mesh(self, v, f, c=None, uv=None, n=None, shading={}, texture_data=None, **kwargs):
shading.update(kwargs)
sh = self.__get_shading(shading)
mesh_obj = {}
# it is a tet
if v.shape[1] == 3 and f.shape[1] == 4:
f_tmp = np.ndarray([f.shape[0] * 4, 3], dtype=f.dtype)
for i in range(f.shape[0]):
f_tmp[i * 4 + 0] = np.array([f[i][1], f[i][0], f[i][2]])
f_tmp[i * 4 + 1] = np.array([f[i][0], f[i][1], f[i][3]])
f_tmp[i * 4 + 2] = np.array([f[i][1], f[i][2], f[i][3]])
f_tmp[i * 4 + 3] = np.array([f[i][2], f[i][0], f[i][3]])
f = f_tmp
if v.shape[1] == 2:
v = np.append(v, np.zeros([v.shape[0], 1]), 1)
# Type adjustment vertices
v = v.astype("float32", copy=False)
# Color setup
colors, coloring = self.__get_colors(v, f, c, sh)
# Type adjustment faces and colors
c = colors.astype("float32", copy=False)
# Material and geometry setup
ba_dict = {"color": p3s.BufferAttribute(c)}
if coloring == "FaceColors":
verts = np.zeros((f.shape[0] * 3, 3), dtype="float32")
for ii in range(f.shape[0]):
# print(ii*3, f[ii])
verts[ii * 3] = v[f[ii, 0]]
verts[ii * 3 + 1] = v[f[ii, 1]]
verts[ii * 3 + 2] = v[f[ii, 2]]
v = verts
else:
f = f.astype("uint32", copy=False).ravel()
ba_dict["index"] = p3s.BufferAttribute(f, normalized=False)
ba_dict["position"] = p3s.BufferAttribute(v, normalized=False)
if uv is not None:
uv = (uv - np.min(uv)) / (np.max(uv) - np.min(uv))
if texture_data is None:
texture_data = gen_checkers(20, 20)
tex = p3s.DataTexture(data=texture_data, format="RGBFormat", type="FloatType")
material = p3s.MeshStandardMaterial(map=tex, reflectivity=sh["reflectivity"], side=sh["side"],
roughness=sh["roughness"], metalness=sh["metalness"],
flatShading=sh["flat"],
polygonOffset=True, polygonOffsetFactor=1, polygonOffsetUnits=5)
ba_dict["uv"] = p3s.BufferAttribute(uv.astype("float32", copy=False))
else:
material = p3s.MeshStandardMaterial(vertexColors=coloring, reflectivity=sh["reflectivity"],
side=sh["side"], roughness=sh["roughness"], metalness=sh["metalness"],
flatShading=sh["flat"],
polygonOffset=True, polygonOffsetFactor=1, polygonOffsetUnits=5)
if type(n) != type(None) and coloring == "VertexColors": # TODO: properly handle normals for FaceColors as well
ba_dict["normal"] = p3s.BufferAttribute(n.astype("float32", copy=False), normalized=True)
geometry = p3s.BufferGeometry(attributes=ba_dict)
if coloring == "VertexColors" and type(n) == type(None):
geometry.exec_three_obj_method('computeVertexNormals')
elif coloring == "FaceColors" and type(n) == type(None):
geometry.exec_three_obj_method('computeFaceNormals')
# Mesh setup
mesh = p3s.Mesh(geometry=geometry, material=material)
# Wireframe setup
mesh_obj["wireframe"] = None
if sh["wireframe"]:
wf_geometry = p3s.WireframeGeometry(mesh.geometry) # WireframeGeometry
wf_material = p3s.LineBasicMaterial(color=sh["wire_color"], linewidth=sh["wire_width"])
wireframe = p3s.LineSegments(wf_geometry, wf_material)
mesh.add(wireframe)
mesh_obj["wireframe"] = wireframe
# Bounding box setup
if sh["bbox"]:
v_box, f_box = self.__get_bbox(v)
_, bbox = self.add_edges(v_box, f_box, sh, mesh)
mesh_obj["bbox"] = [bbox, v_box, f_box]
# Object setup
mesh_obj["max"] = np.max(v, axis=0)
mesh_obj["min"] = np.min(v, axis=0)
mesh_obj["geometry"] = geometry
mesh_obj["mesh"] = mesh
mesh_obj["material"] = material
mesh_obj["type"] = "Mesh"
mesh_obj["shading"] = sh
mesh_obj["coloring"] = coloring
mesh_obj["arrays"] = [v, f, c] # TODO replays with proper storage or remove if not needed
return self.__add_object(mesh_obj)
def add_lines(self, beginning, ending, shading={}, obj=None, **kwargs):
shading.update(kwargs)
if len(beginning.shape) == 1:
if len(beginning) == 2:
beginning = np.array([[beginning[0], beginning[1], 0]])
else:
if beginning.shape[1] == 2:
beginning = np.append(
beginning, np.zeros([beginning.shape[0], 1]), 1)
if len(ending.shape) == 1:
if len(ending) == 2:
ending = np.array([[ending[0], ending[1], 0]])
else:
if ending.shape[1] == 2:
ending = np.append(
ending, np.zeros([ending.shape[0], 1]), 1)
sh = self.__get_shading(shading)
lines = np.hstack([beginning, ending])
lines = lines.reshape((-1, 3))
return self.__add_line_geometry(lines, sh, obj)
def add_edges(self, vertices, edges, shading={}, obj=None, **kwargs):
shading.update(kwargs)
if vertices.shape[1] == 2:
vertices = np.append(
vertices, np.zeros([vertices.shape[0], 1]), 1)
sh = self.__get_shading(shading)
lines = np.zeros((edges.size, 3))
cnt = 0
for e in edges:
lines[cnt, :] = vertices[e[0]]
lines[cnt + 1, :] = vertices[e[1]]
cnt += 2
return self.__add_line_geometry(lines, sh, obj)
def add_points(self, points, c=None, shading={}, obj=None, **kwargs):
shading.update(kwargs)
if len(points.shape) == 1:
if len(points) == 2:
points = np.array([[points[0], points[1], 0]])
else:
if points.shape[1] == 2:
points = np.append(
points, np.zeros([points.shape[0], 1]), 1)
sh = self.__get_shading(shading)
points = points.astype("float32", copy=False)
mi = np.min(points, axis=0)
ma = np.max(points, axis=0)
g_attributes = {"position": p3s.BufferAttribute(points, normalized=False)}
m_attributes = {"size": sh["point_size"]}
if sh["point_shape"] == "circle": # Plot circles
tex = p3s.DataTexture(data=gen_circle(16, 16), format="RGBAFormat", type="FloatType")
m_attributes["map"] = tex
m_attributes["alphaTest"] = 0.5
m_attributes["transparency"] = True
else: # Plot squares
pass
colors, v_colors = self.__get_point_colors(points, c, sh)
if v_colors: # Colors per point
m_attributes["vertexColors"] = 'VertexColors'
g_attributes["color"] = p3s.BufferAttribute(colors, normalized=False)
else: # Colors for all points
m_attributes["color"] = colors
material = p3s.PointsMaterial(**m_attributes)
geometry = p3s.BufferGeometry(attributes=g_attributes)
points = p3s.Points(geometry=geometry, material=material)
point_obj = {"geometry": geometry, "mesh": points, "material": material,
"max": ma, "min": mi, "type": "Points", "wireframe": None}
if obj:
return self.__add_object(point_obj, obj), point_obj
else:
return self.__add_object(point_obj)
def remove_object(self, obj_id):
if obj_id not in self.__objects:
print("Invalid object id. Valid ids are: ", list(self.__objects.keys()))
return
self._scene.remove(self.__objects[obj_id]["mesh"])
del self.__objects[obj_id]
self.__update_view()
def reset(self):
for obj_id in list(self.__objects.keys()).copy():
self._scene.remove(self.__objects[obj_id]["mesh"])
del self.__objects[obj_id]
self.__update_view()
def update_object(self, oid=0, vertices=None, colors=None, faces=None):
obj = self.__objects[oid]
if type(vertices) != type(None):
if obj["coloring"] == "FaceColors":
f = obj["arrays"][1]
verts = np.zeros((f.shape[0] * 3, 3), dtype="float32")
for ii in range(f.shape[0]):
# print(ii*3, f[ii])
verts[ii * 3] = vertices[f[ii, 0]]
verts[ii * 3 + 1] = vertices[f[ii, 1]]
verts[ii * 3 + 2] = vertices[f[ii, 2]]
v = verts
else:
v = vertices.astype("float32", copy=False)
obj["geometry"].attributes["position"].array = v
# self.wireframe.attributes["position"].array = v # Wireframe updates?
obj["geometry"].attributes["position"].needsUpdate = True
# obj["geometry"].exec_three_obj_method('computeVertexNormals')
if type(colors) != type(None):
colors, coloring = self.__get_colors(obj["arrays"][0], obj["arrays"][1], colors, obj["shading"])
colors = colors.astype("float32", copy=False)
obj["geometry"].attributes["color"].array = colors
obj["geometry"].attributes["color"].needsUpdate = True
if type(faces) != type(None):
if obj["coloring"] == "FaceColors":
print("Face updates are currently only possible in vertex color mode.")
return
f = faces.astype("uint32", copy=False).ravel()
print(obj["geometry"].attributes)
obj["geometry"].attributes["index"].array = f
# self.wireframe.attributes["position"].array = v # Wireframe updates?
obj["geometry"].attributes["index"].needsUpdate = True
# obj["geometry"].exec_three_obj_method('computeVertexNormals')
# self.mesh.geometry.verticesNeedUpdate = True
# self.mesh.geometry.elementsNeedUpdate = True
# self.update()
if self.render_mode == "WEBSITE":
return self
# def update(self):
# self.mesh.exec_three_obj_method('update')
# self.orbit.exec_three_obj_method('update')
# self.cam.exec_three_obj_method('updateProjectionMatrix')
# self.scene.exec_three_obj_method('update')
def add_text(self, text, shading={}, **kwargs):
shading.update(kwargs)
sh = self.__get_shading(shading)
tt = p3s.TextTexture(string=text, color=sh["text_color"])
sm = p3s.SpriteMaterial(map=tt)
text = p3s.Sprite(material=sm, scaleToTexture=True)
self._scene.add(text)
# def add_widget(self, widget, callback):
# self.widgets.append(widget)
# widget.observe(callback, names='value')
# def add_dropdown(self, options, default, desc, cb):
# widget = widgets.Dropdown(options=options, value=default, description=desc)
# self.__widgets.append(widget)
# widget.observe(cb, names="value")
# display(widget)
# def add_button(self, text, cb):
# button = widgets.Button(description=text)
# self.__widgets.append(button)
# button.on_click(cb)
# display(button)
def to_html(self, imports=True, html_frame=True):
# Bake positions (fixes centering bug in offline rendering)
if len(self.__objects) == 0:
return
ma = np.zeros((len(self.__objects), 3))
mi = np.zeros((len(self.__objects), 3))
for r, obj in enumerate(self.__objects):
ma[r] = self.__objects[obj]["max"]
mi[r] = self.__objects[obj]["min"]
ma = np.max(ma, axis=0)
mi = np.min(mi, axis=0)
diag = np.linalg.norm(ma - mi)
mean = (ma - mi) / 2 + mi
for r, obj in enumerate(self.__objects):
v = self.__objects[obj]["geometry"].attributes["position"].array
v -= mean
v += np.array([0.0, .9, 0.0]) #! move the object toward the center of the window
scale = self.__s["scale"] * (diag)
self._orbit.target = [0.0, 0.0, 0.0]
self._cam.lookAt([0.0, 0.0, 0.0])
# self._cam.position = [0.0, 0.0, scale]
self._cam.position = [0.0, 0.5, scale * 1.3] #! show four complete meshes in the window
self._light.position = [0.0, 0.0, scale]
state = embed.dependency_state(self._renderer)
# Somehow these entries are missing when the state is exported in python.
# Exporting from the GUI works, so we are inserting the missing entries.
for k in state:
if state[k]["model_name"] == "OrbitControlsModel":
state[k]["state"]["maxAzimuthAngle"] = "inf"
state[k]["state"]["maxDistance"] = "inf"
state[k]["state"]["maxZoom"] = "inf"
state[k]["state"]["minAzimuthAngle"] = "-inf"
tpl = embed.load_requirejs_template
if not imports:
embed.load_requirejs_template = ""
s = embed.embed_snippet(self._renderer, state=state, embed_url=EMBED_URL)
# s = embed.embed_snippet(self.__w, state=state)
embed.load_requirejs_template = tpl
if html_frame:
s = "<html>\n<body>\n" + s + "\n</body>\n</html>"
# Revert changes
for r, obj in enumerate(self.__objects):
v = self.__objects[obj]["geometry"].attributes["position"].array
v += mean
self.__update_view()
return s
def save(self, filename=""):
if filename == "":
uid = str(uuid.uuid4()) + ".html"
else:
filename = filename.replace(".html", "")
uid = filename + '.html'
with open(uid, "w") as f:
f.write(self.to_html())
print("Plot saved to file %s." % uid)

View File

@ -0,0 +1,87 @@
import copy
import json
import os
import numpy as np
from scipy.linalg import polar
from scipy.spatial.transform import Rotation
import open3d as o3d
import torch
from torch.utils.data import Dataset
from .utils import exists
from .utils.logger import print_log
def create_dataset(cfg_dataset):
kwargs = cfg_dataset
name = kwargs.pop('name')
dataset = get_dataset(name)(**kwargs)
print_log(f"Dataset '{name}' init: kwargs={kwargs}, len={len(dataset)}")
return dataset
def get_dataset(name):
return {
'base': PrimitiveDataset,
}[name]
SHAPE_CODE = {
'CubeBevel': 0,
'SphereSharp': 1,
'CylinderSharp': 2,
}
class PrimitiveDataset(Dataset):
def __init__(self,
pc_dir,
bs_dir,
max_length=144,
range_scale=[0, 1],
range_rotation=[-180, 180],
range_translation=[-1, 1],
rotation_type='euler',
pc_format='pc',
):
self.data_filename = os.listdir(pc_dir)
self.pc_dir = pc_dir
self.max_length = max_length
self.range_scale = range_scale
self.range_rotation = range_rotation
self.range_translation = range_translation
self.rotation_type = rotation_type
self.pc_format = pc_format
with open(os.path.join(bs_dir, 'basic_shapes.json'), 'r', encoding='utf-8') as f:
basic_shapes = json.load(f)
self.typeid_map = {
1101002001034001: 'CubeBevel',
1101002001034010: 'SphereSharp',
1101002001034002: 'CylinderSharp',
}
def __len__(self):
return len(self.data_filename)
def __getitem__(self, idx):
pc_file = os.path.join(self.pc_dir, self.data_filename[idx])
pc = o3d.io.read_point_cloud(pc_file)
model_data = {}
points = torch.from_numpy(np.asarray(pc.points)).float()
colors = torch.from_numpy(np.asarray(pc.colors)).float()
normals = torch.from_numpy(np.asarray(pc.normals)).float()
if self.pc_format == 'pc':
model_data['pc'] = torch.concatenate([points, colors], dim=-1).T
elif self.pc_format == 'pn':
model_data['pc'] = torch.concatenate([points, normals], dim=-1)
elif self.pc_format == 'pcn':
model_data['pc'] = torch.concatenate([points, colors, normals], dim=-1)
else:
raise ValueError(f'invalid pc_format: {self.pc_format}')
return model_data

View File

@ -0,0 +1,948 @@
from __future__ import annotations
from functools import partial
from math import ceil
import os
from accelerate.utils import DistributedDataParallelKwargs
from beartype.typing import Tuple, Callable, List
from einops import rearrange, repeat, reduce, pack
from gateloop_transformer import SimpleGateLoopLayer
from huggingface_hub import PyTorchModelHubMixin
import numpy as np
import open3d as o3d
from tqdm import tqdm
import torch
from torch import nn, Tensor
from torch.nn import Module, ModuleList
import torch.nn.functional as F
from pytorch3d.loss import chamfer_distance
from pytorch3d.transforms import euler_angles_to_matrix
from x_transformers import Decoder
from x_transformers.x_transformers import LayerIntermediates
from x_transformers.autoregressive_wrapper import eval_decorator
from .michelangelo import ShapeConditioner as ShapeConditioner_miche
from .utils import (
discretize,
undiscretize,
set_module_requires_grad_,
default,
exists,
safe_cat,
identity,
is_tensor_empty,
)
from .utils.typing import Float, Int, Bool, typecheck
# constants
DEFAULT_DDP_KWARGS = DistributedDataParallelKwargs(
find_unused_parameters = True
)
SHAPE_CODE = {
'CubeBevel': 0,
'SphereSharp': 1,
'CylinderSharp': 2,
}
BS_NAME = {
0: 'CubeBevel',
1: 'SphereSharp',
2: 'CylinderSharp',
}
# FiLM block
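# Feature-wise linear modulation: the condition vector is mapped to per-channel
# gamma/beta that scale and shift the token features; gamma_mult and beta_mult
# start at zero so the block is the identity mapping at initialization.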
class FiLM(Module):
def __init__(self, dim, dim_out = None):
super().__init__()
dim_out = default(dim_out, dim)
self.to_gamma = nn.Linear(dim, dim_out, bias = False)
self.to_beta = nn.Linear(dim, dim_out)
self.gamma_mult = nn.Parameter(torch.zeros(1,))
self.beta_mult = nn.Parameter(torch.zeros(1,))
def forward(self, x, cond):
gamma, beta = self.to_gamma(cond), self.to_beta(cond)
gamma, beta = tuple(rearrange(t, 'b d -> b 1 d') for t in (gamma, beta))
# for initializing to identity
gamma = (1 + self.gamma_mult * gamma.tanh())
beta = beta.tanh() * self.beta_mult
# classic film
return x * gamma + beta
# gateloop layers
class GateLoopBlock(Module):
def __init__(
self,
dim,
*,
depth,
use_heinsen = True
):
super().__init__()
self.gateloops = ModuleList([])
for _ in range(depth):
gateloop = SimpleGateLoopLayer(dim = dim, use_heinsen = use_heinsen)
self.gateloops.append(gateloop)
def forward(
self,
x,
cache = None
):
received_cache = exists(cache)
if is_tensor_empty(x):
return x, None
if received_cache:
prev, x = x[:, :-1], x[:, -1:]
cache = default(cache, [])
cache = iter(cache)
new_caches = []
for gateloop in self.gateloops:
layer_cache = next(cache, None)
out, new_cache = gateloop(x, cache = layer_cache, return_cache = True)
new_caches.append(new_cache)
x = x + out
if received_cache:
x = torch.cat((prev, x), dim = -2)
return x, new_caches
def top_k_2(logits, frac_num_tokens=0.1, k=None):
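# Top-k filtering: keep only the k largest logits (k defaults to ceil of 10% of the
# vocabulary, capped at the vocabulary size) and set the rest to -inf before sampling.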
num_tokens = logits.shape[-1]
k = default(k, ceil(frac_num_tokens * num_tokens))
k = min(k, num_tokens)
val, ind = torch.topk(logits, k)
probs = torch.full_like(logits, float('-inf'))
probs.scatter_(2, ind, val)
return probs
def soft_argmax(labels):
indices = torch.arange(labels.size(-1), dtype=labels.dtype, device=labels.device)
soft_argmax = torch.sum(labels * indices, dim=-1)
return soft_argmax
class PrimitiveTransformerDiscrete(Module, PyTorchModelHubMixin):
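# Auto-regressive transformer over primitive assemblies: each primitive is a
# (type, translation, rotation, scale) tuple whose continuous attributes are
# discretized into bins and embedded, and the sequence is decoded one primitive
# at a time, optionally conditioned on a point-cloud embedding via cross-attention,
# FiLM, and/or token concatenation.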
@typecheck
def __init__(
self,
*,
num_discrete_scale = 128,
continuous_range_scale: List[float, float] = [0, 1],
dim_scale_embed = 64,
num_discrete_rotation = 180,
continuous_range_rotation: List[float, float] = [-180, 180],
dim_rotation_embed = 64,
num_discrete_translation = 128,
continuous_range_translation: List[float, float] = [-1, 1],
dim_translation_embed = 64,
num_type = 3,
dim_type_embed = 64,
embed_order = 'ctrs',
bin_smooth_blur_sigma = 0.4,
dim: int | Tuple[int, int] = 512,
flash_attn = True,
attn_depth = 12,
attn_dim_head = 64,
attn_heads = 16,
attn_kwargs: dict = dict(
ff_glu = True,
attn_num_mem_kv = 4
),
max_primitive_len = 144,
dropout = 0.,
coarse_pre_gateloop_depth = 2,
coarse_post_gateloop_depth = 0,
coarse_adaptive_rmsnorm = False,
gateloop_use_heinsen = False,
pad_id = -1,
num_sos_tokens = None,
condition_on_shape = True,
shape_cond_with_cross_attn = False,
shape_cond_with_film = False,
shape_cond_with_cat = False,
shape_condition_model_type = 'michelangelo',
shape_condition_len = 1,
shape_condition_dim = None,
cross_attn_num_mem_kv = 4, # needed for preventing nan when dropping out shape condition
loss_weight: dict = dict(
eos = 1.0,
type = 1.0,
scale = 1.0,
rotation = 1.0,
translation = 1.0,
reconstruction = 1.0,
scale_huber = 1.0,
rotation_huber = 1.0,
translation_huber = 1.0,
),
bs_pc_dir=None,
):
super().__init__()
# feature embedding
self.num_discrete_scale = num_discrete_scale
self.continuous_range_scale = continuous_range_scale
self.discretize_scale = partial(discretize, num_discrete=num_discrete_scale, continuous_range=continuous_range_scale)
self.undiscretize_scale = partial(undiscretize, num_discrete=num_discrete_scale, continuous_range=continuous_range_scale)
self.scale_embed = nn.Embedding(num_discrete_scale, dim_scale_embed)
self.num_discrete_rotation = num_discrete_rotation
self.continuous_range_rotation = continuous_range_rotation
self.discretize_rotation = partial(discretize, num_discrete=num_discrete_rotation, continuous_range=continuous_range_rotation)
self.undiscretize_rotation = partial(undiscretize, num_discrete=num_discrete_rotation, continuous_range=continuous_range_rotation)
self.rotation_embed = nn.Embedding(num_discrete_rotation, dim_rotation_embed)
self.num_discrete_translation = num_discrete_translation
self.continuous_range_translation = continuous_range_translation
self.discretize_translation = partial(discretize, num_discrete=num_discrete_translation, continuous_range=continuous_range_translation)
self.undiscretize_translation = partial(undiscretize, num_discrete=num_discrete_translation, continuous_range=continuous_range_translation)
self.translation_embed = nn.Embedding(num_discrete_translation, dim_translation_embed)
self.num_type = num_type
self.type_embed = nn.Embedding(num_type, dim_type_embed)
self.embed_order = embed_order
self.bin_smooth_blur_sigma = bin_smooth_blur_sigma
# initial dimension
self.dim = dim
init_dim = 3 * (dim_scale_embed + dim_rotation_embed + dim_translation_embed) + dim_type_embed
# project into model dimension
self.project_in = nn.Linear(init_dim, dim)
num_sos_tokens = default(num_sos_tokens, 1 if not condition_on_shape or not shape_cond_with_film else 4)
assert num_sos_tokens > 0
self.num_sos_tokens = num_sos_tokens
self.sos_token = nn.Parameter(torch.randn(num_sos_tokens, dim))
# the transformer eos token
self.eos_token = nn.Parameter(torch.randn(1, dim))
self.emb_layernorm = nn.LayerNorm(dim)
self.max_seq_len = max_primitive_len
# shape condition
self.condition_on_shape = condition_on_shape
self.shape_cond_with_cross_attn = False
self.shape_cond_with_cat = False
self.shape_condition_model_type = ''
self.conditioner = None
dim_shape = None
if condition_on_shape:
assert shape_cond_with_cross_attn or shape_cond_with_film or shape_cond_with_cat
self.shape_cond_with_cross_attn = shape_cond_with_cross_attn
self.shape_cond_with_cat = shape_cond_with_cat
self.shape_condition_model_type = shape_condition_model_type
if 'michelangelo' in shape_condition_model_type:
self.conditioner = ShapeConditioner_miche(dim_latent=shape_condition_dim)
self.to_cond_dim = nn.Linear(self.conditioner.dim_model_out * 2, self.conditioner.dim_latent)
self.to_cond_dim_head = nn.Linear(self.conditioner.dim_model_out, self.conditioner.dim_latent)
else:
raise ValueError(f'unknown shape_condition_model_type {self.shape_condition_model_type}')
dim_shape = self.conditioner.dim_latent
set_module_requires_grad_(self.conditioner, False)
self.shape_coarse_film_cond = FiLM(dim_shape, dim) if shape_cond_with_film else identity
self.coarse_gateloop_block = GateLoopBlock(dim, depth=coarse_pre_gateloop_depth, use_heinsen=gateloop_use_heinsen) if coarse_pre_gateloop_depth > 0 else None
self.coarse_post_gateloop_block = GateLoopBlock(dim, depth=coarse_post_gateloop_depth, use_heinsen=gateloop_use_heinsen) if coarse_post_gateloop_depth > 0 else None
self.coarse_adaptive_rmsnorm = coarse_adaptive_rmsnorm
self.decoder = Decoder(
dim=dim,
depth=attn_depth,
heads=attn_heads,
attn_dim_head=attn_dim_head,
attn_flash=flash_attn,
attn_dropout=dropout,
ff_dropout=dropout,
use_adaptive_rmsnorm=coarse_adaptive_rmsnorm,
dim_condition=dim_shape,
cross_attend=self.shape_cond_with_cross_attn,
cross_attn_dim_context=dim_shape,
cross_attn_num_mem_kv=cross_attn_num_mem_kv,
**attn_kwargs
)
# to logits
self.to_eos_logits = nn.Sequential(
nn.Linear(dim, dim),
nn.ReLU(),
nn.Linear(dim, 1)
)
self.to_type_logits = nn.Sequential(
nn.Linear(dim, dim),
nn.ReLU(),
nn.Linear(dim, num_type)
)
self.to_translation_logits = nn.Sequential(
nn.Linear(dim + dim_type_embed, dim),
nn.ReLU(),
nn.Linear(dim, 3 * num_discrete_translation)
)
self.to_rotation_logits = nn.Sequential(
nn.Linear(dim + dim_type_embed + 3 * dim_translation_embed, dim),
nn.ReLU(),
nn.Linear(dim, 3 * num_discrete_rotation)
)
self.to_scale_logits = nn.Sequential(
nn.Linear(dim + dim_type_embed + 3 * (dim_translation_embed + dim_rotation_embed), dim),
nn.ReLU(),
nn.Linear(dim, 3 * num_discrete_scale)
)
self.pad_id = pad_id
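# Cache the canonical point cloud of each basic primitive type (loaded from bs_pc_dir).
# These clouds are posed by the predicted scale/rotation/translation when computing the
# chamfer-distance reconstruction loss; rotation_matrix_align_coord applies a fixed
# 90-degree rotation about X afterwards, presumably to align coordinate conventions.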
bs_pc_map = {}
for bs_name, type_code in SHAPE_CODE.items():
pc = o3d.io.read_point_cloud(os.path.join(bs_pc_dir, f'SM_GR_BS_{bs_name}_001.ply'))
bs_pc_map[type_code] = torch.from_numpy(np.asarray(pc.points)).float()
bs_pc_list = []
for i in range(len(bs_pc_map)):
bs_pc_list.append(bs_pc_map[i])
self.bs_pc = torch.stack(bs_pc_list, dim=0)
self.rotation_matrix_align_coord = euler_angles_to_matrix(
torch.Tensor([np.pi/2, 0, 0]), 'XYZ').unsqueeze(0).unsqueeze(0)
@property
def device(self):
return next(self.parameters()).device
@typecheck
@torch.no_grad()
def embed_pc(self, pc: Tensor):
if 'michelangelo' in self.shape_condition_model_type:
pc_head, pc_embed = self.conditioner(shape=pc)
pc_embed = torch.cat([self.to_cond_dim_head(pc_head), self.to_cond_dim(pc_embed)], dim=-2).detach()
else:
raise ValueError(f'unknown shape_condition_model_type {self.shape_condition_model_type}')
return pc_embed
@typecheck
def recon_primitives(
self,
scale_logits: Float['b np 3 nd'],
rotation_logits: Float['b np 3 nd'],
translation_logits: Float['b np 3 nd'],
type_logits: Int['b np nd'],
primitive_mask: Bool['b np']
):
recon_scale = self.undiscretize_scale(scale_logits.argmax(dim=-1))
recon_scale = recon_scale.masked_fill(~primitive_mask.unsqueeze(-1), float('nan'))
recon_rotation = self.undiscretize_rotation(rotation_logits.argmax(dim=-1))
recon_rotation = recon_rotation.masked_fill(~primitive_mask.unsqueeze(-1), float('nan'))
recon_translation = self.undiscretize_translation(translation_logits.argmax(dim=-1))
recon_translation = recon_translation.masked_fill(~primitive_mask.unsqueeze(-1), float('nan'))
recon_type_code = type_logits.argmax(dim=-1)
recon_type_code = recon_type_code.masked_fill(~primitive_mask, -1)
return {
'scale': recon_scale,
'rotation': recon_rotation,
'translation': recon_translation,
'type_code': recon_type_code
}
@typecheck
def sample_primitives(
self,
scale: Float['b np 3 nd'],
rotation: Float['b np 3 nd'],
translation: Float['b np 3 nd'],
type_code: Int['b np nd'],
next_embed: Float['b 1 nd'],
temperature: float = 1.,
filter_logits_fn: Callable = top_k_2,
filter_kwargs: dict = dict()
):
def sample_func(logits):
if logits.ndim == 4:
enable_squeeze = True
logits = logits.squeeze(1)
else:
enable_squeeze = False
filtered_logits = filter_logits_fn(logits, **filter_kwargs)
if temperature == 0.:
sample = filtered_logits.argmax(dim=-1)
else:
probs = F.softmax(filtered_logits / temperature, dim=-1)
sample = torch.zeros((probs.shape[0], probs.shape[1]), dtype=torch.long, device=probs.device)
for b_i in range(probs.shape[0]):
sample[b_i] = torch.multinomial(probs[b_i], 1).squeeze()
if enable_squeeze:
sample = sample.unsqueeze(1)
return sample
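# Decode one primitive's attributes autoregressively in the order
# type -> translation -> rotation -> scale, packing the embedding of each sampled
# attribute back into the context before predicting the next one.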
next_type_logits = self.to_type_logits(next_embed)
next_type_code = sample_func(next_type_logits)
type_code_new, _ = pack([type_code, next_type_code], 'b *')
type_embed = self.type_embed(next_type_code)
next_embed_packed, _ = pack([next_embed, type_embed], 'b np *')
next_translation_logits = rearrange(self.to_translation_logits(next_embed_packed), 'b np (c nd) -> b np c nd', nd=self.num_discrete_translation)
next_discretize_translation = sample_func(next_translation_logits)
next_translation = self.undiscretize_translation(next_discretize_translation)
translation_new, _ = pack([translation, next_translation], 'b * nd')
next_translation_embed = self.translation_embed(next_discretize_translation)
next_embed_packed, _ = pack([next_embed_packed, next_translation_embed], 'b np *')
next_rotation_logits = rearrange(self.to_rotation_logits(next_embed_packed), 'b np (c nd) -> b np c nd', nd=self.num_discrete_rotation)
next_discretize_rotation = sample_func(next_rotation_logits)
next_rotation = self.undiscretize_rotation(next_discretize_rotation)
rotation_new, _ = pack([rotation, next_rotation], 'b * nd')
next_rotation_embed = self.rotation_embed(next_discretize_rotation)
next_embed_packed, _ = pack([next_embed_packed, next_rotation_embed], 'b np *')
next_scale_logits = rearrange(self.to_scale_logits(next_embed_packed), 'b np (c nd) -> b np c nd', nd=self.num_discrete_scale)
next_discretize_scale = sample_func(next_scale_logits)
next_scale = self.undiscretize_scale(next_discretize_scale)
scale_new, _ = pack([scale, next_scale], 'b * nd')
return (
scale_new,
rotation_new,
translation_new,
type_code_new
)
@eval_decorator
@torch.no_grad()
@typecheck
def generate(
self,
batch_size: int | None = None,
filter_logits_fn: Callable = top_k_2,
filter_kwargs: dict = dict(),
temperature: float = 1.,
scale: Float['b np 3'] | None = None,
rotation: Float['b np 3'] | None = None,
translation: Float['b np 3'] | None = None,
type_code: Int['b np'] | None = None,
pc: Tensor | None = None,
pc_embed: Tensor | None = None,
cache_kv = True,
max_seq_len = None,
):
max_seq_len = default(max_seq_len, self.max_seq_len)
if exists(scale) and exists(rotation) and exists(translation) and exists(type_code):
assert not exists(batch_size)
assert scale.shape[1] == rotation.shape[1] == translation.shape[1] == type_code.shape[1]
assert scale.shape[1] <= self.max_seq_len
batch_size = scale.shape[0]
if self.condition_on_shape:
assert exists(pc) ^ exists(pc_embed), '`pc` or `pc_embed` must be passed in'
if exists(pc):
pc_embed = self.embed_pc(pc)
batch_size = default(batch_size, pc_embed.shape[0])
batch_size = default(batch_size, 1)
scale = default(scale, torch.empty((batch_size, 0, 3), dtype=torch.float64, device=self.device))
rotation = default(rotation, torch.empty((batch_size, 0, 3), dtype=torch.float64, device=self.device))
translation = default(translation, torch.empty((batch_size, 0, 3), dtype=torch.float64, device=self.device))
type_code = default(type_code, torch.empty((batch_size, 0), dtype=torch.int64, device=self.device))
curr_length = scale.shape[1]
cache = None
eos_codes = None
for i in tqdm(range(curr_length, max_seq_len)):
can_eos = i != 0
output = self.forward(
scale=scale,
rotation=rotation,
translation=translation,
type_code=type_code,
pc_embed=pc_embed,
return_loss=False,
return_cache=cache_kv,
append_eos=False,
cache=cache
)
if cache_kv:
next_embed, cache = output
else:
next_embed = output
(
scale,
rotation,
translation,
type_code
) = self.sample_primitives(
scale,
rotation,
translation,
type_code,
next_embed,
temperature=temperature,
filter_logits_fn=filter_logits_fn,
filter_kwargs=filter_kwargs
)
next_eos_logits = self.to_eos_logits(next_embed).squeeze(-1)
next_eos_code = (F.sigmoid(next_eos_logits) > 0.5)
eos_codes = safe_cat([eos_codes, next_eos_code], 1)
if can_eos and eos_codes.any(dim=-1).all():
break
# mask out to padding anything after the first eos
mask = eos_codes.float().cumsum(dim=-1) >= 1
# concat cur_length to mask
mask = torch.cat((torch.zeros((batch_size, curr_length), dtype=torch.bool, device=self.device), mask), dim=-1)
type_code = type_code.masked_fill(mask, self.pad_id)
scale = scale.masked_fill(mask.unsqueeze(-1), self.pad_id)
rotation = rotation.masked_fill(mask.unsqueeze(-1), self.pad_id)
translation = translation.masked_fill(mask.unsqueeze(-1), self.pad_id)
recon_primitives = {
'scale': scale,
'rotation': rotation,
'translation': translation,
'type_code': type_code
}
primitive_mask = ~eos_codes
return recon_primitives, primitive_mask
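# Example usage (a sketch): conditioned on a point-cloud tensor `pc` in the format the
# conditioner expects, greedy decoding can be obtained by setting temperature to 0:
#   recon_primitives, primitive_mask = model.generate(pc=pc, temperature=0.)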
@eval_decorator
@torch.no_grad()
@typecheck
def generate_w_recon_loss(
self,
batch_size: int | None = None,
filter_logits_fn: Callable = top_k_2,
filter_kwargs: dict = dict(),
temperature: float = 1.,
scale: Float['b np 3'] | None = None,
rotation: Float['b np 3'] | None = None,
translation: Float['b np 3'] | None = None,
type_code: Int['b np'] | None = None,
pc: Tensor | None = None,
pc_embed: Tensor | None = None,
cache_kv = True,
max_seq_len = None,
single_directional = True,
):
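# Variant of generate() that keeps a sampled primitive only if it lowers the
# single-directional chamfer distance to the input cloud; otherwise it resamples up to
# 5 times and falls back to the best candidate. Currently restricted to batch_size == 1.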
max_seq_len = default(max_seq_len, self.max_seq_len)
if exists(scale) and exists(rotation) and exists(translation) and exists(type_code):
assert not exists(batch_size)
assert scale.shape[1] == rotation.shape[1] == translation.shape[1] == type_code.shape[1]
assert scale.shape[1] <= self.max_seq_len
batch_size = scale.shape[0]
if self.condition_on_shape:
assert exists(pc) ^ exists(pc_embed), '`pc` or `pc_embed` must be passed in'
if exists(pc):
pc_embed = self.embed_pc(pc)
batch_size = default(batch_size, pc_embed.shape[0])
batch_size = default(batch_size, 1)
assert batch_size == 1 # TODO: support any batch size
scale = default(scale, torch.empty((batch_size, 0, 3), dtype=torch.float32, device=self.device))
rotation = default(rotation, torch.empty((batch_size, 0, 3), dtype=torch.float32, device=self.device))
translation = default(translation, torch.empty((batch_size, 0, 3), dtype=torch.float32, device=self.device))
type_code = default(type_code, torch.empty((batch_size, 0), dtype=torch.int64, device=self.device))
curr_length = scale.shape[1]
cache = None
eos_codes = None
last_recon_loss = 1
for i in tqdm(range(curr_length, max_seq_len)):
can_eos = i != 0
output = self.forward(
scale=scale,
rotation=rotation,
translation=translation,
type_code=type_code,
pc_embed=pc_embed,
return_loss=False,
return_cache=cache_kv,
append_eos=False,
cache=cache
)
if cache_kv:
next_embed, cache = output
else:
next_embed = output
(
scale_new,
rotation_new,
translation_new,
type_code_new
) = self.sample_primitives(
scale,
rotation,
translation,
type_code,
next_embed,
temperature=temperature,
filter_logits_fn=filter_logits_fn,
filter_kwargs=filter_kwargs
)
next_eos_logits = self.to_eos_logits(next_embed).squeeze(-1)
next_eos_code = (F.sigmoid(next_eos_logits) > 0.5)
eos_codes = safe_cat([eos_codes, next_eos_code], 1)
if can_eos and eos_codes.any(dim=-1).all():
scale, rotation, translation, type_code = (
scale_new, rotation_new, translation_new, type_code_new)
break
recon_loss = self.compute_chamfer_distance(scale_new, rotation_new, translation_new, type_code_new, ~eos_codes, pc, single_directional)
if recon_loss < last_recon_loss:
last_recon_loss = recon_loss
scale, rotation, translation, type_code = (
scale_new, rotation_new, translation_new, type_code_new)
else:
best_recon_loss = recon_loss
best_primitives = dict(
scale=scale_new, rotation=rotation_new, translation=translation_new, type_code=type_code_new)
success_flag = False
print(f'last_recon_loss:{last_recon_loss}, recon_loss:{recon_loss} -> to find better primitive')
for try_i in range(5):
(
scale_new,
rotation_new,
translation_new,
type_code_new
) = self.sample_primitives(
scale,
rotation,
translation,
type_code,
next_embed,
temperature=1.0,
filter_logits_fn=filter_logits_fn,
filter_kwargs=filter_kwargs
)
recon_loss = self.compute_chamfer_distance(scale_new, rotation_new, translation_new, type_code_new, ~eos_codes, pc)
print(f'[try_{try_i}] last_recon_loss:{last_recon_loss}, best_recon_loss:{best_recon_loss}, cur_recon_loss:{recon_loss}')
if recon_loss < last_recon_loss:
last_recon_loss = recon_loss
scale, rotation, translation, type_code = (
scale_new, rotation_new, translation_new, type_code_new)
success_flag = True
break
else:
if recon_loss < best_recon_loss:
best_recon_loss = recon_loss
best_primitives = dict(
scale=scale_new, rotation=rotation_new, translation=translation_new, type_code=type_code_new)
if not success_flag:
last_recon_loss = best_recon_loss
scale, rotation, translation, type_code = (
best_primitives['scale'], best_primitives['rotation'], best_primitives['translation'], best_primitives['type_code'])
print(f'new_last_recon_loss:{last_recon_loss}')
# mask out to padding anything after the first eos
mask = eos_codes.float().cumsum(dim=-1) >= 1
type_code = type_code.masked_fill(mask, self.pad_id)
scale = scale.masked_fill(mask.unsqueeze(-1), self.pad_id)
rotation = rotation.masked_fill(mask.unsqueeze(-1), self.pad_id)
translation = translation.masked_fill(mask.unsqueeze(-1), self.pad_id)
recon_primitives = {
'scale': scale,
'rotation': rotation,
'translation': translation,
'type_code': type_code
}
primitive_mask = ~eos_codes
return recon_primitives, primitive_mask
@typecheck
def encode(
self,
*,
scale: Float['b np 3'],
rotation: Float['b np 3'],
translation: Float['b np 3'],
type_code: Int['b np'],
primitive_mask: Bool['b np'],
return_primitives = False
):
"""
einops:
b - batch
np - number of primitives
c - coordinates (3)
d - embed dim
"""
# compute feature embedding
discretize_scale = self.discretize_scale(scale)
scale_embed = self.scale_embed(discretize_scale)
scale_embed = rearrange(scale_embed, 'b np c d -> b np (c d)')
discretize_rotation = self.discretize_rotation(rotation)
rotation_embed = self.rotation_embed(discretize_rotation)
rotation_embed = rearrange(rotation_embed, 'b np c d -> b np (c d)')
discretize_translation = self.discretize_translation(translation)
translation_embed = self.translation_embed(discretize_translation)
translation_embed = rearrange(translation_embed, 'b np c d -> b np (c d)')
type_embed = self.type_embed(type_code.masked_fill(~primitive_mask, 0))
# combine all features and project into model dimension
if self.embed_order == 'srtc':
primitive_embed, _ = pack([scale_embed, rotation_embed, translation_embed, type_embed], 'b np *')
else:
primitive_embed, _ = pack([type_embed, translation_embed, rotation_embed, scale_embed], 'b np *')
primitive_embed = self.project_in(primitive_embed)
primitive_embed = primitive_embed.masked_fill(~primitive_mask.unsqueeze(-1), 0.)
if not return_primitives:
return primitive_embed
primitive_embed_unpacked = {
'scale': scale_embed,
'rotation': rotation_embed,
'translation': translation_embed,
'type_code': type_embed
}
primitives_gt = {
'scale': discretize_scale,
'rotation': discretize_rotation,
'translation': discretize_translation,
'type_code': type_code
}
return primitive_embed, primitive_embed_unpacked, primitives_gt
@typecheck
def compute_chamfer_distance(
self,
scale_pred: Float['b np 3'],
rotation_pred: Float['b np 3'],
translation_pred: Float['b np 3'],
type_pred: Int['b np'],
primitive_mask: Bool['b np'],
pc: Tensor, # b, num_points, c
single_directional = True
):
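# Pose each predicted primitive's canonical point cloud with its predicted
# scale/rotation/translation, merge all primitives into one cloud, resample a fixed
# number of points from the valid primitives, and compare against the input point
# cloud with the (optionally single-directional) chamfer distance.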
scale_pred = scale_pred.float()
rotation_pred = rotation_pred.float()
translation_pred = translation_pred.float()
pc_pred = apply_transformation(self.bs_pc.to(type_pred.device)[type_pred], scale_pred, torch.deg2rad(rotation_pred), translation_pred)
pc_pred = torch.matmul(pc_pred, self.rotation_matrix_align_coord.to(type_pred.device))
pc_pred_flat = rearrange(pc_pred, 'b np p c -> b (np p) c')
pc_pred_sampled = random_sample_pc(pc_pred_flat, primitive_mask.sum(dim=-1, keepdim=True), n_points=self.bs_pc.shape[1])
if single_directional:
recon_loss, _ = chamfer_distance(pc[:, :, :3].float(), pc_pred_sampled.float(), single_directional=True) # single directional
else:
recon_loss, _ = chamfer_distance(pc_pred_sampled.float(), pc[:, :, :3].float())
return recon_loss
def forward(
self,
*,
scale: Float['b np 3'],
rotation: Float['b np 3'],
translation: Float['b np 3'],
type_code: Int['b np'],
loss_reduction: str = 'mean',
return_cache = False,
append_eos = True,
cache: LayerIntermediates | None = None,
pc: Tensor | None = None,
pc_embed: Tensor | None = None,
**kwargs
):
primitive_mask = reduce(scale != self.pad_id, 'b np 3 -> b np', 'all')
if scale.shape[1] > 0:
codes, primitives_embeds, primitives_gt = self.encode(
scale=scale,
rotation=rotation,
translation=translation,
type_code=type_code,
primitive_mask=primitive_mask,
return_primitives=True
)
else:
codes = torch.empty((scale.shape[0], 0, self.dim), dtype=torch.float32, device=self.device)
# handle shape conditions
attn_context_kwargs = dict()
if self.condition_on_shape:
assert exists(pc) ^ exists(pc_embed), '`pc` or `pc_embed` must be passed in'
if exists(pc):
if 'michelangelo' in self.shape_condition_model_type:
pc_head, pc_embed = self.conditioner(shape=pc)
pc_embed = torch.cat([self.to_cond_dim_head(pc_head), self.to_cond_dim(pc_embed)], dim=-2)
else:
raise ValueError(f'unknown shape_condition_model_type {self.shape_condition_model_type}')
assert pc_embed.shape[0] == codes.shape[0], 'batch size of point cloud is not equal to the batch size of the primitive codes'
pooled_pc_embed = pc_embed.mean(dim=1) # (b, shape_condition_dim)
if self.shape_cond_with_cross_attn:
attn_context_kwargs = dict(
context=pc_embed
)
if self.coarse_adaptive_rmsnorm:
attn_context_kwargs.update(
condition=pooled_pc_embed
)
batch, seq_len, _ = codes.shape # (b, np, dim)
device = codes.device
assert seq_len <= self.max_seq_len, f'received codes of length {seq_len} but needs to be less than or equal to set max_seq_len {self.max_seq_len}'
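# When append_eos is set, pad the code sequence by one slot and write the learned eos
# token right after each sequence's last valid primitive, so the model can learn where to stop.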
if append_eos:
assert exists(codes)
code_lens = primitive_mask.sum(dim=-1)
codes = pad_tensor(codes)
batch_arange = torch.arange(batch, device=device)
batch_arange = rearrange(batch_arange, '... -> ... 1')
code_lens = rearrange(code_lens, '... -> ... 1')
codes[batch_arange, code_lens] = self.eos_token # (b, np+1, dim)
primitive_codes = codes # (b, np, dim)
primitive_codes_len = primitive_codes.shape[-2]
(
coarse_cache,
coarse_gateloop_cache,
coarse_post_gateloop_cache,
) = cache if exists(cache) else ((None,) * 3)
if not exists(cache):
sos = repeat(self.sos_token, 'n d -> b n d', b=batch)
if self.shape_cond_with_cat:
sos, _ = pack([pc_embed, sos], 'b * d')
primitive_codes, packed_sos_shape = pack([sos, primitive_codes], 'b * d') # (b, n_sos+np, dim)
# condition primitive codes with shape if needed
if self.condition_on_shape:
primitive_codes = self.shape_coarse_film_cond(primitive_codes, pooled_pc_embed)
# attention on primitive codes (coarse)
if exists(self.coarse_gateloop_block):
primitive_codes, coarse_gateloop_cache = self.coarse_gateloop_block(primitive_codes, cache=coarse_gateloop_cache)
attended_primitive_codes, coarse_cache = self.decoder( # (b, n_sos+np, dim)
primitive_codes,
cache=coarse_cache,
return_hiddens=True,
**attn_context_kwargs
)
if exists(self.coarse_post_gateloop_block):
primitive_codes, coarse_post_gateloop_cache = self.coarse_post_gateloop_block(primitive_codes, cache=coarse_post_gateloop_cache)
embed = attended_primitive_codes[:, -(primitive_codes_len + 1):] # (b, np+1, dim)
if not return_cache:
return embed[:, -1:]
next_cache = (
coarse_cache,
coarse_gateloop_cache,
coarse_post_gateloop_cache
)
return embed[:, -1:], next_cache
def pad_tensor(tensor):
if tensor.dim() == 3:
bs, seq_len, dim = tensor.shape
padding = torch.zeros((bs, 1, dim), dtype=tensor.dtype, device=tensor.device)
elif tensor.dim() == 2:
bs, seq_len = tensor.shape
padding = torch.zeros((bs, 1), dtype=tensor.dtype, device=tensor.device)
else:
raise ValueError('Unsupported tensor shape: {}'.format(tensor.shape))
return torch.cat([tensor, padding], dim=1)
def apply_transformation(pc, scale, rotation_vector, translation):
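# Pose canonical primitive point clouds: pc is (b, np, p, 3); scale and translation are
# (b, np, 3); rotation_vector holds XYZ Euler angles in radians.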
bs, np, num_points, _ = pc.shape
scaled_pc = pc * scale.unsqueeze(2)
rotation_matrix = euler_angles_to_matrix(rotation_vector.view(-1, 3), 'XYZ').view(bs, np, 3, 3) # euler tmp
rotated_pc = torch.einsum('bnij,bnpj->bnpi', rotation_matrix, scaled_pc)
transformed_pc = rotated_pc + translation.unsqueeze(2)
return transformed_pc
def random_sample_pc(pc, max_lens, n_points=10000):
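# Randomly pick n_points indices among the first (max_lens * n_points) valid points of
# each batch element, so points belonging to padded primitives never get sampled.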
bs = max_lens.shape[0]
max_len = max_lens.max().item() * n_points
random_values = torch.rand(bs, max_len, device=max_lens.device)
mask = torch.arange(max_len).expand(bs, max_len).to(max_lens.device) < (max_lens * n_points)
masked_random_values = random_values * mask.float()
_, indices = torch.topk(masked_random_values, n_points, dim=1)
return pc[torch.arange(bs).unsqueeze(1), indices]

View File

@ -0,0 +1,275 @@
from math import ceil
from pathlib import Path
import os
import re
from beartype.typing import Tuple
from einops import rearrange, repeat
from toolz import valmap
import torch
from torch import Tensor
from torch.nn import Module
import torch.nn.functional as F
import yaml
from .typing import typecheck
def count_parameters(model):
return sum(p.numel() for p in model.parameters() if p.requires_grad)
def path_mkdir(path):
path = Path(path)
path.mkdir(parents=True, exist_ok=True)
return path
def load_yaml(path, default_path=None):
path = path_exists(path)
with open(path, mode='r') as fp:
cfg_s = yaml.load(fp, Loader=yaml.FullLoader)
if default_path is not None:
default_path = path_exists(default_path)
with open(default_path, mode='r') as fp:
cfg = yaml.load(fp, Loader=yaml.FullLoader)
else:
# try current dir default
default_path = path.parent / 'default.yml'
if default_path.exists():
with open(default_path, mode='r') as fp:
cfg = yaml.load(fp, Loader=yaml.FullLoader)
else:
cfg = {}
update_recursive(cfg, cfg_s)
return cfg
def dump_yaml(cfg, path):
with open(path, mode='w') as f:
return yaml.safe_dump(cfg, f)
def update_recursive(dict1, dict2):
''' Update two config dictionaries recursively.
Args:
dict1 (dict): first dictionary to be updated
dict2 (dict): second dictionary whose entries should be used
'''
for k, v in dict2.items():
if k not in dict1:
dict1[k] = dict()
if isinstance(v, dict):
update_recursive(dict1[k], v)
else:
dict1[k] = v
def load_latest_checkpoint(checkpoint_dir):
pattern = re.compile(r".+\.ckpt\.(\d+)\.pt")
max_epoch = -1
latest_checkpoint = None
for filename in os.listdir(checkpoint_dir):
match = pattern.match(filename)
if match:
num_epoch = int(match.group(1))
if num_epoch > max_epoch:
max_epoch = num_epoch
latest_checkpoint = checkpoint_dir / filename
if not exists(latest_checkpoint):
raise FileNotFoundError(f"No checkpoint files found in {checkpoint_dir}")
checkpoint = torch.load(latest_checkpoint)
return checkpoint, latest_checkpoint
def torch_to(inp, device, non_blocking=False):
nb = non_blocking # set to True when doing distributed jobs
if isinstance(inp, torch.Tensor):
return inp.to(device, non_blocking=nb)
elif isinstance(inp, (list, tuple)):
return type(inp)(map(lambda t: t.to(device, non_blocking=nb) if isinstance(t, torch.Tensor) else t, inp))
elif isinstance(inp, dict):
return valmap(lambda t: t.to(device, non_blocking=nb) if isinstance(t, torch.Tensor) else t, inp)
else:
raise NotImplementedError
# helper functions
def exists(v):
return v is not None
def default(v, d):
return v if exists(v) else d
def first(it):
return it[0]
def identity(t, *args, **kwargs):
return t
def divisible_by(num, den):
return (num % den) == 0
def is_odd(n):
return not divisible_by(n, 2)
def is_empty(x):
return len(x) == 0
def is_tensor_empty(t: Tensor):
return t.numel() == 0
def set_module_requires_grad_(
module: Module,
requires_grad: bool
):
for param in module.parameters():
param.requires_grad = requires_grad
def l1norm(t):
return F.normalize(t, dim = -1, p = 1)
def l2norm(t):
return F.normalize(t, dim = -1, p = 2)
def safe_cat(tensors, dim):
tensors = [*filter(exists, tensors)]
if len(tensors) == 0:
return None
elif len(tensors) == 1:
return first(tensors)
return torch.cat(tensors, dim = dim)
def pad_at_dim(t, padding, dim = -1, value = 0):
ndim = t.ndim
right_dims = (ndim - dim - 1) if dim >= 0 else (-dim - 1)
zeros = (0, 0) * right_dims
return F.pad(t, (*zeros, *padding), value = value)
def pad_to_length(t, length, dim = -1, value = 0, right = True):
curr_length = t.shape[dim]
remainder = length - curr_length
if remainder <= 0:
return t
padding = (0, remainder) if right else (remainder, 0)
return pad_at_dim(t, padding, dim = dim, value = value)
def masked_mean(tensor, mask, dim = -1, eps = 1e-5):
if not exists(mask):
return tensor.mean(dim = dim)
mask = rearrange(mask, '... -> ... 1')
tensor = tensor.masked_fill(~mask, 0.)
total_el = mask.sum(dim = dim)
num = tensor.sum(dim = dim)
den = total_el.float().clamp(min = eps)
mean = num / den
mean = mean.masked_fill(total_el == 0, 0.)
return mean
def cycle(dl):
while True:
for data in dl:
yield data
def maybe_del(d: dict, *keys):
for key in keys:
if key not in d:
continue
del d[key]
# tensor helper functions
@typecheck
def discretize(
t: Tensor,
*,
continuous_range: Tuple[float, float],
num_discrete: int = 128
) -> Tensor:
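# Map continuous values in [lo, hi] to integer bins {0, ..., num_discrete - 1};
# e.g. with num_discrete = 128 over [0, 1], a value of 0.3 lands in bin 38.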
lo, hi = continuous_range
assert hi > lo
t = (t - lo) / (hi - lo)
t *= num_discrete
t -= 0.5
return t.round().long().clamp(min = 0, max = num_discrete - 1)
@typecheck
def undiscretize(
t: Tensor,
*,
continuous_range: Tuple[float, float],
num_discrete: int = 128
) -> Tensor:
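# Inverse of discretize: map a bin index back to the center of its bin in [lo, hi].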
lo, hi = continuous_range
assert hi > lo
t = t.float()
t += 0.5
t /= num_discrete
return t * (hi - lo) + lo
@typecheck
def gaussian_blur_1d(
t: Tensor,
*,
sigma: float = 1.,
kernel_size: int = 5
) -> Tensor:
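# Depthwise 1D Gaussian smoothing over the sequence dimension: build an L1-normalized
# Gaussian kernel per channel and apply it with a grouped conv1d (output keeps the same length).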
_, _, channels, device, dtype = *t.shape, t.device, t.dtype
width = int(ceil(sigma * kernel_size))
width += (width + 1) % 2
half_width = width // 2
distance = torch.arange(-half_width, half_width + 1, dtype = dtype, device = device)
gaussian = torch.exp(-(distance ** 2) / (2 * sigma ** 2))
gaussian = l1norm(gaussian)
kernel = repeat(gaussian, 'n -> c 1 n', c = channels)
t = rearrange(t, 'b n c -> b c n')
out = F.conv1d(t, kernel, padding = half_width, groups = channels)
return rearrange(out, 'b c n -> b n c')
@typecheck
def scatter_mean(
tgt: Tensor,
indices: Tensor,
src: Tensor,
*,
dim: int = -1,
eps: float = 1e-5
):
"""
todo: update to pytorch 2.1 and try https://pytorch.org/docs/stable/generated/torch.Tensor.scatter_reduce_.html#torch.Tensor.scatter_reduce_
"""
num = tgt.scatter_add(dim, indices, src)
den = torch.zeros_like(tgt).scatter_add(dim, indices, torch.ones_like(src))
return num / den.clamp(min = eps)
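# A sketch of the scatter_reduce_ alternative mentioned in the TODO above (assumes a
# PyTorch version that provides Tensor.scatter_reduce_; not used here):
#   out = tgt.clone().scatter_reduce_(dim, indices, src, reduce='mean', include_self=False)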
def prob_mask_like(shape, prob, device):
if prob == 1:
return torch.ones(shape, device = device, dtype = torch.bool)
elif prob == 0:
return torch.zeros(shape, device = device, dtype = torch.bool)
else:
return torch.zeros(shape, device = device).float().uniform_(0, 1) < prob

View File

@ -0,0 +1,65 @@
import logging
import time
import os
class Verbose:
mute = False
def print_log(s, logger=None, level='info'):
if Verbose.mute:
return None
if logger is None:
logger = logging.getLogger('trainer')
if level == 'info':
print_info(s)
logger.info(s)
elif level == 'warning':
print_warning(s)
logger.warning(s)
elif level == 'error':
print_error(s)
logger.error(s)
else:
raise NotImplementedError
def create_logger(log_dir, name='trainer'):
assert os.path.exists(log_dir), 'log_dir {} does not exist.'.format(log_dir)
logger = logging.getLogger(name)
file_path = log_dir / '{}.log'.format(name)
hdlr = logging.FileHandler(file_path)
formatter = logging.Formatter('[%(asctime)s] %(levelname)s: %(message)s')
hdlr.setFormatter(formatter)
logger.addHandler(hdlr)
logger.setLevel(logging.INFO)
return logger
class TerminalColors:
HEADER = '\033[95m'
OKBLUE = '\033[94m'
OKGREEN = '\033[92m'
WARNING = '\033[93m'
FAIL = '\033[91m'
ENDC = '\033[0m'
BOLD = '\033[1m'
UNDERLINE = '\033[4m'
def get_time():
return time.strftime('%Y-%m-%d %H:%M:%S')
def print_info(s):
print(TerminalColors.OKBLUE + '[' + get_time() + '] ' + str(s) + TerminalColors.ENDC)
def print_warning(s):
print(TerminalColors.WARNING + '[' + get_time() + '] WARN ' + str(s) + TerminalColors.ENDC)
def print_error(s):
print(TerminalColors.FAIL + '[' + get_time() + '] ERROR ' + str(s) + TerminalColors.ENDC)

View File

@ -0,0 +1,57 @@
from environs import Env
from torch import Tensor
from beartype import beartype
from beartype.door import is_bearable
from jaxtyping import (
Float,
Int,
Bool,
jaxtyped
)
# environment
env = Env()
env.read_env()
# function
def always(value):
def inner(*args, **kwargs):
return value
return inner
def identity(t):
return t
# jaxtyping is a misnomer, works for pytorch
class TorchTyping:
def __init__(self, abstract_dtype):
self.abstract_dtype = abstract_dtype
def __getitem__(self, shapes: str):
return self.abstract_dtype[Tensor, shapes]
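# e.g. Float['b np 3'] expands to jaxtyping's Float[Tensor, 'b np 3']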
Float = TorchTyping(Float)
Int = TorchTyping(Int)
Bool = TorchTyping(Bool)
# use env variable TYPECHECK to control whether to use beartype + jaxtyping
should_typecheck = env.bool('TYPECHECK', False)
typecheck = jaxtyped(typechecker = beartype) if should_typecheck else identity
beartype_isinstance = is_bearable if should_typecheck else always(True)
__all__ = [
'Float',
'Int',
'Bool',
'typecheck',
'beartype_isinstance'
]

21
requirements.txt Normal file
View File

@ -0,0 +1,21 @@
accelerate
beartype
einops
gateloop_transformer
matplotlib
pytorch_custom_utils
x_transformers
toolz
environs
jaxtyping
omegaconf
transformers
open3d
trimesh
pytorch_lightning
scikit-image
opencv-python
mesh2sdf
seaborn
mesh_to_sdf
point_cloud_utils

104
sample.py Normal file
View File

@ -0,0 +1,104 @@
import argparse
from functools import partial
import glob
import multiprocessing
import os
import time
from mesh_to_sdf import get_surface_point_cloud
import numpy as np
import open3d as o3d
import trimesh
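# Use EGL for headless offscreen rendering so surface scanning works without a display.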
os.environ["PYOPENGL_PLATFORM"] = "egl"
def sample_surface_points(mesh, number_of_points=500000, surface_point_method="scan", sign_method="normal",
scan_count=100, scan_resolution=400, sample_point_count=10000000, return_gradients=False,
return_surface_pc_normals=False, normalized=False):
sample_start = time.time()
if surface_point_method == "sample" and sign_method == "depth":
print("Incompatible methods for sampling points and determining sign, using sign_method='normal' instead.")
sign_method = "normal"
surface_start = time.time()
bound_radius = 1 if normalized else None
surface_point_cloud = get_surface_point_cloud(mesh, surface_point_method, bound_radius, scan_count, scan_resolution,
sample_point_count,
calculate_normals=sign_method == "normal" or return_gradients)
surface_end = time.time()
print("surface point cloud time cost :", surface_end - surface_start)
normal_start = time.time()
if return_surface_pc_normals:
rng = np.random.default_rng()
assert surface_point_cloud.points.shape[0] == surface_point_cloud.normals.shape[0]
indices = rng.choice(surface_point_cloud.points.shape[0], number_of_points, replace=True)
points = surface_point_cloud.points[indices]
normals = surface_point_cloud.normals[indices]
surface_points = np.concatenate([points, normals], axis=-1)
else:
surface_points = surface_point_cloud.get_random_surface_points(number_of_points, use_scans=True)
normal_end = time.time()
print("normal time cost :", normal_end - normal_start)
sample_end = time.time()
print("sample surface point time cost :", sample_end - sample_start)
return surface_points
def process_surface_point(mesh, number_of_near_surface_points, return_surface_pc_normals=False):
mesh = trimesh.load(mesh, force="mesh")
surface_point = sample_surface_points(mesh, number_of_near_surface_points, return_surface_pc_normals=return_surface_pc_normals)
return surface_point
def sample_model(model_path, num_points, return_surface_pc_normals=True):
pc_out_path = os.path.join(args.output_dir, os.path.basename(model_path)).replace(f".{args.postfix}", ".ply")
if os.path.exists(pc_out_path):
print(f"{pc_out_path}: exists!")
return
try:
surface_point = process_surface_point(model_path, num_points, return_surface_pc_normals=return_surface_pc_normals)
coords = surface_point[:, :3]
normals = surface_point[:, 3:]
assert (np.linalg.norm(np.asarray(normals), axis=-1) > 0.99).all()
assert (np.linalg.norm(np.asarray(normals), axis=-1) < 1.01).all()
pcd = o3d.geometry.PointCloud()
pcd.points = o3d.utility.Vector3dVector(coords)
pcd.colors = o3d.utility.Vector3dVector(np.ones_like(coords)*0.5)
pcd.normals = o3d.utility.Vector3dVector(normals)
o3d.io.write_point_cloud(pc_out_path, pcd)
print(f"write_point_cloud: {pc_out_path}")
except Exception as e:
print(f"[ERROR] file: {pc_out_path}, reason: {e}")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--input_dir", type=str, default="./results/infer/JsonResults")
parser.add_argument("--output_dir", type=str, default="./results/infer/PointClouds")
parser.add_argument("--num_points", type=int, default=10000)
parser.add_argument("--postfix", type=str, default="glb")
args = parser.parse_args()
if not os.path.exists(args.input_dir):
print("Invalid input!")
exit(1)
if os.path.exists(args.output_dir):
print(f"path: {args.output_dir} exists!")
# exit(1)
else:
os.makedirs(args.output_dir)
model_prefix = os.path.join(args.input_dir, f"*.{args.postfix}")
model_path_list = sorted(list(glob.glob(model_prefix)))
sample_model_func = partial(sample_model, num_points=args.num_points, return_surface_pc_normals=True)
with multiprocessing.Pool(16) as pool:
pool.map(sample_model_func, model_path_list)