mirror of
				https://github.com/PrimitiveAnything/PrimitiveAnything.git
				synced 2025-11-04 18:02:17 +08:00 
			
		
		
		
	init
This commit is contained in:
		
						commit
						87c3ed5e40
					
				
							
								
								
									
										4
									
								
								.gitignore
									
									
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										4
									
								
								.gitignore
									
									
									
									
										vendored
									
									
										Normal file
									
								
							@ -0,0 +1,4 @@
 | 
				
			|||||||
 | 
					**/__pycache__/
 | 
				
			||||||
 | 
					ckpt
 | 
				
			||||||
 | 
					data
 | 
				
			||||||
 | 
					results
 | 
				
			||||||
							
								
								
									
										112
									
								
								README.md
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										112
									
								
								README.md
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,112 @@
 | 
				
			|||||||
 | 
					# PrimitiveAnything: Human-Crafted 3D Primitive Assembly Generation with Auto-Regressive Transformer
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					<a href="https://primitiveanything.github.io/"><img src="https://img.shields.io/static/v1?label=Project%20Page&message=Github&color=blue&logo=github-pages"></a> <a href="#"><img src="https://img.shields.io/badge/ArXiv-250x.xxxxx-brightgreen"></a> <a href="https://huggingface.co/hyz317/PrimitiveAnything"><img src="https://img.shields.io/badge/%F0%9F%A4%97%20Model_Card-Huggingface-orange"></a> <a href="https://huggingface.co/spaces/hyz317/PrimitiveAnything"><img src="https://img.shields.io/badge/%F0%9F%A4%97%20Gradio%20Demo-Huggingface-orange"></a>
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					<img src="./assets/teaser.jpg" width="100%">
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					## 🔥 Updates
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					**[2025/05/07]** test dataset, code, pretrained checkpoints and Gradio demo are released!
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					## 🔍 Table of Contents
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					- [⚙️ Deployment](#deployment)
 | 
				
			||||||
 | 
					- [🖥️ Run PrimitiveAnything](#run-pa)
 | 
				
			||||||
 | 
					- [📝 Citation](#citation)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					<a name="deployment"></a>
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					## ⚙️ Deployment
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Set up a Python environment and install the required packages:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					```bash
 | 
				
			||||||
 | 
					conda create -n primitiveanything python=3.9 -y
 | 
				
			||||||
 | 
					conda activate primitiveanything
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# Install torch, torchvision based on your machine configuration
 | 
				
			||||||
 | 
					pip install torch==2.1.0 torchvision==0.16.0 --index-url https://download.pytorch.org/whl/cu118
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# Install other dependencies
 | 
				
			||||||
 | 
					pip install -r requirements.txt
 | 
				
			||||||
 | 
					```
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Then download data and pretrained weights:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					1. **Our Model Weights**: 
 | 
				
			||||||
 | 
					   Download from our 🤗 Hugging Face repository ([download here](https://huggingface.co/hyz317/PrimitiveAnything)) and place them in `./ckpt/`.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					2. **Michelangelo’s Point Cloud Encoder**: 
 | 
				
			||||||
 | 
					   Download weights from [Michelangelo’s Hugging Face repo](https://huggingface.co/Maikou/Michelangelo/tree/main/checkpoints/aligned_shape_latents) and save them to `./ckpt/`.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					3. **Demo and test data**:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					   Download from this [Google Drive link](https://drive.google.com/file/d/1FZZjk0OvzETD5j4OODEghS_YppcS_ZbM/view?usp=sharing), then decompress the files into `./data/`.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					After downloading and organizing the files, your project directory should look like this:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					```
 | 
				
			||||||
 | 
					- data/
 | 
				
			||||||
 | 
					    ├── basic_shapes_norm/
 | 
				
			||||||
 | 
					    ├── basic_shapes_norm_pc10000/
 | 
				
			||||||
 | 
					    ├── demo_glb/                   # Demo files in GLB format
 | 
				
			||||||
 | 
					    └── test_pc/                    # Test point cloud data
 | 
				
			||||||
 | 
					- ckpt/
 | 
				
			||||||
 | 
					    ├── mesh-transformer.ckpt.60.pt # Our model checkpoint
 | 
				
			||||||
 | 
					    └── shapevae-256.ckpt           # Michelangelo ShapeVAE checkpoint
 | 
				
			||||||
 | 
					```
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					<a name="run-pa"></a>
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					## 🖥️ Run PrimitiveAnything
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					### Demo
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					```bash
 | 
				
			||||||
 | 
					python demo.py --input ./data/demo_glb --log_path ./results/demo
 | 
				
			||||||
 | 
					```
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					**Notes:**
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					- `--input` accepts either:
 | 
				
			||||||
 | 
					  - Any standard 3D file (GLB, OBJ, etc.)
 | 
				
			||||||
 | 
					  - A directory containing multiple 3D files
 | 
				
			||||||
 | 
					- For optimal results with fine structures, we automatically apply marching cubes and dilation operations (which differs from testing and evaluation). This prevents quality degradation in thin areas.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					### Testing and Evaluation
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					```bash
 | 
				
			||||||
 | 
					# Autoregressive generation
 | 
				
			||||||
 | 
					python infer.py
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# Sample point clouds from predictions  
 | 
				
			||||||
 | 
					python sample.py
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# Calculate evaluation metrics
 | 
				
			||||||
 | 
					python eval.py
 | 
				
			||||||
 | 
					```
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					<a name="citation"></a>
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					## 📝 Citation
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					If you find our work useful, please kindly cite:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					```
 | 
				
			||||||
 | 
					@article{ye2025primitiveanything,
 | 
				
			||||||
 | 
					  title={PrimitiveAnything: Human-Crafted 3D Primitive Assembly Generation with Auto-Regressive Transformer},
 | 
				
			||||||
 | 
					  author={Ye, Jingwen and He, Yuze and Zhou, Yanning and Zhu, Yiqin and Xiao, Kaiwen and Liu, Yong-Jin and Yang, Wei and Han, Xiao},
 | 
				
			||||||
 | 
					  journal={arXiv preprint arXiv:250x.xxxxx},
 | 
				
			||||||
 | 
					  year={2025}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					```
 | 
				
			||||||
 | 
					
 | 
				
			||||||
							
								
								
									
										
											BIN
										
									
								
								assets/teaser.jpg
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										
											BIN
										
									
								
								assets/teaser.jpg
									
									
									
									
									
										Normal file
									
								
							
										
											Binary file not shown.
										
									
								
							| 
		 After Width: | Height: | Size: 1.3 MiB  | 
							
								
								
									
										52
									
								
								configs/infer.yml
									
									
									
									
									
										Executable file
									
								
							
							
						
						
									
										52
									
								
								configs/infer.yml
									
									
									
									
									
										Executable file
									
								
							@ -0,0 +1,52 @@
 | 
				
			|||||||
 | 
					dataset:
 | 
				
			||||||
 | 
					  name: base
 | 
				
			||||||
 | 
					  pc_dir: ./data/test_pc
 | 
				
			||||||
 | 
					  bs_dir: data/basic_shapes_norm
 | 
				
			||||||
 | 
					  max_length: 144
 | 
				
			||||||
 | 
					  range_scale: [0, 1]
 | 
				
			||||||
 | 
					  range_rotation: [-180, 180]
 | 
				
			||||||
 | 
					  range_translation: [-1, 1]
 | 
				
			||||||
 | 
					  rotation_type: euler
 | 
				
			||||||
 | 
					  pc_format: pn
 | 
				
			||||||
 | 
					model:
 | 
				
			||||||
 | 
					  attn_depth: 6
 | 
				
			||||||
 | 
					  attn_heads: 6
 | 
				
			||||||
 | 
					  bin_smooth_blur_sigma: -1
 | 
				
			||||||
 | 
					  bs_pc_dir: data/basic_shapes_norm_pc10000
 | 
				
			||||||
 | 
					  coarse_pre_gateloop_depth: 3
 | 
				
			||||||
 | 
					  continuous_range_rotation:
 | 
				
			||||||
 | 
					  - -181
 | 
				
			||||||
 | 
					  - 181
 | 
				
			||||||
 | 
					  continuous_range_scale:
 | 
				
			||||||
 | 
					  - 0
 | 
				
			||||||
 | 
					  - 1
 | 
				
			||||||
 | 
					  continuous_range_translation:
 | 
				
			||||||
 | 
					  - -1
 | 
				
			||||||
 | 
					  - 1
 | 
				
			||||||
 | 
					  dim: 768
 | 
				
			||||||
 | 
					  dim_rotation_embed: 16
 | 
				
			||||||
 | 
					  dim_scale_embed: 16
 | 
				
			||||||
 | 
					  dim_translation_embed: 16
 | 
				
			||||||
 | 
					  dim_type_embed: 48
 | 
				
			||||||
 | 
					  dropout: 0.0
 | 
				
			||||||
 | 
					  embed_order: ctrs
 | 
				
			||||||
 | 
					  gateloop_use_heinsen: false
 | 
				
			||||||
 | 
					  loss_weight:
 | 
				
			||||||
 | 
					    eos: 1.0
 | 
				
			||||||
 | 
					    reconstruction: 1.0
 | 
				
			||||||
 | 
					    rotation: 1.0
 | 
				
			||||||
 | 
					    scale: 1.0
 | 
				
			||||||
 | 
					    translation: 1.0
 | 
				
			||||||
 | 
					    type: 1.0
 | 
				
			||||||
 | 
					  max_primitive_len: 144
 | 
				
			||||||
 | 
					  name: discrete
 | 
				
			||||||
 | 
					  num_discrete_rotation: 181
 | 
				
			||||||
 | 
					  num_discrete_scale: 128
 | 
				
			||||||
 | 
					  num_discrete_translation: 128
 | 
				
			||||||
 | 
					  num_type: 3
 | 
				
			||||||
 | 
					  shape_cond_with_cat: true
 | 
				
			||||||
 | 
					  shape_cond_with_cross_attn: false
 | 
				
			||||||
 | 
					  shape_cond_with_film: false
 | 
				
			||||||
 | 
					  shape_condition_dim: 768
 | 
				
			||||||
 | 
					  shape_condition_len: 77
 | 
				
			||||||
 | 
					  shape_condition_model_type: michelangelo
 | 
				
			||||||
							
								
								
									
										342
									
								
								demo.py
									
									
									
									
									
										Executable file
									
								
							
							
						
						
									
										342
									
								
								demo.py
									
									
									
									
									
										Executable file
									
								
							@ -0,0 +1,342 @@
 | 
				
			|||||||
 | 
					import os
 | 
				
			||||||
 | 
					import time
 | 
				
			||||||
 | 
					import glob
 | 
				
			||||||
 | 
					import json
 | 
				
			||||||
 | 
					import yaml
 | 
				
			||||||
 | 
					import torch
 | 
				
			||||||
 | 
					import trimesh
 | 
				
			||||||
 | 
					import argparse
 | 
				
			||||||
 | 
					import mesh2sdf.core
 | 
				
			||||||
 | 
					import numpy as np
 | 
				
			||||||
 | 
					import skimage.measure
 | 
				
			||||||
 | 
					import seaborn as sns
 | 
				
			||||||
 | 
					from scipy.spatial.transform import Rotation
 | 
				
			||||||
 | 
					from mesh_to_sdf import get_surface_point_cloud
 | 
				
			||||||
 | 
					from accelerate.utils import set_seed
 | 
				
			||||||
 | 
					from accelerate import Accelerator
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from primitive_anything.utils import path_mkdir, count_parameters
 | 
				
			||||||
 | 
					from primitive_anything.utils.logger import print_log
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					os.environ['PYOPENGL_PLATFORM'] = 'egl'
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def parse_args():
 | 
				
			||||||
 | 
					    parser = argparse.ArgumentParser(description='Process 3D model files')
 | 
				
			||||||
 | 
					    
 | 
				
			||||||
 | 
					    parser.add_argument(
 | 
				
			||||||
 | 
					        '--input',
 | 
				
			||||||
 | 
					        type=str,
 | 
				
			||||||
 | 
					        default='./data/demo_glb/',
 | 
				
			||||||
 | 
					        help='Input file or directory path (default: ./data/demo_glb/)'
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					    
 | 
				
			||||||
 | 
					    parser.add_argument(
 | 
				
			||||||
 | 
					        '--log_path',
 | 
				
			||||||
 | 
					        type=str,
 | 
				
			||||||
 | 
					        default='./results/demo',
 | 
				
			||||||
 | 
					        help='Output directory path (default: results/demo)'
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					    
 | 
				
			||||||
 | 
					    return parser.parse_args()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def get_input_files(input_path):
 | 
				
			||||||
 | 
					    if os.path.isfile(input_path):
 | 
				
			||||||
 | 
					        return [input_path]
 | 
				
			||||||
 | 
					    elif os.path.isdir(input_path):
 | 
				
			||||||
 | 
					        return glob.glob(os.path.join(input_path, '*'))
 | 
				
			||||||
 | 
					    else:
 | 
				
			||||||
 | 
					        raise ValueError(f"Input path {input_path} is neither a file nor a directory")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					args = parse_args()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# Get input files (keeping your original variable name)
 | 
				
			||||||
 | 
					input_3ds = get_input_files(args.input)
 | 
				
			||||||
 | 
					if not input_3ds:
 | 
				
			||||||
 | 
					    raise FileNotFoundError(f"No files found at input path: {args.input}")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# Create output directory (keeping your original variable name)
 | 
				
			||||||
 | 
					LOG_PATH = args.log_path
 | 
				
			||||||
 | 
					os.makedirs(LOG_PATH, exist_ok=True)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					print(f"Found {len(input_3ds)} input files")
 | 
				
			||||||
 | 
					print(f"Output directory: {LOG_PATH}")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					CODE_SHAPE = {
 | 
				
			||||||
 | 
					    0: 'SM_GR_BS_CubeBevel_001.ply',
 | 
				
			||||||
 | 
					    1: 'SM_GR_BS_SphereSharp_001.ply',
 | 
				
			||||||
 | 
					    2: 'SM_GR_BS_CylinderSharp_001.ply',
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					shapename_map = {
 | 
				
			||||||
 | 
					    'SM_GR_BS_CubeBevel_001.ply': 1101002001034001,
 | 
				
			||||||
 | 
					    'SM_GR_BS_SphereSharp_001.ply': 1101002001034010,
 | 
				
			||||||
 | 
					    'SM_GR_BS_CylinderSharp_001.ply': 1101002001034002,
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#### config
 | 
				
			||||||
 | 
					bs_dir = 'data/basic_shapes_norm'
 | 
				
			||||||
 | 
					config_path = './configs/infer.yml'
 | 
				
			||||||
 | 
					AR_checkpoint_path = './ckpt/mesh-transformer.ckpt.60.pt'
 | 
				
			||||||
 | 
					temperature= 0.0
 | 
				
			||||||
 | 
					#### init model
 | 
				
			||||||
 | 
					mesh_bs = {}
 | 
				
			||||||
 | 
					for bs_path in glob.glob(os.path.join(bs_dir, '*.ply')):
 | 
				
			||||||
 | 
					    bs_name = os.path.basename(bs_path)
 | 
				
			||||||
 | 
					    bs = trimesh.load(bs_path)
 | 
				
			||||||
 | 
					    bs.visual.uv = np.clip(bs.visual.uv, 0, 1)
 | 
				
			||||||
 | 
					    bs.visual = bs.visual.to_color()
 | 
				
			||||||
 | 
					    mesh_bs[bs_name] = bs
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def create_model(cfg_model):
 | 
				
			||||||
 | 
					    kwargs = cfg_model
 | 
				
			||||||
 | 
					    name = kwargs.pop('name')
 | 
				
			||||||
 | 
					    model = get_model(name)(**kwargs)
 | 
				
			||||||
 | 
					    print_log("Model '{}' init: nb_params={:,}, kwargs={}".format(name, count_parameters(model), kwargs))
 | 
				
			||||||
 | 
					    return model
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from primitive_anything.primitive_transformer import PrimitiveTransformerDiscrete
 | 
				
			||||||
 | 
					def get_model(name):
 | 
				
			||||||
 | 
					    return {
 | 
				
			||||||
 | 
					        'discrete': PrimitiveTransformerDiscrete,
 | 
				
			||||||
 | 
					    }[name]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					with open(config_path, mode='r') as fp:
 | 
				
			||||||
 | 
					    AR_train_cfg = yaml.load(fp, Loader=yaml.FullLoader)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					AR_checkpoint = torch.load(AR_checkpoint_path)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					transformer = create_model(AR_train_cfg['model'])
 | 
				
			||||||
 | 
					transformer.load_state_dict(AR_checkpoint)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					device = torch.device('cuda')
 | 
				
			||||||
 | 
					accelerator = Accelerator(
 | 
				
			||||||
 | 
					    mixed_precision='fp16',
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					transformer = accelerator.prepare(transformer)
 | 
				
			||||||
 | 
					transformer.eval()
 | 
				
			||||||
 | 
					transformer.bs_pc = transformer.bs_pc.cuda()
 | 
				
			||||||
 | 
					transformer.rotation_matrix_align_coord = transformer.rotation_matrix_align_coord.cuda()
 | 
				
			||||||
 | 
					print('model loaded to device')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def sample_surface_points(mesh, number_of_points=500000, surface_point_method='scan', sign_method='normal',
 | 
				
			||||||
 | 
					                          scan_count=100, scan_resolution=400, sample_point_count=10000000, return_gradients=False,
 | 
				
			||||||
 | 
					                          return_surface_pc_normals=False, normalized=False):
 | 
				
			||||||
 | 
					    sample_start = time.time()
 | 
				
			||||||
 | 
					    if surface_point_method == 'sample' and sign_method == 'depth':
 | 
				
			||||||
 | 
					        print("Incompatible methods for sampling points and determining sign, using sign_method='normal' instead.")
 | 
				
			||||||
 | 
					        sign_method = 'normal'
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    surface_start = time.time()
 | 
				
			||||||
 | 
					    bound_radius = 1 if normalized else None
 | 
				
			||||||
 | 
					    surface_point_cloud = get_surface_point_cloud(mesh, surface_point_method, bound_radius, scan_count, scan_resolution,
 | 
				
			||||||
 | 
					                                                  sample_point_count,
 | 
				
			||||||
 | 
					                                                  calculate_normals=sign_method == 'normal' or return_gradients)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    surface_end = time.time()
 | 
				
			||||||
 | 
					    print('surface point cloud time cost :', surface_end - surface_start)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    normal_start = time.time()
 | 
				
			||||||
 | 
					    if return_surface_pc_normals:
 | 
				
			||||||
 | 
					        rng = np.random.default_rng()
 | 
				
			||||||
 | 
					        assert surface_point_cloud.points.shape[0] == surface_point_cloud.normals.shape[0]
 | 
				
			||||||
 | 
					        indices = rng.choice(surface_point_cloud.points.shape[0], number_of_points, replace=True)
 | 
				
			||||||
 | 
					        points = surface_point_cloud.points[indices]
 | 
				
			||||||
 | 
					        normals = surface_point_cloud.normals[indices]
 | 
				
			||||||
 | 
					        surface_points = np.concatenate([points, normals], axis=-1)
 | 
				
			||||||
 | 
					    else:
 | 
				
			||||||
 | 
					        surface_points = surface_point_cloud.get_random_surface_points(number_of_points, use_scans=True)
 | 
				
			||||||
 | 
					    normal_end = time.time()
 | 
				
			||||||
 | 
					    print('normal time cost :', normal_end - normal_start)
 | 
				
			||||||
 | 
					    sample_end = time.time()
 | 
				
			||||||
 | 
					    print('sample surface point time cost :', sample_end - sample_start)
 | 
				
			||||||
 | 
					    return surface_points
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def normalize_vertices(vertices, scale=0.9):
 | 
				
			||||||
 | 
					    bbmin, bbmax = vertices.min(0), vertices.max(0)
 | 
				
			||||||
 | 
					    center = (bbmin + bbmax) * 0.5
 | 
				
			||||||
 | 
					    scale = 2.0 * scale / (bbmax - bbmin).max()
 | 
				
			||||||
 | 
					    vertices = (vertices - center) * scale
 | 
				
			||||||
 | 
					    return vertices, center, scale
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def export_to_watertight(normalized_mesh, octree_depth: int = 7):
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					        Convert the non-watertight mesh to watertight.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Args:
 | 
				
			||||||
 | 
					            input_path (str): normalized path
 | 
				
			||||||
 | 
					            octree_depth (int):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Returns:
 | 
				
			||||||
 | 
					            mesh(trimesh.Trimesh): watertight mesh
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					    size = 2 ** octree_depth
 | 
				
			||||||
 | 
					    level = 2 / size
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    scaled_vertices, to_orig_center, to_orig_scale = normalize_vertices(normalized_mesh.vertices)
 | 
				
			||||||
 | 
					    sdf = mesh2sdf.core.compute(scaled_vertices, normalized_mesh.faces, size=size)
 | 
				
			||||||
 | 
					    vertices, faces, normals, _ = skimage.measure.marching_cubes(np.abs(sdf), level)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # watertight mesh
 | 
				
			||||||
 | 
					    vertices = vertices / size * 2 - 1 # -1 to 1
 | 
				
			||||||
 | 
					    vertices = vertices / to_orig_scale + to_orig_center
 | 
				
			||||||
 | 
					    mesh = trimesh.Trimesh(vertices, faces, normals=normals)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    return mesh
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def process_mesh_to_surface_pc(mesh_list, marching_cubes=False, dilated_offset=0.0, sample_num=10000):
 | 
				
			||||||
 | 
					    # mesh_list : list of trimesh
 | 
				
			||||||
 | 
					    pc_normal_list = []
 | 
				
			||||||
 | 
					    return_mesh_list = []
 | 
				
			||||||
 | 
					    for mesh in mesh_list:
 | 
				
			||||||
 | 
					        if marching_cubes:
 | 
				
			||||||
 | 
					            mesh = export_to_watertight(mesh)
 | 
				
			||||||
 | 
					            print("MC over!")
 | 
				
			||||||
 | 
					        if dilated_offset > 0:
 | 
				
			||||||
 | 
					            new_vertices = mesh.vertices + mesh.vertex_normals * dilated_offset
 | 
				
			||||||
 | 
					            mesh.vertices = new_vertices
 | 
				
			||||||
 | 
					            print("dilate over!")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        mesh.merge_vertices()
 | 
				
			||||||
 | 
					        mesh.update_faces(mesh.unique_faces())
 | 
				
			||||||
 | 
					        mesh.fix_normals()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return_mesh_list.append(mesh)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        pc_normal = np.asarray(sample_surface_points(mesh, sample_num, return_surface_pc_normals=True))
 | 
				
			||||||
 | 
					        pc_normal_list.append(pc_normal)
 | 
				
			||||||
 | 
					        print("process mesh success")
 | 
				
			||||||
 | 
					    return pc_normal_list, return_mesh_list
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					####    utils
 | 
				
			||||||
 | 
					def euler_to_quat(euler):
 | 
				
			||||||
 | 
					    return Rotation.from_euler('XYZ', euler, degrees=True).as_quat()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def SRT_quat_to_matrix(scale, quat, translation):
 | 
				
			||||||
 | 
					    rotation_matrix = Rotation.from_quat(quat).as_matrix()
 | 
				
			||||||
 | 
					    transform_matrix = np.eye(4)
 | 
				
			||||||
 | 
					    transform_matrix[:3, :3] = rotation_matrix * scale
 | 
				
			||||||
 | 
					    transform_matrix[:3, 3] = translation
 | 
				
			||||||
 | 
					    return transform_matrix
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def write_output(primitives, name):
 | 
				
			||||||
 | 
					    out_json = {}
 | 
				
			||||||
 | 
					    out_json['operation'] = 0
 | 
				
			||||||
 | 
					    out_json['type'] = 1
 | 
				
			||||||
 | 
					    out_json['scene_id'] = None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    new_group = []
 | 
				
			||||||
 | 
					    model_scene = trimesh.Scene()
 | 
				
			||||||
 | 
					    color_map = sns.color_palette("hls", primitives['type_code'].squeeze().shape[0])
 | 
				
			||||||
 | 
					    color_map = (np.array(color_map) * 255).astype("uint8")
 | 
				
			||||||
 | 
					    for idx, (scale, rotation, translation, type_code) in enumerate(zip(
 | 
				
			||||||
 | 
					        primitives['scale'].squeeze().cpu().numpy(),
 | 
				
			||||||
 | 
					        primitives['rotation'].squeeze().cpu().numpy(),
 | 
				
			||||||
 | 
					        primitives['translation'].squeeze().cpu().numpy(),
 | 
				
			||||||
 | 
					        primitives['type_code'].squeeze().cpu().numpy()
 | 
				
			||||||
 | 
					    )):
 | 
				
			||||||
 | 
					        if type_code == -1:
 | 
				
			||||||
 | 
					            break
 | 
				
			||||||
 | 
					        bs_name = CODE_SHAPE[type_code]
 | 
				
			||||||
 | 
					        new_block = {}
 | 
				
			||||||
 | 
					        new_block['type_id'] = shapename_map[bs_name]
 | 
				
			||||||
 | 
					        new_block['data'] = {}
 | 
				
			||||||
 | 
					        new_block['data']['location'] = translation.tolist()
 | 
				
			||||||
 | 
					        new_block['data']['rotation'] = euler_to_quat(rotation).tolist()
 | 
				
			||||||
 | 
					        new_block['data']['scale'] = scale.tolist()
 | 
				
			||||||
 | 
					        new_block['data']['color'] = ['808080']
 | 
				
			||||||
 | 
					        new_group.append(new_block)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        trans = SRT_quat_to_matrix(scale, euler_to_quat(rotation), translation)
 | 
				
			||||||
 | 
					        bs = mesh_bs[bs_name].copy().apply_transform(trans)
 | 
				
			||||||
 | 
					        new_vertex_colors = np.repeat(color_map[idx:idx+1], bs.visual.vertex_colors.shape[0], axis=0)
 | 
				
			||||||
 | 
					        bs.visual.vertex_colors[:, :3] = new_vertex_colors
 | 
				
			||||||
 | 
					        vertices = bs.vertices.copy()
 | 
				
			||||||
 | 
					        vertices[:, 1] = bs.vertices[:, 2]
 | 
				
			||||||
 | 
					        vertices[:, 2] = -bs.vertices[:, 1]
 | 
				
			||||||
 | 
					        bs.vertices = vertices
 | 
				
			||||||
 | 
					        model_scene.add_geometry(bs)
 | 
				
			||||||
 | 
					    out_json['group'] = new_group
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    json_path = os.path.join(LOG_PATH, f'output_{name}.json')
 | 
				
			||||||
 | 
					    with open(json_path, 'w') as json_file:
 | 
				
			||||||
 | 
					        json.dump(out_json, json_file, indent=4)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    glb_path = os.path.join(LOG_PATH, f'output_{name}.glb')
 | 
				
			||||||
 | 
					    model_scene.export(glb_path)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    return glb_path, out_json
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@torch.no_grad()
 | 
				
			||||||
 | 
					def do_inference(input_3d, dilated_offset=0.0, sample_seed=0, do_sampling=False, do_marching_cubes=False, postprocess='none'):
 | 
				
			||||||
 | 
					    t1 = time.time()
 | 
				
			||||||
 | 
					    set_seed(sample_seed)
 | 
				
			||||||
 | 
					    input_mesh = trimesh.load(input_3d, force='mesh')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # scale mesh
 | 
				
			||||||
 | 
					    vertices = input_mesh.vertices
 | 
				
			||||||
 | 
					    bounds = np.array([vertices.min(axis=0), vertices.max(axis=0)])
 | 
				
			||||||
 | 
					    vertices = vertices - (bounds[0] + bounds[1])[None, :] / 2
 | 
				
			||||||
 | 
					    vertices = vertices / (bounds[1] - bounds[0]).max() * 1.6
 | 
				
			||||||
 | 
					    input_mesh.vertices = vertices
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    pc_list, mesh_list = process_mesh_to_surface_pc(
 | 
				
			||||||
 | 
					        [input_mesh],
 | 
				
			||||||
 | 
					        marching_cubes=do_marching_cubes,
 | 
				
			||||||
 | 
					        dilated_offset=dilated_offset
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					    pc_normal = pc_list[0] # 10000, 6
 | 
				
			||||||
 | 
					    mesh = mesh_list[0]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    pc_coor = pc_normal[:, :3]
 | 
				
			||||||
 | 
					    normals = pc_normal[:, 3:]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if dilated_offset > 0:
 | 
				
			||||||
 | 
					        # scale mesh and pc
 | 
				
			||||||
 | 
					        vertices = mesh.vertices
 | 
				
			||||||
 | 
					        bounds = np.array([vertices.min(axis=0), vertices.max(axis=0)])
 | 
				
			||||||
 | 
					        vertices = vertices - (bounds[0] + bounds[1])[None, :] / 2
 | 
				
			||||||
 | 
					        vertices = vertices / (bounds[1] - bounds[0]).max() * 1.6
 | 
				
			||||||
 | 
					        mesh.vertices = vertices
 | 
				
			||||||
 | 
					        pc_coor = pc_coor - (bounds[0] + bounds[1])[None, :] / 2
 | 
				
			||||||
 | 
					        pc_coor = pc_coor / (bounds[1] - bounds[0]).max() * 1.6
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    input_save_name = os.path.join(LOG_PATH, f'processed_{os.path.basename(input_3d)}')
 | 
				
			||||||
 | 
					    mesh.export(input_save_name)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    assert (np.linalg.norm(normals, axis=-1) > 0.99).all(), 'normals should be unit vectors, something wrong'
 | 
				
			||||||
 | 
					    normalized_pc_normal = np.concatenate([pc_coor, normals], axis=-1, dtype=np.float16)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    input_pc = torch.tensor(normalized_pc_normal, dtype=torch.float16, device=device)[None]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    with accelerator.autocast():
 | 
				
			||||||
 | 
					        if postprocess == 'postprocess1':
 | 
				
			||||||
 | 
					            recon_primitives, mask = transformer.generate_w_recon_loss(pc=input_pc, temperature=temperature, single_directional=True)
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            recon_primitives, mask = transformer.generate(pc=input_pc, temperature=temperature)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    output_glb, output_json = write_output(recon_primitives, os.path.basename(input_3d)[:-4])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    return input_save_name, output_glb, output_json
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					dilated_offset = 0.015
 | 
				
			||||||
 | 
					do_marching_cubes = True
 | 
				
			||||||
 | 
					postprocess = 'postprocess1'
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					for input_3d in input_3ds:
 | 
				
			||||||
 | 
					    print(f"processing: {input_3d}")
 | 
				
			||||||
 | 
					    preprocess_model_obj, output_model_obj, output_model_json = do_inference(
 | 
				
			||||||
 | 
					        input_3d,
 | 
				
			||||||
 | 
					        dilated_offset=dilated_offset,
 | 
				
			||||||
 | 
					        do_marching_cubes=do_marching_cubes,
 | 
				
			||||||
 | 
					        postprocess=postprocess
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
							
								
								
									
										113
									
								
								eval.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										113
									
								
								eval.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,113 @@
 | 
				
			|||||||
 | 
					import argparse
 | 
				
			||||||
 | 
					from collections import defaultdict
 | 
				
			||||||
 | 
					import glob
 | 
				
			||||||
 | 
					import json
 | 
				
			||||||
 | 
					import os
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import numpy as np
 | 
				
			||||||
 | 
					import trimesh
 | 
				
			||||||
 | 
					import point_cloud_utils as pcu
 | 
				
			||||||
 | 
					from tqdm import tqdm
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import numpy as np
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def voxelize(points, voxel_size=0.1):
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    Converts a set of 3D points to a voxel grid representation.
 | 
				
			||||||
 | 
					    Points are quantized to the nearest voxel center.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Parameters:
 | 
				
			||||||
 | 
					    - points: (N, 3) numpy array of 3D points
 | 
				
			||||||
 | 
					    - voxel_size: the size of each voxel (adjust this depending on your data)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Returns:
 | 
				
			||||||
 | 
					    - voxels: Set of unique voxel coordinates
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    # Quantize the points to the nearest voxel
 | 
				
			||||||
 | 
					    quantized_points = np.floor(points / voxel_size).astype(int)
 | 
				
			||||||
 | 
					    # Use a set to get unique voxel coordinates
 | 
				
			||||||
 | 
					    voxels = set(map(tuple, quantized_points))
 | 
				
			||||||
 | 
					    return voxels
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def calculate_iou(model_vox, target_vox, voxel_size=0.1):
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    Calculate the IoU (Intersection over Union) between two point clouds.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Parameters:
 | 
				
			||||||
 | 
					    - model_vox: (N, 3) numpy array of the first point cloud
 | 
				
			||||||
 | 
					    - target_vox: (M, 3) numpy array of the second point cloud
 | 
				
			||||||
 | 
					    - voxel_size: Size of the voxels (default is 0.1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Returns:
 | 
				
			||||||
 | 
					    - iou: Intersection over Union (IoU) score
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    # Voxelize both point clouds
 | 
				
			||||||
 | 
					    model_voxels = voxelize(model_vox, voxel_size)
 | 
				
			||||||
 | 
					    target_voxels = voxelize(target_vox, voxel_size)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Calculate intersection and union
 | 
				
			||||||
 | 
					    intersection = len(model_voxels.intersection(target_voxels))
 | 
				
			||||||
 | 
					    union = len(model_voxels.union(target_voxels))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Compute IoU
 | 
				
			||||||
 | 
					    iou = intersection / union if union > 0 else 0.0
 | 
				
			||||||
 | 
					    return iou
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					if __name__ == "__main__":
 | 
				
			||||||
 | 
					    parser = argparse.ArgumentParser()
 | 
				
			||||||
 | 
					    parser.add_argument("--input_dir", type=str, default="./results/infer/PointClouds")
 | 
				
			||||||
 | 
					    parser.add_argument("--target_dir", type=str, default="./data/test_pc")
 | 
				
			||||||
 | 
					    parser.add_argument("--detail", action="store_true")
 | 
				
			||||||
 | 
					    args = parser.parse_args()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if not os.path.exists(args.input_dir) or not os.path.exists(args.target_dir):
 | 
				
			||||||
 | 
					        print("Invalid input!")
 | 
				
			||||||
 | 
					        exit(1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    model_prefix = os.path.join(args.input_dir, "*.ply")
 | 
				
			||||||
 | 
					    model_path_list = sorted(list(glob.glob(model_prefix)))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    distance_json = {}
 | 
				
			||||||
 | 
					    distance_list = []
 | 
				
			||||||
 | 
					    emu_distance_list = []
 | 
				
			||||||
 | 
					    hausdorff_distance_list = []
 | 
				
			||||||
 | 
					    voxel_iou_list = []
 | 
				
			||||||
 | 
					    for model_path in tqdm(model_path_list):
 | 
				
			||||||
 | 
					        model_name = os.path.basename(model_path)
 | 
				
			||||||
 | 
					        target_path = os.path.join(args.target_dir, model_name)
 | 
				
			||||||
 | 
					        if not os.path.exists(target_path):
 | 
				
			||||||
 | 
					            print(f"{target_path}: not found!")
 | 
				
			||||||
 | 
					            exit(1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        model_pc = np.array(trimesh.load(model_path).vertices)
 | 
				
			||||||
 | 
					        target_pc = np.array(trimesh.load(target_path).vertices)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        distance = pcu.chamfer_distance(model_pc, target_pc)
 | 
				
			||||||
 | 
					        model_pc_downsampled = model_pc[np.random.choice(model_pc.shape[0], 1000, replace=False)]
 | 
				
			||||||
 | 
					        target_pc_downsampled = target_pc[np.random.choice(target_pc.shape[0], 1000, replace=False)]
 | 
				
			||||||
 | 
					        emu_distance, _ = pcu.earth_movers_distance(model_pc_downsampled, target_pc_downsampled)
 | 
				
			||||||
 | 
					        hausdorff_distance = pcu.hausdorff_distance(model_pc, target_pc)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        iou = calculate_iou(model_pc, target_pc, voxel_size=1/32.)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        distance_list.append(distance)
 | 
				
			||||||
 | 
					        emu_distance_list.append(emu_distance)
 | 
				
			||||||
 | 
					        hausdorff_distance_list.append(hausdorff_distance)
 | 
				
			||||||
 | 
					        voxel_iou_list.append(iou)
 | 
				
			||||||
 | 
					        model_id = os.path.splitext(model_name)[0]
 | 
				
			||||||
 | 
					        distance_json[model_id] = distance
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        print(f"{model_id}: chamfer distance: {distance:.3f}, earth movers distance: {emu_distance:.3f}, hausdorff distance: {hausdorff_distance:.3f}, voxel IoU: {iou:.3f}")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    distance_json["mean"] = np.mean(distance_list)
 | 
				
			||||||
 | 
					    distance_json["mean_emu"] = np.mean(emu_distance_list)
 | 
				
			||||||
 | 
					    distance_json["mean_hausdorff"] = np.mean(hausdorff_distance_list)
 | 
				
			||||||
 | 
					    distance_json["mean_voxel_iou"] = np.mean(voxel_iou_list)
 | 
				
			||||||
 | 
					    print(f"mean chamfer distance: {np.mean(distance_list)}")
 | 
				
			||||||
 | 
					    print(f"mean earth movers distance: {np.mean(emu_distance_list)}")
 | 
				
			||||||
 | 
					    print(f"mean hausdorff distance: {np.mean(hausdorff_distance_list)}")
 | 
				
			||||||
 | 
					    print(f"mean voxel IoU: {np.mean(voxel_iou_list)}")
 | 
				
			||||||
 | 
					    with open(os.path.join(args.input_dir, "distance.json"), "w") as json_file:
 | 
				
			||||||
 | 
					        json.dump(distance_json, json_file, indent=4)
 | 
				
			||||||
							
								
								
									
										172
									
								
								infer.py
									
									
									
									
									
										Executable file
									
								
							
							
						
						
									
										172
									
								
								infer.py
									
									
									
									
									
										Executable file
									
								
							@ -0,0 +1,172 @@
 | 
				
			|||||||
 | 
					import argparse
 | 
				
			||||||
 | 
					import glob
 | 
				
			||||||
 | 
					import json
 | 
				
			||||||
 | 
					import yaml
 | 
				
			||||||
 | 
					from pathlib import Path
 | 
				
			||||||
 | 
					import os
 | 
				
			||||||
 | 
					import re
 | 
				
			||||||
 | 
					import numpy as np
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from scipy.spatial.transform import Rotation
 | 
				
			||||||
 | 
					from tqdm import tqdm
 | 
				
			||||||
 | 
					import torch
 | 
				
			||||||
 | 
					import trimesh
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from primitive_anything.primitive_dataset import create_dataset
 | 
				
			||||||
 | 
					from primitive_anything.utils import torch_to, count_parameters
 | 
				
			||||||
 | 
					from primitive_anything.utils.logger import create_logger, print_log
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					CODE_SHAPE = {
 | 
				
			||||||
 | 
					    0: 'SM_GR_BS_CubeBevel_001.ply',
 | 
				
			||||||
 | 
					    1: 'SM_GR_BS_SphereSharp_001.ply',
 | 
				
			||||||
 | 
					    2: 'SM_GR_BS_CylinderSharp_001.ply',
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					shapename_map = {
 | 
				
			||||||
 | 
					    'SM_GR_BS_CubeBevel_001.ply': 1101002001034001,
 | 
				
			||||||
 | 
					    'SM_GR_BS_SphereSharp_001.ply': 1101002001034010,
 | 
				
			||||||
 | 
					    'SM_GR_BS_CylinderSharp_001.ply': 1101002001034002,
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					bs_dir = 'data/basic_shapes_norm'
 | 
				
			||||||
 | 
					mesh_bs = {}
 | 
				
			||||||
 | 
					for bs_path in glob.glob(os.path.join(bs_dir, '*.ply')):
 | 
				
			||||||
 | 
					    bs_name = os.path.basename(bs_path)
 | 
				
			||||||
 | 
					    bs = trimesh.load(bs_path)
 | 
				
			||||||
 | 
					    bs.visual.uv = np.clip(bs.visual.uv, 0, 1)
 | 
				
			||||||
 | 
					    bs.visual = bs.visual.to_color()
 | 
				
			||||||
 | 
					    mesh_bs[bs_name] = bs
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def create_model(cfg_model):
 | 
				
			||||||
 | 
					    kwargs = cfg_model
 | 
				
			||||||
 | 
					    name = kwargs.pop('name')
 | 
				
			||||||
 | 
					    model = get_model(name)(**kwargs)
 | 
				
			||||||
 | 
					    print_log("Model '{}' init: nb_params={:,}, kwargs={}".format(name, count_parameters(model), kwargs))
 | 
				
			||||||
 | 
					    return model
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from primitive_anything.primitive_transformer import PrimitiveTransformerDiscrete
 | 
				
			||||||
 | 
					def get_model(name):
 | 
				
			||||||
 | 
					    return {
 | 
				
			||||||
 | 
					        'discrete': PrimitiveTransformerDiscrete,
 | 
				
			||||||
 | 
					    }[name]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def euler_to_quat(euler):
 | 
				
			||||||
 | 
					    return Rotation.from_euler('XYZ', euler, degrees=True).as_quat()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def rotvec_to_quat(rotvec):
 | 
				
			||||||
 | 
					    return Rotation.from_rotvec(rotvec, degrees=True).as_quat()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def SRT_quat_to_matrix(scale, quat, translation):
 | 
				
			||||||
 | 
					    rotation_matrix = Rotation.from_quat(quat).as_matrix()
 | 
				
			||||||
 | 
					    transform_matrix = np.eye(4)
 | 
				
			||||||
 | 
					    transform_matrix[:3, :3] = rotation_matrix * scale
 | 
				
			||||||
 | 
					    transform_matrix[:3, 3] = translation
 | 
				
			||||||
 | 
					    return transform_matrix
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def write_json(primitives, shapename_map, out_path):
 | 
				
			||||||
 | 
					    out_json = {}
 | 
				
			||||||
 | 
					    out_json['operation'] = 0
 | 
				
			||||||
 | 
					    out_json['type'] = 1
 | 
				
			||||||
 | 
					    out_json['scene_id'] = None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    new_group = []
 | 
				
			||||||
 | 
					    model_scene = trimesh.Scene()
 | 
				
			||||||
 | 
					    for scale, rotation, translation, type_code in zip(
 | 
				
			||||||
 | 
					        primitives['scale'].squeeze().cpu().numpy(),
 | 
				
			||||||
 | 
					        primitives['rotation'].squeeze().cpu().numpy(),
 | 
				
			||||||
 | 
					        primitives['translation'].squeeze().cpu().numpy(),
 | 
				
			||||||
 | 
					        primitives['type_code'].squeeze().cpu().numpy()
 | 
				
			||||||
 | 
					    ):
 | 
				
			||||||
 | 
					        if type_code == -1:
 | 
				
			||||||
 | 
					            break
 | 
				
			||||||
 | 
					        bs_name = CODE_SHAPE[type_code]
 | 
				
			||||||
 | 
					        new_block = {}
 | 
				
			||||||
 | 
					        new_block['type_id'] = shapename_map[bs_name]
 | 
				
			||||||
 | 
					        new_block['data'] = {}
 | 
				
			||||||
 | 
					        new_block['data']['location'] = translation.tolist()
 | 
				
			||||||
 | 
					        new_block['data']['rotation'] = euler_to_quat(rotation).tolist()
 | 
				
			||||||
 | 
					        new_block['data']['scale'] = scale.tolist()
 | 
				
			||||||
 | 
					        new_block['data']['color'] = ['808080']
 | 
				
			||||||
 | 
					        new_group.append(new_block)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if new_block['type_id'] == 1101002001034001:
 | 
				
			||||||
 | 
					            cur_color = "#2FA9FF"
 | 
				
			||||||
 | 
					        elif new_block['type_id'] == 1101002001034002:
 | 
				
			||||||
 | 
					            cur_color = "#FFC203"
 | 
				
			||||||
 | 
					        elif new_block['type_id'] == 1101002001034010:
 | 
				
			||||||
 | 
					            cur_color = "#FF8A9C"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        def hex_to_rgb(hex_color):
 | 
				
			||||||
 | 
					            hex_color = hex_color.lstrip('#')
 | 
				
			||||||
 | 
					            return np.array([
 | 
				
			||||||
 | 
					                int(hex_color[0:2], 16),  # R
 | 
				
			||||||
 | 
					                int(hex_color[2:4], 16),  # G
 | 
				
			||||||
 | 
					                int(hex_color[4:6], 16),  # B
 | 
				
			||||||
 | 
					            ], dtype=np.uint8)[None]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        trans = SRT_quat_to_matrix(scale, euler_to_quat(rotation), translation)
 | 
				
			||||||
 | 
					        bs = mesh_bs[bs_name].copy().apply_transform(trans)
 | 
				
			||||||
 | 
					        new_vertex_colors = np.repeat(hex_to_rgb(cur_color), bs.visual.vertex_colors.shape[0], axis=0)
 | 
				
			||||||
 | 
					        bs.visual.vertex_colors[:, :3] = new_vertex_colors
 | 
				
			||||||
 | 
					        vertices = bs.vertices.copy()
 | 
				
			||||||
 | 
					        vertices[:, 1] = bs.vertices[:, 2]
 | 
				
			||||||
 | 
					        vertices[:, 2] = -bs.vertices[:, 1]
 | 
				
			||||||
 | 
					        bs.vertices = vertices
 | 
				
			||||||
 | 
					        model_scene.add_geometry(bs)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    out_json['group'] = new_group
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    with open(out_path, 'w') as json_file:
 | 
				
			||||||
 | 
					        json.dump(out_json, json_file, indent=4)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    glb_path = out_path.replace('.json', '.glb')
 | 
				
			||||||
 | 
					    model_scene.export(glb_path)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    return glb_path, out_json
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					if __name__ == '__main__':
 | 
				
			||||||
 | 
					    parser = argparse.ArgumentParser()
 | 
				
			||||||
 | 
					    parser.add_argument('-c', '--config', type=str, default='./configs/infer.yml', help='Config file path')
 | 
				
			||||||
 | 
					    parser.add_argument('-ck', '--AR_ckpt', type=str, default='./ckpt/mesh-transformer.ckpt.60.pt')
 | 
				
			||||||
 | 
					    parser.add_argument('-o', '--output', type=str, default='./results/infer')
 | 
				
			||||||
 | 
					    parser.add_argument('--bs_dir', type=str, default='data/basic_shapes_norm')
 | 
				
			||||||
 | 
					    parser.add_argument('--temperature', type=float, default=0.0)
 | 
				
			||||||
 | 
					    args = parser.parse_args()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    bs_names = []
 | 
				
			||||||
 | 
					    for bs_path in glob.glob(os.path.join(args.bs_dir, '*.ply')):
 | 
				
			||||||
 | 
					        bs_names.append(os.path.basename(bs_path))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    with open(args.config, mode='r') as fp:
 | 
				
			||||||
 | 
					        cfg = yaml.load(fp, Loader=yaml.FullLoader)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    AR_checkpoint = torch.load(args.AR_ckpt)
 | 
				
			||||||
 | 
					    
 | 
				
			||||||
 | 
					    os.makedirs(args.output, exist_ok=True)
 | 
				
			||||||
 | 
					    json_result_folder = os.path.join(args.output, 'JsonResults')
 | 
				
			||||||
 | 
					    os.makedirs(json_result_folder, exist_ok=True)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    create_logger(Path(args.output))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    dataset = create_dataset(cfg['dataset'])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    transformer = create_model(cfg['model'])
 | 
				
			||||||
 | 
					    transformer.load_state_dict(AR_checkpoint)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    for item_i, item in tqdm(enumerate(dataset)):
 | 
				
			||||||
 | 
					        pc = item.pop('pc')
 | 
				
			||||||
 | 
					    
 | 
				
			||||||
 | 
					        item_filename = dataset.data_filename[item_i]
 | 
				
			||||||
 | 
					        if torch.cuda.is_available():
 | 
				
			||||||
 | 
					            pc = pc.cuda()
 | 
				
			||||||
 | 
					            item = torch_to(item, torch.device('cuda'))
 | 
				
			||||||
 | 
					            transformer = transformer.cuda()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        recon_primitives, mask = transformer.generate(pc=pc.unsqueeze(0), temperature=args.temperature)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        out_path = os.path.join(json_result_folder, os.path.basename(item_filename).replace('.ply', '.json'))
 | 
				
			||||||
 | 
					        write_json(recon_primitives, shapename_map, out_path)
 | 
				
			||||||
							
								
								
									
										0
									
								
								primitive_anything/__init__.py
									
									
									
									
									
										Executable file
									
								
							
							
						
						
									
										0
									
								
								primitive_anything/__init__.py
									
									
									
									
									
										Executable file
									
								
							
							
								
								
									
										51
									
								
								primitive_anything/michelangelo/__init__.py
									
									
									
									
									
										Executable file
									
								
							
							
						
						
									
										51
									
								
								primitive_anything/michelangelo/__init__.py
									
									
									
									
									
										Executable file
									
								
							@ -0,0 +1,51 @@
 | 
				
			|||||||
 | 
					import os
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from omegaconf import OmegaConf
 | 
				
			||||||
 | 
					import torch
 | 
				
			||||||
 | 
					from torch import nn
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from .utils.misc import instantiate_from_config
 | 
				
			||||||
 | 
					from ..utils import default, exists
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def load_model():
 | 
				
			||||||
 | 
					    model_config = OmegaConf.load(os.path.join(os.path.dirname(__file__), "shapevae-256.yaml"))
 | 
				
			||||||
 | 
					    # print(model_config)
 | 
				
			||||||
 | 
					    if hasattr(model_config, "model"):
 | 
				
			||||||
 | 
					        model_config = model_config.model
 | 
				
			||||||
 | 
					    ckpt_path = "./ckpt/shapevae-256.ckpt"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    model = instantiate_from_config(model_config, ckpt_path=ckpt_path)
 | 
				
			||||||
 | 
					    # model = model.cuda()
 | 
				
			||||||
 | 
					    model = model.eval()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    return model
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class ShapeConditioner(nn.Module):
 | 
				
			||||||
 | 
					    def __init__(
 | 
				
			||||||
 | 
					        self,
 | 
				
			||||||
 | 
					        *,
 | 
				
			||||||
 | 
					        dim_latent = None
 | 
				
			||||||
 | 
					    ):
 | 
				
			||||||
 | 
					        super().__init__()
 | 
				
			||||||
 | 
					        self.model = load_model()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.dim_model_out = 768
 | 
				
			||||||
 | 
					        dim_latent = default(dim_latent, self.dim_model_out)
 | 
				
			||||||
 | 
					        self.dim_latent = dim_latent
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def forward(
 | 
				
			||||||
 | 
					        self,
 | 
				
			||||||
 | 
					        shape = None,
 | 
				
			||||||
 | 
					        shape_embed = None,
 | 
				
			||||||
 | 
					    ):
 | 
				
			||||||
 | 
					        assert exists(shape) ^ exists(shape_embed)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if not exists(shape_embed):
 | 
				
			||||||
 | 
					            point_feature = self.model.encode_latents(shape)
 | 
				
			||||||
 | 
					            shape_latents = self.model.to_shape_latents(point_feature[:, 1:])
 | 
				
			||||||
 | 
					            shape_head = point_feature[:, 0:1]
 | 
				
			||||||
 | 
					            shape_embed = torch.cat([point_feature[:, 1:], shape_latents], dim=-1)
 | 
				
			||||||
 | 
					            # shape_embed = torch.cat([point_feature[:, 1:], shape_latents], dim=-2) # cat tmp
 | 
				
			||||||
 | 
					        return shape_head, shape_embed
 | 
				
			||||||
							
								
								
									
										1
									
								
								primitive_anything/michelangelo/graphics/__init__.py
									
									
									
									
									
										Executable file
									
								
							
							
						
						
									
										1
									
								
								primitive_anything/michelangelo/graphics/__init__.py
									
									
									
									
									
										Executable file
									
								
							@ -0,0 +1 @@
 | 
				
			|||||||
 | 
					# -*- coding: utf-8 -*-
 | 
				
			||||||
							
								
								
									
										9
									
								
								primitive_anything/michelangelo/graphics/primitives/__init__.py
									
									
									
									
									
										Executable file
									
								
							
							
						
						
									
										9
									
								
								primitive_anything/michelangelo/graphics/primitives/__init__.py
									
									
									
									
									
										Executable file
									
								
							@ -0,0 +1,9 @@
 | 
				
			|||||||
 | 
					# -*- coding: utf-8 -*-
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from .volume import generate_dense_grid_points
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from .mesh import (
 | 
				
			||||||
 | 
					    MeshOutput,
 | 
				
			||||||
 | 
					    save_obj,
 | 
				
			||||||
 | 
					    savemeshtes2
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
							
								
								
									
										114
									
								
								primitive_anything/michelangelo/graphics/primitives/mesh.py
									
									
									
									
									
										Executable file
									
								
							
							
						
						
									
										114
									
								
								primitive_anything/michelangelo/graphics/primitives/mesh.py
									
									
									
									
									
										Executable file
									
								
							@ -0,0 +1,114 @@
 | 
				
			|||||||
 | 
					# -*- coding: utf-8 -*-
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import os
 | 
				
			||||||
 | 
					import cv2
 | 
				
			||||||
 | 
					import numpy as np
 | 
				
			||||||
 | 
					import PIL.Image
 | 
				
			||||||
 | 
					from typing import Optional
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import trimesh
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def save_obj(pointnp_px3, facenp_fx3, fname):
 | 
				
			||||||
 | 
					    fid = open(fname, "w")
 | 
				
			||||||
 | 
					    write_str = ""
 | 
				
			||||||
 | 
					    for pidx, p in enumerate(pointnp_px3):
 | 
				
			||||||
 | 
					        pp = p
 | 
				
			||||||
 | 
					        write_str += "v %f %f %f\n" % (pp[0], pp[1], pp[2])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    for i, f in enumerate(facenp_fx3):
 | 
				
			||||||
 | 
					        f1 = f + 1
 | 
				
			||||||
 | 
					        write_str += "f %d %d %d\n" % (f1[0], f1[1], f1[2])
 | 
				
			||||||
 | 
					    fid.write(write_str)
 | 
				
			||||||
 | 
					    fid.close()
 | 
				
			||||||
 | 
					    return
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def savemeshtes2(pointnp_px3, tcoords_px2, facenp_fx3, facetex_fx3, tex_map, fname):
 | 
				
			||||||
 | 
					    fol, na = os.path.split(fname)
 | 
				
			||||||
 | 
					    na, _ = os.path.splitext(na)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    matname = "%s/%s.mtl" % (fol, na)
 | 
				
			||||||
 | 
					    fid = open(matname, "w")
 | 
				
			||||||
 | 
					    fid.write("newmtl material_0\n")
 | 
				
			||||||
 | 
					    fid.write("Kd 1 1 1\n")
 | 
				
			||||||
 | 
					    fid.write("Ka 0 0 0\n")
 | 
				
			||||||
 | 
					    fid.write("Ks 0.4 0.4 0.4\n")
 | 
				
			||||||
 | 
					    fid.write("Ns 10\n")
 | 
				
			||||||
 | 
					    fid.write("illum 2\n")
 | 
				
			||||||
 | 
					    fid.write("map_Kd %s.png\n" % na)
 | 
				
			||||||
 | 
					    fid.close()
 | 
				
			||||||
 | 
					    ####
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    fid = open(fname, "w")
 | 
				
			||||||
 | 
					    fid.write("mtllib %s.mtl\n" % na)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    for pidx, p in enumerate(pointnp_px3):
 | 
				
			||||||
 | 
					        pp = p
 | 
				
			||||||
 | 
					        fid.write("v %f %f %f\n" % (pp[0], pp[1], pp[2]))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    for pidx, p in enumerate(tcoords_px2):
 | 
				
			||||||
 | 
					        pp = p
 | 
				
			||||||
 | 
					        fid.write("vt %f %f\n" % (pp[0], pp[1]))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    fid.write("usemtl material_0\n")
 | 
				
			||||||
 | 
					    for i, f in enumerate(facenp_fx3):
 | 
				
			||||||
 | 
					        f1 = f + 1
 | 
				
			||||||
 | 
					        f2 = facetex_fx3[i] + 1
 | 
				
			||||||
 | 
					        fid.write("f %d/%d %d/%d %d/%d\n" % (f1[0], f2[0], f1[1], f2[1], f1[2], f2[2]))
 | 
				
			||||||
 | 
					    fid.close()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    PIL.Image.fromarray(np.ascontiguousarray(tex_map), "RGB").save(
 | 
				
			||||||
 | 
					        os.path.join(fol, "%s.png" % na))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    return
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class MeshOutput(object):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __init__(self,
 | 
				
			||||||
 | 
					                 mesh_v: np.ndarray,
 | 
				
			||||||
 | 
					                 mesh_f: np.ndarray,
 | 
				
			||||||
 | 
					                 vertex_colors: Optional[np.ndarray] = None,
 | 
				
			||||||
 | 
					                 uvs: Optional[np.ndarray] = None,
 | 
				
			||||||
 | 
					                 mesh_tex_idx: Optional[np.ndarray] = None,
 | 
				
			||||||
 | 
					                 tex_map: Optional[np.ndarray] = None):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.mesh_v = mesh_v
 | 
				
			||||||
 | 
					        self.mesh_f = mesh_f
 | 
				
			||||||
 | 
					        self.vertex_colors = vertex_colors
 | 
				
			||||||
 | 
					        self.uvs = uvs
 | 
				
			||||||
 | 
					        self.mesh_tex_idx = mesh_tex_idx
 | 
				
			||||||
 | 
					        self.tex_map = tex_map
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def contain_uv_texture(self):
 | 
				
			||||||
 | 
					        return (self.uvs is not None) and (self.mesh_tex_idx is not None) and (self.tex_map is not None)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def contain_vertex_colors(self):
 | 
				
			||||||
 | 
					        return self.vertex_colors is not None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def export(self, fname):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if self.contain_uv_texture():
 | 
				
			||||||
 | 
					            savemeshtes2(
 | 
				
			||||||
 | 
					                self.mesh_v,
 | 
				
			||||||
 | 
					                self.uvs,
 | 
				
			||||||
 | 
					                self.mesh_f,
 | 
				
			||||||
 | 
					                self.mesh_tex_idx,
 | 
				
			||||||
 | 
					                self.tex_map,
 | 
				
			||||||
 | 
					                fname
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        elif self.contain_vertex_colors():
 | 
				
			||||||
 | 
					            mesh_obj = trimesh.Trimesh(vertices=self.mesh_v, faces=self.mesh_f, vertex_colors=self.vertex_colors)
 | 
				
			||||||
 | 
					            mesh_obj.export(fname)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            save_obj(
 | 
				
			||||||
 | 
					                self.mesh_v,
 | 
				
			||||||
 | 
					                self.mesh_f,
 | 
				
			||||||
 | 
					                fname
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
							
								
								
									
										21
									
								
								primitive_anything/michelangelo/graphics/primitives/volume.py
									
									
									
									
									
										Executable file
									
								
							
							
						
						
									
										21
									
								
								primitive_anything/michelangelo/graphics/primitives/volume.py
									
									
									
									
									
										Executable file
									
								
							@ -0,0 +1,21 @@
 | 
				
			|||||||
 | 
					# -*- coding: utf-8 -*-
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import numpy as np
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def generate_dense_grid_points(bbox_min: np.ndarray,
 | 
				
			||||||
 | 
					                               bbox_max: np.ndarray,
 | 
				
			||||||
 | 
					                               octree_depth: int,
 | 
				
			||||||
 | 
					                               indexing: str = "ij"):
 | 
				
			||||||
 | 
					    length = bbox_max - bbox_min
 | 
				
			||||||
 | 
					    num_cells = np.exp2(octree_depth)
 | 
				
			||||||
 | 
					    x = np.linspace(bbox_min[0], bbox_max[0], int(num_cells) + 1, dtype=np.float32)
 | 
				
			||||||
 | 
					    y = np.linspace(bbox_min[1], bbox_max[1], int(num_cells) + 1, dtype=np.float32)
 | 
				
			||||||
 | 
					    z = np.linspace(bbox_min[2], bbox_max[2], int(num_cells) + 1, dtype=np.float32)
 | 
				
			||||||
 | 
					    [xs, ys, zs] = np.meshgrid(x, y, z, indexing=indexing)
 | 
				
			||||||
 | 
					    xyz = np.stack((xs, ys, zs), axis=-1)
 | 
				
			||||||
 | 
					    xyz = xyz.reshape(-1, 3)
 | 
				
			||||||
 | 
					    grid_size = [int(num_cells) + 1, int(num_cells) + 1, int(num_cells) + 1]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    return xyz, grid_size, length
 | 
				
			||||||
 | 
					
 | 
				
			||||||
							
								
								
									
										1
									
								
								primitive_anything/michelangelo/models/__init__.py
									
									
									
									
									
										Executable file
									
								
							
							
						
						
									
										1
									
								
								primitive_anything/michelangelo/models/__init__.py
									
									
									
									
									
										Executable file
									
								
							@ -0,0 +1 @@
 | 
				
			|||||||
 | 
					# -*- coding: utf-8 -*-
 | 
				
			||||||
							
								
								
									
										1
									
								
								primitive_anything/michelangelo/models/asl_diffusion/__init__.py
									
									
									
									
									
										Executable file
									
								
							
							
						
						
									
										1
									
								
								primitive_anything/michelangelo/models/asl_diffusion/__init__.py
									
									
									
									
									
										Executable file
									
								
							@ -0,0 +1 @@
 | 
				
			|||||||
 | 
					# -*- coding: utf-8 -*-
 | 
				
			||||||
							
								
								
									
										483
									
								
								primitive_anything/michelangelo/models/asl_diffusion/asl_diffuser_pl_module.py
									
									
									
									
									
										Executable file
									
								
							
							
						
						
									
										483
									
								
								primitive_anything/michelangelo/models/asl_diffusion/asl_diffuser_pl_module.py
									
									
									
									
									
										Executable file
									
								
							@ -0,0 +1,483 @@
 | 
				
			|||||||
 | 
					# -*- coding: utf-8 -*-
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from omegaconf import DictConfig
 | 
				
			||||||
 | 
					from typing import List, Tuple, Dict, Optional, Union
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import torch
 | 
				
			||||||
 | 
					import torch.nn as nn
 | 
				
			||||||
 | 
					import torch.nn.functional as F
 | 
				
			||||||
 | 
					from torch.optim import lr_scheduler
 | 
				
			||||||
 | 
					import pytorch_lightning as pl
 | 
				
			||||||
 | 
					from pytorch_lightning.utilities import rank_zero_only
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from einops import rearrange
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from diffusers.schedulers import (
 | 
				
			||||||
 | 
					    DDPMScheduler,
 | 
				
			||||||
 | 
					    DDIMScheduler,
 | 
				
			||||||
 | 
					    KarrasVeScheduler,
 | 
				
			||||||
 | 
					    DPMSolverMultistepScheduler
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from ...utils import instantiate_from_config
 | 
				
			||||||
 | 
					# from ..tsal.tsal_base import ShapeAsLatentPLModule
 | 
				
			||||||
 | 
					from ..tsal.tsal_base import AlignedShapeAsLatentPLModule
 | 
				
			||||||
 | 
					from .inference_utils import ddim_sample
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					SchedulerType = Union[DDIMScheduler, KarrasVeScheduler, DPMSolverMultistepScheduler]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def disabled_train(self, mode=True):
 | 
				
			||||||
 | 
					    """Overwrite model.train with this function to make sure train/eval mode
 | 
				
			||||||
 | 
					    does not change anymore."""
 | 
				
			||||||
 | 
					    return self
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class ASLDiffuser(pl.LightningModule):
 | 
				
			||||||
 | 
					    first_stage_model: Optional[AlignedShapeAsLatentPLModule]
 | 
				
			||||||
 | 
					    # cond_stage_model: Optional[Union[nn.Module, pl.LightningModule]]
 | 
				
			||||||
 | 
					    model: nn.Module
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __init__(self, *,
 | 
				
			||||||
 | 
					                 first_stage_config,
 | 
				
			||||||
 | 
					                 denoiser_cfg,
 | 
				
			||||||
 | 
					                 scheduler_cfg,
 | 
				
			||||||
 | 
					                 optimizer_cfg,
 | 
				
			||||||
 | 
					                 loss_cfg,
 | 
				
			||||||
 | 
					                 first_stage_key: str = "surface",
 | 
				
			||||||
 | 
					                 cond_stage_key: str = "image",
 | 
				
			||||||
 | 
					                 cond_stage_trainable: bool = True,
 | 
				
			||||||
 | 
					                 scale_by_std: bool = False,
 | 
				
			||||||
 | 
					                 z_scale_factor: float = 1.0,
 | 
				
			||||||
 | 
					                 ckpt_path: Optional[str] = None,
 | 
				
			||||||
 | 
					                 ignore_keys: Union[Tuple[str], List[str]] = ()):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        super().__init__()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.first_stage_key = first_stage_key
 | 
				
			||||||
 | 
					        self.cond_stage_key = cond_stage_key
 | 
				
			||||||
 | 
					        self.cond_stage_trainable = cond_stage_trainable
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # 1. initialize first stage. 
 | 
				
			||||||
 | 
					        # Note: the condition model contained in the first stage model.
 | 
				
			||||||
 | 
					        self.first_stage_config = first_stage_config
 | 
				
			||||||
 | 
					        self.first_stage_model = None
 | 
				
			||||||
 | 
					        # self.instantiate_first_stage(first_stage_config)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # 2. initialize conditional stage
 | 
				
			||||||
 | 
					        # self.instantiate_cond_stage(cond_stage_config)
 | 
				
			||||||
 | 
					        self.cond_stage_model = {
 | 
				
			||||||
 | 
					            "image": self.encode_image,
 | 
				
			||||||
 | 
					            "image_unconditional_embedding": self.empty_img_cond,
 | 
				
			||||||
 | 
					            "text": self.encode_text,
 | 
				
			||||||
 | 
					            "text_unconditional_embedding": self.empty_text_cond,
 | 
				
			||||||
 | 
					            "surface": self.encode_surface,
 | 
				
			||||||
 | 
					            "surface_unconditional_embedding": self.empty_surface_cond,
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # 3. diffusion model
 | 
				
			||||||
 | 
					        self.model = instantiate_from_config(
 | 
				
			||||||
 | 
					            denoiser_cfg, device=None, dtype=None
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.optimizer_cfg = optimizer_cfg
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # 4. scheduling strategy
 | 
				
			||||||
 | 
					        self.scheduler_cfg = scheduler_cfg
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.noise_scheduler: DDPMScheduler = instantiate_from_config(scheduler_cfg.noise)
 | 
				
			||||||
 | 
					        self.denoise_scheduler: SchedulerType = instantiate_from_config(scheduler_cfg.denoise)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # 5. loss configures
 | 
				
			||||||
 | 
					        self.loss_cfg = loss_cfg
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.scale_by_std = scale_by_std
 | 
				
			||||||
 | 
					        if scale_by_std:
 | 
				
			||||||
 | 
					            self.register_buffer("z_scale_factor", torch.tensor(z_scale_factor))
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            self.z_scale_factor = z_scale_factor
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.ckpt_path = ckpt_path
 | 
				
			||||||
 | 
					        if ckpt_path is not None:
 | 
				
			||||||
 | 
					            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def instantiate_first_stage(self, config):
 | 
				
			||||||
 | 
					        model = instantiate_from_config(config)
 | 
				
			||||||
 | 
					        self.first_stage_model = model.eval()
 | 
				
			||||||
 | 
					        self.first_stage_model.train = disabled_train
 | 
				
			||||||
 | 
					        for param in self.first_stage_model.parameters():
 | 
				
			||||||
 | 
					            param.requires_grad = False
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.first_stage_model = self.first_stage_model.to(self.device)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # def instantiate_cond_stage(self, config):
 | 
				
			||||||
 | 
					    #     if not self.cond_stage_trainable:
 | 
				
			||||||
 | 
					    #         if config == "__is_first_stage__":
 | 
				
			||||||
 | 
					    #             print("Using first stage also as cond stage.")
 | 
				
			||||||
 | 
					    #             self.cond_stage_model = self.first_stage_model
 | 
				
			||||||
 | 
					    #         elif config == "__is_unconditional__":
 | 
				
			||||||
 | 
					    #             print(f"Training {self.__class__.__name__} as an unconditional model.")
 | 
				
			||||||
 | 
					    #             self.cond_stage_model = None
 | 
				
			||||||
 | 
					    #             # self.be_unconditional = True
 | 
				
			||||||
 | 
					    #         else:
 | 
				
			||||||
 | 
					    #             model = instantiate_from_config(config)
 | 
				
			||||||
 | 
					    #             self.cond_stage_model = model.eval()
 | 
				
			||||||
 | 
					    #             self.cond_stage_model.train = disabled_train
 | 
				
			||||||
 | 
					    #             for param in self.cond_stage_model.parameters():
 | 
				
			||||||
 | 
					    #                 param.requires_grad = False
 | 
				
			||||||
 | 
					    #     else:
 | 
				
			||||||
 | 
					    #         assert config != "__is_first_stage__"
 | 
				
			||||||
 | 
					    #         assert config != "__is_unconditional__"
 | 
				
			||||||
 | 
					    #         model = instantiate_from_config(config)
 | 
				
			||||||
 | 
					    #         self.cond_stage_model = model
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def init_from_ckpt(self, path, ignore_keys=()):
 | 
				
			||||||
 | 
					        state_dict = torch.load(path, map_location="cpu")["state_dict"]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        keys = list(state_dict.keys())
 | 
				
			||||||
 | 
					        for k in keys:
 | 
				
			||||||
 | 
					            for ik in ignore_keys:
 | 
				
			||||||
 | 
					                if k.startswith(ik):
 | 
				
			||||||
 | 
					                    print("Deleting key {} from state_dict.".format(k))
 | 
				
			||||||
 | 
					                    del state_dict[k]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        missing, unexpected = self.load_state_dict(state_dict, strict=False)
 | 
				
			||||||
 | 
					        print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
 | 
				
			||||||
 | 
					        if len(missing) > 0:
 | 
				
			||||||
 | 
					            print(f"Missing Keys: {missing}")
 | 
				
			||||||
 | 
					            print(f"Unexpected Keys: {unexpected}")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @property
 | 
				
			||||||
 | 
					    def zero_rank(self):
 | 
				
			||||||
 | 
					        if self._trainer:
 | 
				
			||||||
 | 
					            zero_rank = self.trainer.local_rank == 0
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            zero_rank = True
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return zero_rank
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def configure_optimizers(self) -> Tuple[List, List]:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        lr = self.learning_rate
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        trainable_parameters = list(self.model.parameters())
 | 
				
			||||||
 | 
					        # if the conditional encoder is trainable
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # if self.cond_stage_trainable:
 | 
				
			||||||
 | 
					        #     conditioner_params = [p for p in self.cond_stage_model.parameters() if p.requires_grad]
 | 
				
			||||||
 | 
					        #     trainable_parameters += conditioner_params
 | 
				
			||||||
 | 
					        #     print(f"number of trainable conditional parameters: {len(conditioner_params)}.")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if self.optimizer_cfg is None:
 | 
				
			||||||
 | 
					            optimizers = [torch.optim.AdamW(trainable_parameters, lr=lr, betas=(0.9, 0.99), weight_decay=1e-3)]
 | 
				
			||||||
 | 
					            schedulers = []
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            optimizer = instantiate_from_config(self.optimizer_cfg.optimizer, params=trainable_parameters)
 | 
				
			||||||
 | 
					            scheduler_func = instantiate_from_config(
 | 
				
			||||||
 | 
					                self.optimizer_cfg.scheduler,
 | 
				
			||||||
 | 
					                max_decay_steps=self.trainer.max_steps,
 | 
				
			||||||
 | 
					                lr_max=lr
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					            scheduler = {
 | 
				
			||||||
 | 
					                "scheduler": lr_scheduler.LambdaLR(optimizer, lr_lambda=scheduler_func.schedule),
 | 
				
			||||||
 | 
					                "interval": "step",
 | 
				
			||||||
 | 
					                "frequency": 1
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
 | 
					            optimizers = [optimizer]
 | 
				
			||||||
 | 
					            schedulers = [scheduler]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return optimizers, schedulers
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @torch.no_grad()
 | 
				
			||||||
 | 
					    def encode_text(self, text):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        b = text.shape[0]
 | 
				
			||||||
 | 
					        text_tokens = rearrange(text, "b t l -> (b t) l")
 | 
				
			||||||
 | 
					        text_embed = self.first_stage_model.model.encode_text_embed(text_tokens)
 | 
				
			||||||
 | 
					        text_embed = rearrange(text_embed, "(b t) d -> b t d", b=b)
 | 
				
			||||||
 | 
					        text_embed = text_embed.mean(dim=1)
 | 
				
			||||||
 | 
					        text_embed = text_embed / text_embed.norm(dim=-1, keepdim=True)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return text_embed
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @torch.no_grad()
 | 
				
			||||||
 | 
					    def encode_image(self, img):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return self.first_stage_model.model.encode_image_embed(img)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @torch.no_grad()
 | 
				
			||||||
 | 
					    def encode_surface(self, surface):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return self.first_stage_model.model.encode_shape_embed(surface, return_latents=False)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @torch.no_grad()
 | 
				
			||||||
 | 
					    def empty_text_cond(self, cond):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return torch.zeros_like(cond, device=cond.device)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @torch.no_grad()
 | 
				
			||||||
 | 
					    def empty_img_cond(self, cond):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return torch.zeros_like(cond, device=cond.device)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @torch.no_grad()
 | 
				
			||||||
 | 
					    def empty_surface_cond(self, cond):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return torch.zeros_like(cond, device=cond.device)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @torch.no_grad()
 | 
				
			||||||
 | 
					    def encode_first_stage(self, surface: torch.FloatTensor, sample_posterior=True):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        z_q = self.first_stage_model.encode(surface, sample_posterior)
 | 
				
			||||||
 | 
					        z_q = self.z_scale_factor * z_q
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return z_q
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @torch.no_grad()
 | 
				
			||||||
 | 
					    def decode_first_stage(self, z_q: torch.FloatTensor, **kwargs):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        z_q = 1. / self.z_scale_factor * z_q
 | 
				
			||||||
 | 
					        latents = self.first_stage_model.decode(z_q, **kwargs)
 | 
				
			||||||
 | 
					        return latents
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @rank_zero_only
 | 
				
			||||||
 | 
					    @torch.no_grad()
 | 
				
			||||||
 | 
					    def on_train_batch_start(self, batch, batch_idx):
 | 
				
			||||||
 | 
					        # only for very first batch
 | 
				
			||||||
 | 
					        if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 \
 | 
				
			||||||
 | 
					                and batch_idx == 0 and self.ckpt_path is None:
 | 
				
			||||||
 | 
					            # set rescale weight to 1./std of encodings
 | 
				
			||||||
 | 
					            print("### USING STD-RESCALING ###")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            z_q = self.encode_first_stage(batch[self.first_stage_key])
 | 
				
			||||||
 | 
					            z = z_q.detach()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            del self.z_scale_factor
 | 
				
			||||||
 | 
					            self.register_buffer("z_scale_factor", 1. / z.flatten().std())
 | 
				
			||||||
 | 
					            print(f"setting self.z_scale_factor to {self.z_scale_factor}")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            print("### USING STD-RESCALING ###")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def compute_loss(self, model_outputs, split):
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Args:
 | 
				
			||||||
 | 
					            model_outputs (dict):
 | 
				
			||||||
 | 
					                - x_0:
 | 
				
			||||||
 | 
					                - noise:
 | 
				
			||||||
 | 
					                - noise_prior:
 | 
				
			||||||
 | 
					                - noise_pred:
 | 
				
			||||||
 | 
					                - noise_pred_prior:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            split (str):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Returns:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        pred = model_outputs["pred"]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if self.noise_scheduler.prediction_type == "epsilon":
 | 
				
			||||||
 | 
					            target = model_outputs["noise"]
 | 
				
			||||||
 | 
					        elif self.noise_scheduler.prediction_type == "sample":
 | 
				
			||||||
 | 
					            target = model_outputs["x_0"]
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            raise NotImplementedError(f"Prediction Type: {self.noise_scheduler.prediction_type} not yet supported.")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if self.loss_cfg.loss_type == "l1":
 | 
				
			||||||
 | 
					            simple = F.l1_loss(pred, target, reduction="mean")
 | 
				
			||||||
 | 
					        elif self.loss_cfg.loss_type in ["mse", "l2"]:
 | 
				
			||||||
 | 
					            simple = F.mse_loss(pred, target, reduction="mean")
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            raise NotImplementedError(f"Loss Type: {self.loss_cfg.loss_type} not yet supported.")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        total_loss = simple
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        loss_dict = {
 | 
				
			||||||
 | 
					            f"{split}/total_loss": total_loss.clone().detach(),
 | 
				
			||||||
 | 
					            f"{split}/simple": simple.detach(),
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return total_loss, loss_dict
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def forward(self, batch):
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Args:
 | 
				
			||||||
 | 
					            batch:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Returns:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if self.first_stage_model is None:
 | 
				
			||||||
 | 
					            self.instantiate_first_stage(self.first_stage_config)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        latents = self.encode_first_stage(batch[self.first_stage_key])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # conditions = self.cond_stage_model.encode(batch[self.cond_stage_key])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        conditions = self.cond_stage_model[self.cond_stage_key](batch[self.cond_stage_key]).unsqueeze(1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        mask = torch.rand((len(conditions), 1, 1), device=conditions.device, dtype=conditions.dtype) >= 0.1
 | 
				
			||||||
 | 
					        conditions = conditions * mask.to(conditions)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # Sample noise that we"ll add to the latents
 | 
				
			||||||
 | 
					        # [batch_size, n_token, latent_dim]
 | 
				
			||||||
 | 
					        noise = torch.randn_like(latents)
 | 
				
			||||||
 | 
					        bs = latents.shape[0]
 | 
				
			||||||
 | 
					        # Sample a random timestep for each motion
 | 
				
			||||||
 | 
					        timesteps = torch.randint(
 | 
				
			||||||
 | 
					            0,
 | 
				
			||||||
 | 
					            self.noise_scheduler.config.num_train_timesteps,
 | 
				
			||||||
 | 
					            (bs,),
 | 
				
			||||||
 | 
					            device=latents.device,
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        timesteps = timesteps.long()
 | 
				
			||||||
 | 
					        # Add noise to the latents according to the noise magnitude at each timestep
 | 
				
			||||||
 | 
					        noisy_z = self.noise_scheduler.add_noise(latents, noise, timesteps)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # diffusion model forward
 | 
				
			||||||
 | 
					        noise_pred = self.model(noisy_z, timesteps, conditions)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        diffusion_outputs = {
 | 
				
			||||||
 | 
					            "x_0": noisy_z,
 | 
				
			||||||
 | 
					            "noise": noise,
 | 
				
			||||||
 | 
					            "pred": noise_pred
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return diffusion_outputs
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def training_step(self, batch: Dict[str, Union[torch.FloatTensor, List[str]]],
 | 
				
			||||||
 | 
					                      batch_idx: int, optimizer_idx: int = 0) -> torch.FloatTensor:
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Args:
 | 
				
			||||||
 | 
					            batch (dict): the batch sample, and it contains:
 | 
				
			||||||
 | 
					                - surface (torch.FloatTensor):
 | 
				
			||||||
 | 
					                - image (torch.FloatTensor): if provide, [bs, 3, h, w], item range [0, 1]
 | 
				
			||||||
 | 
					                - depth (torch.FloatTensor): if provide, [bs, 1, h, w], item range [-1, 1]
 | 
				
			||||||
 | 
					                - normal (torch.FloatTensor): if provide, [bs, 3, h, w], item range [-1, 1]
 | 
				
			||||||
 | 
					                - text (list of str):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            batch_idx (int):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            optimizer_idx (int):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Returns:
 | 
				
			||||||
 | 
					            loss (torch.FloatTensor):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        diffusion_outputs = self(batch)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        loss, loss_dict = self.compute_loss(diffusion_outputs, "train")
 | 
				
			||||||
 | 
					        self.log_dict(loss_dict, prog_bar=True, logger=True, sync_dist=False, rank_zero_only=True)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return loss
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def validation_step(self, batch: Dict[str, torch.FloatTensor],
 | 
				
			||||||
 | 
					                        batch_idx: int, optimizer_idx: int = 0) -> torch.FloatTensor:
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Args:
 | 
				
			||||||
 | 
					            batch (dict): the batch sample, and it contains:
 | 
				
			||||||
 | 
					                - surface_pc (torch.FloatTensor): [n_pts, 4]
 | 
				
			||||||
 | 
					                - surface_feats (torch.FloatTensor): [n_pts, c]
 | 
				
			||||||
 | 
					                - text (list of str):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            batch_idx (int):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            optimizer_idx (int):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Returns:
 | 
				
			||||||
 | 
					            loss (torch.FloatTensor):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        diffusion_outputs = self(batch)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        loss, loss_dict = self.compute_loss(diffusion_outputs, "val")
 | 
				
			||||||
 | 
					        self.log_dict(loss_dict, prog_bar=True, logger=True, sync_dist=False, rank_zero_only=True)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return loss
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @torch.no_grad()
 | 
				
			||||||
 | 
					    def sample(self,
 | 
				
			||||||
 | 
					               batch: Dict[str, Union[torch.FloatTensor, List[str]]],
 | 
				
			||||||
 | 
					               sample_times: int = 1,
 | 
				
			||||||
 | 
					               steps: Optional[int] = None,
 | 
				
			||||||
 | 
					               guidance_scale: Optional[float] = None,
 | 
				
			||||||
 | 
					               eta: float = 0.0,
 | 
				
			||||||
 | 
					               return_intermediates: bool = False, **kwargs):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if self.first_stage_model is None:
 | 
				
			||||||
 | 
					            self.instantiate_first_stage(self.first_stage_config)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if steps is None:
 | 
				
			||||||
 | 
					            steps = self.scheduler_cfg.num_inference_steps
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if guidance_scale is None:
 | 
				
			||||||
 | 
					            guidance_scale = self.scheduler_cfg.guidance_scale
 | 
				
			||||||
 | 
					        do_classifier_free_guidance = guidance_scale > 0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # conditional encode
 | 
				
			||||||
 | 
					        xc = batch[self.cond_stage_key]
 | 
				
			||||||
 | 
					        # cond = self.cond_stage_model[self.cond_stage_key](xc)
 | 
				
			||||||
 | 
					        cond = self.cond_stage_model[self.cond_stage_key](xc).unsqueeze(1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if do_classifier_free_guidance:
 | 
				
			||||||
 | 
					            """
 | 
				
			||||||
 | 
					            Note: There are two kinds of uncond for text. 
 | 
				
			||||||
 | 
					            1: using "" as uncond text; (in SAL diffusion)
 | 
				
			||||||
 | 
					            2: zeros_like(cond) as uncond text; (in MDM)
 | 
				
			||||||
 | 
					            """
 | 
				
			||||||
 | 
					            # un_cond = self.cond_stage_model.unconditional_embedding(batch_size=len(xc))
 | 
				
			||||||
 | 
					            un_cond = self.cond_stage_model[f"{self.cond_stage_key}_unconditional_embedding"](cond)
 | 
				
			||||||
 | 
					            # un_cond = torch.zeros_like(cond, device=cond.device)
 | 
				
			||||||
 | 
					            cond = torch.cat([un_cond, cond], dim=0)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        outputs = []
 | 
				
			||||||
 | 
					        latents = None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if not return_intermediates:
 | 
				
			||||||
 | 
					            for _ in range(sample_times):
 | 
				
			||||||
 | 
					                sample_loop = ddim_sample(
 | 
				
			||||||
 | 
					                    self.denoise_scheduler,
 | 
				
			||||||
 | 
					                    self.model,
 | 
				
			||||||
 | 
					                    shape=self.first_stage_model.latent_shape,
 | 
				
			||||||
 | 
					                    cond=cond,
 | 
				
			||||||
 | 
					                    steps=steps,
 | 
				
			||||||
 | 
					                    guidance_scale=guidance_scale,
 | 
				
			||||||
 | 
					                    do_classifier_free_guidance=do_classifier_free_guidance,
 | 
				
			||||||
 | 
					                    device=self.device,
 | 
				
			||||||
 | 
					                    eta=eta,
 | 
				
			||||||
 | 
					                    disable_prog=not self.zero_rank
 | 
				
			||||||
 | 
					                )
 | 
				
			||||||
 | 
					                for sample, t in sample_loop:
 | 
				
			||||||
 | 
					                    latents = sample
 | 
				
			||||||
 | 
					                outputs.append(self.decode_first_stage(latents, **kwargs))
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            sample_loop = ddim_sample(
 | 
				
			||||||
 | 
					                self.denoise_scheduler,
 | 
				
			||||||
 | 
					                self.model,
 | 
				
			||||||
 | 
					                shape=self.first_stage_model.latent_shape,
 | 
				
			||||||
 | 
					                cond=cond,
 | 
				
			||||||
 | 
					                steps=steps,
 | 
				
			||||||
 | 
					                guidance_scale=guidance_scale,
 | 
				
			||||||
 | 
					                do_classifier_free_guidance=do_classifier_free_guidance,
 | 
				
			||||||
 | 
					                device=self.device,
 | 
				
			||||||
 | 
					                eta=eta,
 | 
				
			||||||
 | 
					                disable_prog=not self.zero_rank
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            iter_size = steps // sample_times
 | 
				
			||||||
 | 
					            i = 0
 | 
				
			||||||
 | 
					            for sample, t in sample_loop:
 | 
				
			||||||
 | 
					                latents = sample
 | 
				
			||||||
 | 
					                if i % iter_size == 0 or i == steps - 1:
 | 
				
			||||||
 | 
					                    outputs.append(self.decode_first_stage(latents, **kwargs))
 | 
				
			||||||
 | 
					                i += 1
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return outputs
 | 
				
			||||||
							
								
								
									
										104
									
								
								primitive_anything/michelangelo/models/asl_diffusion/asl_udt.py
									
									
									
									
									
										Executable file
									
								
							
							
						
						
									
										104
									
								
								primitive_anything/michelangelo/models/asl_diffusion/asl_udt.py
									
									
									
									
									
										Executable file
									
								
							@ -0,0 +1,104 @@
 | 
				
			|||||||
 | 
					# -*- coding: utf-8 -*-
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import torch
 | 
				
			||||||
 | 
					import torch.nn as nn
 | 
				
			||||||
 | 
					from typing import Optional
 | 
				
			||||||
 | 
					from diffusers.models.embeddings import Timesteps
 | 
				
			||||||
 | 
					import math
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from ..modules.transformer_blocks import MLP
 | 
				
			||||||
 | 
					from ..modules.diffusion_transformer import UNetDiffusionTransformer
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class ConditionalASLUDTDenoiser(nn.Module):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __init__(self, *,
 | 
				
			||||||
 | 
					                 device: Optional[torch.device],
 | 
				
			||||||
 | 
					                 dtype: Optional[torch.dtype],
 | 
				
			||||||
 | 
					                 input_channels: int,
 | 
				
			||||||
 | 
					                 output_channels: int,
 | 
				
			||||||
 | 
					                 n_ctx: int,
 | 
				
			||||||
 | 
					                 width: int,
 | 
				
			||||||
 | 
					                 layers: int,
 | 
				
			||||||
 | 
					                 heads: int,
 | 
				
			||||||
 | 
					                 context_dim: int,
 | 
				
			||||||
 | 
					                 context_ln: bool = True,
 | 
				
			||||||
 | 
					                 skip_ln: bool = False,
 | 
				
			||||||
 | 
					                 init_scale: float = 0.25,
 | 
				
			||||||
 | 
					                 flip_sin_to_cos: bool = False,
 | 
				
			||||||
 | 
					                 use_checkpoint: bool = False):
 | 
				
			||||||
 | 
					        super().__init__()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.use_checkpoint = use_checkpoint
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        init_scale = init_scale * math.sqrt(1.0 / width)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.backbone = UNetDiffusionTransformer(
 | 
				
			||||||
 | 
					            device=device,
 | 
				
			||||||
 | 
					            dtype=dtype,
 | 
				
			||||||
 | 
					            n_ctx=n_ctx,
 | 
				
			||||||
 | 
					            width=width,
 | 
				
			||||||
 | 
					            layers=layers,
 | 
				
			||||||
 | 
					            heads=heads,
 | 
				
			||||||
 | 
					            skip_ln=skip_ln,
 | 
				
			||||||
 | 
					            init_scale=init_scale,
 | 
				
			||||||
 | 
					            use_checkpoint=use_checkpoint
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        self.ln_post = nn.LayerNorm(width, device=device, dtype=dtype)
 | 
				
			||||||
 | 
					        self.input_proj = nn.Linear(input_channels, width, device=device, dtype=dtype)
 | 
				
			||||||
 | 
					        self.output_proj = nn.Linear(width, output_channels, device=device, dtype=dtype)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # timestep embedding
 | 
				
			||||||
 | 
					        self.time_embed = Timesteps(width, flip_sin_to_cos=flip_sin_to_cos, downscale_freq_shift=0)
 | 
				
			||||||
 | 
					        self.time_proj = MLP(
 | 
				
			||||||
 | 
					            device=device, dtype=dtype, width=width, init_scale=init_scale
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.context_embed = nn.Sequential(
 | 
				
			||||||
 | 
					            nn.LayerNorm(context_dim, device=device, dtype=dtype),
 | 
				
			||||||
 | 
					            nn.Linear(context_dim, width, device=device, dtype=dtype),
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if context_ln:
 | 
				
			||||||
 | 
					            self.context_embed = nn.Sequential(
 | 
				
			||||||
 | 
					                nn.LayerNorm(context_dim, device=device, dtype=dtype),
 | 
				
			||||||
 | 
					                nn.Linear(context_dim, width, device=device, dtype=dtype),
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            self.context_embed = nn.Linear(context_dim, width, device=device, dtype=dtype)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def forward(self,
 | 
				
			||||||
 | 
					                model_input: torch.FloatTensor,
 | 
				
			||||||
 | 
					                timestep: torch.LongTensor,
 | 
				
			||||||
 | 
					                context: torch.FloatTensor):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        r"""
 | 
				
			||||||
 | 
					        Args:
 | 
				
			||||||
 | 
					            model_input (torch.FloatTensor): [bs, n_data, c]
 | 
				
			||||||
 | 
					            timestep (torch.LongTensor): [bs,]
 | 
				
			||||||
 | 
					            context (torch.FloatTensor): [bs, context_tokens, c]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Returns:
 | 
				
			||||||
 | 
					            sample (torch.FloatTensor): [bs, n_data, c]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        _, n_data, _ = model_input.shape
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # 1. time
 | 
				
			||||||
 | 
					        t_emb = self.time_proj(self.time_embed(timestep)).unsqueeze(dim=1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # 2. conditions projector
 | 
				
			||||||
 | 
					        context = self.context_embed(context)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # 3. denoiser
 | 
				
			||||||
 | 
					        x = self.input_proj(model_input)
 | 
				
			||||||
 | 
					        x = torch.cat([t_emb, context, x], dim=1)
 | 
				
			||||||
 | 
					        x = self.backbone(x)
 | 
				
			||||||
 | 
					        x = self.ln_post(x)
 | 
				
			||||||
 | 
					        x = x[:, -n_data:]
 | 
				
			||||||
 | 
					        sample = self.output_proj(x)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return sample
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
							
								
								
									
										13
									
								
								primitive_anything/michelangelo/models/asl_diffusion/base.py
									
									
									
									
									
										Executable file
									
								
							
							
						
						
									
										13
									
								
								primitive_anything/michelangelo/models/asl_diffusion/base.py
									
									
									
									
									
										Executable file
									
								
							@ -0,0 +1,13 @@
 | 
				
			|||||||
 | 
					# -*- coding: utf-8 -*-
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import torch
 | 
				
			||||||
 | 
					import torch.nn as nn
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class BaseDenoiser(nn.Module):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __init__(self):
 | 
				
			||||||
 | 
					        super().__init__()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def forward(self, x, t, context):
 | 
				
			||||||
 | 
					        raise NotImplementedError
 | 
				
			||||||
@ -0,0 +1,393 @@
 | 
				
			|||||||
 | 
					# -*- coding: utf-8 -*-
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from omegaconf import DictConfig
 | 
				
			||||||
 | 
					from typing import List, Tuple, Dict, Optional, Union
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import torch
 | 
				
			||||||
 | 
					import torch.nn as nn
 | 
				
			||||||
 | 
					import torch.nn.functional as F
 | 
				
			||||||
 | 
					from torch.optim import lr_scheduler
 | 
				
			||||||
 | 
					import pytorch_lightning as pl
 | 
				
			||||||
 | 
					from pytorch_lightning.utilities import rank_zero_only
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from diffusers.schedulers import (
 | 
				
			||||||
 | 
					    DDPMScheduler,
 | 
				
			||||||
 | 
					    DDIMScheduler,
 | 
				
			||||||
 | 
					    KarrasVeScheduler,
 | 
				
			||||||
 | 
					    DPMSolverMultistepScheduler
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from ...utils import instantiate_from_config
 | 
				
			||||||
 | 
					from ..tsal.tsal_base import AlignedShapeAsLatentPLModule
 | 
				
			||||||
 | 
					from .inference_utils import ddim_sample
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					SchedulerType = Union[DDIMScheduler, KarrasVeScheduler, DPMSolverMultistepScheduler]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def disabled_train(self, mode=True):
 | 
				
			||||||
 | 
					    """Overwrite model.train with this function to make sure train/eval mode
 | 
				
			||||||
 | 
					    does not change anymore."""
 | 
				
			||||||
 | 
					    return self
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class ClipASLDiffuser(pl.LightningModule):
 | 
				
			||||||
 | 
					    first_stage_model: Optional[AlignedShapeAsLatentPLModule]
 | 
				
			||||||
 | 
					    cond_stage_model: Optional[Union[nn.Module, pl.LightningModule]]
 | 
				
			||||||
 | 
					    model: nn.Module
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __init__(self, *,
 | 
				
			||||||
 | 
					                 first_stage_config,
 | 
				
			||||||
 | 
					                 cond_stage_config,
 | 
				
			||||||
 | 
					                 denoiser_cfg,
 | 
				
			||||||
 | 
					                 scheduler_cfg,
 | 
				
			||||||
 | 
					                 optimizer_cfg,
 | 
				
			||||||
 | 
					                 loss_cfg,
 | 
				
			||||||
 | 
					                 first_stage_key: str = "surface",
 | 
				
			||||||
 | 
					                 cond_stage_key: str = "image",
 | 
				
			||||||
 | 
					                 scale_by_std: bool = False,
 | 
				
			||||||
 | 
					                 z_scale_factor: float = 1.0,
 | 
				
			||||||
 | 
					                 ckpt_path: Optional[str] = None,
 | 
				
			||||||
 | 
					                 ignore_keys: Union[Tuple[str], List[str]] = ()):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        super().__init__()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.first_stage_key = first_stage_key
 | 
				
			||||||
 | 
					        self.cond_stage_key = cond_stage_key
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # 1. lazy initialize first stage
 | 
				
			||||||
 | 
					        self.instantiate_first_stage(first_stage_config)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # 2. initialize conditional stage
 | 
				
			||||||
 | 
					        self.instantiate_cond_stage(cond_stage_config)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # 3. diffusion model
 | 
				
			||||||
 | 
					        self.model = instantiate_from_config(
 | 
				
			||||||
 | 
					            denoiser_cfg, device=None, dtype=None
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.optimizer_cfg = optimizer_cfg
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # 4. scheduling strategy
 | 
				
			||||||
 | 
					        self.scheduler_cfg = scheduler_cfg
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.noise_scheduler: DDPMScheduler = instantiate_from_config(scheduler_cfg.noise)
 | 
				
			||||||
 | 
					        self.denoise_scheduler: SchedulerType = instantiate_from_config(scheduler_cfg.denoise)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # 5. loss configures
 | 
				
			||||||
 | 
					        self.loss_cfg = loss_cfg
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.scale_by_std = scale_by_std
 | 
				
			||||||
 | 
					        if scale_by_std:
 | 
				
			||||||
 | 
					            self.register_buffer("z_scale_factor", torch.tensor(z_scale_factor))
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            self.z_scale_factor = z_scale_factor
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.ckpt_path = ckpt_path
 | 
				
			||||||
 | 
					        if ckpt_path is not None:
 | 
				
			||||||
 | 
					            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def instantiate_non_trainable_model(self, config):
 | 
				
			||||||
 | 
					        model = instantiate_from_config(config)
 | 
				
			||||||
 | 
					        model = model.eval()
 | 
				
			||||||
 | 
					        model.train = disabled_train
 | 
				
			||||||
 | 
					        for param in model.parameters():
 | 
				
			||||||
 | 
					            param.requires_grad = False
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return model
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def instantiate_first_stage(self, first_stage_config):
 | 
				
			||||||
 | 
					        self.first_stage_model = self.instantiate_non_trainable_model(first_stage_config)
 | 
				
			||||||
 | 
					        self.first_stage_model.set_shape_model_only()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def instantiate_cond_stage(self, cond_stage_config):
 | 
				
			||||||
 | 
					        self.cond_stage_model = self.instantiate_non_trainable_model(cond_stage_config)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def init_from_ckpt(self, path, ignore_keys=()):
 | 
				
			||||||
 | 
					        state_dict = torch.load(path, map_location="cpu")["state_dict"]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        keys = list(state_dict.keys())
 | 
				
			||||||
 | 
					        for k in keys:
 | 
				
			||||||
 | 
					            for ik in ignore_keys:
 | 
				
			||||||
 | 
					                if k.startswith(ik):
 | 
				
			||||||
 | 
					                    print("Deleting key {} from state_dict.".format(k))
 | 
				
			||||||
 | 
					                    del state_dict[k]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        missing, unexpected = self.load_state_dict(state_dict, strict=False)
 | 
				
			||||||
 | 
					        print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
 | 
				
			||||||
 | 
					        if len(missing) > 0:
 | 
				
			||||||
 | 
					            print(f"Missing Keys: {missing}")
 | 
				
			||||||
 | 
					            print(f"Unexpected Keys: {unexpected}")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @property
 | 
				
			||||||
 | 
					    def zero_rank(self):
 | 
				
			||||||
 | 
					        if self._trainer:
 | 
				
			||||||
 | 
					            zero_rank = self.trainer.local_rank == 0
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            zero_rank = True
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return zero_rank
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def configure_optimizers(self) -> Tuple[List, List]:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        lr = self.learning_rate
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        trainable_parameters = list(self.model.parameters())
 | 
				
			||||||
 | 
					        if self.optimizer_cfg is None:
 | 
				
			||||||
 | 
					            optimizers = [torch.optim.AdamW(trainable_parameters, lr=lr, betas=(0.9, 0.99), weight_decay=1e-3)]
 | 
				
			||||||
 | 
					            schedulers = []
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            optimizer = instantiate_from_config(self.optimizer_cfg.optimizer, params=trainable_parameters)
 | 
				
			||||||
 | 
					            scheduler_func = instantiate_from_config(
 | 
				
			||||||
 | 
					                self.optimizer_cfg.scheduler,
 | 
				
			||||||
 | 
					                max_decay_steps=self.trainer.max_steps,
 | 
				
			||||||
 | 
					                lr_max=lr
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					            scheduler = {
 | 
				
			||||||
 | 
					                "scheduler": lr_scheduler.LambdaLR(optimizer, lr_lambda=scheduler_func.schedule),
 | 
				
			||||||
 | 
					                "interval": "step",
 | 
				
			||||||
 | 
					                "frequency": 1
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
 | 
					            optimizers = [optimizer]
 | 
				
			||||||
 | 
					            schedulers = [scheduler]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return optimizers, schedulers
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @torch.no_grad()
 | 
				
			||||||
 | 
					    def encode_first_stage(self, surface: torch.FloatTensor, sample_posterior=True):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        z_q = self.first_stage_model.encode(surface, sample_posterior)
 | 
				
			||||||
 | 
					        z_q = self.z_scale_factor * z_q
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return z_q
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @torch.no_grad()
 | 
				
			||||||
 | 
					    def decode_first_stage(self, z_q: torch.FloatTensor, **kwargs):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        z_q = 1. / self.z_scale_factor * z_q
 | 
				
			||||||
 | 
					        latents = self.first_stage_model.decode(z_q, **kwargs)
 | 
				
			||||||
 | 
					        return latents
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @rank_zero_only
 | 
				
			||||||
 | 
					    @torch.no_grad()
 | 
				
			||||||
 | 
					    def on_train_batch_start(self, batch, batch_idx):
 | 
				
			||||||
 | 
					        # only for very first batch
 | 
				
			||||||
 | 
					        if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 \
 | 
				
			||||||
 | 
					                and batch_idx == 0 and self.ckpt_path is None:
 | 
				
			||||||
 | 
					            # set rescale weight to 1./std of encodings
 | 
				
			||||||
 | 
					            print("### USING STD-RESCALING ###")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            z_q = self.encode_first_stage(batch[self.first_stage_key])
 | 
				
			||||||
 | 
					            z = z_q.detach()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            del self.z_scale_factor
 | 
				
			||||||
 | 
					            self.register_buffer("z_scale_factor", 1. / z.flatten().std())
 | 
				
			||||||
 | 
					            print(f"setting self.z_scale_factor to {self.z_scale_factor}")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            print("### USING STD-RESCALING ###")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def compute_loss(self, model_outputs, split):
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Args:
 | 
				
			||||||
 | 
					            model_outputs (dict):
 | 
				
			||||||
 | 
					                - x_0:
 | 
				
			||||||
 | 
					                - noise:
 | 
				
			||||||
 | 
					                - noise_prior:
 | 
				
			||||||
 | 
					                - noise_pred:
 | 
				
			||||||
 | 
					                - noise_pred_prior:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            split (str):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Returns:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        pred = model_outputs["pred"]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if self.noise_scheduler.prediction_type == "epsilon":
 | 
				
			||||||
 | 
					            target = model_outputs["noise"]
 | 
				
			||||||
 | 
					        elif self.noise_scheduler.prediction_type == "sample":
 | 
				
			||||||
 | 
					            target = model_outputs["x_0"]
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            raise NotImplementedError(f"Prediction Type: {self.noise_scheduler.prediction_type} not yet supported.")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if self.loss_cfg.loss_type == "l1":
 | 
				
			||||||
 | 
					            simple = F.l1_loss(pred, target, reduction="mean")
 | 
				
			||||||
 | 
					        elif self.loss_cfg.loss_type in ["mse", "l2"]:
 | 
				
			||||||
 | 
					            simple = F.mse_loss(pred, target, reduction="mean")
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            raise NotImplementedError(f"Loss Type: {self.loss_cfg.loss_type} not yet supported.")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        total_loss = simple
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        loss_dict = {
 | 
				
			||||||
 | 
					            f"{split}/total_loss": total_loss.clone().detach(),
 | 
				
			||||||
 | 
					            f"{split}/simple": simple.detach(),
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return total_loss, loss_dict
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def forward(self, batch):
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Args:
 | 
				
			||||||
 | 
					            batch:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Returns:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        latents = self.encode_first_stage(batch[self.first_stage_key])
 | 
				
			||||||
 | 
					        conditions = self.cond_stage_model.encode(batch[self.cond_stage_key])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # Sample noise that we"ll add to the latents
 | 
				
			||||||
 | 
					        # [batch_size, n_token, latent_dim]
 | 
				
			||||||
 | 
					        noise = torch.randn_like(latents)
 | 
				
			||||||
 | 
					        bs = latents.shape[0]
 | 
				
			||||||
 | 
					        # Sample a random timestep for each motion
 | 
				
			||||||
 | 
					        timesteps = torch.randint(
 | 
				
			||||||
 | 
					            0,
 | 
				
			||||||
 | 
					            self.noise_scheduler.config.num_train_timesteps,
 | 
				
			||||||
 | 
					            (bs,),
 | 
				
			||||||
 | 
					            device=latents.device,
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        timesteps = timesteps.long()
 | 
				
			||||||
 | 
					        # Add noise to the latents according to the noise magnitude at each timestep
 | 
				
			||||||
 | 
					        noisy_z = self.noise_scheduler.add_noise(latents, noise, timesteps)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # diffusion model forward
 | 
				
			||||||
 | 
					        noise_pred = self.model(noisy_z, timesteps, conditions)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        diffusion_outputs = {
 | 
				
			||||||
 | 
					            "x_0": noisy_z,
 | 
				
			||||||
 | 
					            "noise": noise,
 | 
				
			||||||
 | 
					            "pred": noise_pred
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return diffusion_outputs
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def training_step(self, batch: Dict[str, Union[torch.FloatTensor, List[str]]],
 | 
				
			||||||
 | 
					                      batch_idx: int, optimizer_idx: int = 0) -> torch.FloatTensor:
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Args:
 | 
				
			||||||
 | 
					            batch (dict): the batch sample, and it contains:
 | 
				
			||||||
 | 
					                - surface (torch.FloatTensor):
 | 
				
			||||||
 | 
					                - image (torch.FloatTensor): if provide, [bs, 3, h, w], item range [0, 1]
 | 
				
			||||||
 | 
					                - depth (torch.FloatTensor): if provide, [bs, 1, h, w], item range [-1, 1]
 | 
				
			||||||
 | 
					                - normal (torch.FloatTensor): if provide, [bs, 3, h, w], item range [-1, 1]
 | 
				
			||||||
 | 
					                - text (list of str):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            batch_idx (int):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            optimizer_idx (int):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Returns:
 | 
				
			||||||
 | 
					            loss (torch.FloatTensor):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        diffusion_outputs = self(batch)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        loss, loss_dict = self.compute_loss(diffusion_outputs, "train")
 | 
				
			||||||
 | 
					        self.log_dict(loss_dict, prog_bar=True, logger=True, sync_dist=False, rank_zero_only=True)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return loss
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def validation_step(self, batch: Dict[str, torch.FloatTensor],
 | 
				
			||||||
 | 
					                        batch_idx: int, optimizer_idx: int = 0) -> torch.FloatTensor:
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Args:
 | 
				
			||||||
 | 
					            batch (dict): the batch sample, and it contains:
 | 
				
			||||||
 | 
					                - surface_pc (torch.FloatTensor): [n_pts, 4]
 | 
				
			||||||
 | 
					                - surface_feats (torch.FloatTensor): [n_pts, c]
 | 
				
			||||||
 | 
					                - text (list of str):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            batch_idx (int):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            optimizer_idx (int):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Returns:
 | 
				
			||||||
 | 
					            loss (torch.FloatTensor):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        diffusion_outputs = self(batch)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        loss, loss_dict = self.compute_loss(diffusion_outputs, "val")
 | 
				
			||||||
 | 
					        self.log_dict(loss_dict, prog_bar=True, logger=True, sync_dist=False, rank_zero_only=True)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return loss
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @torch.no_grad()
 | 
				
			||||||
 | 
					    def sample(self,
 | 
				
			||||||
 | 
					               batch: Dict[str, Union[torch.FloatTensor, List[str]]],
 | 
				
			||||||
 | 
					               sample_times: int = 1,
 | 
				
			||||||
 | 
					               steps: Optional[int] = None,
 | 
				
			||||||
 | 
					               guidance_scale: Optional[float] = None,
 | 
				
			||||||
 | 
					               eta: float = 0.0,
 | 
				
			||||||
 | 
					               return_intermediates: bool = False, **kwargs):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if steps is None:
 | 
				
			||||||
 | 
					            steps = self.scheduler_cfg.num_inference_steps
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if guidance_scale is None:
 | 
				
			||||||
 | 
					            guidance_scale = self.scheduler_cfg.guidance_scale
 | 
				
			||||||
 | 
					        do_classifier_free_guidance = guidance_scale > 0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # conditional encode
 | 
				
			||||||
 | 
					        xc = batch[self.cond_stage_key]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # print(self.first_stage_model.device, self.cond_stage_model.device, self.device)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        cond = self.cond_stage_model(xc)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if do_classifier_free_guidance:
 | 
				
			||||||
 | 
					            un_cond = self.cond_stage_model.unconditional_embedding(batch_size=len(xc))
 | 
				
			||||||
 | 
					            cond = torch.cat([un_cond, cond], dim=0)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        outputs = []
 | 
				
			||||||
 | 
					        latents = None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if not return_intermediates:
 | 
				
			||||||
 | 
					            for _ in range(sample_times):
 | 
				
			||||||
 | 
					                sample_loop = ddim_sample(
 | 
				
			||||||
 | 
					                    self.denoise_scheduler,
 | 
				
			||||||
 | 
					                    self.model,
 | 
				
			||||||
 | 
					                    shape=self.first_stage_model.latent_shape,
 | 
				
			||||||
 | 
					                    cond=cond,
 | 
				
			||||||
 | 
					                    steps=steps,
 | 
				
			||||||
 | 
					                    guidance_scale=guidance_scale,
 | 
				
			||||||
 | 
					                    do_classifier_free_guidance=do_classifier_free_guidance,
 | 
				
			||||||
 | 
					                    device=self.device,
 | 
				
			||||||
 | 
					                    eta=eta,
 | 
				
			||||||
 | 
					                    disable_prog=not self.zero_rank
 | 
				
			||||||
 | 
					                )
 | 
				
			||||||
 | 
					                for sample, t in sample_loop:
 | 
				
			||||||
 | 
					                    latents = sample
 | 
				
			||||||
 | 
					                outputs.append(self.decode_first_stage(latents, **kwargs))
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            sample_loop = ddim_sample(
 | 
				
			||||||
 | 
					                self.denoise_scheduler,
 | 
				
			||||||
 | 
					                self.model,
 | 
				
			||||||
 | 
					                shape=self.first_stage_model.latent_shape,
 | 
				
			||||||
 | 
					                cond=cond,
 | 
				
			||||||
 | 
					                steps=steps,
 | 
				
			||||||
 | 
					                guidance_scale=guidance_scale,
 | 
				
			||||||
 | 
					                do_classifier_free_guidance=do_classifier_free_guidance,
 | 
				
			||||||
 | 
					                device=self.device,
 | 
				
			||||||
 | 
					                eta=eta,
 | 
				
			||||||
 | 
					                disable_prog=not self.zero_rank
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            iter_size = steps // sample_times
 | 
				
			||||||
 | 
					            i = 0
 | 
				
			||||||
 | 
					            for sample, t in sample_loop:
 | 
				
			||||||
 | 
					                latents = sample
 | 
				
			||||||
 | 
					                if i % iter_size == 0 or i == steps - 1:
 | 
				
			||||||
 | 
					                    outputs.append(self.decode_first_stage(latents, **kwargs))
 | 
				
			||||||
 | 
					                i += 1
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return outputs
 | 
				
			||||||
							
								
								
									
										80
									
								
								primitive_anything/michelangelo/models/asl_diffusion/inference_utils.py
									
									
									
									
									
										Executable file
									
								
							
							
						
						
									
										80
									
								
								primitive_anything/michelangelo/models/asl_diffusion/inference_utils.py
									
									
									
									
									
										Executable file
									
								
							@ -0,0 +1,80 @@
 | 
				
			|||||||
 | 
					# -*- coding: utf-8 -*-
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import torch
 | 
				
			||||||
 | 
					from tqdm import tqdm
 | 
				
			||||||
 | 
					from typing import Tuple, List, Union, Optional
 | 
				
			||||||
 | 
					from diffusers.schedulers import DDIMScheduler
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					__all__ = ["ddim_sample"]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def ddim_sample(ddim_scheduler: DDIMScheduler,
 | 
				
			||||||
 | 
					                diffusion_model: torch.nn.Module,
 | 
				
			||||||
 | 
					                shape: Union[List[int], Tuple[int]],
 | 
				
			||||||
 | 
					                cond: torch.FloatTensor,
 | 
				
			||||||
 | 
					                steps: int,
 | 
				
			||||||
 | 
					                eta: float = 0.0,
 | 
				
			||||||
 | 
					                guidance_scale: float = 3.0,
 | 
				
			||||||
 | 
					                do_classifier_free_guidance: bool = True,
 | 
				
			||||||
 | 
					                generator: Optional[torch.Generator] = None,
 | 
				
			||||||
 | 
					                device: torch.device = "cuda:0",
 | 
				
			||||||
 | 
					                disable_prog: bool = True):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    assert steps > 0, f"{steps} must > 0."
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # init latents
 | 
				
			||||||
 | 
					    bsz = cond.shape[0]
 | 
				
			||||||
 | 
					    if do_classifier_free_guidance:
 | 
				
			||||||
 | 
					        bsz = bsz // 2
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    latents = torch.randn(
 | 
				
			||||||
 | 
					        (bsz, *shape),
 | 
				
			||||||
 | 
					        generator=generator,
 | 
				
			||||||
 | 
					        device=cond.device,
 | 
				
			||||||
 | 
					        dtype=cond.dtype,
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					    # scale the initial noise by the standard deviation required by the scheduler
 | 
				
			||||||
 | 
					    latents = latents * ddim_scheduler.init_noise_sigma
 | 
				
			||||||
 | 
					    # set timesteps
 | 
				
			||||||
 | 
					    ddim_scheduler.set_timesteps(steps)
 | 
				
			||||||
 | 
					    timesteps = ddim_scheduler.timesteps.to(device)
 | 
				
			||||||
 | 
					    # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
 | 
				
			||||||
 | 
					    # eta (η) is only used with the DDIMScheduler, and between [0, 1]
 | 
				
			||||||
 | 
					    extra_step_kwargs = {
 | 
				
			||||||
 | 
					        "eta": eta,
 | 
				
			||||||
 | 
					        "generator": generator
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # reverse
 | 
				
			||||||
 | 
					    for i, t in enumerate(tqdm(timesteps, disable=disable_prog, desc="DDIM Sampling:", leave=False)):
 | 
				
			||||||
 | 
					        # expand the latents if we are doing classifier free guidance
 | 
				
			||||||
 | 
					        latent_model_input = (
 | 
				
			||||||
 | 
					            torch.cat([latents] * 2)
 | 
				
			||||||
 | 
					            if do_classifier_free_guidance
 | 
				
			||||||
 | 
					            else latents
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        # latent_model_input = scheduler.scale_model_input(latent_model_input, t)
 | 
				
			||||||
 | 
					        # predict the noise residual
 | 
				
			||||||
 | 
					        timestep_tensor = torch.tensor([t], dtype=torch.long, device=device)
 | 
				
			||||||
 | 
					        timestep_tensor = timestep_tensor.expand(latent_model_input.shape[0])
 | 
				
			||||||
 | 
					        noise_pred = diffusion_model.forward(latent_model_input, timestep_tensor, cond)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # perform guidance
 | 
				
			||||||
 | 
					        if do_classifier_free_guidance:
 | 
				
			||||||
 | 
					            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
 | 
				
			||||||
 | 
					            noise_pred = noise_pred_uncond + guidance_scale * (
 | 
				
			||||||
 | 
					                    noise_pred_text - noise_pred_uncond
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					            # text_embeddings_for_guidance = encoder_hidden_states.chunk(
 | 
				
			||||||
 | 
					            #     2)[1] if do_classifier_free_guidance else encoder_hidden_states
 | 
				
			||||||
 | 
					        # compute the previous noisy sample x_t -> x_t-1
 | 
				
			||||||
 | 
					        latents = ddim_scheduler.step(
 | 
				
			||||||
 | 
					            noise_pred, t, latents, **extra_step_kwargs
 | 
				
			||||||
 | 
					        ).prev_sample
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        yield latents, t
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def karra_sample():
 | 
				
			||||||
 | 
					    pass
 | 
				
			||||||
							
								
								
									
										3
									
								
								primitive_anything/michelangelo/models/conditional_encoders/__init__.py
									
									
									
									
									
										Executable file
									
								
							
							
						
						
									
										3
									
								
								primitive_anything/michelangelo/models/conditional_encoders/__init__.py
									
									
									
									
									
										Executable file
									
								
							@ -0,0 +1,3 @@
 | 
				
			|||||||
 | 
					# -*- coding: utf-8 -*-
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from .clip import CLIPEncoder
 | 
				
			||||||
							
								
								
									
										89
									
								
								primitive_anything/michelangelo/models/conditional_encoders/clip.py
									
									
									
									
									
										Executable file
									
								
							
							
						
						
									
										89
									
								
								primitive_anything/michelangelo/models/conditional_encoders/clip.py
									
									
									
									
									
										Executable file
									
								
							@ -0,0 +1,89 @@
 | 
				
			|||||||
 | 
					# -*- coding: utf-8 -*-
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import torch
 | 
				
			||||||
 | 
					import numpy as np
 | 
				
			||||||
 | 
					from PIL import Image
 | 
				
			||||||
 | 
					from dataclasses import dataclass
 | 
				
			||||||
 | 
					from torchvision.transforms import Normalize
 | 
				
			||||||
 | 
					from transformers import CLIPModel, CLIPTokenizer
 | 
				
			||||||
 | 
					from transformers.utils import ModelOutput
 | 
				
			||||||
 | 
					from typing import Iterable, Optional, Union, List
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					ImageType = Union[np.ndarray, torch.Tensor, Image.Image]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@dataclass
 | 
				
			||||||
 | 
					class CLIPEmbedOutput(ModelOutput):
 | 
				
			||||||
 | 
					    last_hidden_state: torch.FloatTensor = None
 | 
				
			||||||
 | 
					    pooler_output: torch.FloatTensor = None
 | 
				
			||||||
 | 
					    embeds: torch.FloatTensor = None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class CLIPEncoder(torch.nn.Module):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __init__(self, model_path="openai/clip-vit-base-patch32"):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        super().__init__()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # Load the CLIP model and processor
 | 
				
			||||||
 | 
					        self.model: CLIPModel = CLIPModel.from_pretrained(model_path)
 | 
				
			||||||
 | 
					        self.tokenizer = CLIPTokenizer.from_pretrained(model_path)
 | 
				
			||||||
 | 
					        self.image_preprocess = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.model.training = False
 | 
				
			||||||
 | 
					        for p in self.model.parameters():
 | 
				
			||||||
 | 
					            p.requires_grad = False
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @torch.no_grad()
 | 
				
			||||||
 | 
					    def encode_image(self, images: Iterable[Optional[ImageType]]):
 | 
				
			||||||
 | 
					        pixel_values = self.image_preprocess(images)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        vision_outputs = self.model.vision_model(pixel_values=pixel_values)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        pooler_output = vision_outputs[1]  # pooled_output
 | 
				
			||||||
 | 
					        image_features = self.model.visual_projection(pooler_output)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        visual_embeds = CLIPEmbedOutput(
 | 
				
			||||||
 | 
					            last_hidden_state=vision_outputs.last_hidden_state,
 | 
				
			||||||
 | 
					            pooler_output=pooler_output,
 | 
				
			||||||
 | 
					            embeds=image_features
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return visual_embeds
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @torch.no_grad()
 | 
				
			||||||
 | 
					    def encode_text(self, texts: List[str]):
 | 
				
			||||||
 | 
					        text_inputs = self.tokenizer(texts, padding=True, return_tensors="pt")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        text_outputs = self.model.text_model(input_ids=text_inputs)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        pooler_output = text_outputs[1]  # pooled_output
 | 
				
			||||||
 | 
					        text_features = self.model.text_projection(pooler_output)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        text_embeds = CLIPEmbedOutput(
 | 
				
			||||||
 | 
					            last_hidden_state=text_outputs.last_hidden_state,
 | 
				
			||||||
 | 
					            pooler_output=pooler_output,
 | 
				
			||||||
 | 
					            embeds=text_features
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return text_embeds
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def forward(self,
 | 
				
			||||||
 | 
					                images: Iterable[Optional[ImageType]],
 | 
				
			||||||
 | 
					                texts: List[str]):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        visual_embeds = self.encode_image(images)
 | 
				
			||||||
 | 
					        text_embeds = self.encode_text(texts)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return visual_embeds, text_embeds
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
							
								
								
									
										562
									
								
								primitive_anything/michelangelo/models/conditional_encoders/encoder_factory.py
									
									
									
									
									
										Executable file
									
								
							
							
						
						
									
										562
									
								
								primitive_anything/michelangelo/models/conditional_encoders/encoder_factory.py
									
									
									
									
									
										Executable file
									
								
							@ -0,0 +1,562 @@
 | 
				
			|||||||
 | 
					# -*- coding: utf-8 -*-
 | 
				
			||||||
 | 
					import os
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import torch
 | 
				
			||||||
 | 
					import torch.nn as nn
 | 
				
			||||||
 | 
					from torchvision import transforms
 | 
				
			||||||
 | 
					from transformers import CLIPModel, CLIPTokenizer
 | 
				
			||||||
 | 
					from collections import OrderedDict
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from ...data.transforms import RandomResize
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class AbstractEncoder(nn.Module):
 | 
				
			||||||
 | 
					    embedding_dim: int
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __init__(self):
 | 
				
			||||||
 | 
					        super().__init__()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def encode(self, *args, **kwargs):
 | 
				
			||||||
 | 
					        raise NotImplementedError
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class ClassEmbedder(nn.Module):
 | 
				
			||||||
 | 
					    def __init__(self, embed_dim, n_classes=1000, key="class"):
 | 
				
			||||||
 | 
					        super().__init__()
 | 
				
			||||||
 | 
					        self.key = key
 | 
				
			||||||
 | 
					        self.embedding = nn.Embedding(n_classes, embed_dim)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def forward(self, batch, key=None):
 | 
				
			||||||
 | 
					        if key is None:
 | 
				
			||||||
 | 
					            key = self.key
 | 
				
			||||||
 | 
					        # this is for use in crossattn
 | 
				
			||||||
 | 
					        c = batch[key][:, None]
 | 
				
			||||||
 | 
					        c = self.embedding(c)
 | 
				
			||||||
 | 
					        return c
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class FrozenCLIPTextEmbedder(AbstractEncoder):
 | 
				
			||||||
 | 
					    """Uses the CLIP transformer encoder for text (from Hugging Face)"""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __init__(
 | 
				
			||||||
 | 
					        self,
 | 
				
			||||||
 | 
					        version="openai/clip-vit-large-patch14",
 | 
				
			||||||
 | 
					        tokenizer_version=None,
 | 
				
			||||||
 | 
					        device="cuda",
 | 
				
			||||||
 | 
					        max_length=77,
 | 
				
			||||||
 | 
					        zero_embedding_radio: float = 0.1,
 | 
				
			||||||
 | 
					    ):
 | 
				
			||||||
 | 
					        super().__init__()
 | 
				
			||||||
 | 
					        self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_version or version)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.device = device
 | 
				
			||||||
 | 
					        self.max_length = max_length
 | 
				
			||||||
 | 
					        self.zero_embedding_radio = zero_embedding_radio
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.clip_dict = OrderedDict()
 | 
				
			||||||
 | 
					        self.clip_name = os.path.split(version)[-1]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        transformer = CLIPModel.from_pretrained(version).text_model
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        for param in transformer.parameters():
 | 
				
			||||||
 | 
					            param.requires_grad = False
 | 
				
			||||||
 | 
					        self.clip_dict[self.clip_name] = transformer
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self._move_flag = False
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @property
 | 
				
			||||||
 | 
					    def clip(self):
 | 
				
			||||||
 | 
					        return self.clip_dict[self.clip_name]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def move(self):
 | 
				
			||||||
 | 
					        if self._move_flag:
 | 
				
			||||||
 | 
					            return
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.clip_dict[self.clip_name] = self.clip_dict[self.clip_name].to(self.device)
 | 
				
			||||||
 | 
					        self._move_flag = True
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def unconditional_embedding(self, batch_size):
 | 
				
			||||||
 | 
					        empty_text = [""] * batch_size
 | 
				
			||||||
 | 
					        empty_z = self.forward(empty_text)
 | 
				
			||||||
 | 
					        return empty_z
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def forward(self, text):
 | 
				
			||||||
 | 
					        self.move()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        batch_encoding = self.tokenizer(
 | 
				
			||||||
 | 
					            text,
 | 
				
			||||||
 | 
					            truncation=True,
 | 
				
			||||||
 | 
					            max_length=self.max_length,
 | 
				
			||||||
 | 
					            return_length=True,
 | 
				
			||||||
 | 
					            return_overflowing_tokens=False,
 | 
				
			||||||
 | 
					            padding="max_length",
 | 
				
			||||||
 | 
					            return_tensors="pt",
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        tokens = batch_encoding["input_ids"].to(self.device)
 | 
				
			||||||
 | 
					        outputs = self.clip(input_ids=tokens)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        z = outputs.last_hidden_state
 | 
				
			||||||
 | 
					        return z
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def encode(self, text):
 | 
				
			||||||
 | 
					        batch_size = len(text)
 | 
				
			||||||
 | 
					        batch_mask = torch.rand((batch_size,))
 | 
				
			||||||
 | 
					        for i in range(batch_size):
 | 
				
			||||||
 | 
					            if batch_mask[i] < self.zero_embedding_radio:
 | 
				
			||||||
 | 
					                text[i] = ""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return self(text)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class FrozenAlignedCLIPTextEmbedder(AbstractEncoder):
 | 
				
			||||||
 | 
					    """Uses the CLIP transformer encoder for text (from Hugging Face)"""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __init__(
 | 
				
			||||||
 | 
					        self,
 | 
				
			||||||
 | 
					        version="openai/clip-vit-large-patch14",
 | 
				
			||||||
 | 
					        tokenizer_version=None,
 | 
				
			||||||
 | 
					        device="cuda",
 | 
				
			||||||
 | 
					        max_length=77,
 | 
				
			||||||
 | 
					        zero_embedding_radio: float = 0.1,
 | 
				
			||||||
 | 
					    ):
 | 
				
			||||||
 | 
					        super().__init__()
 | 
				
			||||||
 | 
					        self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_version or version)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.device = device
 | 
				
			||||||
 | 
					        self.max_length = max_length
 | 
				
			||||||
 | 
					        self.zero_embedding_radio = zero_embedding_radio
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.clip_dict = OrderedDict()
 | 
				
			||||||
 | 
					        self.clip_name = os.path.split(version)[-1]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        transformer = CLIPModel.from_pretrained(version).text_model
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        for param in transformer.parameters():
 | 
				
			||||||
 | 
					            param.requires_grad = False
 | 
				
			||||||
 | 
					        self.clip_dict[self.clip_name] = transformer
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self._move_flag = False
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @property
 | 
				
			||||||
 | 
					    def clip(self):
 | 
				
			||||||
 | 
					        return self.clip_dict[self.clip_name]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def move(self):
 | 
				
			||||||
 | 
					        if self._move_flag:
 | 
				
			||||||
 | 
					            return
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.clip_dict[self.clip_name] = self.clip_dict[self.clip_name].to(self.device)
 | 
				
			||||||
 | 
					        self._move_flag = True
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def unconditional_embedding(self, batch_size):
 | 
				
			||||||
 | 
					        empty_text = [""] * batch_size
 | 
				
			||||||
 | 
					        empty_z = self.forward(empty_text)
 | 
				
			||||||
 | 
					        return empty_z
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def forward(self, text):
 | 
				
			||||||
 | 
					        self.move()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        batch_encoding = self.tokenizer(
 | 
				
			||||||
 | 
					            text,
 | 
				
			||||||
 | 
					            truncation=True,
 | 
				
			||||||
 | 
					            max_length=self.max_length,
 | 
				
			||||||
 | 
					            return_length=True,
 | 
				
			||||||
 | 
					            return_overflowing_tokens=False,
 | 
				
			||||||
 | 
					            padding="max_length",
 | 
				
			||||||
 | 
					            return_tensors="pt",
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        tokens = batch_encoding["input_ids"].to(self.device)
 | 
				
			||||||
 | 
					        outputs = self.clip(input_ids=tokens)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        z = outputs.last_hidden_state
 | 
				
			||||||
 | 
					        return z
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def encode(self, text):
 | 
				
			||||||
 | 
					        batch_size = len(text)
 | 
				
			||||||
 | 
					        batch_mask = torch.rand((batch_size,))
 | 
				
			||||||
 | 
					        for i in range(batch_size):
 | 
				
			||||||
 | 
					            if batch_mask[i] < self.zero_embedding_radio:
 | 
				
			||||||
 | 
					                text[i] = ""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return self(text)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class FrozenCLIPImageEmbedder(AbstractEncoder):
 | 
				
			||||||
 | 
					    """Uses the CLIP transformer encoder for text (from Hugging Face)"""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __init__(
 | 
				
			||||||
 | 
					            self,
 | 
				
			||||||
 | 
					            version="openai/clip-vit-large-patch14",
 | 
				
			||||||
 | 
					            device="cuda",
 | 
				
			||||||
 | 
					            zero_embedding_radio=0.1,
 | 
				
			||||||
 | 
					            normalize_embedding=True,
 | 
				
			||||||
 | 
					            num_projection_vector=0,
 | 
				
			||||||
 | 
					            linear_mapping_bias=True,
 | 
				
			||||||
 | 
					            reverse_visual_projection=False,
 | 
				
			||||||
 | 
					    ):
 | 
				
			||||||
 | 
					        super().__init__()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.device = device
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.clip_dict = OrderedDict()
 | 
				
			||||||
 | 
					        self.clip_name = os.path.split(version)[-1]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        clip_model = CLIPModel.from_pretrained(version)
 | 
				
			||||||
 | 
					        clip_model.text_model = None
 | 
				
			||||||
 | 
					        clip_model.text_projection = None
 | 
				
			||||||
 | 
					        clip_model = clip_model.eval()
 | 
				
			||||||
 | 
					        for param in self.parameters():
 | 
				
			||||||
 | 
					            param.requires_grad = False
 | 
				
			||||||
 | 
					        self.clip_dict[self.clip_name] = clip_model
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.transform = transforms.Compose(
 | 
				
			||||||
 | 
					            [
 | 
				
			||||||
 | 
					                transforms.Resize(224, transforms.InterpolationMode.BICUBIC, antialias=True),
 | 
				
			||||||
 | 
					                transforms.CenterCrop(224),  # crop a (224, 224) square
 | 
				
			||||||
 | 
					                transforms.Normalize(
 | 
				
			||||||
 | 
					                    mean=[0.48145466, 0.4578275, 0.40821073],
 | 
				
			||||||
 | 
					                    std=[0.26862954, 0.26130258, 0.27577711],
 | 
				
			||||||
 | 
					                ),
 | 
				
			||||||
 | 
					            ]
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        self.zero_embedding_radio = zero_embedding_radio
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.num_projection_vector = num_projection_vector
 | 
				
			||||||
 | 
					        self.reverse_visual_projection = reverse_visual_projection
 | 
				
			||||||
 | 
					        self.normalize_embedding = normalize_embedding
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        embedding_dim = (
 | 
				
			||||||
 | 
					            clip_model.visual_projection.in_features
 | 
				
			||||||
 | 
					            if reverse_visual_projection
 | 
				
			||||||
 | 
					            else clip_model.visual_projection.out_features
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        self.embedding_dim = embedding_dim
 | 
				
			||||||
 | 
					        if self.num_projection_vector > 0:
 | 
				
			||||||
 | 
					            self.projection = nn.Linear(
 | 
				
			||||||
 | 
					                embedding_dim,
 | 
				
			||||||
 | 
					                clip_model.visual_projection.out_features * num_projection_vector,
 | 
				
			||||||
 | 
					                bias=linear_mapping_bias,
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					            nn.init.normal_(self.projection.weight, std=embedding_dim ** -0.5)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self._move_flag = False
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @property
 | 
				
			||||||
 | 
					    def clip(self):
 | 
				
			||||||
 | 
					        return self.clip_dict[self.clip_name]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def unconditional_embedding(self, batch_size):
 | 
				
			||||||
 | 
					        zero = torch.zeros(
 | 
				
			||||||
 | 
					            batch_size,
 | 
				
			||||||
 | 
					            1,
 | 
				
			||||||
 | 
					            self.embedding_dim,
 | 
				
			||||||
 | 
					            device=self.device,
 | 
				
			||||||
 | 
					            dtype=self.clip.visual_projection.weight.dtype,
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        if self.num_projection_vector > 0:
 | 
				
			||||||
 | 
					            zero = self.projection(zero).view(batch_size, self.num_projection_vector, -1)
 | 
				
			||||||
 | 
					        return zero
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def forward(self, image, value_range=(-1, 1), zero_embedding_radio=0):
 | 
				
			||||||
 | 
					        if value_range is not None:
 | 
				
			||||||
 | 
					            low, high = value_range
 | 
				
			||||||
 | 
					            image = (image - low) / (high - low)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        image = image.to(self.device, dtype=self.clip.visual_projection.weight.dtype)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if self.reverse_visual_projection:
 | 
				
			||||||
 | 
					            z = self.clip.vision_model(self.transform(image))[1]
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            z = self.clip.get_image_features(self.transform(image))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if self.normalize_embedding:
 | 
				
			||||||
 | 
					            z = z / z.norm(dim=-1, keepdim=True)
 | 
				
			||||||
 | 
					        if z.ndim == 2:
 | 
				
			||||||
 | 
					            z = z.unsqueeze(dim=-2)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if zero_embedding_radio > 0:
 | 
				
			||||||
 | 
					            mask = torch.rand((len(image), 1, 1), device=z.device, dtype=z.dtype) < zero_embedding_radio
 | 
				
			||||||
 | 
					            z = z * mask.to(z)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if self.num_projection_vector > 0:
 | 
				
			||||||
 | 
					            z = self.projection(z).view(len(image), self.num_projection_vector, -1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return z
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def move(self):
 | 
				
			||||||
 | 
					        if self._move_flag:
 | 
				
			||||||
 | 
					            return
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.clip_dict[self.clip_name] = self.clip_dict[self.clip_name].to(self.device)
 | 
				
			||||||
 | 
					        self._move_flag = True
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def encode(self, image):
 | 
				
			||||||
 | 
					        self.move()
 | 
				
			||||||
 | 
					        return self(image, zero_embedding_radio=self.zero_embedding_radio)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class FrozenCLIPImageGridEmbedder(AbstractEncoder):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __init__(
 | 
				
			||||||
 | 
					            self,
 | 
				
			||||||
 | 
					            version="openai/clip-vit-large-patch14",
 | 
				
			||||||
 | 
					            device="cuda",
 | 
				
			||||||
 | 
					            zero_embedding_radio=0.1,
 | 
				
			||||||
 | 
					    ):
 | 
				
			||||||
 | 
					        super().__init__()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.device = device
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.clip_dict = OrderedDict()
 | 
				
			||||||
 | 
					        self.clip_name = os.path.split(version)[-1]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        clip_model: CLIPModel = CLIPModel.from_pretrained(version)
 | 
				
			||||||
 | 
					        clip_model.text_model = None
 | 
				
			||||||
 | 
					        clip_model.text_projection = None
 | 
				
			||||||
 | 
					        clip_model = clip_model.eval()
 | 
				
			||||||
 | 
					        for param in self.parameters():
 | 
				
			||||||
 | 
					            param.requires_grad = False
 | 
				
			||||||
 | 
					        self.clip_dict[self.clip_name] = clip_model
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.transform = transforms.Compose(
 | 
				
			||||||
 | 
					            [
 | 
				
			||||||
 | 
					                transforms.Resize(224, transforms.InterpolationMode.BILINEAR, antialias=True),
 | 
				
			||||||
 | 
					                transforms.CenterCrop(224),  # crop a (224, 224) square
 | 
				
			||||||
 | 
					                transforms.Normalize(
 | 
				
			||||||
 | 
					                    mean=[0.48145466, 0.4578275, 0.40821073],
 | 
				
			||||||
 | 
					                    std=[0.26862954, 0.26130258, 0.27577711],
 | 
				
			||||||
 | 
					                ),
 | 
				
			||||||
 | 
					            ]
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        self.zero_embedding_radio = zero_embedding_radio
 | 
				
			||||||
 | 
					        self.embedding_dim = clip_model.vision_embed_dim
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self._move_flag = False
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @property
 | 
				
			||||||
 | 
					    def clip(self):
 | 
				
			||||||
 | 
					        return self.clip_dict[self.clip_name]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def move(self):
 | 
				
			||||||
 | 
					        if self._move_flag:
 | 
				
			||||||
 | 
					            return
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.clip_dict[self.clip_name] = self.clip_dict[self.clip_name].to(self.device)
 | 
				
			||||||
 | 
					        self._move_flag = True
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def unconditional_embedding(self, batch_size):
 | 
				
			||||||
 | 
					        zero = torch.zeros(
 | 
				
			||||||
 | 
					            batch_size,
 | 
				
			||||||
 | 
					            self.clip.vision_model.embeddings.num_positions,
 | 
				
			||||||
 | 
					            self.embedding_dim,
 | 
				
			||||||
 | 
					            device=self.device,
 | 
				
			||||||
 | 
					            dtype=self.clip.visual_projection.weight.dtype,
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        return zero
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def forward(self, image, value_range=(-1, 1), zero_embedding_radio=0):
 | 
				
			||||||
 | 
					        self.move()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if value_range is not None:
 | 
				
			||||||
 | 
					            low, high = value_range
 | 
				
			||||||
 | 
					            image = (image - low) / (high - low)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        image = image.to(self.device, dtype=self.clip.visual_projection.weight.dtype)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        z = self.clip.vision_model(self.transform(image)).last_hidden_state
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if zero_embedding_radio > 0:
 | 
				
			||||||
 | 
					            mask = torch.rand((len(image), 1, 1), device=z.device, dtype=z.dtype) >= zero_embedding_radio
 | 
				
			||||||
 | 
					            z = z * mask.to(z)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return z
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def encode(self, image):
 | 
				
			||||||
 | 
					        return self(image, zero_embedding_radio=self.zero_embedding_radio)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class MoECLIPImageEncoder(nn.Module):
 | 
				
			||||||
 | 
					    def __init__(
 | 
				
			||||||
 | 
					            self,
 | 
				
			||||||
 | 
					            versions,
 | 
				
			||||||
 | 
					            hidden_state_dim,
 | 
				
			||||||
 | 
					            num_projection_vector=8,
 | 
				
			||||||
 | 
					            zero_embedding_radio=0.1,
 | 
				
			||||||
 | 
					            device="cuda",
 | 
				
			||||||
 | 
					            precision="fp16",
 | 
				
			||||||
 | 
					            normalize=False,
 | 
				
			||||||
 | 
					            clip_max=0,
 | 
				
			||||||
 | 
					            transform_type="base",
 | 
				
			||||||
 | 
					            argument_p=0.2,
 | 
				
			||||||
 | 
					    ):
 | 
				
			||||||
 | 
					        super().__init__()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.device = torch.device(device)
 | 
				
			||||||
 | 
					        self.hidden_state_dim = hidden_state_dim
 | 
				
			||||||
 | 
					        self.zero_embedding_radio = zero_embedding_radio
 | 
				
			||||||
 | 
					        self.num_projection_vector = num_projection_vector
 | 
				
			||||||
 | 
					        self.dtype = dict(fp16=torch.float16, fp32=torch.float32, bf16=torch.bfloat16)[precision]
 | 
				
			||||||
 | 
					        self.normalize = normalize
 | 
				
			||||||
 | 
					        self.clip_max = clip_max
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if transform_type == "base":
 | 
				
			||||||
 | 
					            self.transform = transforms.Compose(
 | 
				
			||||||
 | 
					                [
 | 
				
			||||||
 | 
					                    transforms.Resize(224, transforms.InterpolationMode.BICUBIC, antialias=True),
 | 
				
			||||||
 | 
					                    transforms.CenterCrop(224),  # crop a (224, 224) square
 | 
				
			||||||
 | 
					                    transforms.Normalize(
 | 
				
			||||||
 | 
					                        mean=[0.48145466, 0.4578275, 0.40821073],
 | 
				
			||||||
 | 
					                        std=[0.26862954, 0.26130258, 0.27577711],
 | 
				
			||||||
 | 
					                    ),
 | 
				
			||||||
 | 
					                ]
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					        elif transform_type == "crop_blur_resize":
 | 
				
			||||||
 | 
					            self.transform = transforms.Compose(
 | 
				
			||||||
 | 
					                [
 | 
				
			||||||
 | 
					                    transforms.Resize(224, transforms.InterpolationMode.BICUBIC, antialias=True),
 | 
				
			||||||
 | 
					                    transforms.CenterCrop(224),  # crop a (224, 224) square
 | 
				
			||||||
 | 
					                    transforms.RandomApply(
 | 
				
			||||||
 | 
					                        transforms=[
 | 
				
			||||||
 | 
					                            transforms.RandomResizedCrop(
 | 
				
			||||||
 | 
					                                size=224,
 | 
				
			||||||
 | 
					                                scale=(0.8, 1.0),
 | 
				
			||||||
 | 
					                                ratio=(0.99, 1.01),
 | 
				
			||||||
 | 
					                                interpolation=transforms.InterpolationMode.BICUBIC,
 | 
				
			||||||
 | 
					                            ),
 | 
				
			||||||
 | 
					                        ],
 | 
				
			||||||
 | 
					                        p=argument_p,
 | 
				
			||||||
 | 
					                    ),
 | 
				
			||||||
 | 
					                    transforms.RandomApply(
 | 
				
			||||||
 | 
					                        transforms=[
 | 
				
			||||||
 | 
					                            transforms.GaussianBlur(kernel_size=9, sigma=(0.1, 5)),
 | 
				
			||||||
 | 
					                        ],
 | 
				
			||||||
 | 
					                        p=argument_p,
 | 
				
			||||||
 | 
					                    ),
 | 
				
			||||||
 | 
					                    transforms.RandomApply(
 | 
				
			||||||
 | 
					                        transforms=[
 | 
				
			||||||
 | 
					                            RandomResize(size=224, resize_radio=(0.2, 1)),
 | 
				
			||||||
 | 
					                        ],
 | 
				
			||||||
 | 
					                        p=argument_p,
 | 
				
			||||||
 | 
					                    ),
 | 
				
			||||||
 | 
					                    transforms.Normalize(
 | 
				
			||||||
 | 
					                        mean=[0.48145466, 0.4578275, 0.40821073],
 | 
				
			||||||
 | 
					                        std=[0.26862954, 0.26130258, 0.27577711],
 | 
				
			||||||
 | 
					                    ),
 | 
				
			||||||
 | 
					                ]
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            raise ValueError(f"invalid {transform_type=}")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if isinstance(versions, str):
 | 
				
			||||||
 | 
					            versions = (versions,)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # 如果直接把clips定位为当前类的子module,1. 会在保存ckp时存无用的多个权重。 2. pl会调用to,导致layer_norm的权重也被转换成fp16
 | 
				
			||||||
 | 
					        clips = OrderedDict()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        for v in versions:
 | 
				
			||||||
 | 
					            # 因为clips不是子module,直接指定device="cuda"会错误地导致clip模型权重都被放到cuda:0上。
 | 
				
			||||||
 | 
					            clips[v], _ = clip.load(name=v, device="cpu", jit=False, download_root=None)
 | 
				
			||||||
 | 
					            delattr(clips[v], "transformer")
 | 
				
			||||||
 | 
					            clips[v].eval()
 | 
				
			||||||
 | 
					            clips[v].requires_grad_(False)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.clips_hidden_dim = sum(clips[v].ln_final.weight.size(0) for v in clips)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if self.num_projection_vector == 0:
 | 
				
			||||||
 | 
					            self.projection = nn.Identity()
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            self.projection = nn.Linear(self.clips_hidden_dim, hidden_state_dim * self.num_projection_vector, bias=True)
 | 
				
			||||||
 | 
					            self.projection.to(dtype=self.dtype)
 | 
				
			||||||
 | 
					            nn.init.normal_(self.projection.weight, std=self.clips_hidden_dim ** -0.5)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.clips = clips
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self._move_flag = False
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def move(self):
 | 
				
			||||||
 | 
					        if self._move_flag:
 | 
				
			||||||
 | 
					            return
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        def convert_weights(model: nn.Module):
 | 
				
			||||||
 | 
					            """Convert applicable model parameters to fp16"""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            def _convert_weights_to_fp16(l):
 | 
				
			||||||
 | 
					                if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
 | 
				
			||||||
 | 
					                    l.weight.data = l.weight.data.type(self.dtype)
 | 
				
			||||||
 | 
					                    if l.bias is not None:
 | 
				
			||||||
 | 
					                        l.bias.data = l.bias.data.type(self.dtype)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                if isinstance(l, nn.MultiheadAttention):
 | 
				
			||||||
 | 
					                    for attr in [
 | 
				
			||||||
 | 
					                        *[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]],
 | 
				
			||||||
 | 
					                        "in_proj_bias",
 | 
				
			||||||
 | 
					                        "bias_k",
 | 
				
			||||||
 | 
					                        "bias_v",
 | 
				
			||||||
 | 
					                    ]:
 | 
				
			||||||
 | 
					                        tensor = getattr(l, attr)
 | 
				
			||||||
 | 
					                        if tensor is not None:
 | 
				
			||||||
 | 
					                            tensor.data = tensor.data.type(self.dtype)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                for name in ["text_projection", "proj"]:
 | 
				
			||||||
 | 
					                    if hasattr(l, name):
 | 
				
			||||||
 | 
					                        attr = getattr(l, name)
 | 
				
			||||||
 | 
					                        if attr is not None:
 | 
				
			||||||
 | 
					                            attr.data = attr.data.type(self.dtype)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            model.apply(_convert_weights_to_fp16)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        for k in self.clips:
 | 
				
			||||||
 | 
					            self.clips[k].to(self.device)
 | 
				
			||||||
 | 
					            convert_weights(self.clips[k])  # fp32 -> self.dtype
 | 
				
			||||||
 | 
					        self._move_flag = True
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def unconditional_embedding(self, batch_size=None):
 | 
				
			||||||
 | 
					        zero = torch.zeros(
 | 
				
			||||||
 | 
					            batch_size,
 | 
				
			||||||
 | 
					            self.clips_hidden_dim,
 | 
				
			||||||
 | 
					            device=self.device,
 | 
				
			||||||
 | 
					            dtype=self.dtype,
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        if self.num_projection_vector > 0:
 | 
				
			||||||
 | 
					            zero = self.projection(zero).view(batch_size, self.num_projection_vector, -1)
 | 
				
			||||||
 | 
					        return zero
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def convert_embedding(self, z):
 | 
				
			||||||
 | 
					        if self.num_projection_vector > 0:
 | 
				
			||||||
 | 
					            z = self.projection(z.type(self.projection.weight.dtype)).view(len(z), self.num_projection_vector, -1)
 | 
				
			||||||
 | 
					        return z
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def forward(self, image, value_range=(-1, 1), zero_embedding_radio=0):
 | 
				
			||||||
 | 
					        if value_range is not None:
 | 
				
			||||||
 | 
					            low, high = value_range
 | 
				
			||||||
 | 
					            image = (image - low) / (high - low)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        image = self.transform(image)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        with torch.no_grad():
 | 
				
			||||||
 | 
					            embs = []
 | 
				
			||||||
 | 
					            for v in self.clips:
 | 
				
			||||||
 | 
					                x = self.clips[v].encode_image(image)
 | 
				
			||||||
 | 
					                if self.normalize:
 | 
				
			||||||
 | 
					                    x = x / x.norm(p=2, dim=-1, keepdim=True) * (x.size(-1) ** 0.5)
 | 
				
			||||||
 | 
					                    # clip_max only works with normalization
 | 
				
			||||||
 | 
					                    if self.clip_max > 0:
 | 
				
			||||||
 | 
					                        x = x.clamp(-self.clip_max, self.clip_max)
 | 
				
			||||||
 | 
					                embs.append(x)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            z = torch.cat(embs, dim=-1)
 | 
				
			||||||
 | 
					            if self.normalize:
 | 
				
			||||||
 | 
					                z /= z.size(-1) ** 0.5
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if zero_embedding_radio > 0:
 | 
				
			||||||
 | 
					            mask = torch.rand((len(image), 1, 1), device=z.device, dtype=z.dtype) >= zero_embedding_radio
 | 
				
			||||||
 | 
					            z = z + mask.to(z)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if self.num_projection_vector > 0:
 | 
				
			||||||
 | 
					            z = self.projection(z).view(len(image), self.num_projection_vector, -1)
 | 
				
			||||||
 | 
					        return z
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def encode(self, image):
 | 
				
			||||||
 | 
					        self.move()
 | 
				
			||||||
 | 
					        return self(image, zero_embedding_radio=self.zero_embedding_radio)
 | 
				
			||||||
							
								
								
									
										3
									
								
								primitive_anything/michelangelo/models/modules/__init__.py
									
									
									
									
									
										Executable file
									
								
							
							
						
						
									
										3
									
								
								primitive_anything/michelangelo/models/modules/__init__.py
									
									
									
									
									
										Executable file
									
								
							@ -0,0 +1,3 @@
 | 
				
			|||||||
 | 
					# -*- coding: utf-8 -*-
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from .checkpoint import checkpoint
 | 
				
			||||||
							
								
								
									
										69
									
								
								primitive_anything/michelangelo/models/modules/checkpoint.py
									
									
									
									
									
										Executable file
									
								
							
							
						
						
									
										69
									
								
								primitive_anything/michelangelo/models/modules/checkpoint.py
									
									
									
									
									
										Executable file
									
								
							@ -0,0 +1,69 @@
 | 
				
			|||||||
 | 
					# -*- coding: utf-8 -*-
 | 
				
			||||||
 | 
					"""
 | 
				
			||||||
 | 
					Adapted from: https://github.com/openai/guided-diffusion/blob/22e0df8183507e13a7813f8d38d51b072ca1e67c/guided_diffusion/nn.py#L124
 | 
				
			||||||
 | 
					"""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import torch
 | 
				
			||||||
 | 
					from typing import Callable, Iterable, Sequence, Union
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def checkpoint(
 | 
				
			||||||
 | 
					    func: Callable[..., Union[torch.Tensor, Sequence[torch.Tensor]]],
 | 
				
			||||||
 | 
					    inputs: Sequence[torch.Tensor],
 | 
				
			||||||
 | 
					    params: Iterable[torch.Tensor],
 | 
				
			||||||
 | 
					    flag: bool,
 | 
				
			||||||
 | 
					    use_deepspeed: bool = False
 | 
				
			||||||
 | 
					):
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    Evaluate a function without caching intermediate activations, allowing for
 | 
				
			||||||
 | 
					    reduced memory at the expense of extra compute in the backward pass.
 | 
				
			||||||
 | 
					    :param func: the function to evaluate.
 | 
				
			||||||
 | 
					    :param inputs: the argument sequence to pass to `func`.
 | 
				
			||||||
 | 
					    :param params: a sequence of parameters `func` depends on but does not
 | 
				
			||||||
 | 
					                   explicitly take as arguments.
 | 
				
			||||||
 | 
					    :param flag: if False, disable gradient checkpointing.
 | 
				
			||||||
 | 
					    :param use_deepspeed: if True, use deepspeed
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    if flag:
 | 
				
			||||||
 | 
					        if use_deepspeed:
 | 
				
			||||||
 | 
					            import deepspeed
 | 
				
			||||||
 | 
					            return deepspeed.checkpointing.checkpoint(func, *inputs)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        args = tuple(inputs) + tuple(params)
 | 
				
			||||||
 | 
					        return CheckpointFunction.apply(func, len(inputs), *args)
 | 
				
			||||||
 | 
					    else:
 | 
				
			||||||
 | 
					        return func(*inputs)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class CheckpointFunction(torch.autograd.Function):
 | 
				
			||||||
 | 
					    @staticmethod
 | 
				
			||||||
 | 
					    @torch.cuda.amp.custom_fwd
 | 
				
			||||||
 | 
					    def forward(ctx, run_function, length, *args):
 | 
				
			||||||
 | 
					        ctx.run_function = run_function
 | 
				
			||||||
 | 
					        ctx.input_tensors = list(args[:length])
 | 
				
			||||||
 | 
					        ctx.input_params = list(args[length:])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        with torch.no_grad():
 | 
				
			||||||
 | 
					            output_tensors = ctx.run_function(*ctx.input_tensors)
 | 
				
			||||||
 | 
					        return output_tensors
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @staticmethod
 | 
				
			||||||
 | 
					    @torch.cuda.amp.custom_bwd
 | 
				
			||||||
 | 
					    def backward(ctx, *output_grads):
 | 
				
			||||||
 | 
					        ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
 | 
				
			||||||
 | 
					        with torch.enable_grad():
 | 
				
			||||||
 | 
					            # Fixes a bug where the first op in run_function modifies the
 | 
				
			||||||
 | 
					            # Tensor storage in place, which is not allowed for detach()'d
 | 
				
			||||||
 | 
					            # Tensors.
 | 
				
			||||||
 | 
					            shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
 | 
				
			||||||
 | 
					            output_tensors = ctx.run_function(*shallow_copies)
 | 
				
			||||||
 | 
					        input_grads = torch.autograd.grad(
 | 
				
			||||||
 | 
					            output_tensors,
 | 
				
			||||||
 | 
					            ctx.input_tensors + ctx.input_params,
 | 
				
			||||||
 | 
					            output_grads,
 | 
				
			||||||
 | 
					            allow_unused=True,
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        del ctx.input_tensors
 | 
				
			||||||
 | 
					        del ctx.input_params
 | 
				
			||||||
 | 
					        del output_tensors
 | 
				
			||||||
 | 
					        return (None, None) + input_grads
 | 
				
			||||||
							
								
								
									
										218
									
								
								primitive_anything/michelangelo/models/modules/diffusion_transformer.py
									
									
									
									
									
										Executable file
									
								
							
							
						
						
									
										218
									
								
								primitive_anything/michelangelo/models/modules/diffusion_transformer.py
									
									
									
									
									
										Executable file
									
								
							@ -0,0 +1,218 @@
 | 
				
			|||||||
 | 
					# -*- coding: utf-8 -*-
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import math
 | 
				
			||||||
 | 
					import torch
 | 
				
			||||||
 | 
					import torch.nn as nn
 | 
				
			||||||
 | 
					from typing import Optional
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from .checkpoint import checkpoint
 | 
				
			||||||
 | 
					from .transformer_blocks import (
 | 
				
			||||||
 | 
					    init_linear,
 | 
				
			||||||
 | 
					    MLP,
 | 
				
			||||||
 | 
					    MultiheadCrossAttention,
 | 
				
			||||||
 | 
					    MultiheadAttention,
 | 
				
			||||||
 | 
					    ResidualAttentionBlock
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class AdaLayerNorm(nn.Module):
 | 
				
			||||||
 | 
					    def __init__(self,
 | 
				
			||||||
 | 
					                 device: torch.device,
 | 
				
			||||||
 | 
					                 dtype: torch.dtype,
 | 
				
			||||||
 | 
					                 width: int):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        super().__init__()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.silu = nn.SiLU(inplace=True)
 | 
				
			||||||
 | 
					        self.linear = nn.Linear(width, width * 2, device=device, dtype=dtype)
 | 
				
			||||||
 | 
					        self.layernorm = nn.LayerNorm(width, elementwise_affine=False, device=device, dtype=dtype)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def forward(self, x, timestep):
 | 
				
			||||||
 | 
					        emb = self.linear(timestep)
 | 
				
			||||||
 | 
					        scale, shift = torch.chunk(emb, 2, dim=2)
 | 
				
			||||||
 | 
					        x = self.layernorm(x) * (1 + scale) + shift
 | 
				
			||||||
 | 
					        return x
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class DitBlock(nn.Module):
 | 
				
			||||||
 | 
					    def __init__(
 | 
				
			||||||
 | 
					            self,
 | 
				
			||||||
 | 
					            *,
 | 
				
			||||||
 | 
					            device: torch.device,
 | 
				
			||||||
 | 
					            dtype: torch.dtype,
 | 
				
			||||||
 | 
					            n_ctx: int,
 | 
				
			||||||
 | 
					            width: int,
 | 
				
			||||||
 | 
					            heads: int,
 | 
				
			||||||
 | 
					            context_dim: int,
 | 
				
			||||||
 | 
					            qkv_bias: bool = False,
 | 
				
			||||||
 | 
					            init_scale: float = 1.0,
 | 
				
			||||||
 | 
					            use_checkpoint: bool = False
 | 
				
			||||||
 | 
					    ):
 | 
				
			||||||
 | 
					        super().__init__()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.use_checkpoint = use_checkpoint
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.attn = MultiheadAttention(
 | 
				
			||||||
 | 
					            device=device,
 | 
				
			||||||
 | 
					            dtype=dtype,
 | 
				
			||||||
 | 
					            n_ctx=n_ctx,
 | 
				
			||||||
 | 
					            width=width,
 | 
				
			||||||
 | 
					            heads=heads,
 | 
				
			||||||
 | 
					            init_scale=init_scale,
 | 
				
			||||||
 | 
					            qkv_bias=qkv_bias
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        self.ln_1 = AdaLayerNorm(device, dtype, width)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if context_dim is not None:
 | 
				
			||||||
 | 
					            self.ln_2 = AdaLayerNorm(device, dtype, width)
 | 
				
			||||||
 | 
					            self.cross_attn = MultiheadCrossAttention(
 | 
				
			||||||
 | 
					                device=device,
 | 
				
			||||||
 | 
					                dtype=dtype,
 | 
				
			||||||
 | 
					                width=width,
 | 
				
			||||||
 | 
					                heads=heads,
 | 
				
			||||||
 | 
					                data_width=context_dim,
 | 
				
			||||||
 | 
					                init_scale=init_scale,
 | 
				
			||||||
 | 
					                qkv_bias=qkv_bias
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.mlp = MLP(device=device, dtype=dtype, width=width, init_scale=init_scale)
 | 
				
			||||||
 | 
					        self.ln_3 = AdaLayerNorm(device, dtype, width)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def forward(self, x: torch.Tensor, t: torch.Tensor, context: Optional[torch.Tensor] = None):
 | 
				
			||||||
 | 
					        return checkpoint(self._forward, (x, t, context), self.parameters(), self.use_checkpoint)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def _forward(self, x: torch.Tensor, t: torch.Tensor, context: Optional[torch.Tensor] = None):
 | 
				
			||||||
 | 
					        x = x + self.attn(self.ln_1(x, t))
 | 
				
			||||||
 | 
					        if context is not None:
 | 
				
			||||||
 | 
					            x = x + self.cross_attn(self.ln_2(x, t), context)
 | 
				
			||||||
 | 
					        x = x + self.mlp(self.ln_3(x, t))
 | 
				
			||||||
 | 
					        return x
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class DiT(nn.Module):
 | 
				
			||||||
 | 
					    def __init__(
 | 
				
			||||||
 | 
					            self,
 | 
				
			||||||
 | 
					            *,
 | 
				
			||||||
 | 
					            device: Optional[torch.device],
 | 
				
			||||||
 | 
					            dtype: Optional[torch.dtype],
 | 
				
			||||||
 | 
					            n_ctx: int,
 | 
				
			||||||
 | 
					            width: int,
 | 
				
			||||||
 | 
					            layers: int,
 | 
				
			||||||
 | 
					            heads: int,
 | 
				
			||||||
 | 
					            context_dim: int,
 | 
				
			||||||
 | 
					            init_scale: float = 0.25,
 | 
				
			||||||
 | 
					            qkv_bias: bool = False,
 | 
				
			||||||
 | 
					            use_checkpoint: bool = False
 | 
				
			||||||
 | 
					    ):
 | 
				
			||||||
 | 
					        super().__init__()
 | 
				
			||||||
 | 
					        self.n_ctx = n_ctx
 | 
				
			||||||
 | 
					        self.width = width
 | 
				
			||||||
 | 
					        self.layers = layers
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.resblocks = nn.ModuleList(
 | 
				
			||||||
 | 
					            [
 | 
				
			||||||
 | 
					                DitBlock(
 | 
				
			||||||
 | 
					                    device=device,
 | 
				
			||||||
 | 
					                    dtype=dtype,
 | 
				
			||||||
 | 
					                    n_ctx=n_ctx,
 | 
				
			||||||
 | 
					                    width=width,
 | 
				
			||||||
 | 
					                    heads=heads,
 | 
				
			||||||
 | 
					                    context_dim=context_dim,
 | 
				
			||||||
 | 
					                    qkv_bias=qkv_bias,
 | 
				
			||||||
 | 
					                    init_scale=init_scale,
 | 
				
			||||||
 | 
					                    use_checkpoint=use_checkpoint
 | 
				
			||||||
 | 
					                )
 | 
				
			||||||
 | 
					                for _ in range(layers)
 | 
				
			||||||
 | 
					            ]
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def forward(self, x: torch.Tensor, t: torch.Tensor, context: Optional[torch.Tensor] = None):
 | 
				
			||||||
 | 
					        for block in self.resblocks:
 | 
				
			||||||
 | 
					            x = block(x, t, context)
 | 
				
			||||||
 | 
					        return x
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class UNetDiffusionTransformer(nn.Module):
 | 
				
			||||||
 | 
					    def __init__(
 | 
				
			||||||
 | 
					            self,
 | 
				
			||||||
 | 
					            *,
 | 
				
			||||||
 | 
					            device: Optional[torch.device],
 | 
				
			||||||
 | 
					            dtype: Optional[torch.dtype],
 | 
				
			||||||
 | 
					            n_ctx: int,
 | 
				
			||||||
 | 
					            width: int,
 | 
				
			||||||
 | 
					            layers: int,
 | 
				
			||||||
 | 
					            heads: int,
 | 
				
			||||||
 | 
					            init_scale: float = 0.25,
 | 
				
			||||||
 | 
					            qkv_bias: bool = False,
 | 
				
			||||||
 | 
					            skip_ln: bool = False,
 | 
				
			||||||
 | 
					            use_checkpoint: bool = False
 | 
				
			||||||
 | 
					    ):
 | 
				
			||||||
 | 
					        super().__init__()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.n_ctx = n_ctx
 | 
				
			||||||
 | 
					        self.width = width
 | 
				
			||||||
 | 
					        self.layers = layers
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.encoder = nn.ModuleList()
 | 
				
			||||||
 | 
					        for _ in range(layers):
 | 
				
			||||||
 | 
					            resblock = ResidualAttentionBlock(
 | 
				
			||||||
 | 
					                device=device,
 | 
				
			||||||
 | 
					                dtype=dtype,
 | 
				
			||||||
 | 
					                n_ctx=n_ctx,
 | 
				
			||||||
 | 
					                width=width,
 | 
				
			||||||
 | 
					                heads=heads,
 | 
				
			||||||
 | 
					                init_scale=init_scale,
 | 
				
			||||||
 | 
					                qkv_bias=qkv_bias,
 | 
				
			||||||
 | 
					                use_checkpoint=use_checkpoint
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					            self.encoder.append(resblock)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.middle_block = ResidualAttentionBlock(
 | 
				
			||||||
 | 
					            device=device,
 | 
				
			||||||
 | 
					            dtype=dtype,
 | 
				
			||||||
 | 
					            n_ctx=n_ctx,
 | 
				
			||||||
 | 
					            width=width,
 | 
				
			||||||
 | 
					            heads=heads,
 | 
				
			||||||
 | 
					            init_scale=init_scale,
 | 
				
			||||||
 | 
					            qkv_bias=qkv_bias,
 | 
				
			||||||
 | 
					            use_checkpoint=use_checkpoint
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.decoder = nn.ModuleList()
 | 
				
			||||||
 | 
					        for _ in range(layers):
 | 
				
			||||||
 | 
					            resblock = ResidualAttentionBlock(
 | 
				
			||||||
 | 
					                device=device,
 | 
				
			||||||
 | 
					                dtype=dtype,
 | 
				
			||||||
 | 
					                n_ctx=n_ctx,
 | 
				
			||||||
 | 
					                width=width,
 | 
				
			||||||
 | 
					                heads=heads,
 | 
				
			||||||
 | 
					                init_scale=init_scale,
 | 
				
			||||||
 | 
					                qkv_bias=qkv_bias,
 | 
				
			||||||
 | 
					                use_checkpoint=use_checkpoint
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					            linear = nn.Linear(width * 2, width, device=device, dtype=dtype)
 | 
				
			||||||
 | 
					            init_linear(linear, init_scale)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            layer_norm = nn.LayerNorm(width, device=device, dtype=dtype) if skip_ln else None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            self.decoder.append(nn.ModuleList([resblock, linear, layer_norm]))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def forward(self, x: torch.Tensor):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        enc_outputs = []
 | 
				
			||||||
 | 
					        for block in self.encoder:
 | 
				
			||||||
 | 
					            x = block(x)
 | 
				
			||||||
 | 
					            enc_outputs.append(x)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        x = self.middle_block(x)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        for i, (resblock, linear, layer_norm) in enumerate(self.decoder):
 | 
				
			||||||
 | 
					            x = torch.cat([enc_outputs.pop(), x], dim=-1)
 | 
				
			||||||
 | 
					            x = linear(x)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            if layer_norm is not None:
 | 
				
			||||||
 | 
					                x = layer_norm(x)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            x = resblock(x)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return x
 | 
				
			||||||
							
								
								
									
										100
									
								
								primitive_anything/michelangelo/models/modules/distributions.py
									
									
									
									
									
										Executable file
									
								
							
							
						
						
									
										100
									
								
								primitive_anything/michelangelo/models/modules/distributions.py
									
									
									
									
									
										Executable file
									
								
							@ -0,0 +1,100 @@
 | 
				
			|||||||
 | 
					import torch
 | 
				
			||||||
 | 
					import numpy as np
 | 
				
			||||||
 | 
					from typing import Union, List
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class AbstractDistribution(object):
 | 
				
			||||||
 | 
					    def sample(self):
 | 
				
			||||||
 | 
					        raise NotImplementedError()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def mode(self):
 | 
				
			||||||
 | 
					        raise NotImplementedError()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class DiracDistribution(AbstractDistribution):
 | 
				
			||||||
 | 
					    def __init__(self, value):
 | 
				
			||||||
 | 
					        self.value = value
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def sample(self):
 | 
				
			||||||
 | 
					        return self.value
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def mode(self):
 | 
				
			||||||
 | 
					        return self.value
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class DiagonalGaussianDistribution(object):
 | 
				
			||||||
 | 
					    def __init__(self, parameters: Union[torch.Tensor, List[torch.Tensor]], deterministic=False, feat_dim=1):
 | 
				
			||||||
 | 
					        self.feat_dim = feat_dim
 | 
				
			||||||
 | 
					        self.parameters = parameters
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if isinstance(parameters, list):
 | 
				
			||||||
 | 
					            self.mean = parameters[0]
 | 
				
			||||||
 | 
					            self.logvar = parameters[1]
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            self.mean, self.logvar = torch.chunk(parameters, 2, dim=feat_dim)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
 | 
				
			||||||
 | 
					        self.deterministic = deterministic
 | 
				
			||||||
 | 
					        self.std = torch.exp(0.5 * self.logvar)
 | 
				
			||||||
 | 
					        self.var = torch.exp(self.logvar)
 | 
				
			||||||
 | 
					        if self.deterministic:
 | 
				
			||||||
 | 
					            self.var = self.std = torch.zeros_like(self.mean)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def sample(self):
 | 
				
			||||||
 | 
					        x = self.mean + self.std * torch.randn_like(self.mean)
 | 
				
			||||||
 | 
					        return x
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def kl(self, other=None, dims=(1, 2, 3)):
 | 
				
			||||||
 | 
					        if self.deterministic:
 | 
				
			||||||
 | 
					            return torch.Tensor([0.])
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            if other is None:
 | 
				
			||||||
 | 
					                return 0.5 * torch.mean(torch.pow(self.mean, 2)
 | 
				
			||||||
 | 
					                                        + self.var - 1.0 - self.logvar,
 | 
				
			||||||
 | 
					                                        dim=dims)
 | 
				
			||||||
 | 
					            else:
 | 
				
			||||||
 | 
					                return 0.5 * torch.mean(
 | 
				
			||||||
 | 
					                    torch.pow(self.mean - other.mean, 2) / other.var
 | 
				
			||||||
 | 
					                    + self.var / other.var - 1.0 - self.logvar + other.logvar,
 | 
				
			||||||
 | 
					                    dim=dims)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def nll(self, sample, dims=(1, 2, 3)):
 | 
				
			||||||
 | 
					        if self.deterministic:
 | 
				
			||||||
 | 
					            return torch.Tensor([0.])
 | 
				
			||||||
 | 
					        logtwopi = np.log(2.0 * np.pi)
 | 
				
			||||||
 | 
					        return 0.5 * torch.sum(
 | 
				
			||||||
 | 
					            logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
 | 
				
			||||||
 | 
					            dim=dims)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def mode(self):
 | 
				
			||||||
 | 
					        return self.mean
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def normal_kl(mean1, logvar1, mean2, logvar2):
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
 | 
				
			||||||
 | 
					    Compute the KL divergence between two gaussians.
 | 
				
			||||||
 | 
					    Shapes are automatically broadcasted, so batches can be compared to
 | 
				
			||||||
 | 
					    scalars, among other use cases.
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    tensor = None
 | 
				
			||||||
 | 
					    for obj in (mean1, logvar1, mean2, logvar2):
 | 
				
			||||||
 | 
					        if isinstance(obj, torch.Tensor):
 | 
				
			||||||
 | 
					            tensor = obj
 | 
				
			||||||
 | 
					            break
 | 
				
			||||||
 | 
					    assert tensor is not None, "at least one argument must be a Tensor"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Force variances to be Tensors. Broadcasting helps convert scalars to
 | 
				
			||||||
 | 
					    # Tensors, but it does not work for torch.exp().
 | 
				
			||||||
 | 
					    logvar1, logvar2 = [
 | 
				
			||||||
 | 
					        x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
 | 
				
			||||||
 | 
					        for x in (logvar1, logvar2)
 | 
				
			||||||
 | 
					    ]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    return 0.5 * (
 | 
				
			||||||
 | 
					            -1.0
 | 
				
			||||||
 | 
					            + logvar2
 | 
				
			||||||
 | 
					            - logvar1
 | 
				
			||||||
 | 
					            + torch.exp(logvar1 - logvar2)
 | 
				
			||||||
 | 
					            + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
							
								
								
									
										213
									
								
								primitive_anything/michelangelo/models/modules/embedder.py
									
									
									
									
									
										Executable file
									
								
							
							
						
						
									
										213
									
								
								primitive_anything/michelangelo/models/modules/embedder.py
									
									
									
									
									
										Executable file
									
								
							@ -0,0 +1,213 @@
 | 
				
			|||||||
 | 
					# -*- coding: utf-8 -*-
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import numpy as np
 | 
				
			||||||
 | 
					import torch
 | 
				
			||||||
 | 
					import torch.nn as nn
 | 
				
			||||||
 | 
					import math
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					VALID_EMBED_TYPES = ["identity", "fourier", "hashgrid", "sphere_harmonic", "triplane_fourier"]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class FourierEmbedder(nn.Module):
 | 
				
			||||||
 | 
					    """The sin/cosine positional embedding. Given an input tensor `x` of shape [n_batch, ..., c_dim], it converts
 | 
				
			||||||
 | 
					    each feature dimension of `x[..., i]` into:
 | 
				
			||||||
 | 
					        [
 | 
				
			||||||
 | 
					            sin(x[..., i]),
 | 
				
			||||||
 | 
					            sin(f_1*x[..., i]),
 | 
				
			||||||
 | 
					            sin(f_2*x[..., i]),
 | 
				
			||||||
 | 
					            ...
 | 
				
			||||||
 | 
					            sin(f_N * x[..., i]),
 | 
				
			||||||
 | 
					            cos(x[..., i]),
 | 
				
			||||||
 | 
					            cos(f_1*x[..., i]),
 | 
				
			||||||
 | 
					            cos(f_2*x[..., i]),
 | 
				
			||||||
 | 
					            ...
 | 
				
			||||||
 | 
					            cos(f_N * x[..., i]),
 | 
				
			||||||
 | 
					            x[..., i]     # only present if include_input is True.
 | 
				
			||||||
 | 
					        ], here f_i is the frequency.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Denote the space is [0 / num_freqs, 1 / num_freqs, 2 / num_freqs, 3 / num_freqs, ..., (num_freqs - 1) / num_freqs].
 | 
				
			||||||
 | 
					    If logspace is True, then the frequency f_i is [2^(0 / num_freqs), ..., 2^(i / num_freqs), ...];
 | 
				
			||||||
 | 
					    Otherwise, the frequencies are linearly spaced between [1.0, 2^(num_freqs - 1)].
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Args:
 | 
				
			||||||
 | 
					        num_freqs (int): the number of frequencies, default is 6;
 | 
				
			||||||
 | 
					        logspace (bool): If logspace is True, then the frequency f_i is [..., 2^(i / num_freqs), ...],
 | 
				
			||||||
 | 
					            otherwise, the frequencies are linearly spaced between [1.0, 2^(num_freqs - 1)];
 | 
				
			||||||
 | 
					        input_dim (int): the input dimension, default is 3;
 | 
				
			||||||
 | 
					        include_input (bool): include the input tensor or not, default is True.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Attributes:
 | 
				
			||||||
 | 
					        frequencies (torch.Tensor): If logspace is True, then the frequency f_i is [..., 2^(i / num_freqs), ...],
 | 
				
			||||||
 | 
					                otherwise, the frequencies are linearly spaced between [1.0, 2^(num_freqs - 1);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        out_dim (int): the embedding size, if include_input is True, it is input_dim * (num_freqs * 2 + 1),
 | 
				
			||||||
 | 
					            otherwise, it is input_dim * num_freqs * 2.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __init__(self,
 | 
				
			||||||
 | 
					                 num_freqs: int = 6,
 | 
				
			||||||
 | 
					                 logspace: bool = True,
 | 
				
			||||||
 | 
					                 input_dim: int = 3,
 | 
				
			||||||
 | 
					                 include_input: bool = True,
 | 
				
			||||||
 | 
					                 include_pi: bool = True) -> None:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        """The initialization"""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        super().__init__()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if logspace:
 | 
				
			||||||
 | 
					            frequencies = 2.0 ** torch.arange(
 | 
				
			||||||
 | 
					                num_freqs,
 | 
				
			||||||
 | 
					                dtype=torch.float32
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            frequencies = torch.linspace(
 | 
				
			||||||
 | 
					                1.0,
 | 
				
			||||||
 | 
					                2.0 ** (num_freqs - 1),
 | 
				
			||||||
 | 
					                num_freqs,
 | 
				
			||||||
 | 
					                dtype=torch.float32
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if include_pi:
 | 
				
			||||||
 | 
					            frequencies *= torch.pi
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.register_buffer("frequencies", frequencies, persistent=False)
 | 
				
			||||||
 | 
					        self.include_input = include_input
 | 
				
			||||||
 | 
					        self.num_freqs = num_freqs
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.out_dim = self.get_dims(input_dim)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def get_dims(self, input_dim):
 | 
				
			||||||
 | 
					        temp = 1 if self.include_input or self.num_freqs == 0 else 0
 | 
				
			||||||
 | 
					        out_dim = input_dim * (self.num_freqs * 2 + temp)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return out_dim
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def forward(self, x: torch.Tensor) -> torch.Tensor:
 | 
				
			||||||
 | 
					        """ Forward process.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Args:
 | 
				
			||||||
 | 
					            x: tensor of shape [..., dim]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Returns:
 | 
				
			||||||
 | 
					            embedding: an embedding of `x` of shape [..., dim * (num_freqs * 2 + temp)]
 | 
				
			||||||
 | 
					                where temp is 1 if include_input is True and 0 otherwise.
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if self.num_freqs > 0:
 | 
				
			||||||
 | 
					            embed = (x[..., None].contiguous() * self.frequencies).view(*x.shape[:-1], -1)
 | 
				
			||||||
 | 
					            if self.include_input:
 | 
				
			||||||
 | 
					                return torch.cat((x, embed.sin(), embed.cos()), dim=-1)
 | 
				
			||||||
 | 
					            else:
 | 
				
			||||||
 | 
					                return torch.cat((embed.sin(), embed.cos()), dim=-1)
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            return x
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class LearnedFourierEmbedder(nn.Module):
 | 
				
			||||||
 | 
					    """ following @crowsonkb "s lead with learned sinusoidal pos emb """
 | 
				
			||||||
 | 
					    """ https://github.com/crowsonkb/v-diffusion-jax/blob/master/diffusion/models/danbooru_128.py#L8 """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __init__(self, in_channels, dim):
 | 
				
			||||||
 | 
					        super().__init__()
 | 
				
			||||||
 | 
					        assert (dim % 2) == 0
 | 
				
			||||||
 | 
					        half_dim = dim // 2
 | 
				
			||||||
 | 
					        per_channel_dim = half_dim // in_channels
 | 
				
			||||||
 | 
					        self.weights = nn.Parameter(torch.randn(per_channel_dim))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def forward(self, x):
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Args:
 | 
				
			||||||
 | 
					            x (torch.FloatTensor): [..., c]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Returns:
 | 
				
			||||||
 | 
					            x (torch.FloatTensor): [..., d]
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # [b, t, c, 1] * [1, d] = [b, t, c, d] -> [b, t, c * d]
 | 
				
			||||||
 | 
					        freqs = (x[..., None] * self.weights[None] * 2 * np.pi).view(*x.shape[:-1], -1)
 | 
				
			||||||
 | 
					        fouriered = torch.cat((x, freqs.sin(), freqs.cos()), dim=-1)
 | 
				
			||||||
 | 
					        return fouriered
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class TriplaneLearnedFourierEmbedder(nn.Module):
 | 
				
			||||||
 | 
					    def __init__(self, in_channels, dim):
 | 
				
			||||||
 | 
					        super().__init__()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.yz_plane_embedder = LearnedFourierEmbedder(in_channels, dim)
 | 
				
			||||||
 | 
					        self.xz_plane_embedder = LearnedFourierEmbedder(in_channels, dim)
 | 
				
			||||||
 | 
					        self.xy_plane_embedder = LearnedFourierEmbedder(in_channels, dim)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.out_dim = in_channels + dim
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def forward(self, x):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        yz_embed = self.yz_plane_embedder(x)
 | 
				
			||||||
 | 
					        xz_embed = self.xz_plane_embedder(x)
 | 
				
			||||||
 | 
					        xy_embed = self.xy_plane_embedder(x)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        embed = yz_embed + xz_embed + xy_embed
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return embed
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def sequential_pos_embed(num_len, embed_dim):
 | 
				
			||||||
 | 
					    assert embed_dim % 2 == 0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    pos = torch.arange(num_len, dtype=torch.float32)
 | 
				
			||||||
 | 
					    omega = torch.arange(embed_dim // 2, dtype=torch.float32)
 | 
				
			||||||
 | 
					    omega /= embed_dim / 2.
 | 
				
			||||||
 | 
					    omega = 1. / 10000 ** omega  # (D/2,)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    pos = pos.reshape(-1)  # (M,)
 | 
				
			||||||
 | 
					    out = torch.einsum("m,d->md", pos, omega)  # (M, D/2), outer product
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    emb_sin = torch.sin(out)  # (M, D/2)
 | 
				
			||||||
 | 
					    emb_cos = torch.cos(out)  # (M, D/2)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    embeddings = torch.cat([emb_sin, emb_cos], dim=1)  # (M, D)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    return embeddings
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def timestep_embedding(timesteps, dim, max_period=10000):
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    Create sinusoidal timestep embeddings.
 | 
				
			||||||
 | 
					    :param timesteps: a 1-D Tensor of N indices, one per batch element.
 | 
				
			||||||
 | 
					                      These may be fractional.
 | 
				
			||||||
 | 
					    :param dim: the dimension of the output.
 | 
				
			||||||
 | 
					    :param max_period: controls the minimum frequency of the embeddings.
 | 
				
			||||||
 | 
					    :return: an [N x dim] Tensor of positional embeddings.
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    half = dim // 2
 | 
				
			||||||
 | 
					    freqs = torch.exp(
 | 
				
			||||||
 | 
					        -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
 | 
				
			||||||
 | 
					    ).to(device=timesteps.device)
 | 
				
			||||||
 | 
					    args = timesteps[:, None].to(timesteps.dtype) * freqs[None]
 | 
				
			||||||
 | 
					    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
 | 
				
			||||||
 | 
					    if dim % 2:
 | 
				
			||||||
 | 
					        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
 | 
				
			||||||
 | 
					    return embedding
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def get_embedder(embed_type="fourier", num_freqs=-1, input_dim=3, degree=4,
 | 
				
			||||||
 | 
					                 num_levels=16, level_dim=2, per_level_scale=2, base_resolution=16,
 | 
				
			||||||
 | 
					                 log2_hashmap_size=19, desired_resolution=None):
 | 
				
			||||||
 | 
					    if embed_type == "identity" or (embed_type == "fourier" and num_freqs == -1):
 | 
				
			||||||
 | 
					        return nn.Identity(), input_dim
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    elif embed_type == "fourier":
 | 
				
			||||||
 | 
					        embedder_obj = FourierEmbedder(num_freqs=num_freqs, input_dim=input_dim,
 | 
				
			||||||
 | 
					                                       logspace=True, include_input=True)
 | 
				
			||||||
 | 
					        return embedder_obj, embedder_obj.out_dim
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    elif embed_type == "hashgrid":
 | 
				
			||||||
 | 
					        raise NotImplementedError
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    elif embed_type == "sphere_harmonic":
 | 
				
			||||||
 | 
					        raise NotImplementedError
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    else:
 | 
				
			||||||
 | 
					        raise ValueError(f"{embed_type} is not valid. Currently only supprts {VALID_EMBED_TYPES}")
 | 
				
			||||||
							
								
								
									
										286
									
								
								primitive_anything/michelangelo/models/modules/transformer_blocks.py
									
									
									
									
									
										Executable file
									
								
							
							
						
						
									
										286
									
								
								primitive_anything/michelangelo/models/modules/transformer_blocks.py
									
									
									
									
									
										Executable file
									
								
							@ -0,0 +1,286 @@
 | 
				
			|||||||
 | 
					# -*- coding: utf-8 -*-
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import math
 | 
				
			||||||
 | 
					import torch
 | 
				
			||||||
 | 
					import torch.nn as nn
 | 
				
			||||||
 | 
					import torch.nn.functional as F
 | 
				
			||||||
 | 
					from typing import Optional
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from .checkpoint import checkpoint
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def init_linear(l, stddev):
 | 
				
			||||||
 | 
					    nn.init.normal_(l.weight, std=stddev)
 | 
				
			||||||
 | 
					    if l.bias is not None:
 | 
				
			||||||
 | 
					        nn.init.constant_(l.bias, 0.0)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class MultiheadAttention(nn.Module):
 | 
				
			||||||
 | 
					    def __init__(
 | 
				
			||||||
 | 
					        self,
 | 
				
			||||||
 | 
					        *,
 | 
				
			||||||
 | 
					        device: torch.device,
 | 
				
			||||||
 | 
					        dtype: torch.dtype,
 | 
				
			||||||
 | 
					        n_ctx: int,
 | 
				
			||||||
 | 
					        width: int,
 | 
				
			||||||
 | 
					        heads: int,
 | 
				
			||||||
 | 
					        init_scale: float,
 | 
				
			||||||
 | 
					        qkv_bias: bool,
 | 
				
			||||||
 | 
					        flash: bool = False
 | 
				
			||||||
 | 
					    ):
 | 
				
			||||||
 | 
					        super().__init__()
 | 
				
			||||||
 | 
					        self.n_ctx = n_ctx
 | 
				
			||||||
 | 
					        self.width = width
 | 
				
			||||||
 | 
					        self.heads = heads
 | 
				
			||||||
 | 
					        self.c_qkv = nn.Linear(width, width * 3, bias=qkv_bias, device=device, dtype=dtype)
 | 
				
			||||||
 | 
					        self.c_proj = nn.Linear(width, width, device=device, dtype=dtype)
 | 
				
			||||||
 | 
					        self.attention = QKVMultiheadAttention(device=device, dtype=dtype, heads=heads, n_ctx=n_ctx, flash=flash)
 | 
				
			||||||
 | 
					        init_linear(self.c_qkv, init_scale)
 | 
				
			||||||
 | 
					        init_linear(self.c_proj, init_scale)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def forward(self, x):
 | 
				
			||||||
 | 
					        x = self.c_qkv(x)
 | 
				
			||||||
 | 
					        x = checkpoint(self.attention, (x,), (), True)
 | 
				
			||||||
 | 
					        x = self.c_proj(x)
 | 
				
			||||||
 | 
					        return x
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class QKVMultiheadAttention(nn.Module):
 | 
				
			||||||
 | 
					    def __init__(self, *, device: torch.device, dtype: torch.dtype, heads: int, n_ctx: int, flash: bool = False):
 | 
				
			||||||
 | 
					        super().__init__()
 | 
				
			||||||
 | 
					        self.device = device
 | 
				
			||||||
 | 
					        self.dtype = dtype
 | 
				
			||||||
 | 
					        self.heads = heads
 | 
				
			||||||
 | 
					        self.n_ctx = n_ctx
 | 
				
			||||||
 | 
					        self.flash = flash
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def forward(self, qkv):
 | 
				
			||||||
 | 
					        bs, n_ctx, width = qkv.shape
 | 
				
			||||||
 | 
					        attn_ch = width // self.heads // 3
 | 
				
			||||||
 | 
					        scale = 1 / math.sqrt(math.sqrt(attn_ch))
 | 
				
			||||||
 | 
					        qkv = qkv.view(bs, n_ctx, self.heads, -1)
 | 
				
			||||||
 | 
					        q, k, v = torch.split(qkv, attn_ch, dim=-1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if self.flash:
 | 
				
			||||||
 | 
					            out = F.scaled_dot_product_attention(q, k, v)
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            weight = torch.einsum(
 | 
				
			||||||
 | 
					                "bthc,bshc->bhts", q * scale, k * scale
 | 
				
			||||||
 | 
					            )  # More stable with f16 than dividing afterwards
 | 
				
			||||||
 | 
					            wdtype = weight.dtype
 | 
				
			||||||
 | 
					            weight = torch.softmax(weight.float(), dim=-1).type(wdtype)
 | 
				
			||||||
 | 
					            out = torch.einsum("bhts,bshc->bthc", weight, v).reshape(bs, n_ctx, -1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return out
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class ResidualAttentionBlock(nn.Module):
 | 
				
			||||||
 | 
					    def __init__(
 | 
				
			||||||
 | 
					        self,
 | 
				
			||||||
 | 
					        *,
 | 
				
			||||||
 | 
					        device: torch.device,
 | 
				
			||||||
 | 
					        dtype: torch.dtype,
 | 
				
			||||||
 | 
					        n_ctx: int,
 | 
				
			||||||
 | 
					        width: int,
 | 
				
			||||||
 | 
					        heads: int,
 | 
				
			||||||
 | 
					        init_scale: float = 1.0,
 | 
				
			||||||
 | 
					        qkv_bias: bool = True,
 | 
				
			||||||
 | 
					        flash: bool = False,
 | 
				
			||||||
 | 
					        use_checkpoint: bool = False
 | 
				
			||||||
 | 
					    ):
 | 
				
			||||||
 | 
					        super().__init__()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.use_checkpoint = use_checkpoint
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.attn = MultiheadAttention(
 | 
				
			||||||
 | 
					            device=device,
 | 
				
			||||||
 | 
					            dtype=dtype,
 | 
				
			||||||
 | 
					            n_ctx=n_ctx,
 | 
				
			||||||
 | 
					            width=width,
 | 
				
			||||||
 | 
					            heads=heads,
 | 
				
			||||||
 | 
					            init_scale=init_scale,
 | 
				
			||||||
 | 
					            qkv_bias=qkv_bias,
 | 
				
			||||||
 | 
					            flash=flash
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        self.ln_1 = nn.LayerNorm(width, device=device, dtype=dtype)
 | 
				
			||||||
 | 
					        self.mlp = MLP(device=device, dtype=dtype, width=width, init_scale=init_scale)
 | 
				
			||||||
 | 
					        self.ln_2 = nn.LayerNorm(width, device=device, dtype=dtype)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def _forward(self, x: torch.Tensor):
 | 
				
			||||||
 | 
					        x = x + self.attn(self.ln_1(x))
 | 
				
			||||||
 | 
					        x = x + self.mlp(self.ln_2(x))
 | 
				
			||||||
 | 
					        return x
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def forward(self, x: torch.Tensor):
 | 
				
			||||||
 | 
					        return checkpoint(self._forward, (x,), self.parameters(), self.use_checkpoint)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class MultiheadCrossAttention(nn.Module):
 | 
				
			||||||
 | 
					    def __init__(
 | 
				
			||||||
 | 
					        self,
 | 
				
			||||||
 | 
					        *,
 | 
				
			||||||
 | 
					        device: torch.device,
 | 
				
			||||||
 | 
					        dtype: torch.dtype,
 | 
				
			||||||
 | 
					        width: int,
 | 
				
			||||||
 | 
					        heads: int,
 | 
				
			||||||
 | 
					        init_scale: float,
 | 
				
			||||||
 | 
					        qkv_bias: bool = True,
 | 
				
			||||||
 | 
					        flash: bool = False,
 | 
				
			||||||
 | 
					        n_data: Optional[int] = None,
 | 
				
			||||||
 | 
					        data_width: Optional[int] = None,
 | 
				
			||||||
 | 
					    ):
 | 
				
			||||||
 | 
					        super().__init__()
 | 
				
			||||||
 | 
					        self.n_data = n_data
 | 
				
			||||||
 | 
					        self.width = width
 | 
				
			||||||
 | 
					        self.heads = heads
 | 
				
			||||||
 | 
					        self.data_width = width if data_width is None else data_width
 | 
				
			||||||
 | 
					        self.c_q = nn.Linear(width, width, bias=qkv_bias, device=device, dtype=dtype)
 | 
				
			||||||
 | 
					        self.c_kv = nn.Linear(self.data_width, width * 2, bias=qkv_bias, device=device, dtype=dtype)
 | 
				
			||||||
 | 
					        self.c_proj = nn.Linear(width, width, device=device, dtype=dtype)
 | 
				
			||||||
 | 
					        self.attention = QKVMultiheadCrossAttention(
 | 
				
			||||||
 | 
					            device=device, dtype=dtype, heads=heads, n_data=n_data, flash=flash
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        init_linear(self.c_q, init_scale)
 | 
				
			||||||
 | 
					        init_linear(self.c_kv, init_scale)
 | 
				
			||||||
 | 
					        init_linear(self.c_proj, init_scale)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def forward(self, x, data):
 | 
				
			||||||
 | 
					        x = self.c_q(x)
 | 
				
			||||||
 | 
					        data = self.c_kv(data)
 | 
				
			||||||
 | 
					        x = checkpoint(self.attention, (x, data), (), True)
 | 
				
			||||||
 | 
					        x = self.c_proj(x)
 | 
				
			||||||
 | 
					        return x
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class QKVMultiheadCrossAttention(nn.Module):
 | 
				
			||||||
 | 
					    def __init__(self, *, device: torch.device, dtype: torch.dtype, heads: int,
 | 
				
			||||||
 | 
					                 flash: bool = False, n_data: Optional[int] = None):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        super().__init__()
 | 
				
			||||||
 | 
					        self.device = device
 | 
				
			||||||
 | 
					        self.dtype = dtype
 | 
				
			||||||
 | 
					        self.heads = heads
 | 
				
			||||||
 | 
					        self.n_data = n_data
 | 
				
			||||||
 | 
					        self.flash = flash
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def forward(self, q, kv):
 | 
				
			||||||
 | 
					        _, n_ctx, _ = q.shape
 | 
				
			||||||
 | 
					        bs, n_data, width = kv.shape
 | 
				
			||||||
 | 
					        attn_ch = width // self.heads // 2
 | 
				
			||||||
 | 
					        scale = 1 / math.sqrt(math.sqrt(attn_ch))
 | 
				
			||||||
 | 
					        q = q.view(bs, n_ctx, self.heads, -1)
 | 
				
			||||||
 | 
					        kv = kv.view(bs, n_data, self.heads, -1)
 | 
				
			||||||
 | 
					        k, v = torch.split(kv, attn_ch, dim=-1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if self.flash:
 | 
				
			||||||
 | 
					            out = F.scaled_dot_product_attention(q, k, v)
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            weight = torch.einsum(
 | 
				
			||||||
 | 
					                "bthc,bshc->bhts", q * scale, k * scale
 | 
				
			||||||
 | 
					            )  # More stable with f16 than dividing afterwards
 | 
				
			||||||
 | 
					            wdtype = weight.dtype
 | 
				
			||||||
 | 
					            weight = torch.softmax(weight.float(), dim=-1).type(wdtype)
 | 
				
			||||||
 | 
					            out = torch.einsum("bhts,bshc->bthc", weight, v).reshape(bs, n_ctx, -1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return out
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class ResidualCrossAttentionBlock(nn.Module):
 | 
				
			||||||
 | 
					    def __init__(
 | 
				
			||||||
 | 
					        self,
 | 
				
			||||||
 | 
					        *,
 | 
				
			||||||
 | 
					        device: Optional[torch.device],
 | 
				
			||||||
 | 
					        dtype: Optional[torch.dtype],
 | 
				
			||||||
 | 
					        n_data: Optional[int] = None,
 | 
				
			||||||
 | 
					        width: int,
 | 
				
			||||||
 | 
					        heads: int,
 | 
				
			||||||
 | 
					        data_width: Optional[int] = None,
 | 
				
			||||||
 | 
					        init_scale: float = 0.25,
 | 
				
			||||||
 | 
					        qkv_bias: bool = True,
 | 
				
			||||||
 | 
					        flash: bool = False
 | 
				
			||||||
 | 
					    ):
 | 
				
			||||||
 | 
					        super().__init__()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if data_width is None:
 | 
				
			||||||
 | 
					            data_width = width
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.attn = MultiheadCrossAttention(
 | 
				
			||||||
 | 
					            device=device,
 | 
				
			||||||
 | 
					            dtype=dtype,
 | 
				
			||||||
 | 
					            n_data=n_data,
 | 
				
			||||||
 | 
					            width=width,
 | 
				
			||||||
 | 
					            heads=heads,
 | 
				
			||||||
 | 
					            data_width=data_width,
 | 
				
			||||||
 | 
					            init_scale=init_scale,
 | 
				
			||||||
 | 
					            qkv_bias=qkv_bias,
 | 
				
			||||||
 | 
					            flash=flash,
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        self.ln_1 = nn.LayerNorm(width, device=device, dtype=dtype)
 | 
				
			||||||
 | 
					        self.ln_2 = nn.LayerNorm(data_width, device=device, dtype=dtype)
 | 
				
			||||||
 | 
					        self.mlp = MLP(device=device, dtype=dtype, width=width, init_scale=init_scale)
 | 
				
			||||||
 | 
					        self.ln_3 = nn.LayerNorm(width, device=device, dtype=dtype)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def forward(self, x: torch.Tensor, data: torch.Tensor):
 | 
				
			||||||
 | 
					        x = x + self.attn(self.ln_1(x), self.ln_2(data))
 | 
				
			||||||
 | 
					        x = x + self.mlp(self.ln_3(x))
 | 
				
			||||||
 | 
					        return x
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class MLP(nn.Module):
 | 
				
			||||||
 | 
					    def __init__(self, *,
 | 
				
			||||||
 | 
					                 device: Optional[torch.device],
 | 
				
			||||||
 | 
					                 dtype: Optional[torch.dtype],
 | 
				
			||||||
 | 
					                 width: int,
 | 
				
			||||||
 | 
					                 init_scale: float):
 | 
				
			||||||
 | 
					        super().__init__()
 | 
				
			||||||
 | 
					        self.width = width
 | 
				
			||||||
 | 
					        self.c_fc = nn.Linear(width, width * 4, device=device, dtype=dtype)
 | 
				
			||||||
 | 
					        self.c_proj = nn.Linear(width * 4, width, device=device, dtype=dtype)
 | 
				
			||||||
 | 
					        self.gelu = nn.GELU()
 | 
				
			||||||
 | 
					        init_linear(self.c_fc, init_scale)
 | 
				
			||||||
 | 
					        init_linear(self.c_proj, init_scale)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def forward(self, x):
 | 
				
			||||||
 | 
					        return self.c_proj(self.gelu(self.c_fc(x)))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class Transformer(nn.Module):
 | 
				
			||||||
 | 
					    def __init__(
 | 
				
			||||||
 | 
					        self,
 | 
				
			||||||
 | 
					        *,
 | 
				
			||||||
 | 
					        device: Optional[torch.device],
 | 
				
			||||||
 | 
					        dtype: Optional[torch.dtype],
 | 
				
			||||||
 | 
					        n_ctx: int,
 | 
				
			||||||
 | 
					        width: int,
 | 
				
			||||||
 | 
					        layers: int,
 | 
				
			||||||
 | 
					        heads: int,
 | 
				
			||||||
 | 
					        init_scale: float = 0.25,
 | 
				
			||||||
 | 
					        qkv_bias: bool = True,
 | 
				
			||||||
 | 
					        flash: bool = False,
 | 
				
			||||||
 | 
					        use_checkpoint: bool = False
 | 
				
			||||||
 | 
					    ):
 | 
				
			||||||
 | 
					        super().__init__()
 | 
				
			||||||
 | 
					        self.n_ctx = n_ctx
 | 
				
			||||||
 | 
					        self.width = width
 | 
				
			||||||
 | 
					        self.layers = layers
 | 
				
			||||||
 | 
					        self.resblocks = nn.ModuleList(
 | 
				
			||||||
 | 
					            [
 | 
				
			||||||
 | 
					                ResidualAttentionBlock(
 | 
				
			||||||
 | 
					                    device=device,
 | 
				
			||||||
 | 
					                    dtype=dtype,
 | 
				
			||||||
 | 
					                    n_ctx=n_ctx,
 | 
				
			||||||
 | 
					                    width=width,
 | 
				
			||||||
 | 
					                    heads=heads,
 | 
				
			||||||
 | 
					                    init_scale=init_scale,
 | 
				
			||||||
 | 
					                    qkv_bias=qkv_bias,
 | 
				
			||||||
 | 
					                    flash=flash,
 | 
				
			||||||
 | 
					                    use_checkpoint=use_checkpoint
 | 
				
			||||||
 | 
					                )
 | 
				
			||||||
 | 
					                for _ in range(layers)
 | 
				
			||||||
 | 
					            ]
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def forward(self, x: torch.Tensor):
 | 
				
			||||||
 | 
					        for block in self.resblocks:
 | 
				
			||||||
 | 
					            x = block(x)
 | 
				
			||||||
 | 
					        return x
 | 
				
			||||||
							
								
								
									
										308
									
								
								primitive_anything/michelangelo/models/modules/transformer_vit.py
									
									
									
									
									
										Executable file
									
								
							
							
						
						
									
										308
									
								
								primitive_anything/michelangelo/models/modules/transformer_vit.py
									
									
									
									
									
										Executable file
									
								
							@ -0,0 +1,308 @@
 | 
				
			|||||||
 | 
					# -*- coding: utf-8 -*-
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import math
 | 
				
			||||||
 | 
					import torch
 | 
				
			||||||
 | 
					import torch.nn as nn
 | 
				
			||||||
 | 
					from typing import Optional
 | 
				
			||||||
 | 
					import warnings
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from .checkpoint import checkpoint
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def _trunc_normal_(tensor, mean, std, a, b):
 | 
				
			||||||
 | 
					    # Cut & paste from PyTorch official master until it's in a few official releases - RW
 | 
				
			||||||
 | 
					    # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
 | 
				
			||||||
 | 
					    def norm_cdf(x):
 | 
				
			||||||
 | 
					        # Computes standard normal cumulative distribution function
 | 
				
			||||||
 | 
					        return (1. + math.erf(x / math.sqrt(2.))) / 2.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if (mean < a - 2 * std) or (mean > b + 2 * std):
 | 
				
			||||||
 | 
					        warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
 | 
				
			||||||
 | 
					                      "The distribution of values may be incorrect.",
 | 
				
			||||||
 | 
					                      stacklevel=2)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Values are generated by using a truncated uniform distribution and
 | 
				
			||||||
 | 
					    # then using the inverse CDF for the normal distribution.
 | 
				
			||||||
 | 
					    # Get upper and lower cdf values
 | 
				
			||||||
 | 
					    l = norm_cdf((a - mean) / std)
 | 
				
			||||||
 | 
					    u = norm_cdf((b - mean) / std)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Uniformly fill tensor with values from [l, u], then translate to
 | 
				
			||||||
 | 
					    # [2l-1, 2u-1].
 | 
				
			||||||
 | 
					    tensor.uniform_(2 * l - 1, 2 * u - 1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Use inverse cdf transform for normal distribution to get truncated
 | 
				
			||||||
 | 
					    # standard normal
 | 
				
			||||||
 | 
					    tensor.erfinv_()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Transform to proper mean, std
 | 
				
			||||||
 | 
					    tensor.mul_(std * math.sqrt(2.))
 | 
				
			||||||
 | 
					    tensor.add_(mean)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Clamp to ensure it's in the proper range
 | 
				
			||||||
 | 
					    tensor.clamp_(min=a, max=b)
 | 
				
			||||||
 | 
					    return tensor
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
 | 
				
			||||||
 | 
					    # type: (Tensor | nn.Parameter, float, float, float, float) -> Tensor
 | 
				
			||||||
 | 
					    r"""Fills the input Tensor with values drawn from a truncated
 | 
				
			||||||
 | 
					    normal distribution. The values are effectively drawn from the
 | 
				
			||||||
 | 
					    normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
 | 
				
			||||||
 | 
					    with values outside :math:`[a, b]` redrawn until they are within
 | 
				
			||||||
 | 
					    the bounds. The method used for generating the random values works
 | 
				
			||||||
 | 
					    best when :math:`a \leq \text{mean} \leq b`.
 | 
				
			||||||
 | 
					    NOTE: this impl is similar to the PyTorch trunc_normal_, the bounds [a, b] are
 | 
				
			||||||
 | 
					    applied while sampling the normal with mean/std applied, therefore a, b args
 | 
				
			||||||
 | 
					    should be adjusted to match the range of mean, std args.
 | 
				
			||||||
 | 
					    Args:
 | 
				
			||||||
 | 
					        tensor: an n-dimensional `torch.Tensor`
 | 
				
			||||||
 | 
					        mean: the mean of the normal distribution
 | 
				
			||||||
 | 
					        std: the standard deviation of the normal distribution
 | 
				
			||||||
 | 
					        a: the minimum cutoff value
 | 
				
			||||||
 | 
					        b: the maximum cutoff value
 | 
				
			||||||
 | 
					    Examples:
 | 
				
			||||||
 | 
					        >>> w = torch.empty(3, 5)
 | 
				
			||||||
 | 
					        >>> nn.init.trunc_normal_(w)
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    with torch.no_grad():
 | 
				
			||||||
 | 
					        return _trunc_normal_(tensor, mean, std, a, b)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def init_weights(m):
 | 
				
			||||||
 | 
					    if isinstance(m, nn.Linear):
 | 
				
			||||||
 | 
					        trunc_normal_(m.weight, std=.02)
 | 
				
			||||||
 | 
					        if isinstance(m, nn.Linear) and m.bias is not None:
 | 
				
			||||||
 | 
					            nn.init.constant_(m.bias, 0)
 | 
				
			||||||
 | 
					    elif isinstance(m, nn.LayerNorm):
 | 
				
			||||||
 | 
					        nn.init.constant_(m.bias, 0)
 | 
				
			||||||
 | 
					        nn.init.constant_(m.weight, 1.0)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class MultiheadAttention(nn.Module):
 | 
				
			||||||
 | 
					    def __init__(
 | 
				
			||||||
 | 
					        self,
 | 
				
			||||||
 | 
					        *,
 | 
				
			||||||
 | 
					        device: torch.device,
 | 
				
			||||||
 | 
					        dtype: torch.dtype,
 | 
				
			||||||
 | 
					        n_ctx: int,
 | 
				
			||||||
 | 
					        width: int,
 | 
				
			||||||
 | 
					        heads: int,
 | 
				
			||||||
 | 
					        qkv_bias: bool
 | 
				
			||||||
 | 
					    ):
 | 
				
			||||||
 | 
					        super().__init__()
 | 
				
			||||||
 | 
					        self.n_ctx = n_ctx
 | 
				
			||||||
 | 
					        self.width = width
 | 
				
			||||||
 | 
					        self.heads = heads
 | 
				
			||||||
 | 
					        self.c_qkv = nn.Linear(width, width * 3, bias=qkv_bias, device=device, dtype=dtype)
 | 
				
			||||||
 | 
					        self.c_proj = nn.Linear(width, width, device=device, dtype=dtype)
 | 
				
			||||||
 | 
					        self.attention = QKVMultiheadAttention(device=device, dtype=dtype, heads=heads, n_ctx=n_ctx)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def forward(self, x):
 | 
				
			||||||
 | 
					        x = self.c_qkv(x)
 | 
				
			||||||
 | 
					        x = checkpoint(self.attention, (x,), (), True)
 | 
				
			||||||
 | 
					        x = self.c_proj(x)
 | 
				
			||||||
 | 
					        return x
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class QKVMultiheadAttention(nn.Module):
 | 
				
			||||||
 | 
					    def __init__(self, *, device: torch.device, dtype: torch.dtype, heads: int, n_ctx: int):
 | 
				
			||||||
 | 
					        super().__init__()
 | 
				
			||||||
 | 
					        self.device = device
 | 
				
			||||||
 | 
					        self.dtype = dtype
 | 
				
			||||||
 | 
					        self.heads = heads
 | 
				
			||||||
 | 
					        self.n_ctx = n_ctx
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def forward(self, qkv):
 | 
				
			||||||
 | 
					        bs, n_ctx, width = qkv.shape
 | 
				
			||||||
 | 
					        attn_ch = width // self.heads // 3
 | 
				
			||||||
 | 
					        scale = 1 / math.sqrt(attn_ch)
 | 
				
			||||||
 | 
					        qkv = qkv.view(bs, n_ctx, self.heads, -1)
 | 
				
			||||||
 | 
					        q, k, v = torch.split(qkv, attn_ch, dim=-1)
 | 
				
			||||||
 | 
					        weight = torch.einsum("bthc,bshc->bhts", q, k) * scale
 | 
				
			||||||
 | 
					        wdtype = weight.dtype
 | 
				
			||||||
 | 
					        weight = torch.softmax(weight.float(), dim=-1).type(wdtype)
 | 
				
			||||||
 | 
					        return torch.einsum("bhts,bshc->bthc", weight, v).reshape(bs, n_ctx, -1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class ResidualAttentionBlock(nn.Module):
 | 
				
			||||||
 | 
					    def __init__(
 | 
				
			||||||
 | 
					        self,
 | 
				
			||||||
 | 
					        *,
 | 
				
			||||||
 | 
					        device: torch.device,
 | 
				
			||||||
 | 
					        dtype: torch.dtype,
 | 
				
			||||||
 | 
					        n_ctx: int,
 | 
				
			||||||
 | 
					        width: int,
 | 
				
			||||||
 | 
					        heads: int,
 | 
				
			||||||
 | 
					        qkv_bias: bool = True,
 | 
				
			||||||
 | 
					        use_checkpoint: bool = False
 | 
				
			||||||
 | 
					    ):
 | 
				
			||||||
 | 
					        super().__init__()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.use_checkpoint = use_checkpoint
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.attn = MultiheadAttention(
 | 
				
			||||||
 | 
					            device=device,
 | 
				
			||||||
 | 
					            dtype=dtype,
 | 
				
			||||||
 | 
					            n_ctx=n_ctx,
 | 
				
			||||||
 | 
					            width=width,
 | 
				
			||||||
 | 
					            heads=heads,
 | 
				
			||||||
 | 
					            qkv_bias=qkv_bias
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        self.ln_1 = nn.LayerNorm(width, device=device, dtype=dtype)
 | 
				
			||||||
 | 
					        self.mlp = MLP(device=device, dtype=dtype, width=width)
 | 
				
			||||||
 | 
					        self.ln_2 = nn.LayerNorm(width, device=device, dtype=dtype)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def _forward(self, x: torch.Tensor):
 | 
				
			||||||
 | 
					        x = x + self.attn(self.ln_1(x))
 | 
				
			||||||
 | 
					        x = x + self.mlp(self.ln_2(x))
 | 
				
			||||||
 | 
					        return x
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def forward(self, x: torch.Tensor):
 | 
				
			||||||
 | 
					        return checkpoint(self._forward, (x,), self.parameters(), self.use_checkpoint)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class MultiheadCrossAttention(nn.Module):
 | 
				
			||||||
 | 
					    def __init__(
 | 
				
			||||||
 | 
					        self,
 | 
				
			||||||
 | 
					        *,
 | 
				
			||||||
 | 
					        device: torch.device,
 | 
				
			||||||
 | 
					        dtype: torch.dtype,
 | 
				
			||||||
 | 
					        width: int,
 | 
				
			||||||
 | 
					        heads: int,
 | 
				
			||||||
 | 
					        qkv_bias: bool = True,
 | 
				
			||||||
 | 
					        n_data: Optional[int] = None,
 | 
				
			||||||
 | 
					        data_width: Optional[int] = None,
 | 
				
			||||||
 | 
					    ):
 | 
				
			||||||
 | 
					        super().__init__()
 | 
				
			||||||
 | 
					        self.n_data = n_data
 | 
				
			||||||
 | 
					        self.width = width
 | 
				
			||||||
 | 
					        self.heads = heads
 | 
				
			||||||
 | 
					        self.data_width = width if data_width is None else data_width
 | 
				
			||||||
 | 
					        self.c_q = nn.Linear(width, width, bias=qkv_bias, device=device, dtype=dtype)
 | 
				
			||||||
 | 
					        self.c_kv = nn.Linear(self.data_width, width * 2, bias=qkv_bias, device=device, dtype=dtype)
 | 
				
			||||||
 | 
					        self.c_proj = nn.Linear(width, width, device=device, dtype=dtype)
 | 
				
			||||||
 | 
					        self.attention = QKVMultiheadCrossAttention(
 | 
				
			||||||
 | 
					            device=device, dtype=dtype, heads=heads, n_data=n_data
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def forward(self, x, data):
 | 
				
			||||||
 | 
					        x = self.c_q(x)
 | 
				
			||||||
 | 
					        data = self.c_kv(data)
 | 
				
			||||||
 | 
					        x = checkpoint(self.attention, (x, data), (), True)
 | 
				
			||||||
 | 
					        x = self.c_proj(x)
 | 
				
			||||||
 | 
					        return x
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class QKVMultiheadCrossAttention(nn.Module):
 | 
				
			||||||
 | 
					    def __init__(self, *, device: torch.device, dtype: torch.dtype, heads: int, n_data: Optional[int] = None):
 | 
				
			||||||
 | 
					        super().__init__()
 | 
				
			||||||
 | 
					        self.device = device
 | 
				
			||||||
 | 
					        self.dtype = dtype
 | 
				
			||||||
 | 
					        self.heads = heads
 | 
				
			||||||
 | 
					        self.n_data = n_data
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def forward(self, q, kv):
 | 
				
			||||||
 | 
					        _, n_ctx, _ = q.shape
 | 
				
			||||||
 | 
					        bs, n_data, width = kv.shape
 | 
				
			||||||
 | 
					        attn_ch = width // self.heads // 2
 | 
				
			||||||
 | 
					        scale = 1 / math.sqrt(attn_ch)
 | 
				
			||||||
 | 
					        q = q.view(bs, n_ctx, self.heads, -1)
 | 
				
			||||||
 | 
					        kv = kv.view(bs, n_data, self.heads, -1)
 | 
				
			||||||
 | 
					        k, v = torch.split(kv, attn_ch, dim=-1)
 | 
				
			||||||
 | 
					        weight = torch.einsum("bthc,bshc->bhts", q, k) * scale
 | 
				
			||||||
 | 
					        wdtype = weight.dtype
 | 
				
			||||||
 | 
					        weight = torch.softmax(weight.float(), dim=-1).type(wdtype)
 | 
				
			||||||
 | 
					        return torch.einsum("bhts,bshc->bthc", weight, v).reshape(bs, n_ctx, -1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class ResidualCrossAttentionBlock(nn.Module):
 | 
				
			||||||
 | 
					    def __init__(
 | 
				
			||||||
 | 
					        self,
 | 
				
			||||||
 | 
					        *,
 | 
				
			||||||
 | 
					        device: Optional[torch.device],
 | 
				
			||||||
 | 
					        dtype: Optional[torch.dtype],
 | 
				
			||||||
 | 
					        n_data: Optional[int] = None,
 | 
				
			||||||
 | 
					        width: int,
 | 
				
			||||||
 | 
					        heads: int,
 | 
				
			||||||
 | 
					        data_width: Optional[int] = None,
 | 
				
			||||||
 | 
					        qkv_bias: bool = True
 | 
				
			||||||
 | 
					    ):
 | 
				
			||||||
 | 
					        super().__init__()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if data_width is None:
 | 
				
			||||||
 | 
					            data_width = width
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.attn = MultiheadCrossAttention(
 | 
				
			||||||
 | 
					            device=device,
 | 
				
			||||||
 | 
					            dtype=dtype,
 | 
				
			||||||
 | 
					            n_data=n_data,
 | 
				
			||||||
 | 
					            width=width,
 | 
				
			||||||
 | 
					            heads=heads,
 | 
				
			||||||
 | 
					            data_width=data_width,
 | 
				
			||||||
 | 
					            qkv_bias=qkv_bias
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        self.ln_1 = nn.LayerNorm(width, device=device, dtype=dtype)
 | 
				
			||||||
 | 
					        self.ln_2 = nn.LayerNorm(data_width, device=device, dtype=dtype)
 | 
				
			||||||
 | 
					        self.mlp = MLP(device=device, dtype=dtype, width=width)
 | 
				
			||||||
 | 
					        self.ln_3 = nn.LayerNorm(width, device=device, dtype=dtype)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def forward(self, x: torch.Tensor, data: torch.Tensor):
 | 
				
			||||||
 | 
					        x = x + self.attn(self.ln_1(x), self.ln_2(data))
 | 
				
			||||||
 | 
					        x = x + self.mlp(self.ln_3(x))
 | 
				
			||||||
 | 
					        return x
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class MLP(nn.Module):
 | 
				
			||||||
 | 
					    def __init__(self, *,
 | 
				
			||||||
 | 
					                 device: Optional[torch.device],
 | 
				
			||||||
 | 
					                 dtype: Optional[torch.dtype],
 | 
				
			||||||
 | 
					                 width: int):
 | 
				
			||||||
 | 
					        super().__init__()
 | 
				
			||||||
 | 
					        self.width = width
 | 
				
			||||||
 | 
					        self.c_fc = nn.Linear(width, width * 4, device=device, dtype=dtype)
 | 
				
			||||||
 | 
					        self.c_proj = nn.Linear(width * 4, width, device=device, dtype=dtype)
 | 
				
			||||||
 | 
					        self.gelu = nn.GELU()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def forward(self, x):
 | 
				
			||||||
 | 
					        return self.c_proj(self.gelu(self.c_fc(x)))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class Transformer(nn.Module):
 | 
				
			||||||
 | 
					    def __init__(
 | 
				
			||||||
 | 
					        self,
 | 
				
			||||||
 | 
					        *,
 | 
				
			||||||
 | 
					        device: Optional[torch.device],
 | 
				
			||||||
 | 
					        dtype: Optional[torch.dtype],
 | 
				
			||||||
 | 
					        n_ctx: int,
 | 
				
			||||||
 | 
					        width: int,
 | 
				
			||||||
 | 
					        layers: int,
 | 
				
			||||||
 | 
					        heads: int,
 | 
				
			||||||
 | 
					        qkv_bias: bool = True,
 | 
				
			||||||
 | 
					        use_checkpoint: bool = False
 | 
				
			||||||
 | 
					    ):
 | 
				
			||||||
 | 
					        super().__init__()
 | 
				
			||||||
 | 
					        self.n_ctx = n_ctx
 | 
				
			||||||
 | 
					        self.width = width
 | 
				
			||||||
 | 
					        self.layers = layers
 | 
				
			||||||
 | 
					        self.resblocks = nn.ModuleList(
 | 
				
			||||||
 | 
					            [
 | 
				
			||||||
 | 
					                ResidualAttentionBlock(
 | 
				
			||||||
 | 
					                    device=device,
 | 
				
			||||||
 | 
					                    dtype=dtype,
 | 
				
			||||||
 | 
					                    n_ctx=n_ctx,
 | 
				
			||||||
 | 
					                    width=width,
 | 
				
			||||||
 | 
					                    heads=heads,
 | 
				
			||||||
 | 
					                    qkv_bias=qkv_bias,
 | 
				
			||||||
 | 
					                    use_checkpoint=use_checkpoint
 | 
				
			||||||
 | 
					                )
 | 
				
			||||||
 | 
					                for _ in range(layers)
 | 
				
			||||||
 | 
					            ]
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.apply(init_weights)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def forward(self, x: torch.Tensor):
 | 
				
			||||||
 | 
					        for block in self.resblocks:
 | 
				
			||||||
 | 
					            x = block(x)
 | 
				
			||||||
 | 
					        return x
 | 
				
			||||||
							
								
								
									
										1
									
								
								primitive_anything/michelangelo/models/tsal/__init__.py
									
									
									
									
									
										Executable file
									
								
							
							
						
						
									
										1
									
								
								primitive_anything/michelangelo/models/tsal/__init__.py
									
									
									
									
									
										Executable file
									
								
							@ -0,0 +1 @@
 | 
				
			|||||||
 | 
					# -*- coding: utf-8 -*-
 | 
				
			||||||
							
								
								
									
										373
									
								
								primitive_anything/michelangelo/models/tsal/asl_pl_module.py
									
									
									
									
									
										Executable file
									
								
							
							
						
						
									
										373
									
								
								primitive_anything/michelangelo/models/tsal/asl_pl_module.py
									
									
									
									
									
										Executable file
									
								
							@ -0,0 +1,373 @@
 | 
				
			|||||||
 | 
					# -*- coding: utf-8 -*-
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from typing import List, Tuple, Dict, Optional
 | 
				
			||||||
 | 
					from omegaconf import DictConfig
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import torch
 | 
				
			||||||
 | 
					import torch.nn.functional as F
 | 
				
			||||||
 | 
					from torch.optim import lr_scheduler
 | 
				
			||||||
 | 
					import pytorch_lightning as pl
 | 
				
			||||||
 | 
					from typing import Union
 | 
				
			||||||
 | 
					from functools import partial
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from ...utils import instantiate_from_config
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from .inference_utils import extract_geometry
 | 
				
			||||||
 | 
					from .tsal_base import (
 | 
				
			||||||
 | 
					    AlignedShapeAsLatentModule,
 | 
				
			||||||
 | 
					    ShapeAsLatentModule,
 | 
				
			||||||
 | 
					    Latent2MeshOutput,
 | 
				
			||||||
 | 
					    AlignedMeshOutput
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class AlignedShapeAsLatentPLModule(pl.LightningModule):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __init__(self, *,
 | 
				
			||||||
 | 
					                 shape_module_cfg,
 | 
				
			||||||
 | 
					                 aligned_module_cfg,
 | 
				
			||||||
 | 
					                 loss_cfg,
 | 
				
			||||||
 | 
					                 optimizer_cfg: Optional[DictConfig] = None,
 | 
				
			||||||
 | 
					                 ckpt_path: Optional[str] = None,
 | 
				
			||||||
 | 
					                 ignore_keys: Union[Tuple[str], List[str]] = ()):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        super().__init__()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        shape_model: ShapeAsLatentModule = instantiate_from_config(
 | 
				
			||||||
 | 
					            shape_module_cfg, device=None, dtype=None
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        self.model: AlignedShapeAsLatentModule = instantiate_from_config(
 | 
				
			||||||
 | 
					            aligned_module_cfg, shape_model=shape_model
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.loss = instantiate_from_config(loss_cfg)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.optimizer_cfg = optimizer_cfg
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if ckpt_path is not None:
 | 
				
			||||||
 | 
					            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.save_hyperparameters()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def set_shape_model_only(self):
 | 
				
			||||||
 | 
					        self.model.set_shape_model_only()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @property
 | 
				
			||||||
 | 
					    def latent_shape(self):
 | 
				
			||||||
 | 
					        return self.model.shape_model.latent_shape
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @property
 | 
				
			||||||
 | 
					    def zero_rank(self):
 | 
				
			||||||
 | 
					        if self._trainer:
 | 
				
			||||||
 | 
					            zero_rank = self.trainer.local_rank == 0
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            zero_rank = True
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return zero_rank
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def init_from_ckpt(self, path, ignore_keys=()):
 | 
				
			||||||
 | 
					        state_dict = torch.load(path, map_location="cpu")["state_dict"]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        keys = list(state_dict.keys())
 | 
				
			||||||
 | 
					        for k in keys:
 | 
				
			||||||
 | 
					            for ik in ignore_keys:
 | 
				
			||||||
 | 
					                if k.startswith(ik):
 | 
				
			||||||
 | 
					                    print("Deleting key {} from state_dict.".format(k))
 | 
				
			||||||
 | 
					                    del state_dict[k]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        missing, unexpected = self.load_state_dict(state_dict, strict=False)
 | 
				
			||||||
 | 
					        print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
 | 
				
			||||||
 | 
					        if len(missing) > 0:
 | 
				
			||||||
 | 
					            print(f"Missing Keys: {missing}")
 | 
				
			||||||
 | 
					            print(f"Unexpected Keys: {unexpected}")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def configure_optimizers(self) -> Tuple[List, List]:
 | 
				
			||||||
 | 
					        lr = self.learning_rate
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        trainable_parameters = list(self.model.parameters())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if self.optimizer_cfg is None:
 | 
				
			||||||
 | 
					            optimizers = [torch.optim.AdamW(trainable_parameters, lr=lr, betas=(0.9, 0.99), weight_decay=1e-3)]
 | 
				
			||||||
 | 
					            schedulers = []
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            optimizer = instantiate_from_config(self.optimizer_cfg.optimizer, params=trainable_parameters)
 | 
				
			||||||
 | 
					            scheduler_func = instantiate_from_config(
 | 
				
			||||||
 | 
					                self.optimizer_cfg.scheduler,
 | 
				
			||||||
 | 
					                max_decay_steps=self.trainer.max_steps,
 | 
				
			||||||
 | 
					                lr_max=lr
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					            scheduler = {
 | 
				
			||||||
 | 
					                "scheduler": lr_scheduler.LambdaLR(optimizer, lr_lambda=scheduler_func.schedule),
 | 
				
			||||||
 | 
					                "interval": "step",
 | 
				
			||||||
 | 
					                "frequency": 1
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
 | 
					            optimizers = [optimizer]
 | 
				
			||||||
 | 
					            schedulers = [scheduler]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return optimizers, schedulers
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def forward(self,
 | 
				
			||||||
 | 
					                surface: torch.FloatTensor,
 | 
				
			||||||
 | 
					                image: torch.FloatTensor,
 | 
				
			||||||
 | 
					                text: torch.FloatTensor,
 | 
				
			||||||
 | 
					                volume_queries: torch.FloatTensor):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Args:
 | 
				
			||||||
 | 
					            surface (torch.FloatTensor):
 | 
				
			||||||
 | 
					            image (torch.FloatTensor):
 | 
				
			||||||
 | 
					            text (torch.FloatTensor):
 | 
				
			||||||
 | 
					            volume_queries (torch.FloatTensor):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Returns:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        embed_outputs, shape_z = self.model(surface, image, text)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        shape_zq, posterior = self.model.shape_model.encode_kl_embed(shape_z)
 | 
				
			||||||
 | 
					        latents = self.model.shape_model.decode(shape_zq)
 | 
				
			||||||
 | 
					        logits = self.model.shape_model.query_geometry(volume_queries, latents)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return embed_outputs, logits, posterior
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def encode(self, surface: torch.FloatTensor, sample_posterior=True):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        pc = surface[..., 0:3]
 | 
				
			||||||
 | 
					        feats = surface[..., 3:6]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        shape_embed, shape_zq, posterior = self.model.shape_model.encode(
 | 
				
			||||||
 | 
					            pc=pc, feats=feats, sample_posterior=sample_posterior
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return shape_zq
 | 
				
			||||||
 | 
					    
 | 
				
			||||||
 | 
					    def encode_latents(self, surface: torch.FloatTensor):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        pc = surface[..., 0:3]
 | 
				
			||||||
 | 
					        feats = surface[..., 3:6]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        shape_embed, shape_latents = self.model.shape_model.encode_latents(
 | 
				
			||||||
 | 
					            pc=pc, feats=feats
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        shape_embed = shape_embed.unsqueeze(1)
 | 
				
			||||||
 | 
					        assert shape_embed.shape[1] == 1 and shape_latents.shape[1] == 256
 | 
				
			||||||
 | 
					        cat_latents = torch.cat([shape_embed, shape_latents], dim=1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return cat_latents
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def to_shape_latents(self, latents):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        shape_zq, posterior = self.model.shape_model.encode_kl_embed(latents, sample_posterior = False)
 | 
				
			||||||
 | 
					        return self.model.shape_model.decode(shape_zq)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def decode(self,
 | 
				
			||||||
 | 
					               z_q,
 | 
				
			||||||
 | 
					               bounds: Union[Tuple[float], List[float], float] = 1.1,
 | 
				
			||||||
 | 
					               octree_depth: int = 7,
 | 
				
			||||||
 | 
					               num_chunks: int = 10000) -> List[Latent2MeshOutput]:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        latents = self.model.shape_model.decode(z_q)  # latents: [bs, num_latents, dim]
 | 
				
			||||||
 | 
					        outputs = self.latent2mesh(latents, bounds=bounds, octree_depth=octree_depth, num_chunks=num_chunks)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return outputs
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def training_step(self, batch: Dict[str, torch.FloatTensor],
 | 
				
			||||||
 | 
					                      batch_idx: int, optimizer_idx: int = 0) -> torch.FloatTensor:
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Args:
 | 
				
			||||||
 | 
					            batch (dict): the batch sample, and it contains:
 | 
				
			||||||
 | 
					                - surface (torch.FloatTensor): [bs, n_surface, (3 + input_dim)]
 | 
				
			||||||
 | 
					                - image (torch.FloatTensor): [bs, 3, 224, 224]
 | 
				
			||||||
 | 
					                - text (torch.FloatTensor): [bs, num_templates, 77]
 | 
				
			||||||
 | 
					                - geo_points (torch.FloatTensor): [bs, n_pts, (3 + 1)]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            batch_idx (int):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            optimizer_idx (int):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Returns:
 | 
				
			||||||
 | 
					            loss (torch.FloatTensor):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        surface = batch["surface"]
 | 
				
			||||||
 | 
					        image = batch["image"]
 | 
				
			||||||
 | 
					        text = batch["text"]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        volume_queries = batch["geo_points"][..., 0:3]
 | 
				
			||||||
 | 
					        shape_labels = batch["geo_points"][..., -1]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        embed_outputs, shape_logits, posteriors = self(surface, image, text, volume_queries)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        aeloss, log_dict_ae = self.loss(
 | 
				
			||||||
 | 
					            **embed_outputs,
 | 
				
			||||||
 | 
					            posteriors=posteriors,
 | 
				
			||||||
 | 
					            shape_logits=shape_logits,
 | 
				
			||||||
 | 
					            shape_labels=shape_labels,
 | 
				
			||||||
 | 
					            split="train"
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.log_dict(log_dict_ae, prog_bar=True, logger=True, batch_size=shape_logits.shape[0],
 | 
				
			||||||
 | 
					                      sync_dist=False, rank_zero_only=True)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return aeloss
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def validation_step(self, batch: Dict[str, torch.FloatTensor], batch_idx: int) -> torch.FloatTensor:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        surface = batch["surface"]
 | 
				
			||||||
 | 
					        image = batch["image"]
 | 
				
			||||||
 | 
					        text = batch["text"]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        volume_queries = batch["geo_points"][..., 0:3]
 | 
				
			||||||
 | 
					        shape_labels = batch["geo_points"][..., -1]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        embed_outputs, shape_logits, posteriors = self(surface, image, text, volume_queries)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        aeloss, log_dict_ae = self.loss(
 | 
				
			||||||
 | 
					            **embed_outputs,
 | 
				
			||||||
 | 
					            posteriors=posteriors,
 | 
				
			||||||
 | 
					            shape_logits=shape_logits,
 | 
				
			||||||
 | 
					            shape_labels=shape_labels,
 | 
				
			||||||
 | 
					            split="val"
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        self.log_dict(log_dict_ae, prog_bar=True, logger=True, batch_size=shape_logits.shape[0],
 | 
				
			||||||
 | 
					                      sync_dist=False, rank_zero_only=True)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return aeloss
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def visual_alignment(self,
 | 
				
			||||||
 | 
					                         surface: torch.FloatTensor,
 | 
				
			||||||
 | 
					                         image: torch.FloatTensor,
 | 
				
			||||||
 | 
					                         text: torch.FloatTensor,
 | 
				
			||||||
 | 
					                         description: Optional[List[str]] = None,
 | 
				
			||||||
 | 
					                         bounds: Union[Tuple[float], List[float]] = (-1.25, -1.25, -1.25, 1.25, 1.25, 1.25),
 | 
				
			||||||
 | 
					                         octree_depth: int = 7,
 | 
				
			||||||
 | 
					                         num_chunks: int = 10000) -> List[AlignedMeshOutput]:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Args:
 | 
				
			||||||
 | 
					            surface:
 | 
				
			||||||
 | 
					            image:
 | 
				
			||||||
 | 
					            text:
 | 
				
			||||||
 | 
					            description:
 | 
				
			||||||
 | 
					            bounds:
 | 
				
			||||||
 | 
					            octree_depth:
 | 
				
			||||||
 | 
					            num_chunks:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Returns:
 | 
				
			||||||
 | 
					            mesh_outputs (List[AlignedMeshOutput]): the mesh outputs list.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        outputs = []
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        device = surface.device
 | 
				
			||||||
 | 
					        bs = surface.shape[0]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        embed_outputs, shape_z = self.model(surface, image, text)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # calculate the similarity
 | 
				
			||||||
 | 
					        image_embed = embed_outputs["image_embed"]
 | 
				
			||||||
 | 
					        text_embed = embed_outputs["text_embed"]
 | 
				
			||||||
 | 
					        shape_embed = embed_outputs["shape_embed"]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # normalized features
 | 
				
			||||||
 | 
					        shape_embed = F.normalize(shape_embed, dim=-1, p=2)
 | 
				
			||||||
 | 
					        text_embed = F.normalize(text_embed, dim=-1, p=2)
 | 
				
			||||||
 | 
					        image_embed = F.normalize(image_embed, dim=-1, p=2)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # B x B
 | 
				
			||||||
 | 
					        shape_text_similarity = (100.0 * shape_embed @ text_embed.T).softmax(dim=-1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # B x B
 | 
				
			||||||
 | 
					        shape_image_similarity = (100.0 * shape_embed @ image_embed.T).softmax(dim=-1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # shape reconstruction
 | 
				
			||||||
 | 
					        shape_zq, posterior = self.model.shape_model.encode_kl_embed(shape_z)
 | 
				
			||||||
 | 
					        latents = self.model.shape_model.decode(shape_zq)
 | 
				
			||||||
 | 
					        geometric_func = partial(self.model.shape_model.query_geometry, latents=latents)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # 2. decode geometry
 | 
				
			||||||
 | 
					        mesh_v_f, has_surface = extract_geometry(
 | 
				
			||||||
 | 
					            geometric_func=geometric_func,
 | 
				
			||||||
 | 
					            device=device,
 | 
				
			||||||
 | 
					            batch_size=bs,
 | 
				
			||||||
 | 
					            bounds=bounds,
 | 
				
			||||||
 | 
					            octree_depth=octree_depth,
 | 
				
			||||||
 | 
					            num_chunks=num_chunks,
 | 
				
			||||||
 | 
					            disable=not self.zero_rank
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # 3. decode texture
 | 
				
			||||||
 | 
					        for i, ((mesh_v, mesh_f), is_surface) in enumerate(zip(mesh_v_f, has_surface)):
 | 
				
			||||||
 | 
					            if not is_surface:
 | 
				
			||||||
 | 
					                outputs.append(None)
 | 
				
			||||||
 | 
					                continue
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            out = AlignedMeshOutput()
 | 
				
			||||||
 | 
					            out.mesh_v = mesh_v
 | 
				
			||||||
 | 
					            out.mesh_f = mesh_f
 | 
				
			||||||
 | 
					            out.surface = surface[i].cpu().numpy()
 | 
				
			||||||
 | 
					            out.image = image[i].cpu().numpy()
 | 
				
			||||||
 | 
					            if description is not None:
 | 
				
			||||||
 | 
					                out.text = description[i]
 | 
				
			||||||
 | 
					            out.shape_text_similarity = shape_text_similarity[i, i]
 | 
				
			||||||
 | 
					            out.shape_image_similarity = shape_image_similarity[i, i]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            outputs.append(out)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return outputs
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def latent2mesh(self,
 | 
				
			||||||
 | 
					                    latents: torch.FloatTensor,
 | 
				
			||||||
 | 
					                    bounds: Union[Tuple[float], List[float], float] = 1.1,
 | 
				
			||||||
 | 
					                    octree_depth: int = 7,
 | 
				
			||||||
 | 
					                    num_chunks: int = 10000) -> List[Latent2MeshOutput]:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Args:
 | 
				
			||||||
 | 
					            latents: [bs, num_latents, dim]
 | 
				
			||||||
 | 
					            bounds:
 | 
				
			||||||
 | 
					            octree_depth:
 | 
				
			||||||
 | 
					            num_chunks:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Returns:
 | 
				
			||||||
 | 
					            mesh_outputs (List[MeshOutput]): the mesh outputs list.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        outputs = []
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        geometric_func = partial(self.model.shape_model.query_geometry, latents=latents)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # 2. decode geometry
 | 
				
			||||||
 | 
					        device = latents.device
 | 
				
			||||||
 | 
					        mesh_v_f, has_surface = extract_geometry(
 | 
				
			||||||
 | 
					            geometric_func=geometric_func,
 | 
				
			||||||
 | 
					            device=device,
 | 
				
			||||||
 | 
					            batch_size=len(latents),
 | 
				
			||||||
 | 
					            bounds=bounds,
 | 
				
			||||||
 | 
					            octree_depth=octree_depth,
 | 
				
			||||||
 | 
					            num_chunks=num_chunks,
 | 
				
			||||||
 | 
					            disable=not self.zero_rank
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # 3. decode texture
 | 
				
			||||||
 | 
					        for i, ((mesh_v, mesh_f), is_surface) in enumerate(zip(mesh_v_f, has_surface)):
 | 
				
			||||||
 | 
					            if not is_surface:
 | 
				
			||||||
 | 
					                outputs.append(None)
 | 
				
			||||||
 | 
					                continue
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            out = Latent2MeshOutput()
 | 
				
			||||||
 | 
					            out.mesh_v = mesh_v
 | 
				
			||||||
 | 
					            out.mesh_f = mesh_f
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            outputs.append(out)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return outputs
 | 
				
			||||||
 | 
					
 | 
				
			||||||
							
								
								
									
										114
									
								
								primitive_anything/michelangelo/models/tsal/clip_asl_module.py
									
									
									
									
									
										Executable file
									
								
							
							
						
						
									
										114
									
								
								primitive_anything/michelangelo/models/tsal/clip_asl_module.py
									
									
									
									
									
										Executable file
									
								
							@ -0,0 +1,114 @@
 | 
				
			|||||||
 | 
					# -*- coding: utf-8 -*-
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import torch
 | 
				
			||||||
 | 
					from torch import nn
 | 
				
			||||||
 | 
					from einops import rearrange
 | 
				
			||||||
 | 
					from transformers import CLIPModel
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from .tsal_base import AlignedShapeAsLatentModule
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class CLIPAlignedShapeAsLatentModule(AlignedShapeAsLatentModule):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __init__(self, *,
 | 
				
			||||||
 | 
					                 shape_model,
 | 
				
			||||||
 | 
					                 projection_dim=768):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        super().__init__()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.shape_model = shape_model
 | 
				
			||||||
 | 
					        self.shape_projection = nn.Parameter(torch.empty(self.shape_model.width, projection_dim))
 | 
				
			||||||
 | 
					        nn.init.normal_(self.shape_projection, std=projection_dim ** -0.5)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def set_shape_model_only(self):
 | 
				
			||||||
 | 
					        self.clip_model = None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def encode_shape_embed(self, surface, return_latents: bool = False):
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Args:
 | 
				
			||||||
 | 
					            surface (torch.FloatTensor): [bs, n, 3 + c]
 | 
				
			||||||
 | 
					            return_latents (bool):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Returns:
 | 
				
			||||||
 | 
					            x (torch.FloatTensor): [bs, projection_dim]
 | 
				
			||||||
 | 
					            shape_latents (torch.FloatTensor): [bs, m, d]
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        pc = surface[..., 0:3]
 | 
				
			||||||
 | 
					        feats = surface[..., 3:]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        shape_embed, shape_latents = self.shape_model.encode_latents(pc, feats)
 | 
				
			||||||
 | 
					        x = shape_embed @ self.shape_projection
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if return_latents:
 | 
				
			||||||
 | 
					            return x, shape_latents
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            return x
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def encode_image_embed(self, image):
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Args:
 | 
				
			||||||
 | 
					            image (torch.FloatTensor): [bs, 3, h, w]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Returns:
 | 
				
			||||||
 | 
					            x (torch.FloatTensor): [bs, projection_dim]
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        x = self.clip_model.get_image_features(image)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return x
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def encode_text_embed(self, text):
 | 
				
			||||||
 | 
					        x = self.clip_model.get_text_features(text)
 | 
				
			||||||
 | 
					        return x
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def forward(self, surface, image, text):
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Args:
 | 
				
			||||||
 | 
					            surface (torch.FloatTensor):
 | 
				
			||||||
 | 
					            image (torch.FloatTensor): [bs, 3, 224, 224]
 | 
				
			||||||
 | 
					            text (torch.LongTensor): [bs, num_templates, 77]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Returns:
 | 
				
			||||||
 | 
					            embed_outputs (dict): the embedding outputs, and it contains:
 | 
				
			||||||
 | 
					                - image_embed (torch.FloatTensor):
 | 
				
			||||||
 | 
					                - text_embed (torch.FloatTensor):
 | 
				
			||||||
 | 
					                - shape_embed (torch.FloatTensor):
 | 
				
			||||||
 | 
					                - logit_scale (float):
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # # text embedding
 | 
				
			||||||
 | 
					        # text_embed_all = []
 | 
				
			||||||
 | 
					        # for i in range(text.shape[0]):
 | 
				
			||||||
 | 
					        #     text_for_one_sample = text[i]
 | 
				
			||||||
 | 
					        #     text_embed = self.encode_text_embed(text_for_one_sample)
 | 
				
			||||||
 | 
					        #     text_embed = text_embed / text_embed.norm(dim=-1, keepdim=True)
 | 
				
			||||||
 | 
					        #     text_embed = text_embed.mean(dim=0)
 | 
				
			||||||
 | 
					        #     text_embed = text_embed / text_embed.norm(dim=-1, keepdim=True)
 | 
				
			||||||
 | 
					        #     text_embed_all.append(text_embed)
 | 
				
			||||||
 | 
					        # text_embed_all = torch.stack(text_embed_all)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        b = text.shape[0]
 | 
				
			||||||
 | 
					        text_tokens = rearrange(text, "b t l -> (b t) l")
 | 
				
			||||||
 | 
					        text_embed = self.encode_text_embed(text_tokens)
 | 
				
			||||||
 | 
					        text_embed = rearrange(text_embed, "(b t) d -> b t d", b=b)
 | 
				
			||||||
 | 
					        text_embed = text_embed.mean(dim=1)
 | 
				
			||||||
 | 
					        text_embed = text_embed / text_embed.norm(dim=-1, keepdim=True)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # image embedding
 | 
				
			||||||
 | 
					        image_embed = self.encode_image_embed(image)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # shape embedding
 | 
				
			||||||
 | 
					        shape_embed, shape_latents = self.encode_shape_embed(surface, return_latents=True)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        embed_outputs = {
 | 
				
			||||||
 | 
					            "image_embed": image_embed,
 | 
				
			||||||
 | 
					            "text_embed": text_embed,
 | 
				
			||||||
 | 
					            "shape_embed": shape_embed,
 | 
				
			||||||
 | 
					            "logit_scale": self.clip_model.logit_scale.exp()
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return embed_outputs, shape_latents
 | 
				
			||||||
							
								
								
									
										80
									
								
								primitive_anything/michelangelo/models/tsal/inference_utils.py
									
									
									
									
									
										Executable file
									
								
							
							
						
						
									
										80
									
								
								primitive_anything/michelangelo/models/tsal/inference_utils.py
									
									
									
									
									
										Executable file
									
								
							@ -0,0 +1,80 @@
 | 
				
			|||||||
 | 
					# -*- coding: utf-8 -*-
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import torch
 | 
				
			||||||
 | 
					from tqdm import tqdm
 | 
				
			||||||
 | 
					from einops import repeat
 | 
				
			||||||
 | 
					import numpy as np
 | 
				
			||||||
 | 
					from typing import Callable, Tuple, List, Union, Optional
 | 
				
			||||||
 | 
					from skimage import measure
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from ...graphics.primitives import generate_dense_grid_points
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@torch.no_grad()
 | 
				
			||||||
 | 
					def extract_geometry(geometric_func: Callable,
 | 
				
			||||||
 | 
					                     device: torch.device,
 | 
				
			||||||
 | 
					                     batch_size: int = 1,
 | 
				
			||||||
 | 
					                     bounds: Union[Tuple[float], List[float], float] = (-1.25, -1.25, -1.25, 1.25, 1.25, 1.25),
 | 
				
			||||||
 | 
					                     octree_depth: int = 7,
 | 
				
			||||||
 | 
					                     num_chunks: int = 10000,
 | 
				
			||||||
 | 
					                     disable: bool = True):
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Args:
 | 
				
			||||||
 | 
					        geometric_func:
 | 
				
			||||||
 | 
					        device:
 | 
				
			||||||
 | 
					        bounds:
 | 
				
			||||||
 | 
					        octree_depth:
 | 
				
			||||||
 | 
					        batch_size:
 | 
				
			||||||
 | 
					        num_chunks:
 | 
				
			||||||
 | 
					        disable:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Returns:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if isinstance(bounds, float):
 | 
				
			||||||
 | 
					        bounds = [-bounds, -bounds, -bounds, bounds, bounds, bounds]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    bbox_min = np.array(bounds[0:3])
 | 
				
			||||||
 | 
					    bbox_max = np.array(bounds[3:6])
 | 
				
			||||||
 | 
					    bbox_size = bbox_max - bbox_min
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    xyz_samples, grid_size, length = generate_dense_grid_points(
 | 
				
			||||||
 | 
					        bbox_min=bbox_min,
 | 
				
			||||||
 | 
					        bbox_max=bbox_max,
 | 
				
			||||||
 | 
					        octree_depth=octree_depth,
 | 
				
			||||||
 | 
					        indexing="ij"
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					    xyz_samples = torch.FloatTensor(xyz_samples)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    batch_logits = []
 | 
				
			||||||
 | 
					    for start in tqdm(range(0, xyz_samples.shape[0], num_chunks),
 | 
				
			||||||
 | 
					                      desc="Implicit Function:", disable=disable, leave=False):
 | 
				
			||||||
 | 
					        queries = xyz_samples[start: start + num_chunks, :].to(device)
 | 
				
			||||||
 | 
					        batch_queries = repeat(queries, "p c -> b p c", b=batch_size)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        logits = geometric_func(batch_queries)
 | 
				
			||||||
 | 
					        batch_logits.append(logits.cpu())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    grid_logits = torch.cat(batch_logits, dim=1).view((batch_size, grid_size[0], grid_size[1], grid_size[2])).numpy()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    mesh_v_f = []
 | 
				
			||||||
 | 
					    has_surface = np.zeros((batch_size,), dtype=np.bool_)
 | 
				
			||||||
 | 
					    for i in range(batch_size):
 | 
				
			||||||
 | 
					        try:
 | 
				
			||||||
 | 
					            vertices, faces, normals, _ = measure.marching_cubes(grid_logits[i], 0, method="lewiner")
 | 
				
			||||||
 | 
					            vertices = vertices / grid_size * bbox_size + bbox_min
 | 
				
			||||||
 | 
					            # vertices[:, [0, 1]] = vertices[:, [1, 0]]
 | 
				
			||||||
 | 
					            mesh_v_f.append((vertices.astype(np.float32), np.ascontiguousarray(faces)))
 | 
				
			||||||
 | 
					            has_surface[i] = True
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        except ValueError:
 | 
				
			||||||
 | 
					            mesh_v_f.append((None, None))
 | 
				
			||||||
 | 
					            has_surface[i] = False
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        except RuntimeError:
 | 
				
			||||||
 | 
					            mesh_v_f.append((None, None))
 | 
				
			||||||
 | 
					            has_surface[i] = False
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    return mesh_v_f, has_surface
 | 
				
			||||||
							
								
								
									
										303
									
								
								primitive_anything/michelangelo/models/tsal/loss.py
									
									
									
									
									
										Executable file
									
								
							
							
						
						
									
										303
									
								
								primitive_anything/michelangelo/models/tsal/loss.py
									
									
									
									
									
										Executable file
									
								
							@ -0,0 +1,303 @@
 | 
				
			|||||||
 | 
					# -*- coding: utf-8 -*-
 | 
				
			||||||
 | 
					import torch
 | 
				
			||||||
 | 
					import torch.nn as nn
 | 
				
			||||||
 | 
					import torch.nn.functional as F
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from typing import Optional, Tuple, Dict
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from ..modules.distributions import DiagonalGaussianDistribution
 | 
				
			||||||
 | 
					from ...utils.eval import compute_psnr
 | 
				
			||||||
 | 
					from ...utils import misc
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class KLNearFar(nn.Module):
 | 
				
			||||||
 | 
					    def __init__(self,
 | 
				
			||||||
 | 
					                 near_weight: float = 0.1,
 | 
				
			||||||
 | 
					                 kl_weight: float = 1.0,
 | 
				
			||||||
 | 
					                 num_near_samples: Optional[int] = None):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        super().__init__()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.near_weight = near_weight
 | 
				
			||||||
 | 
					        self.kl_weight = kl_weight
 | 
				
			||||||
 | 
					        self.num_near_samples = num_near_samples
 | 
				
			||||||
 | 
					        self.geo_criterion = nn.BCEWithLogitsLoss()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def forward(self,
 | 
				
			||||||
 | 
					                posteriors: Optional[DiagonalGaussianDistribution],
 | 
				
			||||||
 | 
					                logits: torch.FloatTensor,
 | 
				
			||||||
 | 
					                labels: torch.FloatTensor,
 | 
				
			||||||
 | 
					                split: Optional[str] = "train", **kwargs) -> Tuple[torch.FloatTensor, Dict[str, float]]:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Args:
 | 
				
			||||||
 | 
					            posteriors (DiagonalGaussianDistribution or torch.distributions.Normal):
 | 
				
			||||||
 | 
					            logits (torch.FloatTensor): [B, 2*N], logits[:, 0:N] is the volume points; logits[:, N:2N] is the near points;
 | 
				
			||||||
 | 
					            labels (torch.FloatTensor): [B, 2*N], labels[:, 0:N] is the volume points; labels[:, N:2N] is the near points;
 | 
				
			||||||
 | 
					            split (str):
 | 
				
			||||||
 | 
					            **kwargs:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Returns:
 | 
				
			||||||
 | 
					            loss (torch.Tensor): (,)
 | 
				
			||||||
 | 
					            log (dict):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if self.num_near_samples is None:
 | 
				
			||||||
 | 
					            num_vol = logits.shape[1] // 2
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            num_vol = logits.shape[1] - self.num_near_samples
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        vol_logits = logits[:, 0:num_vol]
 | 
				
			||||||
 | 
					        vol_labels = labels[:, 0:num_vol]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        near_logits = logits[:, num_vol:]
 | 
				
			||||||
 | 
					        near_labels = labels[:, num_vol:]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # occupancy loss
 | 
				
			||||||
 | 
					        # vol_bce = self.geo_criterion(vol_logits, vol_labels)
 | 
				
			||||||
 | 
					        # near_bce = self.geo_criterion(near_logits, near_labels)
 | 
				
			||||||
 | 
					        vol_bce = self.geo_criterion(vol_logits.float(), vol_labels.float())
 | 
				
			||||||
 | 
					        near_bce = self.geo_criterion(near_logits.float(), near_labels.float())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if posteriors is None:
 | 
				
			||||||
 | 
					            kl_loss = torch.tensor(0.0, dtype=vol_logits.dtype, device=vol_logits.device)
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            kl_loss = posteriors.kl(dims=(1, 2))
 | 
				
			||||||
 | 
					            kl_loss = torch.mean(kl_loss)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        loss = vol_bce + near_bce * self.near_weight + kl_loss * self.kl_weight
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        with torch.no_grad():
 | 
				
			||||||
 | 
					            preds = logits >= 0
 | 
				
			||||||
 | 
					            accuracy = (preds == labels).float()
 | 
				
			||||||
 | 
					            accuracy = accuracy.mean()
 | 
				
			||||||
 | 
					            pos_ratio = torch.mean(labels)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        log = {
 | 
				
			||||||
 | 
					            "{}/total_loss".format(split): loss.clone().detach(),
 | 
				
			||||||
 | 
					            "{}/near".format(split): near_bce.detach(),
 | 
				
			||||||
 | 
					            "{}/far".format(split): vol_bce.detach(),
 | 
				
			||||||
 | 
					            "{}/kl".format(split): kl_loss.detach(),
 | 
				
			||||||
 | 
					            "{}/accuracy".format(split): accuracy,
 | 
				
			||||||
 | 
					            "{}/pos_ratio".format(split): pos_ratio
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if posteriors is not None:
 | 
				
			||||||
 | 
					            log[f"{split}/mean"] = posteriors.mean.mean().detach()
 | 
				
			||||||
 | 
					            log[f"{split}/std_mean"] = posteriors.std.mean().detach()
 | 
				
			||||||
 | 
					            log[f"{split}/std_max"] = posteriors.std.max().detach()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return loss, log
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class KLNearFarColor(nn.Module):
 | 
				
			||||||
 | 
					    def __init__(self,
 | 
				
			||||||
 | 
					                 near_weight: float = 0.1,
 | 
				
			||||||
 | 
					                 kl_weight: float = 1.0,
 | 
				
			||||||
 | 
					                 color_weight: float = 1.0,
 | 
				
			||||||
 | 
					                 color_criterion: str = "mse",
 | 
				
			||||||
 | 
					                 num_near_samples: Optional[int] = None):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        super().__init__()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.color_weight = color_weight
 | 
				
			||||||
 | 
					        self.near_weight = near_weight
 | 
				
			||||||
 | 
					        self.kl_weight = kl_weight
 | 
				
			||||||
 | 
					        self.num_near_samples = num_near_samples
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if color_criterion == "mse":
 | 
				
			||||||
 | 
					            self.color_criterion = nn.MSELoss()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        elif color_criterion == "l1":
 | 
				
			||||||
 | 
					            self.color_criterion = nn.L1Loss()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            raise ValueError(f"{color_criterion} must be [`mse`, `l1`].")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.geo_criterion = nn.BCEWithLogitsLoss()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def forward(self,
 | 
				
			||||||
 | 
					                posteriors: Optional[DiagonalGaussianDistribution],
 | 
				
			||||||
 | 
					                logits: torch.FloatTensor,
 | 
				
			||||||
 | 
					                labels: torch.FloatTensor,
 | 
				
			||||||
 | 
					                pred_colors: torch.FloatTensor,
 | 
				
			||||||
 | 
					                gt_colors: torch.FloatTensor,
 | 
				
			||||||
 | 
					                split: Optional[str] = "train", **kwargs) -> Tuple[torch.FloatTensor, Dict[str, float]]:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Args:
 | 
				
			||||||
 | 
					            posteriors (DiagonalGaussianDistribution or torch.distributions.Normal):
 | 
				
			||||||
 | 
					            logits (torch.FloatTensor): [B, 2*N], logits[:, 0:N] is the volume points; logits[:, N:2N] is the near points;
 | 
				
			||||||
 | 
					            labels (torch.FloatTensor): [B, 2*N], labels[:, 0:N] is the volume points; labels[:, N:2N] is the near points;
 | 
				
			||||||
 | 
					            pred_colors (torch.FloatTensor): [B, M, 3]
 | 
				
			||||||
 | 
					            gt_colors (torch.FloatTensor): [B, M, 3]
 | 
				
			||||||
 | 
					            split (str):
 | 
				
			||||||
 | 
					            **kwargs:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Returns:
 | 
				
			||||||
 | 
					            loss (torch.Tensor): (,)
 | 
				
			||||||
 | 
					            log (dict):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if self.num_near_samples is None:
 | 
				
			||||||
 | 
					            num_vol = logits.shape[1] // 2
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            num_vol = logits.shape[1] - self.num_near_samples
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        vol_logits = logits[:, 0:num_vol]
 | 
				
			||||||
 | 
					        vol_labels = labels[:, 0:num_vol]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        near_logits = logits[:, num_vol:]
 | 
				
			||||||
 | 
					        near_labels = labels[:, num_vol:]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # occupancy loss
 | 
				
			||||||
 | 
					        # vol_bce = self.geo_criterion(vol_logits, vol_labels)
 | 
				
			||||||
 | 
					        # near_bce = self.geo_criterion(near_logits, near_labels)
 | 
				
			||||||
 | 
					        vol_bce = self.geo_criterion(vol_logits.float(), vol_labels.float())
 | 
				
			||||||
 | 
					        near_bce = self.geo_criterion(near_logits.float(), near_labels.float())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # surface color loss
 | 
				
			||||||
 | 
					        color = self.color_criterion(pred_colors, gt_colors)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if posteriors is None:
 | 
				
			||||||
 | 
					            kl_loss = torch.tensor(0.0, dtype=pred_colors.dtype, device=pred_colors.device)
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            kl_loss = posteriors.kl(dims=(1, 2))
 | 
				
			||||||
 | 
					            kl_loss = torch.mean(kl_loss)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        loss = vol_bce + near_bce * self.near_weight + color * self.color_weight + kl_loss * self.kl_weight
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        with torch.no_grad():
 | 
				
			||||||
 | 
					            preds = logits >= 0
 | 
				
			||||||
 | 
					            accuracy = (preds == labels).float()
 | 
				
			||||||
 | 
					            accuracy = accuracy.mean()
 | 
				
			||||||
 | 
					            psnr = compute_psnr(pred_colors, gt_colors)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        log = {
 | 
				
			||||||
 | 
					            "{}/total_loss".format(split): loss.clone().detach(),
 | 
				
			||||||
 | 
					            "{}/near".format(split): near_bce.detach(),
 | 
				
			||||||
 | 
					            "{}/far".format(split): vol_bce.detach(),
 | 
				
			||||||
 | 
					            "{}/color".format(split): color.detach(),
 | 
				
			||||||
 | 
					            "{}/kl".format(split): kl_loss.detach(),
 | 
				
			||||||
 | 
					            "{}/psnr".format(split): psnr.detach(),
 | 
				
			||||||
 | 
					            "{}/accuracy".format(split): accuracy
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return loss, log
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class ContrastKLNearFar(nn.Module):
 | 
				
			||||||
 | 
					    def __init__(self,
 | 
				
			||||||
 | 
					                 contrast_weight: float = 1.0,
 | 
				
			||||||
 | 
					                 near_weight: float = 0.1,
 | 
				
			||||||
 | 
					                 kl_weight: float = 1.0,
 | 
				
			||||||
 | 
					                 num_near_samples: Optional[int] = None):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        super().__init__()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.labels = None
 | 
				
			||||||
 | 
					        self.last_local_batch_size = None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.contrast_weight = contrast_weight
 | 
				
			||||||
 | 
					        self.near_weight = near_weight
 | 
				
			||||||
 | 
					        self.kl_weight = kl_weight
 | 
				
			||||||
 | 
					        self.num_near_samples = num_near_samples
 | 
				
			||||||
 | 
					        self.geo_criterion = nn.BCEWithLogitsLoss()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def forward(self,
 | 
				
			||||||
 | 
					                shape_embed: torch.FloatTensor,
 | 
				
			||||||
 | 
					                text_embed: torch.FloatTensor,
 | 
				
			||||||
 | 
					                image_embed: torch.FloatTensor,
 | 
				
			||||||
 | 
					                logit_scale: torch.FloatTensor,
 | 
				
			||||||
 | 
					                posteriors: Optional[DiagonalGaussianDistribution],
 | 
				
			||||||
 | 
					                shape_logits: torch.FloatTensor,
 | 
				
			||||||
 | 
					                shape_labels: torch.FloatTensor,
 | 
				
			||||||
 | 
					                split: Optional[str] = "train", **kwargs):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        local_batch_size = shape_embed.size(0)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if local_batch_size != self.last_local_batch_size:
 | 
				
			||||||
 | 
					            self.labels = local_batch_size * misc.get_rank() + torch.arange(
 | 
				
			||||||
 | 
					                local_batch_size, device=shape_embed.device
 | 
				
			||||||
 | 
					            ).long()
 | 
				
			||||||
 | 
					            self.last_local_batch_size = local_batch_size
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # normalized features
 | 
				
			||||||
 | 
					        shape_embed = F.normalize(shape_embed, dim=-1, p=2)
 | 
				
			||||||
 | 
					        text_embed = F.normalize(text_embed, dim=-1, p=2)
 | 
				
			||||||
 | 
					        image_embed = F.normalize(image_embed, dim=-1, p=2)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # gather features from all GPUs
 | 
				
			||||||
 | 
					        shape_embed_all, text_embed_all, image_embed_all = misc.all_gather_batch(
 | 
				
			||||||
 | 
					            [shape_embed, text_embed, image_embed]
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # cosine similarity as logits
 | 
				
			||||||
 | 
					        logits_per_shape_text = logit_scale * shape_embed @ text_embed_all.t()
 | 
				
			||||||
 | 
					        logits_per_text_shape = logit_scale * text_embed @ shape_embed_all.t()
 | 
				
			||||||
 | 
					        logits_per_shape_image = logit_scale * shape_embed @ image_embed_all.t()
 | 
				
			||||||
 | 
					        logits_per_image_shape = logit_scale * image_embed @ shape_embed_all.t()
 | 
				
			||||||
 | 
					        contrast_loss = (F.cross_entropy(logits_per_shape_text, self.labels) +
 | 
				
			||||||
 | 
					                         F.cross_entropy(logits_per_text_shape, self.labels)) / 2 + \
 | 
				
			||||||
 | 
					                        (F.cross_entropy(logits_per_shape_image, self.labels) +
 | 
				
			||||||
 | 
					                         F.cross_entropy(logits_per_image_shape, self.labels)) / 2
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # shape reconstruction
 | 
				
			||||||
 | 
					        if self.num_near_samples is None:
 | 
				
			||||||
 | 
					            num_vol = shape_logits.shape[1] // 2
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            num_vol = shape_logits.shape[1] - self.num_near_samples
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        vol_logits = shape_logits[:, 0:num_vol]
 | 
				
			||||||
 | 
					        vol_labels = shape_labels[:, 0:num_vol]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        near_logits = shape_logits[:, num_vol:]
 | 
				
			||||||
 | 
					        near_labels = shape_labels[:, num_vol:]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # occupancy loss
 | 
				
			||||||
 | 
					        vol_bce = self.geo_criterion(vol_logits.float(), vol_labels.float())
 | 
				
			||||||
 | 
					        near_bce = self.geo_criterion(near_logits.float(), near_labels.float())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if posteriors is None:
 | 
				
			||||||
 | 
					            kl_loss = torch.tensor(0.0, dtype=vol_logits.dtype, device=vol_logits.device)
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            kl_loss = posteriors.kl(dims=(1, 2))
 | 
				
			||||||
 | 
					            kl_loss = torch.mean(kl_loss)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        loss = vol_bce + near_bce * self.near_weight + kl_loss * self.kl_weight + contrast_loss * self.contrast_weight
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # compute accuracy
 | 
				
			||||||
 | 
					        with torch.no_grad():
 | 
				
			||||||
 | 
					            pred = torch.argmax(logits_per_shape_text, dim=-1)
 | 
				
			||||||
 | 
					            correct = pred.eq(self.labels).sum()
 | 
				
			||||||
 | 
					            shape_text_acc = 100 * correct / local_batch_size
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            pred = torch.argmax(logits_per_shape_image, dim=-1)
 | 
				
			||||||
 | 
					            correct = pred.eq(self.labels).sum()
 | 
				
			||||||
 | 
					            shape_image_acc = 100 * correct / local_batch_size
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            preds = shape_logits >= 0
 | 
				
			||||||
 | 
					            accuracy = (preds == shape_labels).float()
 | 
				
			||||||
 | 
					            accuracy = accuracy.mean()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            log = {
 | 
				
			||||||
 | 
					                "{}/contrast".format(split): contrast_loss.clone().detach(),
 | 
				
			||||||
 | 
					                "{}/near".format(split): near_bce.detach(),
 | 
				
			||||||
 | 
					                "{}/far".format(split): vol_bce.detach(),
 | 
				
			||||||
 | 
					                "{}/kl".format(split): kl_loss.detach(),
 | 
				
			||||||
 | 
					                "{}/shape_text_acc".format(split): shape_text_acc,
 | 
				
			||||||
 | 
					                "{}/shape_image_acc".format(split): shape_image_acc,
 | 
				
			||||||
 | 
					                "{}/total_loss".format(split): loss.clone().detach(),
 | 
				
			||||||
 | 
					                "{}/accuracy".format(split): accuracy,
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            if posteriors is not None:
 | 
				
			||||||
 | 
					                log[f"{split}/mean"] = posteriors.mean.mean().detach()
 | 
				
			||||||
 | 
					                log[f"{split}/std_mean"] = posteriors.std.mean().detach()
 | 
				
			||||||
 | 
					                log[f"{split}/std_max"] = posteriors.std.max().detach()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return loss, log
 | 
				
			||||||
							
								
								
									
										423
									
								
								primitive_anything/michelangelo/models/tsal/sal_perceiver.py
									
									
									
									
									
										Executable file
									
								
							
							
						
						
									
										423
									
								
								primitive_anything/michelangelo/models/tsal/sal_perceiver.py
									
									
									
									
									
										Executable file
									
								
							@ -0,0 +1,423 @@
 | 
				
			|||||||
 | 
					# -*- coding: utf-8 -*-
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import torch
 | 
				
			||||||
 | 
					import torch.nn as nn
 | 
				
			||||||
 | 
					from typing import Optional
 | 
				
			||||||
 | 
					from einops import repeat
 | 
				
			||||||
 | 
					import math
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from ..modules import checkpoint
 | 
				
			||||||
 | 
					from ..modules.embedder import FourierEmbedder
 | 
				
			||||||
 | 
					from ..modules.distributions import DiagonalGaussianDistribution
 | 
				
			||||||
 | 
					from ..modules.transformer_blocks import (
 | 
				
			||||||
 | 
					    ResidualCrossAttentionBlock,
 | 
				
			||||||
 | 
					    Transformer
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from .tsal_base import ShapeAsLatentModule
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class CrossAttentionEncoder(nn.Module):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __init__(self, *,
 | 
				
			||||||
 | 
					                 device: Optional[torch.device],
 | 
				
			||||||
 | 
					                 dtype: Optional[torch.dtype],
 | 
				
			||||||
 | 
					                 num_latents: int,
 | 
				
			||||||
 | 
					                 fourier_embedder: FourierEmbedder,
 | 
				
			||||||
 | 
					                 point_feats: int,
 | 
				
			||||||
 | 
					                 width: int,
 | 
				
			||||||
 | 
					                 heads: int,
 | 
				
			||||||
 | 
					                 layers: int,
 | 
				
			||||||
 | 
					                 init_scale: float = 0.25,
 | 
				
			||||||
 | 
					                 qkv_bias: bool = True,
 | 
				
			||||||
 | 
					                 flash: bool = False,
 | 
				
			||||||
 | 
					                 use_ln_post: bool = False,
 | 
				
			||||||
 | 
					                 use_checkpoint: bool = False):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        super().__init__()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.use_checkpoint = use_checkpoint
 | 
				
			||||||
 | 
					        self.num_latents = num_latents
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.query = nn.Parameter(torch.randn((num_latents, width), device=device, dtype=dtype) * 0.02)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.fourier_embedder = fourier_embedder
 | 
				
			||||||
 | 
					        self.input_proj = nn.Linear(self.fourier_embedder.out_dim + point_feats, width, device=device, dtype=dtype)
 | 
				
			||||||
 | 
					        self.cross_attn = ResidualCrossAttentionBlock(
 | 
				
			||||||
 | 
					            device=device,
 | 
				
			||||||
 | 
					            dtype=dtype,
 | 
				
			||||||
 | 
					            width=width,
 | 
				
			||||||
 | 
					            heads=heads,
 | 
				
			||||||
 | 
					            init_scale=init_scale,
 | 
				
			||||||
 | 
					            qkv_bias=qkv_bias,
 | 
				
			||||||
 | 
					            flash=flash,
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.self_attn = Transformer(
 | 
				
			||||||
 | 
					            device=device,
 | 
				
			||||||
 | 
					            dtype=dtype,
 | 
				
			||||||
 | 
					            n_ctx=num_latents,
 | 
				
			||||||
 | 
					            width=width,
 | 
				
			||||||
 | 
					            layers=layers,
 | 
				
			||||||
 | 
					            heads=heads,
 | 
				
			||||||
 | 
					            init_scale=init_scale,
 | 
				
			||||||
 | 
					            qkv_bias=qkv_bias,
 | 
				
			||||||
 | 
					            flash=flash,
 | 
				
			||||||
 | 
					            use_checkpoint=False
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if use_ln_post:
 | 
				
			||||||
 | 
					            self.ln_post = nn.LayerNorm(width, dtype=dtype, device=device)
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            self.ln_post = None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def _forward(self, pc, feats):
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Args:
 | 
				
			||||||
 | 
					            pc (torch.FloatTensor): [B, N, 3]
 | 
				
			||||||
 | 
					            feats (torch.FloatTensor or None): [B, N, C]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Returns:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        bs = pc.shape[0]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        data = self.fourier_embedder(pc)
 | 
				
			||||||
 | 
					        if feats is not None:
 | 
				
			||||||
 | 
					            data = torch.cat([data, feats], dim=-1)
 | 
				
			||||||
 | 
					        data = self.input_proj(data)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        query = repeat(self.query, "m c -> b m c", b=bs)
 | 
				
			||||||
 | 
					        latents = self.cross_attn(query, data)
 | 
				
			||||||
 | 
					        latents = self.self_attn(latents)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if self.ln_post is not None:
 | 
				
			||||||
 | 
					            latents = self.ln_post(latents)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return latents, pc
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def forward(self, pc: torch.FloatTensor, feats: Optional[torch.FloatTensor] = None):
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Args:
 | 
				
			||||||
 | 
					            pc (torch.FloatTensor): [B, N, 3]
 | 
				
			||||||
 | 
					            feats (torch.FloatTensor or None): [B, N, C]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Returns:
 | 
				
			||||||
 | 
					            dict
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return checkpoint(self._forward, (pc, feats), self.parameters(), self.use_checkpoint)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class CrossAttentionDecoder(nn.Module):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __init__(self, *,
 | 
				
			||||||
 | 
					                 device: Optional[torch.device],
 | 
				
			||||||
 | 
					                 dtype: Optional[torch.dtype],
 | 
				
			||||||
 | 
					                 num_latents: int,
 | 
				
			||||||
 | 
					                 out_channels: int,
 | 
				
			||||||
 | 
					                 fourier_embedder: FourierEmbedder,
 | 
				
			||||||
 | 
					                 width: int,
 | 
				
			||||||
 | 
					                 heads: int,
 | 
				
			||||||
 | 
					                 init_scale: float = 0.25,
 | 
				
			||||||
 | 
					                 qkv_bias: bool = True,
 | 
				
			||||||
 | 
					                 flash: bool = False,
 | 
				
			||||||
 | 
					                 use_checkpoint: bool = False):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        super().__init__()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.use_checkpoint = use_checkpoint
 | 
				
			||||||
 | 
					        self.fourier_embedder = fourier_embedder
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.query_proj = nn.Linear(self.fourier_embedder.out_dim, width, device=device, dtype=dtype)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.cross_attn_decoder = ResidualCrossAttentionBlock(
 | 
				
			||||||
 | 
					            device=device,
 | 
				
			||||||
 | 
					            dtype=dtype,
 | 
				
			||||||
 | 
					            n_data=num_latents,
 | 
				
			||||||
 | 
					            width=width,
 | 
				
			||||||
 | 
					            heads=heads,
 | 
				
			||||||
 | 
					            init_scale=init_scale,
 | 
				
			||||||
 | 
					            qkv_bias=qkv_bias,
 | 
				
			||||||
 | 
					            flash=flash
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.ln_post = nn.LayerNorm(width, device=device, dtype=dtype)
 | 
				
			||||||
 | 
					        self.output_proj = nn.Linear(width, out_channels, device=device, dtype=dtype)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def _forward(self, queries: torch.FloatTensor, latents: torch.FloatTensor):
 | 
				
			||||||
 | 
					        queries = self.query_proj(self.fourier_embedder(queries))
 | 
				
			||||||
 | 
					        x = self.cross_attn_decoder(queries, latents)
 | 
				
			||||||
 | 
					        x = self.ln_post(x)
 | 
				
			||||||
 | 
					        x = self.output_proj(x)
 | 
				
			||||||
 | 
					        return x
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def forward(self, queries: torch.FloatTensor, latents: torch.FloatTensor):
 | 
				
			||||||
 | 
					        return checkpoint(self._forward, (queries, latents), self.parameters(), self.use_checkpoint)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class ShapeAsLatentPerceiver(ShapeAsLatentModule):
 | 
				
			||||||
 | 
					    def __init__(self, *,
 | 
				
			||||||
 | 
					                 device: Optional[torch.device],
 | 
				
			||||||
 | 
					                 dtype: Optional[torch.dtype],
 | 
				
			||||||
 | 
					                 num_latents: int,
 | 
				
			||||||
 | 
					                 point_feats: int = 0,
 | 
				
			||||||
 | 
					                 embed_dim: int = 0,
 | 
				
			||||||
 | 
					                 num_freqs: int = 8,
 | 
				
			||||||
 | 
					                 include_pi: bool = True,
 | 
				
			||||||
 | 
					                 width: int,
 | 
				
			||||||
 | 
					                 heads: int,
 | 
				
			||||||
 | 
					                 num_encoder_layers: int,
 | 
				
			||||||
 | 
					                 num_decoder_layers: int,
 | 
				
			||||||
 | 
					                 init_scale: float = 0.25,
 | 
				
			||||||
 | 
					                 qkv_bias: bool = True,
 | 
				
			||||||
 | 
					                 flash: bool = False,
 | 
				
			||||||
 | 
					                 use_ln_post: bool = False,
 | 
				
			||||||
 | 
					                 use_checkpoint: bool = False):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        super().__init__()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.use_checkpoint = use_checkpoint
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.num_latents = num_latents
 | 
				
			||||||
 | 
					        self.fourier_embedder = FourierEmbedder(num_freqs=num_freqs, include_pi=include_pi)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        init_scale = init_scale * math.sqrt(1.0 / width)
 | 
				
			||||||
 | 
					        self.encoder = CrossAttentionEncoder(
 | 
				
			||||||
 | 
					            device=device,
 | 
				
			||||||
 | 
					            dtype=dtype,
 | 
				
			||||||
 | 
					            fourier_embedder=self.fourier_embedder,
 | 
				
			||||||
 | 
					            num_latents=num_latents,
 | 
				
			||||||
 | 
					            point_feats=point_feats,
 | 
				
			||||||
 | 
					            width=width,
 | 
				
			||||||
 | 
					            heads=heads,
 | 
				
			||||||
 | 
					            layers=num_encoder_layers,
 | 
				
			||||||
 | 
					            init_scale=init_scale,
 | 
				
			||||||
 | 
					            qkv_bias=qkv_bias,
 | 
				
			||||||
 | 
					            flash=flash,
 | 
				
			||||||
 | 
					            use_ln_post=use_ln_post,
 | 
				
			||||||
 | 
					            use_checkpoint=use_checkpoint
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.embed_dim = embed_dim
 | 
				
			||||||
 | 
					        if embed_dim > 0:
 | 
				
			||||||
 | 
					            # VAE embed
 | 
				
			||||||
 | 
					            self.pre_kl = nn.Linear(width, embed_dim * 2, device=device, dtype=dtype)
 | 
				
			||||||
 | 
					            self.post_kl = nn.Linear(embed_dim, width, device=device, dtype=dtype)
 | 
				
			||||||
 | 
					            self.latent_shape = (num_latents, embed_dim)
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            self.latent_shape = (num_latents, width)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.transformer = Transformer(
 | 
				
			||||||
 | 
					            device=device,
 | 
				
			||||||
 | 
					            dtype=dtype,
 | 
				
			||||||
 | 
					            n_ctx=num_latents,
 | 
				
			||||||
 | 
					            width=width,
 | 
				
			||||||
 | 
					            layers=num_decoder_layers,
 | 
				
			||||||
 | 
					            heads=heads,
 | 
				
			||||||
 | 
					            init_scale=init_scale,
 | 
				
			||||||
 | 
					            qkv_bias=qkv_bias,
 | 
				
			||||||
 | 
					            flash=flash,
 | 
				
			||||||
 | 
					            use_checkpoint=use_checkpoint
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # geometry decoder
 | 
				
			||||||
 | 
					        self.geo_decoder = CrossAttentionDecoder(
 | 
				
			||||||
 | 
					            device=device,
 | 
				
			||||||
 | 
					            dtype=dtype,
 | 
				
			||||||
 | 
					            fourier_embedder=self.fourier_embedder,
 | 
				
			||||||
 | 
					            out_channels=1,
 | 
				
			||||||
 | 
					            num_latents=num_latents,
 | 
				
			||||||
 | 
					            width=width,
 | 
				
			||||||
 | 
					            heads=heads,
 | 
				
			||||||
 | 
					            init_scale=init_scale,
 | 
				
			||||||
 | 
					            qkv_bias=qkv_bias,
 | 
				
			||||||
 | 
					            flash=flash,
 | 
				
			||||||
 | 
					            use_checkpoint=use_checkpoint
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def encode(self,
 | 
				
			||||||
 | 
					               pc: torch.FloatTensor,
 | 
				
			||||||
 | 
					               feats: Optional[torch.FloatTensor] = None,
 | 
				
			||||||
 | 
					               sample_posterior: bool = True):
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Args:
 | 
				
			||||||
 | 
					            pc (torch.FloatTensor): [B, N, 3]
 | 
				
			||||||
 | 
					            feats (torch.FloatTensor or None): [B, N, C]
 | 
				
			||||||
 | 
					            sample_posterior (bool):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Returns:
 | 
				
			||||||
 | 
					            latents (torch.FloatTensor)
 | 
				
			||||||
 | 
					            center_pos (torch.FloatTensor or None):
 | 
				
			||||||
 | 
					            posterior (DiagonalGaussianDistribution or None):
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        latents, center_pos = self.encoder(pc, feats)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        posterior = None
 | 
				
			||||||
 | 
					        if self.embed_dim > 0:
 | 
				
			||||||
 | 
					            moments = self.pre_kl(latents)
 | 
				
			||||||
 | 
					            posterior = DiagonalGaussianDistribution(moments, feat_dim=-1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            if sample_posterior:
 | 
				
			||||||
 | 
					                latents = posterior.sample()
 | 
				
			||||||
 | 
					            else:
 | 
				
			||||||
 | 
					                latents = posterior.mode()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return latents, center_pos, posterior
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def decode(self, latents: torch.FloatTensor):
 | 
				
			||||||
 | 
					        latents = self.post_kl(latents)
 | 
				
			||||||
 | 
					        return self.transformer(latents)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def query_geometry(self, queries: torch.FloatTensor, latents: torch.FloatTensor):
 | 
				
			||||||
 | 
					        logits = self.geo_decoder(queries, latents).squeeze(-1)
 | 
				
			||||||
 | 
					        return logits
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def forward(self,
 | 
				
			||||||
 | 
					                pc: torch.FloatTensor,
 | 
				
			||||||
 | 
					                feats: torch.FloatTensor,
 | 
				
			||||||
 | 
					                volume_queries: torch.FloatTensor,
 | 
				
			||||||
 | 
					                sample_posterior: bool = True):
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Args:
 | 
				
			||||||
 | 
					            pc (torch.FloatTensor): [B, N, 3]
 | 
				
			||||||
 | 
					            feats (torch.FloatTensor or None): [B, N, C]
 | 
				
			||||||
 | 
					            volume_queries (torch.FloatTensor): [B, P, 3]
 | 
				
			||||||
 | 
					            sample_posterior (bool):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Returns:
 | 
				
			||||||
 | 
					            logits (torch.FloatTensor): [B, P]
 | 
				
			||||||
 | 
					            center_pos (torch.FloatTensor): [B, M, 3]
 | 
				
			||||||
 | 
					            posterior (DiagonalGaussianDistribution or None).
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        latents, center_pos, posterior = self.encode(pc, feats, sample_posterior=sample_posterior)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        latents = self.decode(latents)
 | 
				
			||||||
 | 
					        logits = self.query_geometry(volume_queries, latents)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return logits, center_pos, posterior
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class AlignedShapeLatentPerceiver(ShapeAsLatentPerceiver):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __init__(self, *,
 | 
				
			||||||
 | 
					                 device: Optional[torch.device],
 | 
				
			||||||
 | 
					                 dtype: Optional[torch.dtype],
 | 
				
			||||||
 | 
					                 num_latents: int,
 | 
				
			||||||
 | 
					                 point_feats: int = 0,
 | 
				
			||||||
 | 
					                 embed_dim: int = 0,
 | 
				
			||||||
 | 
					                 num_freqs: int = 8,
 | 
				
			||||||
 | 
					                 include_pi: bool = True,
 | 
				
			||||||
 | 
					                 width: int,
 | 
				
			||||||
 | 
					                 heads: int,
 | 
				
			||||||
 | 
					                 num_encoder_layers: int,
 | 
				
			||||||
 | 
					                 num_decoder_layers: int,
 | 
				
			||||||
 | 
					                 init_scale: float = 0.25,
 | 
				
			||||||
 | 
					                 qkv_bias: bool = True,
 | 
				
			||||||
 | 
					                 flash: bool = False,
 | 
				
			||||||
 | 
					                 use_ln_post: bool = False,
 | 
				
			||||||
 | 
					                 use_checkpoint: bool = False):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        super().__init__(
 | 
				
			||||||
 | 
					            device=device,
 | 
				
			||||||
 | 
					            dtype=dtype,
 | 
				
			||||||
 | 
					            num_latents=1 + num_latents,
 | 
				
			||||||
 | 
					            point_feats=point_feats,
 | 
				
			||||||
 | 
					            embed_dim=embed_dim,
 | 
				
			||||||
 | 
					            num_freqs=num_freqs,
 | 
				
			||||||
 | 
					            include_pi=include_pi,
 | 
				
			||||||
 | 
					            width=width,
 | 
				
			||||||
 | 
					            heads=heads,
 | 
				
			||||||
 | 
					            num_encoder_layers=num_encoder_layers,
 | 
				
			||||||
 | 
					            num_decoder_layers=num_decoder_layers,
 | 
				
			||||||
 | 
					            init_scale=init_scale,
 | 
				
			||||||
 | 
					            qkv_bias=qkv_bias,
 | 
				
			||||||
 | 
					            flash=flash,
 | 
				
			||||||
 | 
					            use_ln_post=use_ln_post,
 | 
				
			||||||
 | 
					            use_checkpoint=use_checkpoint
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.width = width
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def encode(self,
 | 
				
			||||||
 | 
					               pc: torch.FloatTensor,
 | 
				
			||||||
 | 
					               feats: Optional[torch.FloatTensor] = None,
 | 
				
			||||||
 | 
					               sample_posterior: bool = True):
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Args:
 | 
				
			||||||
 | 
					            pc (torch.FloatTensor): [B, N, 3]
 | 
				
			||||||
 | 
					            feats (torch.FloatTensor or None): [B, N, c]
 | 
				
			||||||
 | 
					            sample_posterior (bool):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Returns:
 | 
				
			||||||
 | 
					            shape_embed (torch.FloatTensor)
 | 
				
			||||||
 | 
					            kl_embed (torch.FloatTensor):
 | 
				
			||||||
 | 
					            posterior (DiagonalGaussianDistribution or None):
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        shape_embed, latents = self.encode_latents(pc, feats)
 | 
				
			||||||
 | 
					        kl_embed, posterior = self.encode_kl_embed(latents, sample_posterior)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return shape_embed, kl_embed, posterior
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def encode_latents(self,
 | 
				
			||||||
 | 
					                       pc: torch.FloatTensor,
 | 
				
			||||||
 | 
					                       feats: Optional[torch.FloatTensor] = None):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        x, _ = self.encoder(pc, feats)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        shape_embed = x[:, 0]
 | 
				
			||||||
 | 
					        latents = x[:, 1:]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return shape_embed, latents
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def encode_kl_embed(self, latents: torch.FloatTensor, sample_posterior: bool = True):
 | 
				
			||||||
 | 
					        posterior = None
 | 
				
			||||||
 | 
					        if self.embed_dim > 0:
 | 
				
			||||||
 | 
					            moments = self.pre_kl(latents)
 | 
				
			||||||
 | 
					            posterior = DiagonalGaussianDistribution(moments, feat_dim=-1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            if sample_posterior:
 | 
				
			||||||
 | 
					                kl_embed = posterior.sample()
 | 
				
			||||||
 | 
					            else:
 | 
				
			||||||
 | 
					                kl_embed = posterior.mode()
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            kl_embed = latents
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return kl_embed, posterior
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def forward(self,
 | 
				
			||||||
 | 
					                pc: torch.FloatTensor,
 | 
				
			||||||
 | 
					                feats: torch.FloatTensor,
 | 
				
			||||||
 | 
					                volume_queries: torch.FloatTensor,
 | 
				
			||||||
 | 
					                sample_posterior: bool = True):
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Args:
 | 
				
			||||||
 | 
					            pc (torch.FloatTensor): [B, N, 3]
 | 
				
			||||||
 | 
					            feats (torch.FloatTensor or None): [B, N, C]
 | 
				
			||||||
 | 
					            volume_queries (torch.FloatTensor): [B, P, 3]
 | 
				
			||||||
 | 
					            sample_posterior (bool):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Returns:
 | 
				
			||||||
 | 
					            shape_embed (torch.FloatTensor): [B, projection_dim]
 | 
				
			||||||
 | 
					            logits (torch.FloatTensor): [B, M]
 | 
				
			||||||
 | 
					            posterior (DiagonalGaussianDistribution or None).
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        shape_embed, kl_embed, posterior = self.encode(pc, feats, sample_posterior=sample_posterior)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        latents = self.decode(kl_embed)
 | 
				
			||||||
 | 
					        logits = self.query_geometry(volume_queries, latents)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return shape_embed, logits, posterior
 | 
				
			||||||
							
								
								
									
										290
									
								
								primitive_anything/michelangelo/models/tsal/sal_pl_module.py
									
									
									
									
									
										Executable file
									
								
							
							
						
						
									
										290
									
								
								primitive_anything/michelangelo/models/tsal/sal_pl_module.py
									
									
									
									
									
										Executable file
									
								
							@ -0,0 +1,290 @@
 | 
				
			|||||||
 | 
					# -*- coding: utf-8 -*-
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from typing import List, Tuple, Dict, Optional
 | 
				
			||||||
 | 
					from omegaconf import DictConfig
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import torch
 | 
				
			||||||
 | 
					from torch.optim import lr_scheduler
 | 
				
			||||||
 | 
					import pytorch_lightning as pl
 | 
				
			||||||
 | 
					from typing import Union
 | 
				
			||||||
 | 
					from functools import partial
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from ...utils import instantiate_from_config
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from .inference_utils import extract_geometry
 | 
				
			||||||
 | 
					from .tsal_base import (
 | 
				
			||||||
 | 
					    ShapeAsLatentModule,
 | 
				
			||||||
 | 
					    Latent2MeshOutput,
 | 
				
			||||||
 | 
					    Point2MeshOutput
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class ShapeAsLatentPLModule(pl.LightningModule):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __init__(self, *,
 | 
				
			||||||
 | 
					                 module_cfg,
 | 
				
			||||||
 | 
					                 loss_cfg,
 | 
				
			||||||
 | 
					                 optimizer_cfg: Optional[DictConfig] = None,
 | 
				
			||||||
 | 
					                 ckpt_path: Optional[str] = None,
 | 
				
			||||||
 | 
					                 ignore_keys: Union[Tuple[str], List[str]] = ()):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        super().__init__()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.sal: ShapeAsLatentModule = instantiate_from_config(module_cfg, device=None, dtype=None)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.loss = instantiate_from_config(loss_cfg)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.optimizer_cfg = optimizer_cfg
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if ckpt_path is not None:
 | 
				
			||||||
 | 
					            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.save_hyperparameters()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @property
 | 
				
			||||||
 | 
					    def latent_shape(self):
 | 
				
			||||||
 | 
					        return self.sal.latent_shape
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @property
 | 
				
			||||||
 | 
					    def zero_rank(self):
 | 
				
			||||||
 | 
					        if self._trainer:
 | 
				
			||||||
 | 
					            zero_rank = self.trainer.local_rank == 0
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            zero_rank = True
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return zero_rank
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def init_from_ckpt(self, path, ignore_keys=()):
 | 
				
			||||||
 | 
					        state_dict = torch.load(path, map_location="cpu")["state_dict"]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        keys = list(state_dict.keys())
 | 
				
			||||||
 | 
					        for k in keys:
 | 
				
			||||||
 | 
					            for ik in ignore_keys:
 | 
				
			||||||
 | 
					                if k.startswith(ik):
 | 
				
			||||||
 | 
					                    print("Deleting key {} from state_dict.".format(k))
 | 
				
			||||||
 | 
					                    del state_dict[k]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        missing, unexpected = self.load_state_dict(state_dict, strict=False)
 | 
				
			||||||
 | 
					        print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
 | 
				
			||||||
 | 
					        if len(missing) > 0:
 | 
				
			||||||
 | 
					            print(f"Missing Keys: {missing}")
 | 
				
			||||||
 | 
					            print(f"Unexpected Keys: {unexpected}")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def configure_optimizers(self) -> Tuple[List, List]:
 | 
				
			||||||
 | 
					        lr = self.learning_rate
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # optimizers = [torch.optim.AdamW(self.sal.parameters(), lr=lr, betas=(0.9, 0.99), weight_decay=1e-4)]
 | 
				
			||||||
 | 
					        # optimizers = [torch.optim.AdamW(self.sal.parameters(), lr=lr, betas=(0.9, 0.99), weight_decay=1e-3)]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if self.optimizer_cfg is None:
 | 
				
			||||||
 | 
					            optimizers = [torch.optim.AdamW(self.sal.parameters(), lr=lr, betas=(0.9, 0.99), weight_decay=1e-3)]
 | 
				
			||||||
 | 
					            schedulers = []
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            optimizer = instantiate_from_config(self.optimizer_cfg.optimizer, params=self.sal.parameters())
 | 
				
			||||||
 | 
					            scheduler_func = instantiate_from_config(
 | 
				
			||||||
 | 
					                self.optimizer_cfg.scheduler,
 | 
				
			||||||
 | 
					                max_decay_steps=self.trainer.max_steps,
 | 
				
			||||||
 | 
					                lr_max=lr
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					            scheduler = {
 | 
				
			||||||
 | 
					                "scheduler": lr_scheduler.LambdaLR(optimizer, lr_lambda=scheduler_func.schedule),
 | 
				
			||||||
 | 
					                "interval": "step",
 | 
				
			||||||
 | 
					                "frequency": 1
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
 | 
					            optimizers = [optimizer]
 | 
				
			||||||
 | 
					            schedulers = [scheduler]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return optimizers, schedulers
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def forward(self,
 | 
				
			||||||
 | 
					                pc: torch.FloatTensor,
 | 
				
			||||||
 | 
					                feats: torch.FloatTensor,
 | 
				
			||||||
 | 
					                volume_queries: torch.FloatTensor):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        logits, center_pos, posterior = self.sal(pc, feats, volume_queries)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return posterior, logits
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def encode(self, surface: torch.FloatTensor, sample_posterior=True):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        pc = surface[..., 0:3]
 | 
				
			||||||
 | 
					        feats = surface[..., 3:6]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        latents, center_pos, posterior = self.sal.encode(
 | 
				
			||||||
 | 
					            pc=pc, feats=feats, sample_posterior=sample_posterior
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return latents
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def decode(self,
 | 
				
			||||||
 | 
					               z_q,
 | 
				
			||||||
 | 
					               bounds: Union[Tuple[float], List[float], float] = 1.1,
 | 
				
			||||||
 | 
					               octree_depth: int = 7,
 | 
				
			||||||
 | 
					               num_chunks: int = 10000) -> List[Latent2MeshOutput]:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        latents = self.sal.decode(z_q)  # latents: [bs, num_latents, dim]
 | 
				
			||||||
 | 
					        outputs = self.latent2mesh(latents, bounds=bounds, octree_depth=octree_depth, num_chunks=num_chunks)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return outputs
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def training_step(self, batch: Dict[str, torch.FloatTensor],
 | 
				
			||||||
 | 
					                      batch_idx: int, optimizer_idx: int = 0) -> torch.FloatTensor:
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Args:
 | 
				
			||||||
 | 
					            batch (dict): the batch sample, and it contains:
 | 
				
			||||||
 | 
					                - surface (torch.FloatTensor): [bs, n_surface, (3 + input_dim)]
 | 
				
			||||||
 | 
					                - geo_points (torch.FloatTensor): [bs, n_pts, (3 + 1)]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            batch_idx (int):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            optimizer_idx (int):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Returns:
 | 
				
			||||||
 | 
					            loss (torch.FloatTensor):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        pc = batch["surface"][..., 0:3]
 | 
				
			||||||
 | 
					        feats = batch["surface"][..., 3:]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        volume_queries = batch["geo_points"][..., 0:3]
 | 
				
			||||||
 | 
					        volume_labels = batch["geo_points"][..., -1]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        posterior, logits = self(
 | 
				
			||||||
 | 
					            pc=pc, feats=feats, volume_queries=volume_queries
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        aeloss, log_dict_ae = self.loss(posterior, logits, volume_labels, split="train")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.log_dict(log_dict_ae, prog_bar=True, logger=True, batch_size=logits.shape[0],
 | 
				
			||||||
 | 
					                      sync_dist=False, rank_zero_only=True)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return aeloss
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def validation_step(self, batch: Dict[str, torch.FloatTensor], batch_idx: int) -> torch.FloatTensor:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        pc = batch["surface"][..., 0:3]
 | 
				
			||||||
 | 
					        feats = batch["surface"][..., 3:]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        volume_queries = batch["geo_points"][..., 0:3]
 | 
				
			||||||
 | 
					        volume_labels = batch["geo_points"][..., -1]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        posterior, logits = self(
 | 
				
			||||||
 | 
					            pc=pc, feats=feats, volume_queries=volume_queries,
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        aeloss, log_dict_ae = self.loss(posterior, logits, volume_labels, split="val")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.log_dict(log_dict_ae, prog_bar=True, logger=True, batch_size=logits.shape[0],
 | 
				
			||||||
 | 
					                      sync_dist=False, rank_zero_only=True)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return aeloss
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def point2mesh(self,
 | 
				
			||||||
 | 
					                   pc: torch.FloatTensor,
 | 
				
			||||||
 | 
					                   feats: torch.FloatTensor,
 | 
				
			||||||
 | 
					                   bounds: Union[Tuple[float], List[float]] = (-1.25, -1.25, -1.25, 1.25, 1.25, 1.25),
 | 
				
			||||||
 | 
					                   octree_depth: int = 7,
 | 
				
			||||||
 | 
					                   num_chunks: int = 10000) -> List[Point2MeshOutput]:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Args:
 | 
				
			||||||
 | 
					            pc:
 | 
				
			||||||
 | 
					            feats:
 | 
				
			||||||
 | 
					            bounds:
 | 
				
			||||||
 | 
					            octree_depth:
 | 
				
			||||||
 | 
					            num_chunks:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Returns:
 | 
				
			||||||
 | 
					            mesh_outputs (List[MeshOutput]): the mesh outputs list.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        outputs = []
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        device = pc.device
 | 
				
			||||||
 | 
					        bs = pc.shape[0]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # 1. point encoder + latents transformer
 | 
				
			||||||
 | 
					        latents, center_pos, posterior = self.sal.encode(pc, feats)
 | 
				
			||||||
 | 
					        latents = self.sal.decode(latents)  # latents: [bs, num_latents, dim]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        geometric_func = partial(self.sal.query_geometry, latents=latents)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # 2. decode geometry
 | 
				
			||||||
 | 
					        mesh_v_f, has_surface = extract_geometry(
 | 
				
			||||||
 | 
					            geometric_func=geometric_func,
 | 
				
			||||||
 | 
					            device=device,
 | 
				
			||||||
 | 
					            batch_size=bs,
 | 
				
			||||||
 | 
					            bounds=bounds,
 | 
				
			||||||
 | 
					            octree_depth=octree_depth,
 | 
				
			||||||
 | 
					            num_chunks=num_chunks,
 | 
				
			||||||
 | 
					            disable=not self.zero_rank
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # 3. decode texture
 | 
				
			||||||
 | 
					        for i, ((mesh_v, mesh_f), is_surface) in enumerate(zip(mesh_v_f, has_surface)):
 | 
				
			||||||
 | 
					            if not is_surface:
 | 
				
			||||||
 | 
					                outputs.append(None)
 | 
				
			||||||
 | 
					                continue
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            out = Point2MeshOutput()
 | 
				
			||||||
 | 
					            out.mesh_v = mesh_v
 | 
				
			||||||
 | 
					            out.mesh_f = mesh_f
 | 
				
			||||||
 | 
					            out.pc = torch.cat([pc[i], feats[i]], dim=-1).cpu().numpy()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            if center_pos is not None:
 | 
				
			||||||
 | 
					                out.center = center_pos[i].cpu().numpy()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            outputs.append(out)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return outputs
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def latent2mesh(self,
 | 
				
			||||||
 | 
					                    latents: torch.FloatTensor,
 | 
				
			||||||
 | 
					                    bounds: Union[Tuple[float], List[float], float] = 1.1,
 | 
				
			||||||
 | 
					                    octree_depth: int = 7,
 | 
				
			||||||
 | 
					                    num_chunks: int = 10000) -> List[Latent2MeshOutput]:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Args:
 | 
				
			||||||
 | 
					            latents: [bs, num_latents, dim]
 | 
				
			||||||
 | 
					            bounds:
 | 
				
			||||||
 | 
					            octree_depth:
 | 
				
			||||||
 | 
					            num_chunks:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Returns:
 | 
				
			||||||
 | 
					            mesh_outputs (List[MeshOutput]): the mesh outputs list.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        outputs = []
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        geometric_func = partial(self.sal.query_geometry, latents=latents)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # 2. decode geometry
 | 
				
			||||||
 | 
					        device = latents.device
 | 
				
			||||||
 | 
					        mesh_v_f, has_surface = extract_geometry(
 | 
				
			||||||
 | 
					            geometric_func=geometric_func,
 | 
				
			||||||
 | 
					            device=device,
 | 
				
			||||||
 | 
					            batch_size=len(latents),
 | 
				
			||||||
 | 
					            bounds=bounds,
 | 
				
			||||||
 | 
					            octree_depth=octree_depth,
 | 
				
			||||||
 | 
					            num_chunks=num_chunks,
 | 
				
			||||||
 | 
					            disable=not self.zero_rank
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # 3. decode texture
 | 
				
			||||||
 | 
					        for i, ((mesh_v, mesh_f), is_surface) in enumerate(zip(mesh_v_f, has_surface)):
 | 
				
			||||||
 | 
					            if not is_surface:
 | 
				
			||||||
 | 
					                outputs.append(None)
 | 
				
			||||||
 | 
					                continue
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            out = Latent2MeshOutput()
 | 
				
			||||||
 | 
					            out.mesh_v = mesh_v
 | 
				
			||||||
 | 
					            out.mesh_f = mesh_f
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            outputs.append(out)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return outputs
 | 
				
			||||||
							
								
								
									
										121
									
								
								primitive_anything/michelangelo/models/tsal/tsal_base.py
									
									
									
									
									
										Executable file
									
								
							
							
						
						
									
										121
									
								
								primitive_anything/michelangelo/models/tsal/tsal_base.py
									
									
									
									
									
										Executable file
									
								
							@ -0,0 +1,121 @@
 | 
				
			|||||||
 | 
					# -*- coding: utf-8 -*-
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import torch.nn as nn
 | 
				
			||||||
 | 
					from typing import Tuple, List, Optional
 | 
				
			||||||
 | 
					import pytorch_lightning as pl
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class Point2MeshOutput(object):
 | 
				
			||||||
 | 
					    def __init__(self):
 | 
				
			||||||
 | 
					        self.mesh_v = None
 | 
				
			||||||
 | 
					        self.mesh_f = None
 | 
				
			||||||
 | 
					        self.center = None
 | 
				
			||||||
 | 
					        self.pc = None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class Latent2MeshOutput(object):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __init__(self):
 | 
				
			||||||
 | 
					        self.mesh_v = None
 | 
				
			||||||
 | 
					        self.mesh_f = None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class AlignedMeshOutput(object):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __init__(self):
 | 
				
			||||||
 | 
					        self.mesh_v = None
 | 
				
			||||||
 | 
					        self.mesh_f = None
 | 
				
			||||||
 | 
					        self.surface = None
 | 
				
			||||||
 | 
					        self.image = None
 | 
				
			||||||
 | 
					        self.text: Optional[str] = None
 | 
				
			||||||
 | 
					        self.shape_text_similarity: Optional[float] = None
 | 
				
			||||||
 | 
					        self.shape_image_similarity: Optional[float] = None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class ShapeAsLatentPLModule(pl.LightningModule):
 | 
				
			||||||
 | 
					    latent_shape: Tuple[int]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def encode(self, surface, *args, **kwargs):
 | 
				
			||||||
 | 
					        raise NotImplementedError
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def decode(self, z_q, *args, **kwargs):
 | 
				
			||||||
 | 
					        raise NotImplementedError
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def latent2mesh(self, latents, *args, **kwargs) -> List[Latent2MeshOutput]:
 | 
				
			||||||
 | 
					        raise NotImplementedError
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def point2mesh(self, *args, **kwargs) -> List[Point2MeshOutput]:
 | 
				
			||||||
 | 
					        raise NotImplementedError
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class ShapeAsLatentModule(nn.Module):
 | 
				
			||||||
 | 
					    latent_shape: Tuple[int, int]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __init__(self, *args, **kwargs):
 | 
				
			||||||
 | 
					        super().__init__()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def encode(self, *args, **kwargs):
 | 
				
			||||||
 | 
					        raise NotImplementedError
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def decode(self, *args, **kwargs):
 | 
				
			||||||
 | 
					        raise NotImplementedError
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def query_geometry(self, *args, **kwargs):
 | 
				
			||||||
 | 
					        raise NotImplementedError
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class AlignedShapeAsLatentPLModule(pl.LightningModule):
 | 
				
			||||||
 | 
					    latent_shape: Tuple[int]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def set_shape_model_only(self):
 | 
				
			||||||
 | 
					        raise NotImplementedError
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def encode(self, surface, *args, **kwargs):
 | 
				
			||||||
 | 
					        raise NotImplementedError
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def decode(self, z_q, *args, **kwargs):
 | 
				
			||||||
 | 
					        raise NotImplementedError
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def latent2mesh(self, latents, *args, **kwargs) -> List[Latent2MeshOutput]:
 | 
				
			||||||
 | 
					        raise NotImplementedError
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def point2mesh(self, *args, **kwargs) -> List[Point2MeshOutput]:
 | 
				
			||||||
 | 
					        raise NotImplementedError
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class AlignedShapeAsLatentModule(nn.Module):
 | 
				
			||||||
 | 
					    shape_model: ShapeAsLatentModule
 | 
				
			||||||
 | 
					    latent_shape: Tuple[int, int]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __init__(self, *args, **kwargs):
 | 
				
			||||||
 | 
					        super().__init__()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def set_shape_model_only(self):
 | 
				
			||||||
 | 
					        raise NotImplementedError
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def encode_image_embed(self, *args, **kwargs):
 | 
				
			||||||
 | 
					        raise NotImplementedError
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def encode_text_embed(self, *args, **kwargs):
 | 
				
			||||||
 | 
					        raise NotImplementedError
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def encode_shape_embed(self, *args, **kwargs):
 | 
				
			||||||
 | 
					        raise NotImplementedError
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class TexturedShapeAsLatentModule(nn.Module):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __init__(self, *args, **kwargs):
 | 
				
			||||||
 | 
					        super().__init__()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def encode(self, *args, **kwargs):
 | 
				
			||||||
 | 
					        raise NotImplementedError
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def decode(self, *args, **kwargs):
 | 
				
			||||||
 | 
					        raise NotImplementedError
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def query_geometry(self, *args, **kwargs):
 | 
				
			||||||
 | 
					        raise NotImplementedError
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def query_color(self, *args, **kwargs):
 | 
				
			||||||
 | 
					        raise NotImplementedError
 | 
				
			||||||
							
								
								
									
										42
									
								
								primitive_anything/michelangelo/shapevae-256.yaml
									
									
									
									
									
										Executable file
									
								
							
							
						
						
									
										42
									
								
								primitive_anything/michelangelo/shapevae-256.yaml
									
									
									
									
									
										Executable file
									
								
							@ -0,0 +1,42 @@
 | 
				
			|||||||
 | 
					model:
 | 
				
			||||||
 | 
					  target: primitive_anything.michelangelo.models.tsal.asl_pl_module.AlignedShapeAsLatentPLModule
 | 
				
			||||||
 | 
					  params:
 | 
				
			||||||
 | 
					    shape_module_cfg:
 | 
				
			||||||
 | 
					      target: primitive_anything.michelangelo.models.tsal.sal_perceiver.AlignedShapeLatentPerceiver
 | 
				
			||||||
 | 
					      params:
 | 
				
			||||||
 | 
					        num_latents: 256
 | 
				
			||||||
 | 
					        embed_dim: 64
 | 
				
			||||||
 | 
					        point_feats: 3   # normal
 | 
				
			||||||
 | 
					        num_freqs: 8
 | 
				
			||||||
 | 
					        include_pi: false
 | 
				
			||||||
 | 
					        heads: 12
 | 
				
			||||||
 | 
					        width: 768
 | 
				
			||||||
 | 
					        num_encoder_layers: 8
 | 
				
			||||||
 | 
					        num_decoder_layers: 16
 | 
				
			||||||
 | 
					        use_ln_post: true
 | 
				
			||||||
 | 
					        init_scale: 0.25
 | 
				
			||||||
 | 
					        qkv_bias: false
 | 
				
			||||||
 | 
					        use_checkpoint: true
 | 
				
			||||||
 | 
					    aligned_module_cfg:
 | 
				
			||||||
 | 
					      target: primitive_anything.michelangelo.models.tsal.clip_asl_module.CLIPAlignedShapeAsLatentModule
 | 
				
			||||||
 | 
					    loss_cfg:
 | 
				
			||||||
 | 
					      target: primitive_anything.michelangelo.models.tsal.loss.ContrastKLNearFar
 | 
				
			||||||
 | 
					      params:
 | 
				
			||||||
 | 
					        contrast_weight: 0.1
 | 
				
			||||||
 | 
					        near_weight: 0.1
 | 
				
			||||||
 | 
					        kl_weight: 0.001
 | 
				
			||||||
 | 
					    optimizer_cfg:
 | 
				
			||||||
 | 
					      optimizer:
 | 
				
			||||||
 | 
					        target: torch.optim.AdamW
 | 
				
			||||||
 | 
					        params:
 | 
				
			||||||
 | 
					          betas: [0.9, 0.99]
 | 
				
			||||||
 | 
					          eps: 1.e-6
 | 
				
			||||||
 | 
					          weight_decay: 1.e-2
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					      scheduler:
 | 
				
			||||||
 | 
					        target: primitive_anything.michelangelo.utils.trainings.lr_scheduler.LambdaWarmUpCosineFactorScheduler
 | 
				
			||||||
 | 
					        params:
 | 
				
			||||||
 | 
					          warm_up_steps: 5000
 | 
				
			||||||
 | 
					          f_start: 1.e-6
 | 
				
			||||||
 | 
					          f_min: 1.e-3
 | 
				
			||||||
 | 
					          f_max: 1.0
 | 
				
			||||||
							
								
								
									
										4
									
								
								primitive_anything/michelangelo/utils/__init__.py
									
									
									
									
									
										Executable file
									
								
							
							
						
						
									
										4
									
								
								primitive_anything/michelangelo/utils/__init__.py
									
									
									
									
									
										Executable file
									
								
							@ -0,0 +1,4 @@
 | 
				
			|||||||
 | 
					# -*- coding: utf-8 -*-
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from .misc import get_config_from_file
 | 
				
			||||||
 | 
					from .misc import instantiate_from_config
 | 
				
			||||||
							
								
								
									
										12
									
								
								primitive_anything/michelangelo/utils/eval.py
									
									
									
									
									
										Executable file
									
								
							
							
						
						
									
										12
									
								
								primitive_anything/michelangelo/utils/eval.py
									
									
									
									
									
										Executable file
									
								
							@ -0,0 +1,12 @@
 | 
				
			|||||||
 | 
					# -*- coding: utf-8 -*-
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import torch
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def compute_psnr(x, y, data_range: float = 2, eps: float = 1e-7):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    mse = torch.mean((x - y) ** 2)
 | 
				
			||||||
 | 
					    psnr = 10 * torch.log10(data_range / (mse + eps))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    return psnr
 | 
				
			||||||
 | 
					
 | 
				
			||||||
							
								
								
									
										47
									
								
								primitive_anything/michelangelo/utils/io.py
									
									
									
									
									
										Executable file
									
								
							
							
						
						
									
										47
									
								
								primitive_anything/michelangelo/utils/io.py
									
									
									
									
									
										Executable file
									
								
							@ -0,0 +1,47 @@
 | 
				
			|||||||
 | 
					# -*- coding: utf-8 -*-
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import os
 | 
				
			||||||
 | 
					import io
 | 
				
			||||||
 | 
					import tarfile
 | 
				
			||||||
 | 
					import json
 | 
				
			||||||
 | 
					import numpy as np
 | 
				
			||||||
 | 
					import numpy.lib.format
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def mkdir(path):
 | 
				
			||||||
 | 
					    os.makedirs(path, exist_ok=True)
 | 
				
			||||||
 | 
					    return path
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def npy_loads(data):
 | 
				
			||||||
 | 
					    stream = io.BytesIO(data)
 | 
				
			||||||
 | 
					    return np.lib.format.read_array(stream)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def npz_loads(data):
 | 
				
			||||||
 | 
					    return np.load(io.BytesIO(data))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def json_loads(data):
 | 
				
			||||||
 | 
					    return json.loads(data)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def load_json(filepath):
 | 
				
			||||||
 | 
					    with open(filepath, "r") as f:
 | 
				
			||||||
 | 
					        data = json.load(f)
 | 
				
			||||||
 | 
					        return data
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def write_json(filepath, data):
 | 
				
			||||||
 | 
					    with open(filepath, "w") as f:
 | 
				
			||||||
 | 
					        json.dump(data, f, indent=2)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def extract_tar(tar_path, tar_cache_folder):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    with tarfile.open(tar_path, "r") as tar:
 | 
				
			||||||
 | 
					        tar.extractall(path=tar_cache_folder)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    tar_uids = sorted(os.listdir(tar_cache_folder))
 | 
				
			||||||
 | 
					    print(f"extract tar: {tar_path} to {tar_cache_folder}")
 | 
				
			||||||
 | 
					    return tar_uids
 | 
				
			||||||
							
								
								
									
										103
									
								
								primitive_anything/michelangelo/utils/misc.py
									
									
									
									
									
										Executable file
									
								
							
							
						
						
									
										103
									
								
								primitive_anything/michelangelo/utils/misc.py
									
									
									
									
									
										Executable file
									
								
							@ -0,0 +1,103 @@
 | 
				
			|||||||
 | 
					# -*- coding: utf-8 -*-
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import importlib
 | 
				
			||||||
 | 
					from omegaconf import OmegaConf, DictConfig, ListConfig
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import torch
 | 
				
			||||||
 | 
					import torch.distributed as dist
 | 
				
			||||||
 | 
					from typing import Union
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def get_config_from_file(config_file: str) -> Union[DictConfig, ListConfig]:
 | 
				
			||||||
 | 
					    config_file = OmegaConf.load(config_file)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if 'base_config' in config_file.keys():
 | 
				
			||||||
 | 
					        if config_file['base_config'] == "default_base":
 | 
				
			||||||
 | 
					            base_config = OmegaConf.create()
 | 
				
			||||||
 | 
					            # base_config = get_default_config()
 | 
				
			||||||
 | 
					        elif config_file['base_config'].endswith(".yaml"):
 | 
				
			||||||
 | 
					            base_config = get_config_from_file(config_file['base_config'])
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            raise ValueError(f"{config_file} must be `.yaml` file or it contains `base_config` key.")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        config_file = {key: value for key, value in config_file if key != "base_config"}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return OmegaConf.merge(base_config, config_file)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    return config_file
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def get_obj_from_str(string, reload=False):
 | 
				
			||||||
 | 
					    module, cls = string.rsplit(".", 1)
 | 
				
			||||||
 | 
					    if reload:
 | 
				
			||||||
 | 
					        module_imp = importlib.import_module(module)
 | 
				
			||||||
 | 
					        importlib.reload(module_imp)
 | 
				
			||||||
 | 
					    return getattr(importlib.import_module(module, package=None), cls)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def get_obj_from_config(config):
 | 
				
			||||||
 | 
					    if "target" not in config:
 | 
				
			||||||
 | 
					        raise KeyError("Expected key `target` to instantiate.")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    return get_obj_from_str(config["target"])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def instantiate_from_config(config, **kwargs):
 | 
				
			||||||
 | 
					    if "target" not in config:
 | 
				
			||||||
 | 
					        raise KeyError("Expected key `target` to instantiate.")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    cls = get_obj_from_str(config["target"])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    params = config.get("params", dict())
 | 
				
			||||||
 | 
					    # params.update(kwargs)
 | 
				
			||||||
 | 
					    # instance = cls(**params)
 | 
				
			||||||
 | 
					    kwargs.update(params)
 | 
				
			||||||
 | 
					    instance = cls(**kwargs)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    return instance
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def is_dist_avail_and_initialized():
 | 
				
			||||||
 | 
					    if not dist.is_available():
 | 
				
			||||||
 | 
					        return False
 | 
				
			||||||
 | 
					    if not dist.is_initialized():
 | 
				
			||||||
 | 
					        return False
 | 
				
			||||||
 | 
					    return True
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def get_rank():
 | 
				
			||||||
 | 
					    if not is_dist_avail_and_initialized():
 | 
				
			||||||
 | 
					        return 0
 | 
				
			||||||
 | 
					    return dist.get_rank()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def get_world_size():
 | 
				
			||||||
 | 
					    if not is_dist_avail_and_initialized():
 | 
				
			||||||
 | 
					        return 1
 | 
				
			||||||
 | 
					    return dist.get_world_size()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def all_gather_batch(tensors):
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    Performs all_gather operation on the provided tensors.
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    # Queue the gathered tensors
 | 
				
			||||||
 | 
					    world_size = get_world_size()
 | 
				
			||||||
 | 
					    # There is no need for reduction in the single-proc case
 | 
				
			||||||
 | 
					    if world_size == 1:
 | 
				
			||||||
 | 
					        return tensors
 | 
				
			||||||
 | 
					    tensor_list = []
 | 
				
			||||||
 | 
					    output_tensor = []
 | 
				
			||||||
 | 
					    for tensor in tensors:
 | 
				
			||||||
 | 
					        tensor_all = [torch.ones_like(tensor) for _ in range(world_size)]
 | 
				
			||||||
 | 
					        dist.all_gather(
 | 
				
			||||||
 | 
					            tensor_all,
 | 
				
			||||||
 | 
					            tensor,
 | 
				
			||||||
 | 
					            async_op=False  # performance opt
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        tensor_list.append(tensor_all)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    for tensor_all in tensor_list:
 | 
				
			||||||
 | 
					        output_tensor.append(torch.cat(tensor_all, dim=0))
 | 
				
			||||||
 | 
					    return output_tensor
 | 
				
			||||||
							
								
								
									
										1
									
								
								primitive_anything/michelangelo/utils/visualizers/__init__.py
									
									
									
									
									
										Executable file
									
								
							
							
						
						
									
										1
									
								
								primitive_anything/michelangelo/utils/visualizers/__init__.py
									
									
									
									
									
										Executable file
									
								
							@ -0,0 +1 @@
 | 
				
			|||||||
 | 
					# -*- coding: utf-8 -*-
 | 
				
			||||||
							
								
								
									
										43
									
								
								primitive_anything/michelangelo/utils/visualizers/color_util.py
									
									
									
									
									
										Executable file
									
								
							
							
						
						
									
										43
									
								
								primitive_anything/michelangelo/utils/visualizers/color_util.py
									
									
									
									
									
										Executable file
									
								
							@ -0,0 +1,43 @@
 | 
				
			|||||||
 | 
					import numpy as np
 | 
				
			||||||
 | 
					import matplotlib.pyplot as plt
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# Helper functions
 | 
				
			||||||
 | 
					def get_colors(inp, colormap="viridis", normalize=True, vmin=None, vmax=None):
 | 
				
			||||||
 | 
					    colormap = plt.cm.get_cmap(colormap)
 | 
				
			||||||
 | 
					    if normalize:
 | 
				
			||||||
 | 
					        vmin = np.min(inp)
 | 
				
			||||||
 | 
					        vmax = np.max(inp)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    norm = plt.Normalize(vmin, vmax)
 | 
				
			||||||
 | 
					    return colormap(norm(inp))[:, :3]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def gen_checkers(n_checkers_x, n_checkers_y, width=256, height=256):
 | 
				
			||||||
 | 
					    # tex dims need to be power of two.
 | 
				
			||||||
 | 
					    array = np.ones((width, height, 3), dtype='float32')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # width in texels of each checker
 | 
				
			||||||
 | 
					    checker_w = width / n_checkers_x
 | 
				
			||||||
 | 
					    checker_h = height / n_checkers_y
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    for y in range(height):
 | 
				
			||||||
 | 
					        for x in range(width):
 | 
				
			||||||
 | 
					            color_key = int(x / checker_w) + int(y / checker_h)
 | 
				
			||||||
 | 
					            if color_key % 2 == 0:
 | 
				
			||||||
 | 
					                array[x, y, :] = [1., 0.874, 0.0]
 | 
				
			||||||
 | 
					            else:
 | 
				
			||||||
 | 
					                array[x, y, :] = [0., 0., 0.]
 | 
				
			||||||
 | 
					    return array
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def gen_circle(width=256, height=256):
 | 
				
			||||||
 | 
					    xx, yy = np.mgrid[:width, :height]
 | 
				
			||||||
 | 
					    circle = (xx - width / 2 + 0.5) ** 2 + (yy - height / 2 + 0.5) ** 2
 | 
				
			||||||
 | 
					    array = np.ones((width, height, 4), dtype='float32')
 | 
				
			||||||
 | 
					    array[:, :, 0] = (circle <= width)
 | 
				
			||||||
 | 
					    array[:, :, 1] = (circle <= width)
 | 
				
			||||||
 | 
					    array[:, :, 2] = (circle <= width)
 | 
				
			||||||
 | 
					    array[:, :, 3] = circle <= width
 | 
				
			||||||
 | 
					    return array
 | 
				
			||||||
 | 
					
 | 
				
			||||||
							
								
								
									
										49
									
								
								primitive_anything/michelangelo/utils/visualizers/html_util.py
									
									
									
									
									
										Executable file
									
								
							
							
						
						
									
										49
									
								
								primitive_anything/michelangelo/utils/visualizers/html_util.py
									
									
									
									
									
										Executable file
									
								
							@ -0,0 +1,49 @@
 | 
				
			|||||||
 | 
					# -*- coding: utf-8 -*-
 | 
				
			||||||
 | 
					import io
 | 
				
			||||||
 | 
					import base64
 | 
				
			||||||
 | 
					import numpy as np
 | 
				
			||||||
 | 
					from PIL import Image
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def to_html_frame(content):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    html_frame = f"""
 | 
				
			||||||
 | 
					    <html>
 | 
				
			||||||
 | 
					      <body>
 | 
				
			||||||
 | 
					        {content}
 | 
				
			||||||
 | 
					      </body>
 | 
				
			||||||
 | 
					    </html>
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    return html_frame
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def to_single_row_table(caption: str, content: str):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    table_html = f"""
 | 
				
			||||||
 | 
					    <table border = "1">
 | 
				
			||||||
 | 
					        <caption>{caption}</caption>
 | 
				
			||||||
 | 
					        <tr>
 | 
				
			||||||
 | 
					            <td>{content}</td>
 | 
				
			||||||
 | 
					        </tr>
 | 
				
			||||||
 | 
					    </table>
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    return table_html
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def to_image_embed_tag(image: np.ndarray):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Convert np.ndarray to bytes
 | 
				
			||||||
 | 
					    img = Image.fromarray(image)
 | 
				
			||||||
 | 
					    raw_bytes = io.BytesIO()
 | 
				
			||||||
 | 
					    img.save(raw_bytes, "PNG")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Encode bytes to base64
 | 
				
			||||||
 | 
					    image_base64 = base64.b64encode(raw_bytes.getvalue()).decode("utf-8")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    image_tag = f"""
 | 
				
			||||||
 | 
					    <img src="data:image/png;base64,{image_base64}" alt="Embedded Image">
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    return image_tag
 | 
				
			||||||
							
								
								
									
										534
									
								
								primitive_anything/michelangelo/utils/visualizers/pythreejs_viewer.py
									
									
									
									
									
										Executable file
									
								
							
							
						
						
									
										534
									
								
								primitive_anything/michelangelo/utils/visualizers/pythreejs_viewer.py
									
									
									
									
									
										Executable file
									
								
							@ -0,0 +1,534 @@
 | 
				
			|||||||
 | 
					import numpy as np
 | 
				
			||||||
 | 
					from ipywidgets import embed
 | 
				
			||||||
 | 
					import pythreejs as p3s
 | 
				
			||||||
 | 
					import uuid
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from .color_util import get_colors, gen_circle, gen_checkers
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					EMBED_URL = "https://cdn.jsdelivr.net/npm/@jupyter-widgets/html-manager@1.0.1/dist/embed-amd.js"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class PyThreeJSViewer(object):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __init__(self, settings, render_mode="WEBSITE"):
 | 
				
			||||||
 | 
					        self.render_mode = render_mode
 | 
				
			||||||
 | 
					        self.__update_settings(settings)
 | 
				
			||||||
 | 
					        self._light = p3s.DirectionalLight(color='white', position=[0, 0, 1], intensity=0.6)
 | 
				
			||||||
 | 
					        self._light2 = p3s.AmbientLight(intensity=0.5)
 | 
				
			||||||
 | 
					        self._cam = p3s.PerspectiveCamera(position=[0, 0, 1], lookAt=[0, 0, 0], fov=self.__s["fov"],
 | 
				
			||||||
 | 
					                                          aspect=self.__s["width"] / self.__s["height"], children=[self._light])
 | 
				
			||||||
 | 
					        self._orbit = p3s.OrbitControls(controlling=self._cam)
 | 
				
			||||||
 | 
					        self._scene = p3s.Scene(children=[self._cam, self._light2], background=self.__s["background"])  # "#4c4c80"
 | 
				
			||||||
 | 
					        self._renderer = p3s.Renderer(camera=self._cam, scene=self._scene, controls=[self._orbit],
 | 
				
			||||||
 | 
					                                      width=self.__s["width"], height=self.__s["height"],
 | 
				
			||||||
 | 
					                                      antialias=self.__s["antialias"])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.__objects = {}
 | 
				
			||||||
 | 
					        self.__cnt = 0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def jupyter_mode(self):
 | 
				
			||||||
 | 
					        self.render_mode = "JUPYTER"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def offline(self):
 | 
				
			||||||
 | 
					        self.render_mode = "OFFLINE"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def website(self):
 | 
				
			||||||
 | 
					        self.render_mode = "WEBSITE"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __get_shading(self, shading):
 | 
				
			||||||
 | 
					        shad = {"flat": True, "wireframe": False, "wire_width": 0.03, "wire_color": "black",
 | 
				
			||||||
 | 
					                "side": 'DoubleSide', "colormap": "viridis", "normalize": [None, None],
 | 
				
			||||||
 | 
					                "bbox": False, "roughness": 0.5, "metalness": 0.25, "reflectivity": 1.0,
 | 
				
			||||||
 | 
					                "line_width": 1.0, "line_color": "black",
 | 
				
			||||||
 | 
					                "point_color": "red", "point_size": 0.01, "point_shape": "circle",
 | 
				
			||||||
 | 
					                "text_color": "red"
 | 
				
			||||||
 | 
					                }
 | 
				
			||||||
 | 
					        for k in shading:
 | 
				
			||||||
 | 
					            shad[k] = shading[k]
 | 
				
			||||||
 | 
					        return shad
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __update_settings(self, settings={}):
 | 
				
			||||||
 | 
					        sett = {"width": 600, "height": 600, "antialias": True, "scale": 1.5, "background": "#ffffff",
 | 
				
			||||||
 | 
					                "fov": 30}
 | 
				
			||||||
 | 
					        for k in settings:
 | 
				
			||||||
 | 
					            sett[k] = settings[k]
 | 
				
			||||||
 | 
					        self.__s = sett
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __add_object(self, obj, parent=None):
 | 
				
			||||||
 | 
					        if not parent:  # Object is added to global scene and objects dict
 | 
				
			||||||
 | 
					            self.__objects[self.__cnt] = obj
 | 
				
			||||||
 | 
					            self.__cnt += 1
 | 
				
			||||||
 | 
					            self._scene.add(obj["mesh"])
 | 
				
			||||||
 | 
					        else:  # Object is added to parent object and NOT to objects dict
 | 
				
			||||||
 | 
					            parent.add(obj["mesh"])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.__update_view()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if self.render_mode == "JUPYTER":
 | 
				
			||||||
 | 
					            return self.__cnt - 1
 | 
				
			||||||
 | 
					        elif self.render_mode == "WEBSITE":
 | 
				
			||||||
 | 
					            return self
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __add_line_geometry(self, lines, shading, obj=None):
 | 
				
			||||||
 | 
					        lines = lines.astype("float32", copy=False)
 | 
				
			||||||
 | 
					        mi = np.min(lines, axis=0)
 | 
				
			||||||
 | 
					        ma = np.max(lines, axis=0)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        geometry = p3s.LineSegmentsGeometry(positions=lines.reshape((-1, 2, 3)))
 | 
				
			||||||
 | 
					        material = p3s.LineMaterial(linewidth=shading["line_width"], color=shading["line_color"])
 | 
				
			||||||
 | 
					        # , vertexColors='VertexColors'),
 | 
				
			||||||
 | 
					        lines = p3s.LineSegments2(geometry=geometry, material=material)  # type='LinePieces')
 | 
				
			||||||
 | 
					        line_obj = {"geometry": geometry, "mesh": lines, "material": material,
 | 
				
			||||||
 | 
					                    "max": ma, "min": mi, "type": "Lines", "wireframe": None}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if obj:
 | 
				
			||||||
 | 
					            return self.__add_object(line_obj, obj), line_obj
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            return self.__add_object(line_obj)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __update_view(self):
 | 
				
			||||||
 | 
					        if len(self.__objects) == 0:
 | 
				
			||||||
 | 
					            return
 | 
				
			||||||
 | 
					        ma = np.zeros((len(self.__objects), 3))
 | 
				
			||||||
 | 
					        mi = np.zeros((len(self.__objects), 3))
 | 
				
			||||||
 | 
					        for r, obj in enumerate(self.__objects):
 | 
				
			||||||
 | 
					            ma[r] = self.__objects[obj]["max"]
 | 
				
			||||||
 | 
					            mi[r] = self.__objects[obj]["min"]
 | 
				
			||||||
 | 
					        ma = np.max(ma, axis=0)
 | 
				
			||||||
 | 
					        mi = np.min(mi, axis=0)
 | 
				
			||||||
 | 
					        diag = np.linalg.norm(ma - mi)
 | 
				
			||||||
 | 
					        mean = ((ma - mi) / 2 + mi).tolist()
 | 
				
			||||||
 | 
					        scale = self.__s["scale"] * (diag)
 | 
				
			||||||
 | 
					        self._orbit.target = mean
 | 
				
			||||||
 | 
					        self._cam.lookAt(mean)
 | 
				
			||||||
 | 
					        self._cam.position = [mean[0], mean[1], mean[2] + scale]
 | 
				
			||||||
 | 
					        self._light.position = [mean[0], mean[1], mean[2] + scale]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self._orbit.exec_three_obj_method('update')
 | 
				
			||||||
 | 
					        self._cam.exec_three_obj_method('updateProjectionMatrix')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __get_bbox(self, v):
 | 
				
			||||||
 | 
					        m = np.min(v, axis=0)
 | 
				
			||||||
 | 
					        M = np.max(v, axis=0)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # Corners of the bounding box
 | 
				
			||||||
 | 
					        v_box = np.array([[m[0], m[1], m[2]], [M[0], m[1], m[2]], [M[0], M[1], m[2]], [m[0], M[1], m[2]],
 | 
				
			||||||
 | 
					                          [m[0], m[1], M[2]], [M[0], m[1], M[2]], [M[0], M[1], M[2]], [m[0], M[1], M[2]]])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        f_box = np.array([[0, 1], [1, 2], [2, 3], [3, 0], [4, 5], [5, 6], [6, 7], [7, 4],
 | 
				
			||||||
 | 
					                          [0, 4], [1, 5], [2, 6], [7, 3]], dtype=np.uint32)
 | 
				
			||||||
 | 
					        return v_box, f_box
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __get_colors(self, v, f, c, sh):
 | 
				
			||||||
 | 
					        coloring = "VertexColors"
 | 
				
			||||||
 | 
					        if type(c) == np.ndarray and c.size == 3:  # Single color
 | 
				
			||||||
 | 
					            colors = np.ones_like(v)
 | 
				
			||||||
 | 
					            colors[:, 0] = c[0]
 | 
				
			||||||
 | 
					            colors[:, 1] = c[1]
 | 
				
			||||||
 | 
					            colors[:, 2] = c[2]
 | 
				
			||||||
 | 
					            # print("Single colors")
 | 
				
			||||||
 | 
					        elif type(c) == np.ndarray and len(c.shape) == 2 and c.shape[1] == 3:  # Color values for
 | 
				
			||||||
 | 
					            if c.shape[0] == f.shape[0]:  # faces
 | 
				
			||||||
 | 
					                colors = np.hstack([c, c, c]).reshape((-1, 3))
 | 
				
			||||||
 | 
					                coloring = "FaceColors"
 | 
				
			||||||
 | 
					                # print("Face color values")
 | 
				
			||||||
 | 
					            elif c.shape[0] == v.shape[0]:  # vertices
 | 
				
			||||||
 | 
					                colors = c
 | 
				
			||||||
 | 
					                # print("Vertex color values")
 | 
				
			||||||
 | 
					            else:  # Wrong size, fallback
 | 
				
			||||||
 | 
					                print("Invalid color array given! Supported are numpy arrays.", type(c))
 | 
				
			||||||
 | 
					                colors = np.ones_like(v)
 | 
				
			||||||
 | 
					                colors[:, 0] = 1.0
 | 
				
			||||||
 | 
					                colors[:, 1] = 0.874
 | 
				
			||||||
 | 
					                colors[:, 2] = 0.0
 | 
				
			||||||
 | 
					        elif type(c) == np.ndarray and c.size == f.shape[0]:  # Function values for faces
 | 
				
			||||||
 | 
					            normalize = sh["normalize"][0] != None and sh["normalize"][1] != None
 | 
				
			||||||
 | 
					            cc = get_colors(c, sh["colormap"], normalize=normalize,
 | 
				
			||||||
 | 
					                            vmin=sh["normalize"][0], vmax=sh["normalize"][1])
 | 
				
			||||||
 | 
					            # print(cc.shape)
 | 
				
			||||||
 | 
					            colors = np.hstack([cc, cc, cc]).reshape((-1, 3))
 | 
				
			||||||
 | 
					            coloring = "FaceColors"
 | 
				
			||||||
 | 
					            # print("Face function values")
 | 
				
			||||||
 | 
					        elif type(c) == np.ndarray and c.size == v.shape[0]:  # Function values for vertices
 | 
				
			||||||
 | 
					            normalize = sh["normalize"][0] != None and sh["normalize"][1] != None
 | 
				
			||||||
 | 
					            colors = get_colors(c, sh["colormap"], normalize=normalize,
 | 
				
			||||||
 | 
					                                vmin=sh["normalize"][0], vmax=sh["normalize"][1])
 | 
				
			||||||
 | 
					            # print("Vertex function values")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            colors = np.ones_like(v)
 | 
				
			||||||
 | 
					            colors[:, 0] = 1.0
 | 
				
			||||||
 | 
					            colors[:, 1] = 0.874
 | 
				
			||||||
 | 
					            colors[:, 2] = 0.0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            # No color
 | 
				
			||||||
 | 
					            if c is not None:
 | 
				
			||||||
 | 
					                print("Invalid color array given! Supported are numpy arrays.", type(c))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return colors, coloring
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __get_point_colors(self, v, c, sh):
 | 
				
			||||||
 | 
					        v_color = True
 | 
				
			||||||
 | 
					        if c is None:  # No color given, use global color
 | 
				
			||||||
 | 
					            # conv = mpl.colors.ColorConverter()
 | 
				
			||||||
 | 
					            colors = sh["point_color"]  # np.array(conv.to_rgb(sh["point_color"]))
 | 
				
			||||||
 | 
					            v_color = False
 | 
				
			||||||
 | 
					        elif isinstance(c, str):  # No color given, use global color
 | 
				
			||||||
 | 
					            # conv = mpl.colors.ColorConverter()
 | 
				
			||||||
 | 
					            colors = c  # np.array(conv.to_rgb(c))
 | 
				
			||||||
 | 
					            v_color = False
 | 
				
			||||||
 | 
					        elif type(c) == np.ndarray and len(c.shape) == 2 and c.shape[0] == v.shape[0] and c.shape[1] == 3:
 | 
				
			||||||
 | 
					            # Point color
 | 
				
			||||||
 | 
					            colors = c.astype("float32", copy=False)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        elif isinstance(c, np.ndarray) and len(c.shape) == 2 and c.shape[0] == v.shape[0] and c.shape[1] != 3:
 | 
				
			||||||
 | 
					            # Function values for vertices, but the colors are features
 | 
				
			||||||
 | 
					            c_norm = np.linalg.norm(c, ord=2, axis=-1)
 | 
				
			||||||
 | 
					            normalize = sh["normalize"][0] != None and sh["normalize"][1] != None
 | 
				
			||||||
 | 
					            colors = get_colors(c_norm, sh["colormap"], normalize=normalize,
 | 
				
			||||||
 | 
					                                vmin=sh["normalize"][0], vmax=sh["normalize"][1])
 | 
				
			||||||
 | 
					            colors = colors.astype("float32", copy=False)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        elif type(c) == np.ndarray and c.size == v.shape[0]:  # Function color
 | 
				
			||||||
 | 
					            normalize = sh["normalize"][0] != None and sh["normalize"][1] != None
 | 
				
			||||||
 | 
					            colors = get_colors(c, sh["colormap"], normalize=normalize,
 | 
				
			||||||
 | 
					                                vmin=sh["normalize"][0], vmax=sh["normalize"][1])
 | 
				
			||||||
 | 
					            colors = colors.astype("float32", copy=False)
 | 
				
			||||||
 | 
					            # print("Vertex function values")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            print("Invalid color array given! Supported are numpy arrays.", type(c))
 | 
				
			||||||
 | 
					            colors = sh["point_color"]
 | 
				
			||||||
 | 
					            v_color = False
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return colors, v_color
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def add_mesh(self, v, f, c=None, uv=None, n=None, shading={}, texture_data=None, **kwargs):
 | 
				
			||||||
 | 
					        shading.update(kwargs)
 | 
				
			||||||
 | 
					        sh = self.__get_shading(shading)
 | 
				
			||||||
 | 
					        mesh_obj = {}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # it is a tet
 | 
				
			||||||
 | 
					        if v.shape[1] == 3 and f.shape[1] == 4:
 | 
				
			||||||
 | 
					            f_tmp = np.ndarray([f.shape[0] * 4, 3], dtype=f.dtype)
 | 
				
			||||||
 | 
					            for i in range(f.shape[0]):
 | 
				
			||||||
 | 
					                f_tmp[i * 4 + 0] = np.array([f[i][1], f[i][0], f[i][2]])
 | 
				
			||||||
 | 
					                f_tmp[i * 4 + 1] = np.array([f[i][0], f[i][1], f[i][3]])
 | 
				
			||||||
 | 
					                f_tmp[i * 4 + 2] = np.array([f[i][1], f[i][2], f[i][3]])
 | 
				
			||||||
 | 
					                f_tmp[i * 4 + 3] = np.array([f[i][2], f[i][0], f[i][3]])
 | 
				
			||||||
 | 
					            f = f_tmp
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if v.shape[1] == 2:
 | 
				
			||||||
 | 
					            v = np.append(v, np.zeros([v.shape[0], 1]), 1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # Type adjustment vertices
 | 
				
			||||||
 | 
					        v = v.astype("float32", copy=False)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # Color setup
 | 
				
			||||||
 | 
					        colors, coloring = self.__get_colors(v, f, c, sh)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # Type adjustment faces and colors
 | 
				
			||||||
 | 
					        c = colors.astype("float32", copy=False)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # Material and geometry setup
 | 
				
			||||||
 | 
					        ba_dict = {"color": p3s.BufferAttribute(c)}
 | 
				
			||||||
 | 
					        if coloring == "FaceColors":
 | 
				
			||||||
 | 
					            verts = np.zeros((f.shape[0] * 3, 3), dtype="float32")
 | 
				
			||||||
 | 
					            for ii in range(f.shape[0]):
 | 
				
			||||||
 | 
					                # print(ii*3, f[ii])
 | 
				
			||||||
 | 
					                verts[ii * 3] = v[f[ii, 0]]
 | 
				
			||||||
 | 
					                verts[ii * 3 + 1] = v[f[ii, 1]]
 | 
				
			||||||
 | 
					                verts[ii * 3 + 2] = v[f[ii, 2]]
 | 
				
			||||||
 | 
					            v = verts
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            f = f.astype("uint32", copy=False).ravel()
 | 
				
			||||||
 | 
					            ba_dict["index"] = p3s.BufferAttribute(f, normalized=False)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        ba_dict["position"] = p3s.BufferAttribute(v, normalized=False)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if uv is not None:
 | 
				
			||||||
 | 
					            uv = (uv - np.min(uv)) / (np.max(uv) - np.min(uv))
 | 
				
			||||||
 | 
					            if texture_data is None:
 | 
				
			||||||
 | 
					                texture_data = gen_checkers(20, 20)
 | 
				
			||||||
 | 
					            tex = p3s.DataTexture(data=texture_data, format="RGBFormat", type="FloatType")
 | 
				
			||||||
 | 
					            material = p3s.MeshStandardMaterial(map=tex, reflectivity=sh["reflectivity"], side=sh["side"],
 | 
				
			||||||
 | 
					                                                roughness=sh["roughness"], metalness=sh["metalness"],
 | 
				
			||||||
 | 
					                                                flatShading=sh["flat"],
 | 
				
			||||||
 | 
					                                                polygonOffset=True, polygonOffsetFactor=1, polygonOffsetUnits=5)
 | 
				
			||||||
 | 
					            ba_dict["uv"] = p3s.BufferAttribute(uv.astype("float32", copy=False))
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            material = p3s.MeshStandardMaterial(vertexColors=coloring, reflectivity=sh["reflectivity"],
 | 
				
			||||||
 | 
					                                                side=sh["side"], roughness=sh["roughness"], metalness=sh["metalness"],
 | 
				
			||||||
 | 
					                                                flatShading=sh["flat"],
 | 
				
			||||||
 | 
					                                                polygonOffset=True, polygonOffsetFactor=1, polygonOffsetUnits=5)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if type(n) != type(None) and coloring == "VertexColors":  # TODO: properly handle normals for FaceColors as well
 | 
				
			||||||
 | 
					            ba_dict["normal"] = p3s.BufferAttribute(n.astype("float32", copy=False), normalized=True)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        geometry = p3s.BufferGeometry(attributes=ba_dict)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if coloring == "VertexColors" and type(n) == type(None):
 | 
				
			||||||
 | 
					            geometry.exec_three_obj_method('computeVertexNormals')
 | 
				
			||||||
 | 
					        elif coloring == "FaceColors" and type(n) == type(None):
 | 
				
			||||||
 | 
					            geometry.exec_three_obj_method('computeFaceNormals')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # Mesh setup
 | 
				
			||||||
 | 
					        mesh = p3s.Mesh(geometry=geometry, material=material)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # Wireframe setup
 | 
				
			||||||
 | 
					        mesh_obj["wireframe"] = None
 | 
				
			||||||
 | 
					        if sh["wireframe"]:
 | 
				
			||||||
 | 
					            wf_geometry = p3s.WireframeGeometry(mesh.geometry)  # WireframeGeometry
 | 
				
			||||||
 | 
					            wf_material = p3s.LineBasicMaterial(color=sh["wire_color"], linewidth=sh["wire_width"])
 | 
				
			||||||
 | 
					            wireframe = p3s.LineSegments(wf_geometry, wf_material)
 | 
				
			||||||
 | 
					            mesh.add(wireframe)
 | 
				
			||||||
 | 
					            mesh_obj["wireframe"] = wireframe
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # Bounding box setup
 | 
				
			||||||
 | 
					        if sh["bbox"]:
 | 
				
			||||||
 | 
					            v_box, f_box = self.__get_bbox(v)
 | 
				
			||||||
 | 
					            _, bbox = self.add_edges(v_box, f_box, sh, mesh)
 | 
				
			||||||
 | 
					            mesh_obj["bbox"] = [bbox, v_box, f_box]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # Object setup
 | 
				
			||||||
 | 
					        mesh_obj["max"] = np.max(v, axis=0)
 | 
				
			||||||
 | 
					        mesh_obj["min"] = np.min(v, axis=0)
 | 
				
			||||||
 | 
					        mesh_obj["geometry"] = geometry
 | 
				
			||||||
 | 
					        mesh_obj["mesh"] = mesh
 | 
				
			||||||
 | 
					        mesh_obj["material"] = material
 | 
				
			||||||
 | 
					        mesh_obj["type"] = "Mesh"
 | 
				
			||||||
 | 
					        mesh_obj["shading"] = sh
 | 
				
			||||||
 | 
					        mesh_obj["coloring"] = coloring
 | 
				
			||||||
 | 
					        mesh_obj["arrays"] = [v, f, c]  # TODO replays with proper storage or remove if not needed
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return self.__add_object(mesh_obj)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def add_lines(self, beginning, ending, shading={}, obj=None, **kwargs):
 | 
				
			||||||
 | 
					        shading.update(kwargs)
 | 
				
			||||||
 | 
					        if len(beginning.shape) == 1:
 | 
				
			||||||
 | 
					            if len(beginning) == 2:
 | 
				
			||||||
 | 
					                beginning = np.array([[beginning[0], beginning[1], 0]])
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            if beginning.shape[1] == 2:
 | 
				
			||||||
 | 
					                beginning = np.append(
 | 
				
			||||||
 | 
					                    beginning, np.zeros([beginning.shape[0], 1]), 1)
 | 
				
			||||||
 | 
					        if len(ending.shape) == 1:
 | 
				
			||||||
 | 
					            if len(ending) == 2:
 | 
				
			||||||
 | 
					                ending = np.array([[ending[0], ending[1], 0]])
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            if ending.shape[1] == 2:
 | 
				
			||||||
 | 
					                ending = np.append(
 | 
				
			||||||
 | 
					                    ending, np.zeros([ending.shape[0], 1]), 1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        sh = self.__get_shading(shading)
 | 
				
			||||||
 | 
					        lines = np.hstack([beginning, ending])
 | 
				
			||||||
 | 
					        lines = lines.reshape((-1, 3))
 | 
				
			||||||
 | 
					        return self.__add_line_geometry(lines, sh, obj)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def add_edges(self, vertices, edges, shading={}, obj=None, **kwargs):
 | 
				
			||||||
 | 
					        shading.update(kwargs)
 | 
				
			||||||
 | 
					        if vertices.shape[1] == 2:
 | 
				
			||||||
 | 
					            vertices = np.append(
 | 
				
			||||||
 | 
					                vertices, np.zeros([vertices.shape[0], 1]), 1)
 | 
				
			||||||
 | 
					        sh = self.__get_shading(shading)
 | 
				
			||||||
 | 
					        lines = np.zeros((edges.size, 3))
 | 
				
			||||||
 | 
					        cnt = 0
 | 
				
			||||||
 | 
					        for e in edges:
 | 
				
			||||||
 | 
					            lines[cnt, :] = vertices[e[0]]
 | 
				
			||||||
 | 
					            lines[cnt + 1, :] = vertices[e[1]]
 | 
				
			||||||
 | 
					            cnt += 2
 | 
				
			||||||
 | 
					        return self.__add_line_geometry(lines, sh, obj)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def add_points(self, points, c=None, shading={}, obj=None, **kwargs):
 | 
				
			||||||
 | 
					        shading.update(kwargs)
 | 
				
			||||||
 | 
					        if len(points.shape) == 1:
 | 
				
			||||||
 | 
					            if len(points) == 2:
 | 
				
			||||||
 | 
					                points = np.array([[points[0], points[1], 0]])
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            if points.shape[1] == 2:
 | 
				
			||||||
 | 
					                points = np.append(
 | 
				
			||||||
 | 
					                    points, np.zeros([points.shape[0], 1]), 1)
 | 
				
			||||||
 | 
					        sh = self.__get_shading(shading)
 | 
				
			||||||
 | 
					        points = points.astype("float32", copy=False)
 | 
				
			||||||
 | 
					        mi = np.min(points, axis=0)
 | 
				
			||||||
 | 
					        ma = np.max(points, axis=0)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        g_attributes = {"position": p3s.BufferAttribute(points, normalized=False)}
 | 
				
			||||||
 | 
					        m_attributes = {"size": sh["point_size"]}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if sh["point_shape"] == "circle":  # Plot circles
 | 
				
			||||||
 | 
					            tex = p3s.DataTexture(data=gen_circle(16, 16), format="RGBAFormat", type="FloatType")
 | 
				
			||||||
 | 
					            m_attributes["map"] = tex
 | 
				
			||||||
 | 
					            m_attributes["alphaTest"] = 0.5
 | 
				
			||||||
 | 
					            m_attributes["transparency"] = True
 | 
				
			||||||
 | 
					        else:  # Plot squares
 | 
				
			||||||
 | 
					            pass
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        colors, v_colors = self.__get_point_colors(points, c, sh)
 | 
				
			||||||
 | 
					        if v_colors:  # Colors per point
 | 
				
			||||||
 | 
					            m_attributes["vertexColors"] = 'VertexColors'
 | 
				
			||||||
 | 
					            g_attributes["color"] = p3s.BufferAttribute(colors, normalized=False)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        else:  # Colors for all points
 | 
				
			||||||
 | 
					            m_attributes["color"] = colors
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        material = p3s.PointsMaterial(**m_attributes)
 | 
				
			||||||
 | 
					        geometry = p3s.BufferGeometry(attributes=g_attributes)
 | 
				
			||||||
 | 
					        points = p3s.Points(geometry=geometry, material=material)
 | 
				
			||||||
 | 
					        point_obj = {"geometry": geometry, "mesh": points, "material": material,
 | 
				
			||||||
 | 
					                     "max": ma, "min": mi, "type": "Points", "wireframe": None}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if obj:
 | 
				
			||||||
 | 
					            return self.__add_object(point_obj, obj), point_obj
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            return self.__add_object(point_obj)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def remove_object(self, obj_id):
 | 
				
			||||||
 | 
					        if obj_id not in self.__objects:
 | 
				
			||||||
 | 
					            print("Invalid object id. Valid ids are: ", list(self.__objects.keys()))
 | 
				
			||||||
 | 
					            return
 | 
				
			||||||
 | 
					        self._scene.remove(self.__objects[obj_id]["mesh"])
 | 
				
			||||||
 | 
					        del self.__objects[obj_id]
 | 
				
			||||||
 | 
					        self.__update_view()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def reset(self):
 | 
				
			||||||
 | 
					        for obj_id in list(self.__objects.keys()).copy():
 | 
				
			||||||
 | 
					            self._scene.remove(self.__objects[obj_id]["mesh"])
 | 
				
			||||||
 | 
					            del self.__objects[obj_id]
 | 
				
			||||||
 | 
					        self.__update_view()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def update_object(self, oid=0, vertices=None, colors=None, faces=None):
 | 
				
			||||||
 | 
					        obj = self.__objects[oid]
 | 
				
			||||||
 | 
					        if type(vertices) != type(None):
 | 
				
			||||||
 | 
					            if obj["coloring"] == "FaceColors":
 | 
				
			||||||
 | 
					                f = obj["arrays"][1]
 | 
				
			||||||
 | 
					                verts = np.zeros((f.shape[0] * 3, 3), dtype="float32")
 | 
				
			||||||
 | 
					                for ii in range(f.shape[0]):
 | 
				
			||||||
 | 
					                    # print(ii*3, f[ii])
 | 
				
			||||||
 | 
					                    verts[ii * 3] = vertices[f[ii, 0]]
 | 
				
			||||||
 | 
					                    verts[ii * 3 + 1] = vertices[f[ii, 1]]
 | 
				
			||||||
 | 
					                    verts[ii * 3 + 2] = vertices[f[ii, 2]]
 | 
				
			||||||
 | 
					                v = verts
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            else:
 | 
				
			||||||
 | 
					                v = vertices.astype("float32", copy=False)
 | 
				
			||||||
 | 
					            obj["geometry"].attributes["position"].array = v
 | 
				
			||||||
 | 
					            # self.wireframe.attributes["position"].array = v # Wireframe updates?
 | 
				
			||||||
 | 
					            obj["geometry"].attributes["position"].needsUpdate = True
 | 
				
			||||||
 | 
					        #           obj["geometry"].exec_three_obj_method('computeVertexNormals')
 | 
				
			||||||
 | 
					        if type(colors) != type(None):
 | 
				
			||||||
 | 
					            colors, coloring = self.__get_colors(obj["arrays"][0], obj["arrays"][1], colors, obj["shading"])
 | 
				
			||||||
 | 
					            colors = colors.astype("float32", copy=False)
 | 
				
			||||||
 | 
					            obj["geometry"].attributes["color"].array = colors
 | 
				
			||||||
 | 
					            obj["geometry"].attributes["color"].needsUpdate = True
 | 
				
			||||||
 | 
					        if type(faces) != type(None):
 | 
				
			||||||
 | 
					            if obj["coloring"] == "FaceColors":
 | 
				
			||||||
 | 
					                print("Face updates are currently only possible in vertex color mode.")
 | 
				
			||||||
 | 
					                return
 | 
				
			||||||
 | 
					            f = faces.astype("uint32", copy=False).ravel()
 | 
				
			||||||
 | 
					            print(obj["geometry"].attributes)
 | 
				
			||||||
 | 
					            obj["geometry"].attributes["index"].array = f
 | 
				
			||||||
 | 
					            # self.wireframe.attributes["position"].array = v # Wireframe updates?
 | 
				
			||||||
 | 
					            obj["geometry"].attributes["index"].needsUpdate = True
 | 
				
			||||||
 | 
					        #            obj["geometry"].exec_three_obj_method('computeVertexNormals')
 | 
				
			||||||
 | 
					        # self.mesh.geometry.verticesNeedUpdate = True
 | 
				
			||||||
 | 
					        # self.mesh.geometry.elementsNeedUpdate = True
 | 
				
			||||||
 | 
					        # self.update()
 | 
				
			||||||
 | 
					        if self.render_mode == "WEBSITE":
 | 
				
			||||||
 | 
					            return self
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    #    def update(self):
 | 
				
			||||||
 | 
					    #        self.mesh.exec_three_obj_method('update')
 | 
				
			||||||
 | 
					    #        self.orbit.exec_three_obj_method('update')
 | 
				
			||||||
 | 
					    #        self.cam.exec_three_obj_method('updateProjectionMatrix')
 | 
				
			||||||
 | 
					    #        self.scene.exec_three_obj_method('update')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def add_text(self, text, shading={}, **kwargs):
 | 
				
			||||||
 | 
					        shading.update(kwargs)
 | 
				
			||||||
 | 
					        sh = self.__get_shading(shading)
 | 
				
			||||||
 | 
					        tt = p3s.TextTexture(string=text, color=sh["text_color"])
 | 
				
			||||||
 | 
					        sm = p3s.SpriteMaterial(map=tt)
 | 
				
			||||||
 | 
					        text = p3s.Sprite(material=sm, scaleToTexture=True)
 | 
				
			||||||
 | 
					        self._scene.add(text)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # def add_widget(self, widget, callback):
 | 
				
			||||||
 | 
					    #    self.widgets.append(widget)
 | 
				
			||||||
 | 
					    #    widget.observe(callback, names='value')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    #    def add_dropdown(self, options, default, desc, cb):
 | 
				
			||||||
 | 
					    #        widget = widgets.Dropdown(options=options, value=default, description=desc)
 | 
				
			||||||
 | 
					    #        self.__widgets.append(widget)
 | 
				
			||||||
 | 
					    #        widget.observe(cb, names="value")
 | 
				
			||||||
 | 
					    #        display(widget)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    #    def add_button(self, text, cb):
 | 
				
			||||||
 | 
					    #        button = widgets.Button(description=text)
 | 
				
			||||||
 | 
					    #        self.__widgets.append(button)
 | 
				
			||||||
 | 
					    #        button.on_click(cb)
 | 
				
			||||||
 | 
					    #        display(button)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def to_html(self, imports=True, html_frame=True):
 | 
				
			||||||
 | 
					        # Bake positions (fixes centering bug in offline rendering)
 | 
				
			||||||
 | 
					        if len(self.__objects) == 0:
 | 
				
			||||||
 | 
					            return
 | 
				
			||||||
 | 
					        ma = np.zeros((len(self.__objects), 3))
 | 
				
			||||||
 | 
					        mi = np.zeros((len(self.__objects), 3))
 | 
				
			||||||
 | 
					        for r, obj in enumerate(self.__objects):
 | 
				
			||||||
 | 
					            ma[r] = self.__objects[obj]["max"]
 | 
				
			||||||
 | 
					            mi[r] = self.__objects[obj]["min"]
 | 
				
			||||||
 | 
					        ma = np.max(ma, axis=0)
 | 
				
			||||||
 | 
					        mi = np.min(mi, axis=0)
 | 
				
			||||||
 | 
					        diag = np.linalg.norm(ma - mi)
 | 
				
			||||||
 | 
					        mean = (ma - mi) / 2 + mi
 | 
				
			||||||
 | 
					        for r, obj in enumerate(self.__objects):
 | 
				
			||||||
 | 
					            v = self.__objects[obj]["geometry"].attributes["position"].array
 | 
				
			||||||
 | 
					            v -= mean
 | 
				
			||||||
 | 
					            v += np.array([0.0, .9, 0.0]) #! to move the obj to the center of window
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        scale = self.__s["scale"] * (diag)
 | 
				
			||||||
 | 
					        self._orbit.target = [0.0, 0.0, 0.0]
 | 
				
			||||||
 | 
					        self._cam.lookAt([0.0, 0.0, 0.0])
 | 
				
			||||||
 | 
					        # self._cam.position = [0.0, 0.0, scale]
 | 
				
			||||||
 | 
					        self._cam.position = [0.0, 0.5, scale * 1.3] #! show four complete meshes in the window
 | 
				
			||||||
 | 
					        self._light.position = [0.0, 0.0, scale]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        state = embed.dependency_state(self._renderer)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # Somehow these entries are missing when the state is exported in python.
 | 
				
			||||||
 | 
					        # Exporting from the GUI works, so we are inserting the missing entries.
 | 
				
			||||||
 | 
					        for k in state:
 | 
				
			||||||
 | 
					            if state[k]["model_name"] == "OrbitControlsModel":
 | 
				
			||||||
 | 
					                state[k]["state"]["maxAzimuthAngle"] = "inf"
 | 
				
			||||||
 | 
					                state[k]["state"]["maxDistance"] = "inf"
 | 
				
			||||||
 | 
					                state[k]["state"]["maxZoom"] = "inf"
 | 
				
			||||||
 | 
					                state[k]["state"]["minAzimuthAngle"] = "-inf"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        tpl = embed.load_requirejs_template
 | 
				
			||||||
 | 
					        if not imports:
 | 
				
			||||||
 | 
					            embed.load_requirejs_template = ""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        s = embed.embed_snippet(self._renderer, state=state, embed_url=EMBED_URL)
 | 
				
			||||||
 | 
					        # s = embed.embed_snippet(self.__w, state=state)
 | 
				
			||||||
 | 
					        embed.load_requirejs_template = tpl
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if html_frame:
 | 
				
			||||||
 | 
					            s = "<html>\n<body>\n" + s + "\n</body>\n</html>"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # Revert changes
 | 
				
			||||||
 | 
					        for r, obj in enumerate(self.__objects):
 | 
				
			||||||
 | 
					            v = self.__objects[obj]["geometry"].attributes["position"].array
 | 
				
			||||||
 | 
					            v += mean
 | 
				
			||||||
 | 
					        self.__update_view()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return s
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def save(self, filename=""):
 | 
				
			||||||
 | 
					        if filename == "":
 | 
				
			||||||
 | 
					            uid = str(uuid.uuid4()) + ".html"
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            filename = filename.replace(".html", "")
 | 
				
			||||||
 | 
					            uid = filename + '.html'
 | 
				
			||||||
 | 
					        with open(uid, "w") as f:
 | 
				
			||||||
 | 
					            f.write(self.to_html())
 | 
				
			||||||
 | 
					        print("Plot saved to file %s." % uid)
 | 
				
			||||||
							
								
								
									
										87
									
								
								primitive_anything/primitive_dataset.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										87
									
								
								primitive_anything/primitive_dataset.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,87 @@
 | 
				
			|||||||
 | 
					import copy
 | 
				
			||||||
 | 
					import json
 | 
				
			||||||
 | 
					import os
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import numpy as np  
 | 
				
			||||||
 | 
					from scipy.linalg import polar
 | 
				
			||||||
 | 
					from scipy.spatial.transform import Rotation
 | 
				
			||||||
 | 
					import open3d as o3d
 | 
				
			||||||
 | 
					import torch
 | 
				
			||||||
 | 
					from torch.utils.data import Dataset
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from .utils import exists
 | 
				
			||||||
 | 
					from .utils.logger import print_log
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def create_dataset(cfg_dataset):
 | 
				
			||||||
 | 
					    kwargs = cfg_dataset
 | 
				
			||||||
 | 
					    name = kwargs.pop('name')
 | 
				
			||||||
 | 
					    dataset = get_dataset(name)(**kwargs)
 | 
				
			||||||
 | 
					    print_log(f"Dataset '{name}' init: kwargs={kwargs}, len={len(dataset)}")
 | 
				
			||||||
 | 
					    return dataset
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def get_dataset(name):
 | 
				
			||||||
 | 
					    return {
 | 
				
			||||||
 | 
					        'base': PrimitiveDataset,
 | 
				
			||||||
 | 
					    }[name]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					SHAPE_CODE = {
 | 
				
			||||||
 | 
					    'CubeBevel': 0,
 | 
				
			||||||
 | 
					    'SphereSharp': 1,
 | 
				
			||||||
 | 
					    'CylinderSharp': 2,
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class PrimitiveDataset(Dataset): 
 | 
				
			||||||
 | 
					    def __init__(self,
 | 
				
			||||||
 | 
					        pc_dir,
 | 
				
			||||||
 | 
					        bs_dir,
 | 
				
			||||||
 | 
					        max_length=144,
 | 
				
			||||||
 | 
					        range_scale=[0, 1],
 | 
				
			||||||
 | 
					        range_rotation=[-180, 180],
 | 
				
			||||||
 | 
					        range_translation=[-1, 1],
 | 
				
			||||||
 | 
					        rotation_type='euler',
 | 
				
			||||||
 | 
					        pc_format='pc',
 | 
				
			||||||
 | 
					    ):
 | 
				
			||||||
 | 
					        self.data_filename = os.listdir(pc_dir)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.pc_dir = pc_dir
 | 
				
			||||||
 | 
					        self.max_length = max_length
 | 
				
			||||||
 | 
					        self.range_scale = range_scale
 | 
				
			||||||
 | 
					        self.range_rotation = range_rotation
 | 
				
			||||||
 | 
					        self.range_translation = range_translation
 | 
				
			||||||
 | 
					        self.rotation_type = rotation_type
 | 
				
			||||||
 | 
					        self.pc_format = pc_format
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        with open(os.path.join(bs_dir, 'basic_shapes.json'), 'r', encoding='utf-8') as f:
 | 
				
			||||||
 | 
					            basic_shapes = json.load(f)
 | 
				
			||||||
 | 
					            
 | 
				
			||||||
 | 
					        self.typeid_map = {
 | 
				
			||||||
 | 
					            1101002001034001: 'CubeBevel',
 | 
				
			||||||
 | 
					            1101002001034010: 'SphereSharp',
 | 
				
			||||||
 | 
					            1101002001034002: 'CylinderSharp',
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __len__(self):
 | 
				
			||||||
 | 
					        return len(self.data_filename)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __getitem__(self, idx):
 | 
				
			||||||
 | 
					        pc_file = os.path.join(self.pc_dir, self.data_filename[idx])
 | 
				
			||||||
 | 
					        pc = o3d.io.read_point_cloud(pc_file)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        model_data = {}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        points = torch.from_numpy(np.asarray(pc.points)).float()
 | 
				
			||||||
 | 
					        colors = torch.from_numpy(np.asarray(pc.colors)).float()
 | 
				
			||||||
 | 
					        normals = torch.from_numpy(np.asarray(pc.normals)).float()
 | 
				
			||||||
 | 
					        if self.pc_format == 'pc':
 | 
				
			||||||
 | 
					            model_data['pc'] = torch.concatenate([points, colors], dim=-1).T
 | 
				
			||||||
 | 
					        elif self.pc_format == 'pn':
 | 
				
			||||||
 | 
					            model_data['pc'] = torch.concatenate([points, normals], dim=-1)
 | 
				
			||||||
 | 
					        elif self.pc_format == 'pcn':
 | 
				
			||||||
 | 
					            model_data['pc'] = torch.concatenate([points, colors, normals], dim=-1)
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            raise ValueError(f'invalid pc_format: {self.pc_format}')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return model_data
 | 
				
			||||||
							
								
								
									
										948
									
								
								primitive_anything/primitive_transformer.py
									
									
									
									
									
										Executable file
									
								
							
							
						
						
									
										948
									
								
								primitive_anything/primitive_transformer.py
									
									
									
									
									
										Executable file
									
								
							@ -0,0 +1,948 @@
 | 
				
			|||||||
 | 
					from __future__ import annotations
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from functools import partial
 | 
				
			||||||
 | 
					from math import ceil
 | 
				
			||||||
 | 
					import os
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from accelerate.utils import DistributedDataParallelKwargs
 | 
				
			||||||
 | 
					from beartype.typing import Tuple, Callable, List
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from einops import rearrange, repeat, reduce, pack
 | 
				
			||||||
 | 
					from gateloop_transformer import SimpleGateLoopLayer
 | 
				
			||||||
 | 
					from huggingface_hub import PyTorchModelHubMixin
 | 
				
			||||||
 | 
					import numpy as np
 | 
				
			||||||
 | 
					import open3d as o3d
 | 
				
			||||||
 | 
					from tqdm import tqdm
 | 
				
			||||||
 | 
					import torch
 | 
				
			||||||
 | 
					from torch import nn, Tensor
 | 
				
			||||||
 | 
					from torch.nn import Module, ModuleList
 | 
				
			||||||
 | 
					import torch.nn.functional as F
 | 
				
			||||||
 | 
					from pytorch3d.loss import chamfer_distance
 | 
				
			||||||
 | 
					from pytorch3d.transforms import euler_angles_to_matrix
 | 
				
			||||||
 | 
					from x_transformers import Decoder
 | 
				
			||||||
 | 
					from x_transformers.x_transformers import LayerIntermediates
 | 
				
			||||||
 | 
					from x_transformers.autoregressive_wrapper import eval_decorator
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from .michelangelo import ShapeConditioner as ShapeConditioner_miche
 | 
				
			||||||
 | 
					from .utils import (
 | 
				
			||||||
 | 
					    discretize,
 | 
				
			||||||
 | 
					    undiscretize,
 | 
				
			||||||
 | 
					    set_module_requires_grad_,
 | 
				
			||||||
 | 
					    default,
 | 
				
			||||||
 | 
					    exists,
 | 
				
			||||||
 | 
					    safe_cat,
 | 
				
			||||||
 | 
					    identity,
 | 
				
			||||||
 | 
					    is_tensor_empty,
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					from .utils.typing import Float, Int, Bool, typecheck
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# constants
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					DEFAULT_DDP_KWARGS = DistributedDataParallelKwargs(
 | 
				
			||||||
 | 
					    find_unused_parameters = True
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					SHAPE_CODE = {
 | 
				
			||||||
 | 
					    'CubeBevel': 0,
 | 
				
			||||||
 | 
					    'SphereSharp': 1,
 | 
				
			||||||
 | 
					    'CylinderSharp': 2,
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					BS_NAME = {
 | 
				
			||||||
 | 
					    0: 'CubeBevel',
 | 
				
			||||||
 | 
					    1: 'SphereSharp',
 | 
				
			||||||
 | 
					    2: 'CylinderSharp',
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# FiLM block
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class FiLM(Module):
 | 
				
			||||||
 | 
					    def __init__(self, dim, dim_out = None):
 | 
				
			||||||
 | 
					        super().__init__()
 | 
				
			||||||
 | 
					        dim_out = default(dim_out, dim)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.to_gamma = nn.Linear(dim, dim_out, bias = False)
 | 
				
			||||||
 | 
					        self.to_beta = nn.Linear(dim, dim_out)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.gamma_mult = nn.Parameter(torch.zeros(1,))
 | 
				
			||||||
 | 
					        self.beta_mult = nn.Parameter(torch.zeros(1,))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def forward(self, x, cond):
 | 
				
			||||||
 | 
					        gamma, beta = self.to_gamma(cond), self.to_beta(cond)
 | 
				
			||||||
 | 
					        gamma, beta = tuple(rearrange(t, 'b d -> b 1 d') for t in (gamma, beta))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # for initializing to identity
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        gamma = (1 + self.gamma_mult * gamma.tanh())
 | 
				
			||||||
 | 
					        beta = beta.tanh() * self.beta_mult
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # classic film
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return x * gamma + beta
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# gateloop layers
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class GateLoopBlock(Module):
 | 
				
			||||||
 | 
					    def __init__(
 | 
				
			||||||
 | 
					        self,
 | 
				
			||||||
 | 
					        dim,
 | 
				
			||||||
 | 
					        *,
 | 
				
			||||||
 | 
					        depth,
 | 
				
			||||||
 | 
					        use_heinsen = True
 | 
				
			||||||
 | 
					    ):
 | 
				
			||||||
 | 
					        super().__init__()
 | 
				
			||||||
 | 
					        self.gateloops = ModuleList([])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        for _ in range(depth):
 | 
				
			||||||
 | 
					            gateloop = SimpleGateLoopLayer(dim = dim, use_heinsen = use_heinsen)
 | 
				
			||||||
 | 
					            self.gateloops.append(gateloop)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def forward(
 | 
				
			||||||
 | 
					        self,
 | 
				
			||||||
 | 
					        x,
 | 
				
			||||||
 | 
					        cache = None
 | 
				
			||||||
 | 
					    ):
 | 
				
			||||||
 | 
					        received_cache = exists(cache)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if is_tensor_empty(x):
 | 
				
			||||||
 | 
					            return x, None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if received_cache:
 | 
				
			||||||
 | 
					            prev, x = x[:, :-1], x[:, -1:]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        cache = default(cache, [])
 | 
				
			||||||
 | 
					        cache = iter(cache)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        new_caches = []
 | 
				
			||||||
 | 
					        for gateloop in self.gateloops:
 | 
				
			||||||
 | 
					            layer_cache = next(cache, None)
 | 
				
			||||||
 | 
					            out, new_cache = gateloop(x, cache = layer_cache, return_cache = True)
 | 
				
			||||||
 | 
					            new_caches.append(new_cache)
 | 
				
			||||||
 | 
					            x = x + out
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if received_cache:
 | 
				
			||||||
 | 
					            x = torch.cat((prev, x), dim = -2)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return x, new_caches
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def top_k_2(logits, frac_num_tokens=0.1, k=None):
 | 
				
			||||||
 | 
					    num_tokens = logits.shape[-1]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    k = default(k, ceil(frac_num_tokens * num_tokens))
 | 
				
			||||||
 | 
					    k = min(k, num_tokens)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    val, ind = torch.topk(logits, k)
 | 
				
			||||||
 | 
					    probs = torch.full_like(logits, float('-inf'))
 | 
				
			||||||
 | 
					    probs.scatter_(2, ind, val)
 | 
				
			||||||
 | 
					    return probs
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def soft_argmax(labels):
 | 
				
			||||||
 | 
					    indices = torch.arange(labels.size(-1), dtype=labels.dtype, device=labels.device)
 | 
				
			||||||
 | 
					    soft_argmax = torch.sum(labels * indices, dim=-1)
 | 
				
			||||||
 | 
					    return soft_argmax
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class PrimitiveTransformerDiscrete(Module, PyTorchModelHubMixin):
 | 
				
			||||||
 | 
					    @typecheck
 | 
				
			||||||
 | 
					    def __init__(
 | 
				
			||||||
 | 
					        self,
 | 
				
			||||||
 | 
					        *,
 | 
				
			||||||
 | 
					        num_discrete_scale = 128,
 | 
				
			||||||
 | 
					        continuous_range_scale: List[float, float] = [0, 1],
 | 
				
			||||||
 | 
					        dim_scale_embed = 64,
 | 
				
			||||||
 | 
					        num_discrete_rotation = 180,
 | 
				
			||||||
 | 
					        continuous_range_rotation: List[float, float] = [-180, 180],
 | 
				
			||||||
 | 
					        dim_rotation_embed = 64,
 | 
				
			||||||
 | 
					        num_discrete_translation = 128,
 | 
				
			||||||
 | 
					        continuous_range_translation: List[float, float] = [-1, 1],
 | 
				
			||||||
 | 
					        dim_translation_embed = 64,
 | 
				
			||||||
 | 
					        num_type = 3,
 | 
				
			||||||
 | 
					        dim_type_embed = 64,
 | 
				
			||||||
 | 
					        embed_order = 'ctrs',
 | 
				
			||||||
 | 
					        bin_smooth_blur_sigma = 0.4,
 | 
				
			||||||
 | 
					        dim: int | Tuple[int, int] = 512,
 | 
				
			||||||
 | 
					        flash_attn = True,
 | 
				
			||||||
 | 
					        attn_depth = 12,
 | 
				
			||||||
 | 
					        attn_dim_head = 64,
 | 
				
			||||||
 | 
					        attn_heads = 16,
 | 
				
			||||||
 | 
					        attn_kwargs: dict = dict(
 | 
				
			||||||
 | 
					            ff_glu = True,
 | 
				
			||||||
 | 
					            attn_num_mem_kv = 4
 | 
				
			||||||
 | 
					        ),
 | 
				
			||||||
 | 
					        max_primitive_len = 144,
 | 
				
			||||||
 | 
					        dropout = 0.,
 | 
				
			||||||
 | 
					        coarse_pre_gateloop_depth = 2,
 | 
				
			||||||
 | 
					        coarse_post_gateloop_depth = 0,
 | 
				
			||||||
 | 
					        coarse_adaptive_rmsnorm = False,
 | 
				
			||||||
 | 
					        gateloop_use_heinsen = False,
 | 
				
			||||||
 | 
					        pad_id = -1,
 | 
				
			||||||
 | 
					        num_sos_tokens = None,
 | 
				
			||||||
 | 
					        condition_on_shape = True,
 | 
				
			||||||
 | 
					        shape_cond_with_cross_attn = False,
 | 
				
			||||||
 | 
					        shape_cond_with_film = False,
 | 
				
			||||||
 | 
					        shape_cond_with_cat = False,
 | 
				
			||||||
 | 
					        shape_condition_model_type = 'michelangelo',
 | 
				
			||||||
 | 
					        shape_condition_len = 1,
 | 
				
			||||||
 | 
					        shape_condition_dim = None,
 | 
				
			||||||
 | 
					        cross_attn_num_mem_kv = 4, # needed for preventing nan when dropping out shape condition
 | 
				
			||||||
 | 
					        loss_weight: dict = dict(
 | 
				
			||||||
 | 
					            eos = 1.0,
 | 
				
			||||||
 | 
					            type = 1.0,
 | 
				
			||||||
 | 
					            scale = 1.0,
 | 
				
			||||||
 | 
					            rotation = 1.0,
 | 
				
			||||||
 | 
					            translation = 1.0,
 | 
				
			||||||
 | 
					            reconstruction = 1.0,
 | 
				
			||||||
 | 
					            scale_huber = 1.0,
 | 
				
			||||||
 | 
					            rotation_huber = 1.0,
 | 
				
			||||||
 | 
					            translation_huber = 1.0,
 | 
				
			||||||
 | 
					        ),
 | 
				
			||||||
 | 
					        bs_pc_dir=None,
 | 
				
			||||||
 | 
					    ):
 | 
				
			||||||
 | 
					        super().__init__()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # feature embedding
 | 
				
			||||||
 | 
					        self.num_discrete_scale = num_discrete_scale
 | 
				
			||||||
 | 
					        self.continuous_range_scale = continuous_range_scale
 | 
				
			||||||
 | 
					        self.discretize_scale = partial(discretize, num_discrete=num_discrete_scale, continuous_range=continuous_range_scale)
 | 
				
			||||||
 | 
					        self.undiscretize_scale = partial(undiscretize, num_discrete=num_discrete_scale, continuous_range=continuous_range_scale)
 | 
				
			||||||
 | 
					        self.scale_embed = nn.Embedding(num_discrete_scale, dim_scale_embed)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.num_discrete_rotation = num_discrete_rotation
 | 
				
			||||||
 | 
					        self.continuous_range_rotation = continuous_range_rotation
 | 
				
			||||||
 | 
					        self.discretize_rotation = partial(discretize, num_discrete=num_discrete_rotation, continuous_range=continuous_range_rotation)
 | 
				
			||||||
 | 
					        self.undiscretize_rotation = partial(undiscretize, num_discrete=num_discrete_rotation, continuous_range=continuous_range_rotation)
 | 
				
			||||||
 | 
					        self.rotation_embed = nn.Embedding(num_discrete_rotation, dim_rotation_embed)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.num_discrete_translation = num_discrete_translation
 | 
				
			||||||
 | 
					        self.continuous_range_translation = continuous_range_translation
 | 
				
			||||||
 | 
					        self.discretize_translation = partial(discretize, num_discrete=num_discrete_translation, continuous_range=continuous_range_translation)
 | 
				
			||||||
 | 
					        self.undiscretize_translation = partial(undiscretize, num_discrete=num_discrete_translation, continuous_range=continuous_range_translation)
 | 
				
			||||||
 | 
					        self.translation_embed = nn.Embedding(num_discrete_translation, dim_translation_embed)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.num_type = num_type
 | 
				
			||||||
 | 
					        self.type_embed = nn.Embedding(num_type, dim_type_embed)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.embed_order = embed_order
 | 
				
			||||||
 | 
					        self.bin_smooth_blur_sigma = bin_smooth_blur_sigma
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # initial dimension
 | 
				
			||||||
 | 
					        
 | 
				
			||||||
 | 
					        self.dim = dim
 | 
				
			||||||
 | 
					        init_dim = 3 * (dim_scale_embed + dim_rotation_embed + dim_translation_embed) + dim_type_embed
 | 
				
			||||||
 | 
					        
 | 
				
			||||||
 | 
					        # project into model dimension
 | 
				
			||||||
 | 
					        self.project_in = nn.Linear(init_dim, dim)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        num_sos_tokens = default(num_sos_tokens, 1 if not condition_on_shape or not shape_cond_with_film else 4)
 | 
				
			||||||
 | 
					        assert num_sos_tokens > 0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.num_sos_tokens = num_sos_tokens
 | 
				
			||||||
 | 
					        self.sos_token = nn.Parameter(torch.randn(num_sos_tokens, dim))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # the transformer eos token
 | 
				
			||||||
 | 
					        self.eos_token = nn.Parameter(torch.randn(1, dim))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.emb_layernorm = nn.LayerNorm(dim)
 | 
				
			||||||
 | 
					        self.max_seq_len = max_primitive_len
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # shape condition
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.condition_on_shape = condition_on_shape
 | 
				
			||||||
 | 
					        self.shape_cond_with_cross_attn = False
 | 
				
			||||||
 | 
					        self.shape_cond_with_cat = False
 | 
				
			||||||
 | 
					        self.shape_condition_model_type = ''
 | 
				
			||||||
 | 
					        self.conditioner = None
 | 
				
			||||||
 | 
					        dim_shape = None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if condition_on_shape:
 | 
				
			||||||
 | 
					            assert shape_cond_with_cross_attn or shape_cond_with_film or shape_cond_with_cat
 | 
				
			||||||
 | 
					            self.shape_cond_with_cross_attn = shape_cond_with_cross_attn
 | 
				
			||||||
 | 
					            self.shape_cond_with_cat = shape_cond_with_cat
 | 
				
			||||||
 | 
					            self.shape_condition_model_type = shape_condition_model_type
 | 
				
			||||||
 | 
					            if 'michelangelo' in shape_condition_model_type:
 | 
				
			||||||
 | 
					                self.conditioner = ShapeConditioner_miche(dim_latent=shape_condition_dim)
 | 
				
			||||||
 | 
					                self.to_cond_dim = nn.Linear(self.conditioner.dim_model_out * 2, self.conditioner.dim_latent)
 | 
				
			||||||
 | 
					                self.to_cond_dim_head = nn.Linear(self.conditioner.dim_model_out, self.conditioner.dim_latent)
 | 
				
			||||||
 | 
					            else:
 | 
				
			||||||
 | 
					                raise ValueError(f'unknown shape_condition_model_type {self.shape_condition_model_type}')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            dim_shape = self.conditioner.dim_latent
 | 
				
			||||||
 | 
					            set_module_requires_grad_(self.conditioner, False)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            self.shape_coarse_film_cond = FiLM(dim_shape, dim) if shape_cond_with_film else identity
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.coarse_gateloop_block = GateLoopBlock(dim, depth=coarse_pre_gateloop_depth, use_heinsen=gateloop_use_heinsen) if coarse_pre_gateloop_depth > 0 else None
 | 
				
			||||||
 | 
					        self.coarse_post_gateloop_block = GateLoopBlock(dim, depth=coarse_post_gateloop_depth, use_heinsen=gateloop_use_heinsen) if coarse_post_gateloop_depth > 0 else None
 | 
				
			||||||
 | 
					        self.coarse_adaptive_rmsnorm = coarse_adaptive_rmsnorm
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.decoder = Decoder(
 | 
				
			||||||
 | 
					            dim=dim,
 | 
				
			||||||
 | 
					            depth=attn_depth,
 | 
				
			||||||
 | 
					            heads=attn_heads,
 | 
				
			||||||
 | 
					            attn_dim_head=attn_dim_head,
 | 
				
			||||||
 | 
					            attn_flash=flash_attn,
 | 
				
			||||||
 | 
					            attn_dropout=dropout,
 | 
				
			||||||
 | 
					            ff_dropout=dropout,
 | 
				
			||||||
 | 
					            use_adaptive_rmsnorm=coarse_adaptive_rmsnorm,
 | 
				
			||||||
 | 
					            dim_condition=dim_shape,
 | 
				
			||||||
 | 
					            cross_attend=self.shape_cond_with_cross_attn,
 | 
				
			||||||
 | 
					            cross_attn_dim_context=dim_shape,
 | 
				
			||||||
 | 
					            cross_attn_num_mem_kv=cross_attn_num_mem_kv,
 | 
				
			||||||
 | 
					            **attn_kwargs
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # to logits
 | 
				
			||||||
 | 
					        self.to_eos_logits = nn.Sequential(
 | 
				
			||||||
 | 
					            nn.Linear(dim, dim),
 | 
				
			||||||
 | 
					            nn.ReLU(),
 | 
				
			||||||
 | 
					            nn.Linear(dim, 1)
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        self.to_type_logits = nn.Sequential(
 | 
				
			||||||
 | 
					            nn.Linear(dim, dim),
 | 
				
			||||||
 | 
					            nn.ReLU(),
 | 
				
			||||||
 | 
					            nn.Linear(dim, num_type)
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        self.to_translation_logits = nn.Sequential(
 | 
				
			||||||
 | 
					            nn.Linear(dim + dim_type_embed, dim),
 | 
				
			||||||
 | 
					            nn.ReLU(),
 | 
				
			||||||
 | 
					            nn.Linear(dim, 3 * num_discrete_translation)
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        self.to_rotation_logits = nn.Sequential(
 | 
				
			||||||
 | 
					            nn.Linear(dim + dim_type_embed + 3 * dim_translation_embed, dim),
 | 
				
			||||||
 | 
					            nn.ReLU(),
 | 
				
			||||||
 | 
					            nn.Linear(dim, 3 * num_discrete_rotation)
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        self.to_scale_logits = nn.Sequential(
 | 
				
			||||||
 | 
					            nn.Linear(dim + dim_type_embed + 3 * (dim_translation_embed + dim_rotation_embed), dim),
 | 
				
			||||||
 | 
					            nn.ReLU(),
 | 
				
			||||||
 | 
					            nn.Linear(dim, 3 * num_discrete_scale)
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.pad_id = pad_id
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        bs_pc_map = {}
 | 
				
			||||||
 | 
					        for bs_name, type_code in SHAPE_CODE.items():
 | 
				
			||||||
 | 
					            pc = o3d.io.read_point_cloud(os.path.join(bs_pc_dir, f'SM_GR_BS_{bs_name}_001.ply'))
 | 
				
			||||||
 | 
					            bs_pc_map[type_code] = torch.from_numpy(np.asarray(pc.points)).float()
 | 
				
			||||||
 | 
					        bs_pc_list = []
 | 
				
			||||||
 | 
					        for i in range(len(bs_pc_map)):
 | 
				
			||||||
 | 
					            bs_pc_list.append(bs_pc_map[i])
 | 
				
			||||||
 | 
					        self.bs_pc = torch.stack(bs_pc_list, dim=0)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.rotation_matrix_align_coord = euler_angles_to_matrix(
 | 
				
			||||||
 | 
					                                            torch.Tensor([np.pi/2, 0, 0]), 'XYZ').unsqueeze(0).unsqueeze(0)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @property
 | 
				
			||||||
 | 
					    def device(self):
 | 
				
			||||||
 | 
					        return next(self.parameters()).device
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @typecheck
 | 
				
			||||||
 | 
					    @torch.no_grad()
 | 
				
			||||||
 | 
					    def embed_pc(self, pc: Tensor):
 | 
				
			||||||
 | 
					        if 'michelangelo' in self.shape_condition_model_type:
 | 
				
			||||||
 | 
					            pc_head, pc_embed = self.conditioner(shape=pc)
 | 
				
			||||||
 | 
					            pc_embed = torch.cat([self.to_cond_dim_head(pc_head), self.to_cond_dim(pc_embed)], dim=-2).detach()
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            raise ValueError(f'unknown shape_condition_model_type {self.shape_condition_model_type}')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return pc_embed
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @typecheck
 | 
				
			||||||
 | 
					    def recon_primitives(
 | 
				
			||||||
 | 
					        self,
 | 
				
			||||||
 | 
					        scale_logits: Float['b np 3 nd'],
 | 
				
			||||||
 | 
					        rotation_logits: Float['b np 3 nd'],
 | 
				
			||||||
 | 
					        translation_logits: Float['b np 3 nd'],
 | 
				
			||||||
 | 
					        type_logits: Int['b np nd'],
 | 
				
			||||||
 | 
					        primitive_mask: Bool['b np']
 | 
				
			||||||
 | 
					    ):
 | 
				
			||||||
 | 
					        recon_scale = self.undiscretize_scale(scale_logits.argmax(dim=-1))
 | 
				
			||||||
 | 
					        recon_scale = recon_scale.masked_fill(~primitive_mask.unsqueeze(-1), float('nan'))
 | 
				
			||||||
 | 
					        recon_rotation = self.undiscretize_rotation(rotation_logits.argmax(dim=-1))
 | 
				
			||||||
 | 
					        recon_rotation = recon_rotation.masked_fill(~primitive_mask.unsqueeze(-1), float('nan'))
 | 
				
			||||||
 | 
					        recon_translation = self.undiscretize_translation(translation_logits.argmax(dim=-1))
 | 
				
			||||||
 | 
					        recon_translation = recon_translation.masked_fill(~primitive_mask.unsqueeze(-1), float('nan'))
 | 
				
			||||||
 | 
					        recon_type_code = type_logits.argmax(dim=-1)
 | 
				
			||||||
 | 
					        recon_type_code = recon_type_code.masked_fill(~primitive_mask, -1)
 | 
				
			||||||
 | 
					        
 | 
				
			||||||
 | 
					        return {
 | 
				
			||||||
 | 
					            'scale': recon_scale,
 | 
				
			||||||
 | 
					            'rotation': recon_rotation,
 | 
				
			||||||
 | 
					            'translation': recon_translation,
 | 
				
			||||||
 | 
					            'type_code': recon_type_code
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					    
 | 
				
			||||||
 | 
					    @typecheck
 | 
				
			||||||
 | 
					    def sample_primitives(
 | 
				
			||||||
 | 
					        self,
 | 
				
			||||||
 | 
					        scale: Float['b np 3 nd'],
 | 
				
			||||||
 | 
					        rotation: Float['b np 3 nd'],
 | 
				
			||||||
 | 
					        translation: Float['b np 3 nd'],
 | 
				
			||||||
 | 
					        type_code: Int['b np nd'],
 | 
				
			||||||
 | 
					        next_embed: Float['b 1 nd'],
 | 
				
			||||||
 | 
					        temperature: float = 1.,
 | 
				
			||||||
 | 
					        filter_logits_fn: Callable = top_k_2,
 | 
				
			||||||
 | 
					        filter_kwargs: dict = dict()
 | 
				
			||||||
 | 
					    ):
 | 
				
			||||||
 | 
					        def sample_func(logits):
 | 
				
			||||||
 | 
					            if logits.ndim == 4:
 | 
				
			||||||
 | 
					                enable_squeeze = True
 | 
				
			||||||
 | 
					                logits = logits.squeeze(1)
 | 
				
			||||||
 | 
					            else:
 | 
				
			||||||
 | 
					                enable_squeeze = False
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            filtered_logits = filter_logits_fn(logits, **filter_kwargs)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            if temperature == 0.:
 | 
				
			||||||
 | 
					                sample = filtered_logits.argmax(dim=-1)
 | 
				
			||||||
 | 
					            else:
 | 
				
			||||||
 | 
					                probs = F.softmax(filtered_logits / temperature, dim=-1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                sample = torch.zeros((probs.shape[0], probs.shape[1]), dtype=torch.long, device=probs.device)
 | 
				
			||||||
 | 
					                for b_i in range(probs.shape[0]):
 | 
				
			||||||
 | 
					                    sample[b_i] = torch.multinomial(probs[b_i], 1).squeeze()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            if enable_squeeze:
 | 
				
			||||||
 | 
					                sample = sample.unsqueeze(1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            return sample
 | 
				
			||||||
 | 
					        
 | 
				
			||||||
 | 
					        next_type_logits = self.to_type_logits(next_embed)
 | 
				
			||||||
 | 
					        next_type_code = sample_func(next_type_logits)
 | 
				
			||||||
 | 
					        type_code_new, _ = pack([type_code, next_type_code], 'b *')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        type_embed = self.type_embed(next_type_code)
 | 
				
			||||||
 | 
					        next_embed_packed, _ = pack([next_embed, type_embed], 'b np *')
 | 
				
			||||||
 | 
					        next_translation_logits = rearrange(self.to_translation_logits(next_embed_packed), 'b np (c nd) -> b np c nd', nd=self.num_discrete_translation)
 | 
				
			||||||
 | 
					        next_discretize_translation = sample_func(next_translation_logits)
 | 
				
			||||||
 | 
					        next_translation = self.undiscretize_translation(next_discretize_translation)
 | 
				
			||||||
 | 
					        translation_new, _ = pack([translation, next_translation], 'b * nd')
 | 
				
			||||||
 | 
					        
 | 
				
			||||||
 | 
					        next_translation_embed = self.translation_embed(next_discretize_translation)
 | 
				
			||||||
 | 
					        next_embed_packed, _ = pack([next_embed_packed, next_translation_embed], 'b np *')
 | 
				
			||||||
 | 
					        next_rotation_logits = rearrange(self.to_rotation_logits(next_embed_packed), 'b np (c nd) -> b np c nd', nd=self.num_discrete_rotation)
 | 
				
			||||||
 | 
					        next_discretize_rotation = sample_func(next_rotation_logits)
 | 
				
			||||||
 | 
					        next_rotation = self.undiscretize_rotation(next_discretize_rotation)
 | 
				
			||||||
 | 
					        rotation_new, _ = pack([rotation, next_rotation], 'b * nd')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        next_rotation_embed = self.rotation_embed(next_discretize_rotation)
 | 
				
			||||||
 | 
					        next_embed_packed, _ = pack([next_embed_packed, next_rotation_embed], 'b np *')
 | 
				
			||||||
 | 
					        next_scale_logits = rearrange(self.to_scale_logits(next_embed_packed), 'b np (c nd) -> b np c nd', nd=self.num_discrete_scale)
 | 
				
			||||||
 | 
					        next_discretize_scale = sample_func(next_scale_logits)
 | 
				
			||||||
 | 
					        next_scale = self.undiscretize_scale(next_discretize_scale)
 | 
				
			||||||
 | 
					        scale_new, _ = pack([scale, next_scale], 'b * nd')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return (
 | 
				
			||||||
 | 
					            scale_new,
 | 
				
			||||||
 | 
					            rotation_new,
 | 
				
			||||||
 | 
					            translation_new,
 | 
				
			||||||
 | 
					            type_code_new
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @eval_decorator
 | 
				
			||||||
 | 
					    @torch.no_grad()
 | 
				
			||||||
 | 
					    @typecheck
 | 
				
			||||||
 | 
					    def generate(
 | 
				
			||||||
 | 
					        self,
 | 
				
			||||||
 | 
					        batch_size: int | None = None,
 | 
				
			||||||
 | 
					        filter_logits_fn: Callable = top_k_2,
 | 
				
			||||||
 | 
					        filter_kwargs: dict = dict(),
 | 
				
			||||||
 | 
					        temperature: float = 1.,
 | 
				
			||||||
 | 
					        scale: Float['b np 3'] | None = None,
 | 
				
			||||||
 | 
					        rotation: Float['b np 3'] | None = None,
 | 
				
			||||||
 | 
					        translation: Float['b np 3'] | None = None,
 | 
				
			||||||
 | 
					        type_code: Int['b np'] | None = None,
 | 
				
			||||||
 | 
					        pc: Tensor | None = None,
 | 
				
			||||||
 | 
					        pc_embed: Tensor | None = None,
 | 
				
			||||||
 | 
					        cache_kv = True,
 | 
				
			||||||
 | 
					        max_seq_len = None,
 | 
				
			||||||
 | 
					    ):
 | 
				
			||||||
 | 
					        max_seq_len = default(max_seq_len, self.max_seq_len)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if exists(scale) and exists(rotation) and exists(translation) and exists(type_code):
 | 
				
			||||||
 | 
					            assert not exists(batch_size)
 | 
				
			||||||
 | 
					            assert scale.shape[1] == rotation.shape[1] == translation.shape[1] == type_code.shape[1]
 | 
				
			||||||
 | 
					            assert scale.shape[1] <= self.max_seq_len
 | 
				
			||||||
 | 
					            
 | 
				
			||||||
 | 
					            batch_size = scale.shape[0]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if self.condition_on_shape:
 | 
				
			||||||
 | 
					            assert exists(pc) ^ exists(pc_embed), '`pc` or `pc_embed` must be passed in'
 | 
				
			||||||
 | 
					            if exists(pc):
 | 
				
			||||||
 | 
					                pc_embed = self.embed_pc(pc)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            batch_size = default(batch_size, pc_embed.shape[0])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        batch_size = default(batch_size, 1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        scale = default(scale, torch.empty((batch_size, 0, 3), dtype=torch.float64, device=self.device))
 | 
				
			||||||
 | 
					        rotation = default(rotation, torch.empty((batch_size, 0, 3), dtype=torch.float64, device=self.device))
 | 
				
			||||||
 | 
					        translation = default(translation, torch.empty((batch_size, 0, 3), dtype=torch.float64, device=self.device))
 | 
				
			||||||
 | 
					        type_code = default(type_code, torch.empty((batch_size, 0), dtype=torch.int64, device=self.device))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        curr_length = scale.shape[1]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        cache = None
 | 
				
			||||||
 | 
					        eos_codes = None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        for i in tqdm(range(curr_length, max_seq_len)):
 | 
				
			||||||
 | 
					            can_eos = i != 0
 | 
				
			||||||
 | 
					            output = self.forward(
 | 
				
			||||||
 | 
					                scale=scale,
 | 
				
			||||||
 | 
					                rotation=rotation,
 | 
				
			||||||
 | 
					                translation=translation,
 | 
				
			||||||
 | 
					                type_code=type_code,
 | 
				
			||||||
 | 
					                pc_embed=pc_embed,
 | 
				
			||||||
 | 
					                return_loss=False,
 | 
				
			||||||
 | 
					                return_cache=cache_kv,
 | 
				
			||||||
 | 
					                append_eos=False,
 | 
				
			||||||
 | 
					                cache=cache
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					            if cache_kv:
 | 
				
			||||||
 | 
					                next_embed, cache = output
 | 
				
			||||||
 | 
					            else:
 | 
				
			||||||
 | 
					                next_embed = output
 | 
				
			||||||
 | 
					            (
 | 
				
			||||||
 | 
					                scale,
 | 
				
			||||||
 | 
					                rotation,
 | 
				
			||||||
 | 
					                translation,
 | 
				
			||||||
 | 
					                type_code
 | 
				
			||||||
 | 
					            ) = self.sample_primitives(
 | 
				
			||||||
 | 
					                scale,
 | 
				
			||||||
 | 
					                rotation,
 | 
				
			||||||
 | 
					                translation,
 | 
				
			||||||
 | 
					                type_code,
 | 
				
			||||||
 | 
					                next_embed,
 | 
				
			||||||
 | 
					                temperature=temperature,
 | 
				
			||||||
 | 
					                filter_logits_fn=filter_logits_fn,
 | 
				
			||||||
 | 
					                filter_kwargs=filter_kwargs
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            next_eos_logits = self.to_eos_logits(next_embed).squeeze(-1)
 | 
				
			||||||
 | 
					            next_eos_code = (F.sigmoid(next_eos_logits) > 0.5)
 | 
				
			||||||
 | 
					            eos_codes = safe_cat([eos_codes, next_eos_code], 1)
 | 
				
			||||||
 | 
					            if can_eos and eos_codes.any(dim=-1).all():
 | 
				
			||||||
 | 
					                break
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # mask out to padding anything after the first eos
 | 
				
			||||||
 | 
					        mask = eos_codes.float().cumsum(dim=-1) >= 1
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # concat cur_length to mask
 | 
				
			||||||
 | 
					        mask = torch.cat((torch.zeros((batch_size, curr_length), dtype=torch.bool, device=self.device), mask), dim=-1)
 | 
				
			||||||
 | 
					        type_code = type_code.masked_fill(mask, self.pad_id)
 | 
				
			||||||
 | 
					        scale = scale.masked_fill(mask.unsqueeze(-1), self.pad_id)
 | 
				
			||||||
 | 
					        rotation = rotation.masked_fill(mask.unsqueeze(-1), self.pad_id)
 | 
				
			||||||
 | 
					        translation = translation.masked_fill(mask.unsqueeze(-1), self.pad_id)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        recon_primitives = {
 | 
				
			||||||
 | 
					            'scale': scale,
 | 
				
			||||||
 | 
					            'rotation': rotation,
 | 
				
			||||||
 | 
					            'translation': translation,
 | 
				
			||||||
 | 
					            'type_code': type_code
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					        primitive_mask = ~eos_codes
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return recon_primitives, primitive_mask
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @eval_decorator
 | 
				
			||||||
 | 
					    @torch.no_grad()
 | 
				
			||||||
 | 
					    @typecheck
 | 
				
			||||||
 | 
					    def generate_w_recon_loss(
 | 
				
			||||||
 | 
					        self,
 | 
				
			||||||
 | 
					        batch_size: int | None = None,
 | 
				
			||||||
 | 
					        filter_logits_fn: Callable = top_k_2,
 | 
				
			||||||
 | 
					        filter_kwargs: dict = dict(),
 | 
				
			||||||
 | 
					        temperature: float = 1.,
 | 
				
			||||||
 | 
					        scale: Float['b np 3'] | None = None,
 | 
				
			||||||
 | 
					        rotation: Float['b np 3'] | None = None,
 | 
				
			||||||
 | 
					        translation: Float['b np 3'] | None = None,
 | 
				
			||||||
 | 
					        type_code: Int['b np'] | None = None,
 | 
				
			||||||
 | 
					        pc: Tensor | None = None,
 | 
				
			||||||
 | 
					        pc_embed: Tensor | None = None,
 | 
				
			||||||
 | 
					        cache_kv = True,
 | 
				
			||||||
 | 
					        max_seq_len = None,
 | 
				
			||||||
 | 
					        single_directional = True,
 | 
				
			||||||
 | 
					    ):
 | 
				
			||||||
 | 
					        max_seq_len = default(max_seq_len, self.max_seq_len)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if exists(scale) and exists(rotation) and exists(translation) and exists(type_code):
 | 
				
			||||||
 | 
					            assert not exists(batch_size)
 | 
				
			||||||
 | 
					            assert scale.shape[1] == rotation.shape[1] == translation.shape[1] == type_code.shape[1]
 | 
				
			||||||
 | 
					            assert scale.shape[1] <= self.max_seq_len
 | 
				
			||||||
 | 
					            
 | 
				
			||||||
 | 
					            batch_size = scale.shape[0]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if self.condition_on_shape:
 | 
				
			||||||
 | 
					            assert exists(pc) ^ exists(pc_embed), '`pc` or `pc_embed` must be passed in'
 | 
				
			||||||
 | 
					            if exists(pc):
 | 
				
			||||||
 | 
					                pc_embed = self.embed_pc(pc)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            batch_size = default(batch_size, pc_embed.shape[0])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        batch_size = default(batch_size, 1)
 | 
				
			||||||
 | 
					        assert batch_size == 1 # TODO: support any batch size
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        scale = default(scale, torch.empty((batch_size, 0, 3), dtype=torch.float32, device=self.device))
 | 
				
			||||||
 | 
					        rotation = default(rotation, torch.empty((batch_size, 0, 3), dtype=torch.float32, device=self.device))
 | 
				
			||||||
 | 
					        translation = default(translation, torch.empty((batch_size, 0, 3), dtype=torch.float32, device=self.device))
 | 
				
			||||||
 | 
					        type_code = default(type_code, torch.empty((batch_size, 0), dtype=torch.int64, device=self.device))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        curr_length = scale.shape[1]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        cache = None
 | 
				
			||||||
 | 
					        eos_codes = None
 | 
				
			||||||
 | 
					        last_recon_loss = 1
 | 
				
			||||||
 | 
					        for i in tqdm(range(curr_length, max_seq_len)):
 | 
				
			||||||
 | 
					            can_eos = i != 0
 | 
				
			||||||
 | 
					            output = self.forward(
 | 
				
			||||||
 | 
					                scale=scale,
 | 
				
			||||||
 | 
					                rotation=rotation,
 | 
				
			||||||
 | 
					                translation=translation,
 | 
				
			||||||
 | 
					                type_code=type_code,
 | 
				
			||||||
 | 
					                pc_embed=pc_embed,
 | 
				
			||||||
 | 
					                return_loss=False,
 | 
				
			||||||
 | 
					                return_cache=cache_kv,
 | 
				
			||||||
 | 
					                append_eos=False,
 | 
				
			||||||
 | 
					                cache=cache
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					            if cache_kv:
 | 
				
			||||||
 | 
					                next_embed, cache = output
 | 
				
			||||||
 | 
					            else:
 | 
				
			||||||
 | 
					                next_embed = output
 | 
				
			||||||
 | 
					            (
 | 
				
			||||||
 | 
					                scale_new,
 | 
				
			||||||
 | 
					                rotation_new,
 | 
				
			||||||
 | 
					                translation_new,
 | 
				
			||||||
 | 
					                type_code_new
 | 
				
			||||||
 | 
					            ) = self.sample_primitives(
 | 
				
			||||||
 | 
					                scale,
 | 
				
			||||||
 | 
					                rotation,
 | 
				
			||||||
 | 
					                translation,
 | 
				
			||||||
 | 
					                type_code,
 | 
				
			||||||
 | 
					                next_embed,
 | 
				
			||||||
 | 
					                temperature=temperature,
 | 
				
			||||||
 | 
					                filter_logits_fn=filter_logits_fn,
 | 
				
			||||||
 | 
					                filter_kwargs=filter_kwargs
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            next_eos_logits = self.to_eos_logits(next_embed).squeeze(-1)
 | 
				
			||||||
 | 
					            next_eos_code = (F.sigmoid(next_eos_logits) > 0.5)
 | 
				
			||||||
 | 
					            eos_codes = safe_cat([eos_codes, next_eos_code], 1)
 | 
				
			||||||
 | 
					            if can_eos and eos_codes.any(dim=-1).all():
 | 
				
			||||||
 | 
					                scale, rotation, translation, type_code = (
 | 
				
			||||||
 | 
					                    scale_new, rotation_new, translation_new, type_code_new)
 | 
				
			||||||
 | 
					                break
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            recon_loss = self.compute_chamfer_distance(scale_new, rotation_new, translation_new, type_code_new, ~eos_codes, pc, single_directional)
 | 
				
			||||||
 | 
					            if recon_loss < last_recon_loss:
 | 
				
			||||||
 | 
					                last_recon_loss = recon_loss
 | 
				
			||||||
 | 
					                scale, rotation, translation, type_code = (
 | 
				
			||||||
 | 
					                    scale_new, rotation_new, translation_new, type_code_new)
 | 
				
			||||||
 | 
					            else:
 | 
				
			||||||
 | 
					                best_recon_loss = recon_loss
 | 
				
			||||||
 | 
					                best_primitives = dict(
 | 
				
			||||||
 | 
					                    scale=scale_new, rotation=rotation_new, translation=translation_new, type_code=type_code_new)
 | 
				
			||||||
 | 
					                success_flag = False
 | 
				
			||||||
 | 
					                print(f'last_recon_loss:{last_recon_loss}, recon_loss:{recon_loss} -> to find better primitive')
 | 
				
			||||||
 | 
					                for try_i in range(5):
 | 
				
			||||||
 | 
					                    (
 | 
				
			||||||
 | 
					                        scale_new,
 | 
				
			||||||
 | 
					                        rotation_new,
 | 
				
			||||||
 | 
					                        translation_new,
 | 
				
			||||||
 | 
					                        type_code_new
 | 
				
			||||||
 | 
					                    ) = self.sample_primitives(
 | 
				
			||||||
 | 
					                        scale,
 | 
				
			||||||
 | 
					                        rotation,
 | 
				
			||||||
 | 
					                        translation,
 | 
				
			||||||
 | 
					                        type_code,
 | 
				
			||||||
 | 
					                        next_embed,
 | 
				
			||||||
 | 
					                        temperature=1.0,
 | 
				
			||||||
 | 
					                        filter_logits_fn=filter_logits_fn,
 | 
				
			||||||
 | 
					                        filter_kwargs=filter_kwargs
 | 
				
			||||||
 | 
					                    )
 | 
				
			||||||
 | 
					                    recon_loss = self.compute_chamfer_distance(scale_new, rotation_new, translation_new, type_code_new, ~eos_codes, pc)
 | 
				
			||||||
 | 
					                    print(f'[try_{try_i}] last_recon_loss:{last_recon_loss}, best_recon_loss:{best_recon_loss}, cur_recon_loss:{recon_loss}')
 | 
				
			||||||
 | 
					                    if recon_loss < last_recon_loss:
 | 
				
			||||||
 | 
					                        last_recon_loss = recon_loss
 | 
				
			||||||
 | 
					                        scale, rotation, translation, type_code = (
 | 
				
			||||||
 | 
					                            scale_new, rotation_new, translation_new, type_code_new)
 | 
				
			||||||
 | 
					                        success_flag = True
 | 
				
			||||||
 | 
					                        break
 | 
				
			||||||
 | 
					                    else:
 | 
				
			||||||
 | 
					                        if recon_loss < best_recon_loss:
 | 
				
			||||||
 | 
					                            best_recon_loss = recon_loss
 | 
				
			||||||
 | 
					                            best_primitives = dict(
 | 
				
			||||||
 | 
					                                scale=scale_new, rotation=rotation_new, translation=translation_new, type_code=type_code_new)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                if not success_flag:
 | 
				
			||||||
 | 
					                    last_recon_loss = best_recon_loss
 | 
				
			||||||
 | 
					                    scale, rotation, translation, type_code = (
 | 
				
			||||||
 | 
					                        best_primitives['scale'], best_primitives['rotation'], best_primitives['translation'], best_primitives['type_code'])
 | 
				
			||||||
 | 
					                print(f'new_last_recon_loss:{last_recon_loss}')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # mask out to padding anything after the first eos
 | 
				
			||||||
 | 
					        mask = eos_codes.float().cumsum(dim=-1) >= 1
 | 
				
			||||||
 | 
					        type_code = type_code.masked_fill(mask, self.pad_id)
 | 
				
			||||||
 | 
					        scale = scale.masked_fill(mask.unsqueeze(-1), self.pad_id)
 | 
				
			||||||
 | 
					        rotation = rotation.masked_fill(mask.unsqueeze(-1), self.pad_id)
 | 
				
			||||||
 | 
					        translation = translation.masked_fill(mask.unsqueeze(-1), self.pad_id)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        recon_primitives = {
 | 
				
			||||||
 | 
					            'scale': scale,
 | 
				
			||||||
 | 
					            'rotation': rotation,
 | 
				
			||||||
 | 
					            'translation': translation,
 | 
				
			||||||
 | 
					            'type_code': type_code
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					        primitive_mask = ~eos_codes
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return recon_primitives, primitive_mask
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @typecheck
 | 
				
			||||||
 | 
					    def encode(
 | 
				
			||||||
 | 
					        self,
 | 
				
			||||||
 | 
					        *,
 | 
				
			||||||
 | 
					        scale: Float['b np 3'],
 | 
				
			||||||
 | 
					        rotation: Float['b np 3'],
 | 
				
			||||||
 | 
					        translation: Float['b np 3'],
 | 
				
			||||||
 | 
					        type_code: Int['b np'],
 | 
				
			||||||
 | 
					        primitive_mask: Bool['b np'],
 | 
				
			||||||
 | 
					        return_primitives = False
 | 
				
			||||||
 | 
					    ):
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        einops:
 | 
				
			||||||
 | 
					        b - batch
 | 
				
			||||||
 | 
					        np - number of primitives
 | 
				
			||||||
 | 
					        c - coordinates (3)
 | 
				
			||||||
 | 
					        d - embed dim
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					    
 | 
				
			||||||
 | 
					        # compute feature embedding
 | 
				
			||||||
 | 
					        discretize_scale = self.discretize_scale(scale)
 | 
				
			||||||
 | 
					        scale_embed = self.scale_embed(discretize_scale)
 | 
				
			||||||
 | 
					        scale_embed = rearrange(scale_embed, 'b np c d -> b np (c d)')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        discretize_rotation = self.discretize_rotation(rotation)
 | 
				
			||||||
 | 
					        rotation_embed = self.rotation_embed(discretize_rotation)
 | 
				
			||||||
 | 
					        rotation_embed = rearrange(rotation_embed, 'b np c d -> b np (c d)')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        discretize_translation = self.discretize_translation(translation)
 | 
				
			||||||
 | 
					        translation_embed = self.translation_embed(discretize_translation)
 | 
				
			||||||
 | 
					        translation_embed = rearrange(translation_embed, 'b np c d -> b np (c d)')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        type_embed = self.type_embed(type_code.masked_fill(~primitive_mask, 0))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # combine all features and project into model dimension
 | 
				
			||||||
 | 
					        if self.embed_order == 'srtc':
 | 
				
			||||||
 | 
					            primitive_embed, _ = pack([scale_embed, rotation_embed, translation_embed, type_embed], 'b np *')
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            primitive_embed, _ = pack([type_embed, translation_embed, rotation_embed, scale_embed], 'b np *')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        primitive_embed = self.project_in(primitive_embed)
 | 
				
			||||||
 | 
					        primitive_embed = primitive_embed.masked_fill(~primitive_mask.unsqueeze(-1), 0.)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if not return_primitives:
 | 
				
			||||||
 | 
					            return primitive_embed
 | 
				
			||||||
 | 
					        
 | 
				
			||||||
 | 
					        primitive_embed_unpacked = {
 | 
				
			||||||
 | 
					            'scale': scale_embed,
 | 
				
			||||||
 | 
					            'rotation': rotation_embed,
 | 
				
			||||||
 | 
					            'translation': translation_embed,
 | 
				
			||||||
 | 
					            'type_code': type_embed
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        primitives_gt = {
 | 
				
			||||||
 | 
					            'scale': discretize_scale,
 | 
				
			||||||
 | 
					            'rotation': discretize_rotation,
 | 
				
			||||||
 | 
					            'translation': discretize_translation,
 | 
				
			||||||
 | 
					            'type_code': type_code
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					        
 | 
				
			||||||
 | 
					        return primitive_embed, primitive_embed_unpacked, primitives_gt
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @typecheck
 | 
				
			||||||
 | 
					    def compute_chamfer_distance(
 | 
				
			||||||
 | 
					        self,
 | 
				
			||||||
 | 
					        scale_pred: Float['b np 3'],
 | 
				
			||||||
 | 
					        rotation_pred: Float['b np 3'],
 | 
				
			||||||
 | 
					        translation_pred: Float['b np 3'],
 | 
				
			||||||
 | 
					        type_pred: Int['b np'],
 | 
				
			||||||
 | 
					        primitive_mask: Bool['b np'],
 | 
				
			||||||
 | 
					        pc: Tensor, # b, num_points, c
 | 
				
			||||||
 | 
					        single_directional = True
 | 
				
			||||||
 | 
					    ):
 | 
				
			||||||
 | 
					        scale_pred = scale_pred.float()
 | 
				
			||||||
 | 
					        rotation_pred = rotation_pred.float()
 | 
				
			||||||
 | 
					        translation_pred = translation_pred.float()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        pc_pred = apply_transformation(self.bs_pc.to(type_pred.device)[type_pred], scale_pred, torch.deg2rad(rotation_pred), translation_pred)
 | 
				
			||||||
 | 
					        pc_pred = torch.matmul(pc_pred, self.rotation_matrix_align_coord.to(type_pred.device))
 | 
				
			||||||
 | 
					        pc_pred_flat = rearrange(pc_pred, 'b np p c -> b (np p) c')
 | 
				
			||||||
 | 
					        pc_pred_sampled = random_sample_pc(pc_pred_flat, primitive_mask.sum(dim=-1, keepdim=True), n_points=self.bs_pc.shape[1])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if single_directional:
 | 
				
			||||||
 | 
					            recon_loss, _ = chamfer_distance(pc[:, :, :3].float(), pc_pred_sampled.float(), single_directional=True) # single directional
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            recon_loss, _ = chamfer_distance(pc_pred_sampled.float(), pc[:, :, :3].float())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return recon_loss
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def forward(
 | 
				
			||||||
 | 
					        self,
 | 
				
			||||||
 | 
					        *,
 | 
				
			||||||
 | 
					        scale: Float['b np 3'],
 | 
				
			||||||
 | 
					        rotation: Float['b np 3'],
 | 
				
			||||||
 | 
					        translation: Float['b np 3'],
 | 
				
			||||||
 | 
					        type_code: Int['b np'],
 | 
				
			||||||
 | 
					        loss_reduction: str = 'mean',
 | 
				
			||||||
 | 
					        return_cache = False,
 | 
				
			||||||
 | 
					        append_eos = True,
 | 
				
			||||||
 | 
					        cache: LayerIntermediates | None = None,
 | 
				
			||||||
 | 
					        pc: Tensor | None = None,
 | 
				
			||||||
 | 
					        pc_embed: Tensor | None = None,
 | 
				
			||||||
 | 
					        **kwargs
 | 
				
			||||||
 | 
					    ):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        primitive_mask = reduce(scale != self.pad_id, 'b np 3 -> b np', 'all')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if scale.shape[1] > 0:
 | 
				
			||||||
 | 
					            codes, primitives_embeds, primitives_gt = self.encode(
 | 
				
			||||||
 | 
					                scale=scale,
 | 
				
			||||||
 | 
					                rotation=rotation,
 | 
				
			||||||
 | 
					                translation=translation,
 | 
				
			||||||
 | 
					                type_code=type_code,
 | 
				
			||||||
 | 
					                primitive_mask=primitive_mask,
 | 
				
			||||||
 | 
					                return_primitives=True
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            codes = torch.empty((scale.shape[0], 0, self.dim), dtype=torch.float32, device=self.device)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # handle shape conditions
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        attn_context_kwargs = dict()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if self.condition_on_shape:
 | 
				
			||||||
 | 
					            assert exists(pc) ^ exists(pc_embed), '`pc` or `pc_embed` must be passed in'
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            if exists(pc):
 | 
				
			||||||
 | 
					                if 'michelangelo' in self.shape_condition_model_type:
 | 
				
			||||||
 | 
					                    pc_head, pc_embed = self.conditioner(shape=pc)
 | 
				
			||||||
 | 
					                    pc_embed = torch.cat([self.to_cond_dim_head(pc_head), self.to_cond_dim(pc_embed)], dim=-2)
 | 
				
			||||||
 | 
					                else:
 | 
				
			||||||
 | 
					                    raise ValueError(f'unknown shape_condition_model_type {self.shape_condition_model_type}')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            assert pc_embed.shape[0] == codes.shape[0], 'batch size of point cloud is not equal to the batch size of the primitive codes'
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            pooled_pc_embed = pc_embed.mean(dim=1) # (b, shape_condition_dim)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            if self.shape_cond_with_cross_attn:
 | 
				
			||||||
 | 
					                attn_context_kwargs = dict(
 | 
				
			||||||
 | 
					                    context=pc_embed
 | 
				
			||||||
 | 
					                )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            if self.coarse_adaptive_rmsnorm:
 | 
				
			||||||
 | 
					                attn_context_kwargs.update(
 | 
				
			||||||
 | 
					                    condition=pooled_pc_embed
 | 
				
			||||||
 | 
					                )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        batch, seq_len, _ = codes.shape # (b, np, dim)
 | 
				
			||||||
 | 
					        device = codes.device
 | 
				
			||||||
 | 
					        assert seq_len <= self.max_seq_len, f'received codes of length {seq_len} but needs to be less than or equal to set max_seq_len {self.max_seq_len}'
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if append_eos:
 | 
				
			||||||
 | 
					            assert exists(codes)
 | 
				
			||||||
 | 
					            code_lens = primitive_mask.sum(dim=-1)
 | 
				
			||||||
 | 
					            codes = pad_tensor(codes)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            batch_arange = torch.arange(batch, device=device)
 | 
				
			||||||
 | 
					            batch_arange = rearrange(batch_arange, '... -> ... 1')
 | 
				
			||||||
 | 
					            code_lens = rearrange(code_lens, '... -> ... 1')
 | 
				
			||||||
 | 
					            codes[batch_arange, code_lens] = self.eos_token # (b, np+1, dim)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        primitive_codes = codes # (b, np, dim)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        primitive_codes_len = primitive_codes.shape[-2]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        (
 | 
				
			||||||
 | 
					            coarse_cache,
 | 
				
			||||||
 | 
					            coarse_gateloop_cache,
 | 
				
			||||||
 | 
					            coarse_post_gateloop_cache,
 | 
				
			||||||
 | 
					        ) = cache if exists(cache) else ((None,) * 3)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if not exists(cache):
 | 
				
			||||||
 | 
					            sos = repeat(self.sos_token, 'n d -> b n d', b=batch)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            if self.shape_cond_with_cat:
 | 
				
			||||||
 | 
					                sos, _ = pack([pc_embed, sos], 'b * d')
 | 
				
			||||||
 | 
					            primitive_codes, packed_sos_shape = pack([sos, primitive_codes], 'b * d') # (b, n_sos+np, dim)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # condition primitive codes with shape if needed
 | 
				
			||||||
 | 
					        if self.condition_on_shape:
 | 
				
			||||||
 | 
					            primitive_codes = self.shape_coarse_film_cond(primitive_codes, pooled_pc_embed)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # attention on primitive codes (coarse)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if exists(self.coarse_gateloop_block):
 | 
				
			||||||
 | 
					            primitive_codes, coarse_gateloop_cache = self.coarse_gateloop_block(primitive_codes, cache=coarse_gateloop_cache)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        attended_primitive_codes, coarse_cache = self.decoder( # (b, n_sos+np, dim) 
 | 
				
			||||||
 | 
					            primitive_codes,
 | 
				
			||||||
 | 
					            cache=coarse_cache,
 | 
				
			||||||
 | 
					            return_hiddens=True,
 | 
				
			||||||
 | 
					            **attn_context_kwargs
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if exists(self.coarse_post_gateloop_block):
 | 
				
			||||||
 | 
					            primitive_codes, coarse_post_gateloop_cache = self.coarse_post_gateloop_block(primitive_codes, cache=coarse_post_gateloop_cache)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        embed = attended_primitive_codes[:, -(primitive_codes_len + 1):] # (b, np+1, dim)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if not return_cache:
 | 
				
			||||||
 | 
					            return embed[:, -1:]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        next_cache = (
 | 
				
			||||||
 | 
					            coarse_cache,
 | 
				
			||||||
 | 
					            coarse_gateloop_cache,
 | 
				
			||||||
 | 
					            coarse_post_gateloop_cache
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return embed[:, -1:], next_cache
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def pad_tensor(tensor):
 | 
				
			||||||
 | 
					    if tensor.dim() == 3:
 | 
				
			||||||
 | 
					        bs, seq_len, dim = tensor.shape
 | 
				
			||||||
 | 
					        padding = torch.zeros((bs, 1, dim), dtype=tensor.dtype, device=tensor.device)
 | 
				
			||||||
 | 
					    elif tensor.dim() == 2:
 | 
				
			||||||
 | 
					        bs, seq_len = tensor.shape
 | 
				
			||||||
 | 
					        padding = torch.zeros((bs, 1), dtype=tensor.dtype, device=tensor.device)
 | 
				
			||||||
 | 
					    else:
 | 
				
			||||||
 | 
					        raise ValueError('Unsupported tensor shape: {}'.format(tensor.shape))
 | 
				
			||||||
 | 
					    
 | 
				
			||||||
 | 
					    return torch.cat([tensor, padding], dim=1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def apply_transformation(pc, scale, rotation_vector, translation):
 | 
				
			||||||
 | 
					    bs, np, num_points, _ = pc.shape
 | 
				
			||||||
 | 
					    scaled_pc = pc * scale.unsqueeze(2)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    rotation_matrix = euler_angles_to_matrix(rotation_vector.view(-1, 3), 'XYZ').view(bs, np, 3, 3) # euler tmp
 | 
				
			||||||
 | 
					    rotated_pc = torch.einsum('bnij,bnpj->bnpi', rotation_matrix, scaled_pc)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    transformed_pc = rotated_pc + translation.unsqueeze(2)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    return transformed_pc
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def random_sample_pc(pc, max_lens, n_points=10000):
 | 
				
			||||||
 | 
					    bs = max_lens.shape[0]
 | 
				
			||||||
 | 
					    max_len = max_lens.max().item() * n_points
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    random_values = torch.rand(bs, max_len, device=max_lens.device)
 | 
				
			||||||
 | 
					    mask = torch.arange(max_len).expand(bs, max_len).to(max_lens.device) < (max_lens * n_points)
 | 
				
			||||||
 | 
					    masked_random_values = random_values * mask.float()
 | 
				
			||||||
 | 
					    _, indices = torch.topk(masked_random_values, n_points, dim=1)
 | 
				
			||||||
 | 
					    
 | 
				
			||||||
 | 
					    return pc[torch.arange(bs).unsqueeze(1), indices]
 | 
				
			||||||
							
								
								
									
										275
									
								
								primitive_anything/utils/__init__.py
									
									
									
									
									
										Executable file
									
								
							
							
						
						
									
										275
									
								
								primitive_anything/utils/__init__.py
									
									
									
									
									
										Executable file
									
								
							@ -0,0 +1,275 @@
 | 
				
			|||||||
 | 
					from math import ceil
 | 
				
			||||||
 | 
					from pathlib import Path
 | 
				
			||||||
 | 
					import os
 | 
				
			||||||
 | 
					import re
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from beartype.typing import Tuple
 | 
				
			||||||
 | 
					from einops import rearrange, repeat
 | 
				
			||||||
 | 
					from toolz import valmap
 | 
				
			||||||
 | 
					import torch
 | 
				
			||||||
 | 
					from torch import Tensor
 | 
				
			||||||
 | 
					from torch.nn import Module
 | 
				
			||||||
 | 
					import torch.nn.functional as F
 | 
				
			||||||
 | 
					import yaml
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from .typing import typecheck
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def count_parameters(model):
 | 
				
			||||||
 | 
					    return sum(p.numel() for p in model.parameters() if p.requires_grad)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def path_mkdir(path):
 | 
				
			||||||
 | 
					    path = Path(path)
 | 
				
			||||||
 | 
					    path.mkdir(parents=True, exist_ok=True)
 | 
				
			||||||
 | 
					    return path
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def load_yaml(path, default_path=None):
 | 
				
			||||||
 | 
					    path = path_exists(path)
 | 
				
			||||||
 | 
					    with open(path, mode='r') as fp:
 | 
				
			||||||
 | 
					        cfg_s = yaml.load(fp, Loader=yaml.FullLoader)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if default_path is not None:
 | 
				
			||||||
 | 
					        default_path = path_exists(default_path)
 | 
				
			||||||
 | 
					        with open(default_path, mode='r') as fp:
 | 
				
			||||||
 | 
					            cfg = yaml.load(fp, Loader=yaml.FullLoader)
 | 
				
			||||||
 | 
					    else:
 | 
				
			||||||
 | 
					        # try current dir default
 | 
				
			||||||
 | 
					        default_path = path.parent / 'default.yml'
 | 
				
			||||||
 | 
					        if default_path.exists():
 | 
				
			||||||
 | 
					            with open(default_path, mode='r') as fp:
 | 
				
			||||||
 | 
					                cfg = yaml.load(fp, Loader=yaml.FullLoader)
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            cfg = {}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    update_recursive(cfg, cfg_s)
 | 
				
			||||||
 | 
					    return cfg
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def dump_yaml(cfg, path):
 | 
				
			||||||
 | 
					    with open(path, mode='w') as f:
 | 
				
			||||||
 | 
					        return yaml.safe_dump(cfg, f)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def update_recursive(dict1, dict2):
 | 
				
			||||||
 | 
					    ''' Update two config dictionaries recursively.
 | 
				
			||||||
 | 
					    Args:
 | 
				
			||||||
 | 
					        dict1 (dict): first dictionary to be updated
 | 
				
			||||||
 | 
					        dict2 (dict): second dictionary which entries should be used
 | 
				
			||||||
 | 
					    '''
 | 
				
			||||||
 | 
					    for k, v in dict2.items():
 | 
				
			||||||
 | 
					        if k not in dict1:
 | 
				
			||||||
 | 
					            dict1[k] = dict()
 | 
				
			||||||
 | 
					        if isinstance(v, dict):
 | 
				
			||||||
 | 
					            update_recursive(dict1[k], v)
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            dict1[k] = v
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def load_latest_checkpoint(checkpoint_dir):
 | 
				
			||||||
 | 
					    pattern = re.compile(rf".+\.ckpt\.(\d+)\.pt")
 | 
				
			||||||
 | 
					    max_epoch = -1
 | 
				
			||||||
 | 
					    latest_checkpoint = None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    for filename in os.listdir(checkpoint_dir):
 | 
				
			||||||
 | 
					        match = pattern.match(filename)
 | 
				
			||||||
 | 
					        if match:
 | 
				
			||||||
 | 
					            num_epoch = int(match.group(1))
 | 
				
			||||||
 | 
					            if num_epoch > max_epoch:
 | 
				
			||||||
 | 
					                max_epoch = num_epoch
 | 
				
			||||||
 | 
					                latest_checkpoint = checkpoint_dir / filename
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if not exists(latest_checkpoint):
 | 
				
			||||||
 | 
					        raise FileNotFoundError(f"No checkpoint files found in {checkpoint_dir}")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    checkpoint = torch.load(latest_checkpoint)
 | 
				
			||||||
 | 
					    return checkpoint, latest_checkpoint
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def torch_to(inp, device, non_blocking=False):
 | 
				
			||||||
 | 
					    nb = non_blocking  # set to True when doing distributed jobs
 | 
				
			||||||
 | 
					    if isinstance(inp, torch.Tensor):
 | 
				
			||||||
 | 
					        return inp.to(device, non_blocking=nb)
 | 
				
			||||||
 | 
					    elif isinstance(inp, (list, tuple)):
 | 
				
			||||||
 | 
					        return type(inp)(map(lambda t: t.to(device, non_blocking=nb) if isinstance(t, torch.Tensor) else t, inp))
 | 
				
			||||||
 | 
					    elif isinstance(inp, dict):
 | 
				
			||||||
 | 
					        return valmap(lambda t: t.to(device, non_blocking=nb) if isinstance(t, torch.Tensor) else t, inp)
 | 
				
			||||||
 | 
					    else:
 | 
				
			||||||
 | 
					        raise NotImplementedError
 | 
				
			||||||
 | 
					    
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# helper functions
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def exists(v):
 | 
				
			||||||
 | 
					    return v is not None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def default(v, d):
 | 
				
			||||||
 | 
					    return v if exists(v) else d
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def first(it):
 | 
				
			||||||
 | 
					    return it[0]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def identity(t, *args, **kwargs):
 | 
				
			||||||
 | 
					    return t
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def divisible_by(num, den):
 | 
				
			||||||
 | 
					    return (num % den) == 0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def is_odd(n):
 | 
				
			||||||
 | 
					    return not divisible_by(n, 2)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def is_empty(x):
 | 
				
			||||||
 | 
					    return len(x) == 0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def is_tensor_empty(t: Tensor):
 | 
				
			||||||
 | 
					    return t.numel() == 0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def set_module_requires_grad_(
 | 
				
			||||||
 | 
					    module: Module,
 | 
				
			||||||
 | 
					    requires_grad: bool
 | 
				
			||||||
 | 
					):
 | 
				
			||||||
 | 
					    for param in module.parameters():
 | 
				
			||||||
 | 
					        param.requires_grad = requires_grad
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def l1norm(t):
 | 
				
			||||||
 | 
					    return F.normalize(t, dim = -1, p = 1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def l2norm(t):
 | 
				
			||||||
 | 
					    return F.normalize(t, dim = -1, p = 2)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def safe_cat(tensors, dim):
 | 
				
			||||||
 | 
					    tensors = [*filter(exists, tensors)]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if len(tensors) == 0:
 | 
				
			||||||
 | 
					        return None
 | 
				
			||||||
 | 
					    elif len(tensors) == 1:
 | 
				
			||||||
 | 
					        return first(tensors)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    return torch.cat(tensors, dim = dim)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def pad_at_dim(t, padding, dim = -1, value = 0):
 | 
				
			||||||
 | 
					    ndim = t.ndim
 | 
				
			||||||
 | 
					    right_dims = (ndim - dim - 1) if dim >= 0 else (-dim - 1)
 | 
				
			||||||
 | 
					    zeros = (0, 0) * right_dims
 | 
				
			||||||
 | 
					    return F.pad(t, (*zeros, *padding), value = value)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def pad_to_length(t, length, dim = -1, value = 0, right = True):
 | 
				
			||||||
 | 
					    curr_length = t.shape[dim]
 | 
				
			||||||
 | 
					    remainder = length - curr_length
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if remainder <= 0:
 | 
				
			||||||
 | 
					        return t
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    padding = (0, remainder) if right else (remainder, 0)
 | 
				
			||||||
 | 
					    return pad_at_dim(t, padding, dim = dim, value = value)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def masked_mean(tensor, mask, dim = -1, eps = 1e-5):
 | 
				
			||||||
 | 
					    if not exists(mask):
 | 
				
			||||||
 | 
					        return tensor.mean(dim = dim)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    mask = rearrange(mask, '... -> ... 1')
 | 
				
			||||||
 | 
					    tensor = tensor.masked_fill(~mask, 0.)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    total_el = mask.sum(dim = dim)
 | 
				
			||||||
 | 
					    num = tensor.sum(dim = dim)
 | 
				
			||||||
 | 
					    den = total_el.float().clamp(min = eps)
 | 
				
			||||||
 | 
					    mean = num / den
 | 
				
			||||||
 | 
					    mean = mean.masked_fill(total_el == 0, 0.)
 | 
				
			||||||
 | 
					    return mean
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def cycle(dl):
 | 
				
			||||||
 | 
					    while True:
 | 
				
			||||||
 | 
					        for data in dl:
 | 
				
			||||||
 | 
					            yield data
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def maybe_del(d: dict, *keys):
 | 
				
			||||||
 | 
					    for key in keys:
 | 
				
			||||||
 | 
					        if key not in d:
 | 
				
			||||||
 | 
					            continue
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        del d[key]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# tensor helper functions
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@typecheck
 | 
				
			||||||
 | 
					def discretize(
 | 
				
			||||||
 | 
					    t: Tensor,
 | 
				
			||||||
 | 
					    *,
 | 
				
			||||||
 | 
					    continuous_range: Tuple[float, float],
 | 
				
			||||||
 | 
					    num_discrete: int = 128
 | 
				
			||||||
 | 
					) -> Tensor:
 | 
				
			||||||
 | 
					    lo, hi = continuous_range
 | 
				
			||||||
 | 
					    assert hi > lo
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    t = (t - lo) / (hi - lo)
 | 
				
			||||||
 | 
					    t *= num_discrete
 | 
				
			||||||
 | 
					    t -= 0.5
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    return t.round().long().clamp(min = 0, max = num_discrete - 1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@typecheck
 | 
				
			||||||
 | 
					def undiscretize(
 | 
				
			||||||
 | 
					    t: Tensor,
 | 
				
			||||||
 | 
					    *,
 | 
				
			||||||
 | 
					    continuous_range = Tuple[float, float],
 | 
				
			||||||
 | 
					    num_discrete: int = 128
 | 
				
			||||||
 | 
					) -> Tensor:
 | 
				
			||||||
 | 
					    lo, hi = continuous_range
 | 
				
			||||||
 | 
					    assert hi > lo
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    t = t.float()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    t += 0.5
 | 
				
			||||||
 | 
					    t /= num_discrete
 | 
				
			||||||
 | 
					    return t * (hi - lo) + lo
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@typecheck
 | 
				
			||||||
 | 
					def gaussian_blur_1d(
 | 
				
			||||||
 | 
					    t: Tensor,
 | 
				
			||||||
 | 
					    *,
 | 
				
			||||||
 | 
					    sigma: float = 1.,
 | 
				
			||||||
 | 
					    kernel_size: int = 5
 | 
				
			||||||
 | 
					) -> Tensor:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    _, _, channels, device, dtype = *t.shape, t.device, t.dtype
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    width = int(ceil(sigma * kernel_size))
 | 
				
			||||||
 | 
					    width += (width + 1) % 2
 | 
				
			||||||
 | 
					    half_width = width // 2
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    distance = torch.arange(-half_width, half_width + 1, dtype = dtype, device = device)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    gaussian = torch.exp(-(distance ** 2) / (2 * sigma ** 2))
 | 
				
			||||||
 | 
					    gaussian = l1norm(gaussian)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    kernel = repeat(gaussian, 'n -> c 1 n', c = channels)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    t = rearrange(t, 'b n c -> b c n')
 | 
				
			||||||
 | 
					    out = F.conv1d(t, kernel, padding = half_width, groups = channels)
 | 
				
			||||||
 | 
					    return rearrange(out, 'b c n -> b n c')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@typecheck
 | 
				
			||||||
 | 
					def scatter_mean(
 | 
				
			||||||
 | 
					    tgt: Tensor,
 | 
				
			||||||
 | 
					    indices: Tensor,
 | 
				
			||||||
 | 
					    src = Tensor,
 | 
				
			||||||
 | 
					    *,
 | 
				
			||||||
 | 
					    dim: int = -1,
 | 
				
			||||||
 | 
					    eps: float = 1e-5
 | 
				
			||||||
 | 
					):
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    todo: update to pytorch 2.1 and try https://pytorch.org/docs/stable/generated/torch.Tensor.scatter_reduce_.html#torch.Tensor.scatter_reduce_
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    num = tgt.scatter_add(dim, indices, src)
 | 
				
			||||||
 | 
					    den = torch.zeros_like(tgt).scatter_add(dim, indices, torch.ones_like(src))
 | 
				
			||||||
 | 
					    return num / den.clamp(min = eps)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def prob_mask_like(shape, prob, device):
 | 
				
			||||||
 | 
					    if prob == 1:
 | 
				
			||||||
 | 
					        return torch.ones(shape, device = device, dtype = torch.bool)
 | 
				
			||||||
 | 
					    elif prob == 0:
 | 
				
			||||||
 | 
					        return torch.zeros(shape, device = device, dtype = torch.bool)
 | 
				
			||||||
 | 
					    else:
 | 
				
			||||||
 | 
					        return torch.zeros(shape, device = device).float().uniform_(0, 1) < prob
 | 
				
			||||||
							
								
								
									
										65
									
								
								primitive_anything/utils/logger.py
									
									
									
									
									
										Executable file
									
								
							
							
						
						
									
										65
									
								
								primitive_anything/utils/logger.py
									
									
									
									
									
										Executable file
									
								
							@ -0,0 +1,65 @@
 | 
				
			|||||||
 | 
					import logging
 | 
				
			||||||
 | 
					import time
 | 
				
			||||||
 | 
					import os
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class Verbose:
 | 
				
			||||||
 | 
					    mute = False
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def print_log(s, logger=None, level='info'):
 | 
				
			||||||
 | 
					    if Verbose.mute:
 | 
				
			||||||
 | 
					        return None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if logger is None:
 | 
				
			||||||
 | 
					        logger = logging.getLogger('trainer')
 | 
				
			||||||
 | 
					    if level == 'info':
 | 
				
			||||||
 | 
					        print_info(s)
 | 
				
			||||||
 | 
					        logger.info(s)
 | 
				
			||||||
 | 
					    elif level == 'warning':
 | 
				
			||||||
 | 
					        print_warning(s)
 | 
				
			||||||
 | 
					        logger.warning(s)
 | 
				
			||||||
 | 
					    elif level == 'error':
 | 
				
			||||||
 | 
					        print_error(s)
 | 
				
			||||||
 | 
					        logger.error(s)
 | 
				
			||||||
 | 
					    else:
 | 
				
			||||||
 | 
					        raise NotImplementedError
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def create_logger(log_dir, name='trainer'):
 | 
				
			||||||
 | 
					    assert os.path.exists(log_dir), 'log_dir {} does not exist.'
 | 
				
			||||||
 | 
					    logger = logging.getLogger(name)
 | 
				
			||||||
 | 
					    file_path = log_dir / '{}.log'.format(name)
 | 
				
			||||||
 | 
					    hdlr = logging.FileHandler(file_path)
 | 
				
			||||||
 | 
					    formatter = logging.Formatter('[%(asctime)s] %(levelname)s: %(message)s')
 | 
				
			||||||
 | 
					    hdlr.setFormatter(formatter)
 | 
				
			||||||
 | 
					    logger.addHandler(hdlr)
 | 
				
			||||||
 | 
					    logger.setLevel(logging.INFO)
 | 
				
			||||||
 | 
					    return logger
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class TerminalColors:
 | 
				
			||||||
 | 
					    HEADER = '\033[95m'
 | 
				
			||||||
 | 
					    OKBLUE = '\033[94m'
 | 
				
			||||||
 | 
					    OKGREEN = '\033[92m'
 | 
				
			||||||
 | 
					    WARNING = '\033[93m'
 | 
				
			||||||
 | 
					    FAIL = '\033[91m'
 | 
				
			||||||
 | 
					    ENDC = '\033[0m'
 | 
				
			||||||
 | 
					    BOLD = '\033[1m'
 | 
				
			||||||
 | 
					    UNDERLINE = '\033[4m'
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def get_time():
 | 
				
			||||||
 | 
					    return time.strftime('%Y-%m-%d %H:%M:%S')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def print_info(s):
 | 
				
			||||||
 | 
					    print(TerminalColors.OKBLUE + '[' + get_time() + '] ' + str(s) + TerminalColors.ENDC)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def print_warning(s):
 | 
				
			||||||
 | 
					    print(TerminalColors.WARNING + '[' + get_time() + '] WARN ' + str(s) + TerminalColors.ENDC)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def print_error(s):
 | 
				
			||||||
 | 
					    print(TerminalColors.FAIL + '[' + get_time() + '] ERROR ' + str(s) + TerminalColors.ENDC)
 | 
				
			||||||
							
								
								
									
										57
									
								
								primitive_anything/utils/typing.py
									
									
									
									
									
										Executable file
									
								
							
							
						
						
									
										57
									
								
								primitive_anything/utils/typing.py
									
									
									
									
									
										Executable file
									
								
							@ -0,0 +1,57 @@
 | 
				
			|||||||
 | 
					from environs import Env
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from torch import Tensor
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from beartype import beartype
 | 
				
			||||||
 | 
					from beartype.door import is_bearable
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from jaxtyping import (
 | 
				
			||||||
 | 
					    Float,
 | 
				
			||||||
 | 
					    Int,
 | 
				
			||||||
 | 
					    Bool,
 | 
				
			||||||
 | 
					    jaxtyped
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# environment
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					env = Env()
 | 
				
			||||||
 | 
					env.read_env()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# function
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def always(value):
 | 
				
			||||||
 | 
					    def inner(*args, **kwargs):
 | 
				
			||||||
 | 
					        return value
 | 
				
			||||||
 | 
					    return inner
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def identity(t):
 | 
				
			||||||
 | 
					    return t
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# jaxtyping is a misnomer, works for pytorch
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class TorchTyping:
 | 
				
			||||||
 | 
					    def __init__(self, abstract_dtype):
 | 
				
			||||||
 | 
					        self.abstract_dtype = abstract_dtype
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __getitem__(self, shapes: str):
 | 
				
			||||||
 | 
					        return self.abstract_dtype[Tensor, shapes]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Float = TorchTyping(Float)
 | 
				
			||||||
 | 
					Int   = TorchTyping(Int)
 | 
				
			||||||
 | 
					Bool  = TorchTyping(Bool)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# use env variable TYPECHECK to control whether to use beartype + jaxtyping
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					should_typecheck = env.bool('TYPECHECK', False)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					typecheck = jaxtyped(typechecker = beartype) if should_typecheck else identity
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					beartype_isinstance = is_bearable if should_typecheck else always(True)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					__all__ = [
 | 
				
			||||||
 | 
					    Float,
 | 
				
			||||||
 | 
					    Int,
 | 
				
			||||||
 | 
					    Bool,
 | 
				
			||||||
 | 
					    typecheck,
 | 
				
			||||||
 | 
					    beartype_isinstance
 | 
				
			||||||
 | 
					]
 | 
				
			||||||
							
								
								
									
										21
									
								
								requirements.txt
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										21
									
								
								requirements.txt
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,21 @@
 | 
				
			|||||||
 | 
					accelerate
 | 
				
			||||||
 | 
					beartype
 | 
				
			||||||
 | 
					einops
 | 
				
			||||||
 | 
					gateloop_transformer
 | 
				
			||||||
 | 
					matplotlib
 | 
				
			||||||
 | 
					pytorch_custom_utils
 | 
				
			||||||
 | 
					x_transformers
 | 
				
			||||||
 | 
					toolz
 | 
				
			||||||
 | 
					environs
 | 
				
			||||||
 | 
					jaxtyping
 | 
				
			||||||
 | 
					omegaconf
 | 
				
			||||||
 | 
					transformers
 | 
				
			||||||
 | 
					open3d
 | 
				
			||||||
 | 
					trimesh
 | 
				
			||||||
 | 
					pytorch_lightning
 | 
				
			||||||
 | 
					scikit-image
 | 
				
			||||||
 | 
					opencv-python
 | 
				
			||||||
 | 
					mesh2sdf
 | 
				
			||||||
 | 
					seaborn
 | 
				
			||||||
 | 
					mesh_to_sdf
 | 
				
			||||||
 | 
					point_cloud_utils
 | 
				
			||||||
							
								
								
									
										104
									
								
								sample.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										104
									
								
								sample.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,104 @@
 | 
				
			|||||||
 | 
					import argparse
 | 
				
			||||||
 | 
					from functools import partial
 | 
				
			||||||
 | 
					import glob
 | 
				
			||||||
 | 
					import multiprocessing
 | 
				
			||||||
 | 
					import os
 | 
				
			||||||
 | 
					import time
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from mesh_to_sdf import get_surface_point_cloud
 | 
				
			||||||
 | 
					import numpy as np
 | 
				
			||||||
 | 
					import open3d as o3d
 | 
				
			||||||
 | 
					import trimesh
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					os.environ["PYOPENGL_PLATFORM"] = "egl"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def sample_surface_points(mesh, number_of_points=500000, surface_point_method="scan", sign_method="normal",
 | 
				
			||||||
 | 
					                          scan_count=100, scan_resolution=400, sample_point_count=10000000, return_gradients=False,
 | 
				
			||||||
 | 
					                          return_surface_pc_normals=False, normalized=False):
 | 
				
			||||||
 | 
					    sample_start = time.time()
 | 
				
			||||||
 | 
					    if surface_point_method == "sample" and sign_method == "depth":
 | 
				
			||||||
 | 
					        print("Incompatible methods for sampling points and determining sign, using sign_method='normal' instead.")
 | 
				
			||||||
 | 
					        sign_method = "normal"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    surface_start = time.time()
 | 
				
			||||||
 | 
					    bound_radius = 1 if normalized else None
 | 
				
			||||||
 | 
					    surface_point_cloud = get_surface_point_cloud(mesh, surface_point_method, bound_radius, scan_count, scan_resolution,
 | 
				
			||||||
 | 
					                                                  sample_point_count,
 | 
				
			||||||
 | 
					                                                  calculate_normals=sign_method == "normal" or return_gradients)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    surface_end = time.time()
 | 
				
			||||||
 | 
					    print("surface point cloud time cost :", surface_end - surface_start)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    normal_start = time.time()
 | 
				
			||||||
 | 
					    if return_surface_pc_normals:
 | 
				
			||||||
 | 
					        rng = np.random.default_rng()
 | 
				
			||||||
 | 
					        assert surface_point_cloud.points.shape[0] == surface_point_cloud.normals.shape[0]
 | 
				
			||||||
 | 
					        indices = rng.choice(surface_point_cloud.points.shape[0], number_of_points, replace=True)
 | 
				
			||||||
 | 
					        points = surface_point_cloud.points[indices]
 | 
				
			||||||
 | 
					        normals = surface_point_cloud.normals[indices]
 | 
				
			||||||
 | 
					        surface_points = np.concatenate([points, normals], axis=-1)
 | 
				
			||||||
 | 
					    else:
 | 
				
			||||||
 | 
					        surface_points = surface_point_cloud.get_random_surface_points(number_of_points, use_scans=True)
 | 
				
			||||||
 | 
					    normal_end = time.time()
 | 
				
			||||||
 | 
					    print("normal time cost :", normal_end - normal_start)
 | 
				
			||||||
 | 
					    sample_end = time.time()
 | 
				
			||||||
 | 
					    print("sample surface point time cost :", sample_end - sample_start)
 | 
				
			||||||
 | 
					    return surface_points
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def process_surface_point(mesh, number_of_near_surface_points, return_surface_pc_normals=False):
 | 
				
			||||||
 | 
					    mesh = trimesh.load(mesh, force="mesh")
 | 
				
			||||||
 | 
					    surface_point = sample_surface_points(mesh, number_of_near_surface_points, return_surface_pc_normals=return_surface_pc_normals)
 | 
				
			||||||
 | 
					    return surface_point
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def sample_model(model_path, num_points, return_surface_pc_normals=True):
 | 
				
			||||||
 | 
					    pc_out_path = os.path.join(args.output_dir, os.path.basename(model_path)).replace(f".{args.postfix}", ".ply")
 | 
				
			||||||
 | 
					    if os.path.exists(pc_out_path):
 | 
				
			||||||
 | 
					        print(f"{pc_out_path}: exists!")
 | 
				
			||||||
 | 
					        return
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    try:
 | 
				
			||||||
 | 
					        surface_point = process_surface_point(model_path, num_points, return_surface_pc_normals=return_surface_pc_normals)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        coords = surface_point[:, :3]
 | 
				
			||||||
 | 
					        normals = surface_point[:, 3:]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        assert (np.linalg.norm(np.asarray(normals), axis=-1) > 0.99).all()
 | 
				
			||||||
 | 
					        assert (np.linalg.norm(np.asarray(normals), axis=-1) < 1.01).all()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        pcd = o3d.geometry.PointCloud()
 | 
				
			||||||
 | 
					        pcd.points = o3d.utility.Vector3dVector(coords)
 | 
				
			||||||
 | 
					        pcd.colors = o3d.utility.Vector3dVector(np.ones_like(coords)*0.5)
 | 
				
			||||||
 | 
					        pcd.normals = o3d.utility.Vector3dVector(normals)
 | 
				
			||||||
 | 
					        o3d.io.write_point_cloud(pc_out_path, pcd)
 | 
				
			||||||
 | 
					        print(f"write_point_cloud: {pc_out_path}")
 | 
				
			||||||
 | 
					    except:
 | 
				
			||||||
 | 
					        print(f"[ERROR] file: {pc_out_path}")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					if __name__ == "__main__":
 | 
				
			||||||
 | 
					    parser = argparse.ArgumentParser()
 | 
				
			||||||
 | 
					    parser.add_argument("--input_dir", type=str, default="./results/infer/JsonResults")
 | 
				
			||||||
 | 
					    parser.add_argument("--output_dir", type=str, default="./results/infer/PointClouds")
 | 
				
			||||||
 | 
					    parser.add_argument("--num_points", type=int, default=10000)
 | 
				
			||||||
 | 
					    parser.add_argument("--postfix", type=str, default="glb")
 | 
				
			||||||
 | 
					    args = parser.parse_args()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if not os.path.exists(args.input_dir):
 | 
				
			||||||
 | 
					        print("Invalid input!")
 | 
				
			||||||
 | 
					        exit(1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if os.path.exists(args.output_dir):
 | 
				
			||||||
 | 
					        print(f"path: {args.output_dir} exists!")
 | 
				
			||||||
 | 
					        # exit(1)
 | 
				
			||||||
 | 
					    else:
 | 
				
			||||||
 | 
					        os.makedirs(args.output_dir)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    model_prefix = os.path.join(args.input_dir, f"*.{args.postfix}")
 | 
				
			||||||
 | 
					    model_path_list = sorted(list(glob.glob(model_prefix)))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    sample_model_func = partial(sample_model, num_points=args.num_points, return_surface_pc_normals=True)
 | 
				
			||||||
 | 
					    with multiprocessing.Pool(16) as pool:
 | 
				
			||||||
 | 
					        pool.map(sample_model_func, model_path_list)
 | 
				
			||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user