mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
Marching Cubes cuda extension
Summary: Torch CUDA extension for Marching Cubes - MC involving 3 steps: - 1st forward pass to collect vertices and occupied state for each voxel - Compute compactVoxelArray to skip non-empty voxels - 2nd pass to genereate interpolated vertex positions and faces by marching through the grid - In contrast to existing MC: - Bind each interpolated vertex with a global edge_id to address floating-point precision - Added deduplication process to remove redundant vertices and faces Benchmarks (ms): | N / V(^3) | python | C++ | CUDA | Speedup | | 2 / 20 | 12176873 | 24338 | 4363 | 2790x/5x| | 1 / 100 | - | 3070511 | 27126 | 113x | | 2 / 100 | - | 5968934 | 53129 | 112x | | 1 / 256 | - | 61278092 | 430900 | 142x | | 2 / 256 | - |125687930 | 856941 | 146x | Reviewed By: kjchalup Differential Revision: D39644248 fbshipit-source-id: d679c0c79d67b98b235d12296f383d760a00042a
This commit is contained in:
parent
9a0b0c2e74
commit
8b8291830e
559
pytorch3d/csrc/marching_cubes/marching_cubes.cu
Normal file
559
pytorch3d/csrc/marching_cubes/marching_cubes.cu
Normal file
@ -0,0 +1,559 @@
|
||||
/*
|
||||
* Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
* All rights reserved.
|
||||
*
|
||||
* This source code is licensed under the BSD-style license found in the
|
||||
* LICENSE file in the root directory of this source tree.
|
||||
*/
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <thrust/device_vector.h>
|
||||
#include <thrust/scan.h>
|
||||
#include <cstdio>
|
||||
#include "marching_cubes/tables.h"
|
||||
#include "utils/pytorch3d_cutils.h"
|
||||
|
||||
/*
|
||||
Parallelized marching cubes for pytorch extension
|
||||
referenced and adapted from CUDA-Samples:
|
||||
(https://github.com/NVIDIA/cuda-samples/tree/master/Samples/5_Domain_Specific/marchingCubes)
|
||||
We divide the algorithm into two forward-passes:
|
||||
(1) The first forward-pass executes "ClassifyVoxelKernel" to
|
||||
evaluate volume scalar field for each cube and pre-compute
|
||||
two arrays -- number of vertices per cube (d_voxelVerts) and
|
||||
occupied or not per cube (d_voxelOccupied).
|
||||
|
||||
Some prepration steps:
|
||||
With d_voxelOccupied, an exclusive scan is performed to compute
|
||||
the number of activeVoxels, which can be used to accelerate
|
||||
computation. With d_voxelVerts, another exclusive scan
|
||||
is performed to compute the accumulated sum of vertices in the 3d
|
||||
grid and totalVerts.
|
||||
|
||||
(2) The second forward-pass calls "GenerateFacesKernel" to
|
||||
generate interpolated vertex positions and face indices by "marching
|
||||
through" each cube in the grid.
|
||||
|
||||
*/
|
||||
|
||||
// EPS: Used to indicate if two float values are close
|
||||
__constant__ const float EPSILON = 1e-5;
|
||||
|
||||
// Thrust wrapper for exclusive scan
|
||||
//
|
||||
// Args:
|
||||
// output: pointer to on-device output array
|
||||
// input: pointer to on-device input array, where scan is performed
|
||||
// numElements: number of elements for the input array
|
||||
//
|
||||
void ThrustScanWrapper(int* output, int* input, int numElements) {
|
||||
thrust::exclusive_scan(
|
||||
thrust::device_ptr<int>(input),
|
||||
thrust::device_ptr<int>(input + numElements),
|
||||
thrust::device_ptr<int>(output));
|
||||
}
|
||||
|
||||
// Linearly interpolate the position where an isosurface cuts an edge
|
||||
// between two vertices, based on their scalar values
|
||||
//
|
||||
// Args:
|
||||
// isolevel: float value used as threshold
|
||||
// p1: position of point1
|
||||
// p2: position of point2
|
||||
// valp1: field value for p1
|
||||
// valp2: field value for p2
|
||||
//
|
||||
// Returns:
|
||||
// point: interpolated verte
|
||||
//
|
||||
__device__ float3
|
||||
vertexInterp(float isolevel, float3 p1, float3 p2, float valp1, float valp2) {
|
||||
float ratio;
|
||||
float3 p;
|
||||
|
||||
if (abs(isolevel - valp1) < EPSILON) {
|
||||
return p1;
|
||||
} else if (abs(isolevel - valp2) < EPSILON) {
|
||||
return p2;
|
||||
} else if (abs(valp1 - valp2) < EPSILON) {
|
||||
return p1;
|
||||
}
|
||||
|
||||
ratio = (isolevel - valp1) / (valp2 - valp1);
|
||||
|
||||
p.x = p1.x * (1 - ratio) + p2.x * ratio;
|
||||
p.y = p1.y * (1 - ratio) + p2.y * ratio;
|
||||
p.z = p1.z * (1 - ratio) + p2.z * ratio;
|
||||
|
||||
return p;
|
||||
}
|
||||
|
||||
// Determine if the triangle is degenerate
|
||||
// A triangle is degenerate when at least two of the vertices
|
||||
// share the same position.
|
||||
//
|
||||
// Args:
|
||||
// p1: position of vertex p1
|
||||
// p2: position of vertex p2
|
||||
// p3: position of vertex p3
|
||||
//
|
||||
// Returns:
|
||||
// boolean indicator if the triangle is degenerate
|
||||
__device__ bool isDegenerate(float3 p1, float3 p2, float3 p3) {
|
||||
if ((abs(p1.x - p2.x) < EPSILON && abs(p1.y - p2.y) < EPSILON &&
|
||||
abs(p1.z - p2.z) < EPSILON) ||
|
||||
(abs(p2.x - p3.x) < EPSILON && abs(p2.y - p3.y) < EPSILON &&
|
||||
abs(p2.z - p3.z) < EPSILON) ||
|
||||
(abs(p3.x - p1.x) < EPSILON && abs(p3.y - p1.y) < EPSILON &&
|
||||
abs(p3.z - p1.z) < EPSILON)) {
|
||||
return true;
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// Convert from local vertex id to global vertex id, given position
|
||||
// of the cube where the vertex resides. The function ensures vertices
|
||||
// shared from adjacent cubes are mapped to the same global id.
|
||||
|
||||
// Args:
|
||||
// v: local vertex id
|
||||
// x: x position of the cube where the vertex belongs
|
||||
// y: y position of the cube where the vertex belongs
|
||||
// z: z position of the cube where the vertex belongs
|
||||
// W: width of x dimension
|
||||
// H: height of y dimension
|
||||
|
||||
// Returns:
|
||||
// global vertex id represented by its x/y/z offsets
|
||||
__device__ uint localToGlobal(int v, int x, int y, int z, int W, int H) {
|
||||
const int dx = v & 1;
|
||||
const int dy = v >> 1 & 1;
|
||||
const int dz = v >> 2 & 1;
|
||||
return (x + dx) + (y + dy) * W + (z + dz) * W * H;
|
||||
}
|
||||
|
||||
// Hash_combine a pair of global vertex id to a single integer.
|
||||
//
|
||||
// Args:
|
||||
// v1_id: global id of vertex 1
|
||||
// v2_id: global id of vertex 2
|
||||
// W: width of the 3d grid
|
||||
// H: height of the 3d grid
|
||||
// Z: depth of the 3d grid
|
||||
//
|
||||
// Returns:
|
||||
// hashing for a pair of vertex ids
|
||||
//
|
||||
__device__ int64_t hashVpair(uint v1_id, uint v2_id, int W, int H, int D) {
|
||||
return (int64_t)v1_id * (W + W * H + W * H * D) + (int64_t)v2_id;
|
||||
}
|
||||
|
||||
// precompute number of vertices and occupancy
|
||||
// for each voxel in the grid.
|
||||
//
|
||||
// Args:
|
||||
// voxelVerts: pointer to device array to store number
|
||||
// of verts per voxel
|
||||
// voxelOccupied: pointer to device array to store
|
||||
// occupancy state per voxel
|
||||
// vol: torch tensor stored with 3D scalar field
|
||||
// isolevel: threshold to determine isosurface intersection
|
||||
//
|
||||
__global__ void ClassifyVoxelKernel(
|
||||
at::PackedTensorAccessor32<int, 1, at::RestrictPtrTraits> voxelVerts,
|
||||
at::PackedTensorAccessor32<int, 1, at::RestrictPtrTraits> voxelOccupied,
|
||||
const at::PackedTensorAccessor32<float, 3, at::RestrictPtrTraits> vol,
|
||||
// const at::PackedTensorAccessor<int, 1, at::RestrictPtrTraits>
|
||||
// numVertsTable,
|
||||
float isolevel) {
|
||||
const int indexTable[8]{0, 1, 4, 5, 3, 2, 7, 6};
|
||||
const uint D = vol.size(0) - 1;
|
||||
const uint H = vol.size(1) - 1;
|
||||
const uint W = vol.size(2) - 1;
|
||||
|
||||
// 1-d grid
|
||||
uint id = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
uint num_threads = gridDim.x * blockDim.x;
|
||||
|
||||
// Table mapping from cubeindex to number of vertices in the configuration
|
||||
const unsigned char numVertsTable[256] = {
|
||||
0, 3, 3, 6, 3, 6, 6, 9, 3, 6, 6, 9, 6, 9, 9, 6, 3, 6,
|
||||
6, 9, 6, 9, 9, 12, 6, 9, 9, 12, 9, 12, 12, 9, 3, 6, 6, 9,
|
||||
6, 9, 9, 12, 6, 9, 9, 12, 9, 12, 12, 9, 6, 9, 9, 6, 9, 12,
|
||||
12, 9, 9, 12, 12, 9, 12, 15, 15, 6, 3, 6, 6, 9, 6, 9, 9, 12,
|
||||
6, 9, 9, 12, 9, 12, 12, 9, 6, 9, 9, 12, 9, 12, 12, 15, 9, 12,
|
||||
12, 15, 12, 15, 15, 12, 6, 9, 9, 12, 9, 12, 6, 9, 9, 12, 12, 15,
|
||||
12, 15, 9, 6, 9, 12, 12, 9, 12, 15, 9, 6, 12, 15, 15, 12, 15, 6,
|
||||
12, 3, 3, 6, 6, 9, 6, 9, 9, 12, 6, 9, 9, 12, 9, 12, 12, 9,
|
||||
6, 9, 9, 12, 9, 12, 12, 15, 9, 6, 12, 9, 12, 9, 15, 6, 6, 9,
|
||||
9, 12, 9, 12, 12, 15, 9, 12, 12, 15, 12, 15, 15, 12, 9, 12, 12, 9,
|
||||
12, 15, 15, 12, 12, 9, 15, 6, 15, 12, 6, 3, 6, 9, 9, 12, 9, 12,
|
||||
12, 15, 9, 12, 12, 15, 6, 9, 9, 6, 9, 12, 12, 15, 12, 15, 15, 6,
|
||||
12, 9, 15, 12, 9, 6, 12, 3, 9, 12, 12, 15, 12, 15, 9, 12, 12, 15,
|
||||
15, 6, 9, 12, 6, 3, 6, 9, 9, 6, 9, 12, 6, 3, 9, 6, 12, 3,
|
||||
6, 3, 3, 0,
|
||||
};
|
||||
|
||||
for (uint tid = id; tid < D * H * W; tid += num_threads) {
|
||||
// compute global location of the voxel
|
||||
const int gx = tid % W;
|
||||
const int gy = tid / W % H;
|
||||
const int gz = tid / (W * H);
|
||||
|
||||
int cubeindex = 0;
|
||||
for (int i = 0; i < 8; i++) {
|
||||
const int dx = i & 1;
|
||||
const int dy = i >> 1 & 1;
|
||||
const int dz = i >> 2 & 1;
|
||||
|
||||
const int x = gx + dx;
|
||||
const int y = gy + dy;
|
||||
const int z = gz + dz;
|
||||
|
||||
if (vol[z][y][x] < isolevel) {
|
||||
cubeindex |= 1 << indexTable[i];
|
||||
}
|
||||
}
|
||||
// collect number of vertices for each voxel
|
||||
unsigned char numVerts = numVertsTable[cubeindex];
|
||||
voxelVerts[tid] = numVerts;
|
||||
voxelOccupied[tid] = (numVerts > 0);
|
||||
}
|
||||
}
|
||||
|
||||
// extract compact voxel array for acceleration
|
||||
//
|
||||
// Args:
|
||||
// compactedVoxelArray: tensor of shape (activeVoxels,) which maps
|
||||
// from accumulated non-empty voxel index to original 3d grid index
|
||||
// voxelOccupied: tensor of shape (numVoxels,) which stores
|
||||
// the occupancy state per voxel
|
||||
// voxelOccupiedScan: tensor of shape (numVoxels,) which
|
||||
// stores the accumulated occupied voxel counts
|
||||
// numVoxels: number of total voxels in the grid
|
||||
//
|
||||
__global__ void CompactVoxelsKernel(
|
||||
at::PackedTensorAccessor32<int, 1, at::RestrictPtrTraits>
|
||||
compactedVoxelArray,
|
||||
const at::PackedTensorAccessor32<int, 1, at::RestrictPtrTraits>
|
||||
voxelOccupied,
|
||||
const at::PackedTensorAccessor32<int, 1, at::RestrictPtrTraits>
|
||||
voxelOccupiedScan,
|
||||
uint numVoxels) {
|
||||
uint id = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
uint num_threads = gridDim.x * blockDim.x;
|
||||
for (uint tid = id; tid < numVoxels; tid += num_threads) {
|
||||
if (voxelOccupied[tid]) {
|
||||
compactedVoxelArray[voxelOccupiedScan[tid]] = tid;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// generate triangles for each voxel using marching cubes
|
||||
//
|
||||
// Args:
|
||||
// verts: torch tensor of shape (V, 3) to store interpolated mesh vertices
|
||||
// faces: torch tensor of shape (F, 3) to store indices for mesh faces
|
||||
// ids: torch tensor of shape (V) to store id of each vertex
|
||||
// compactedVoxelArray: tensor of shape (activeVoxels,) which stores
|
||||
// non-empty voxel index.
|
||||
// numVertsScanned: tensor of shape (numVoxels,) which stores accumulated
|
||||
// vertices count in the voxel
|
||||
// activeVoxels: number of active voxels used for acceleration
|
||||
// vol: torch tensor stored with 3D scalar field
|
||||
// isolevel: threshold to determine isosurface intersection
|
||||
//
|
||||
__global__ void GenerateFacesKernel(
|
||||
torch::PackedTensorAccessor32<float, 2, torch::RestrictPtrTraits> verts,
|
||||
torch::PackedTensorAccessor<int64_t, 2, torch::RestrictPtrTraits> faces,
|
||||
torch::PackedTensorAccessor<int64_t, 1, torch::RestrictPtrTraits> ids,
|
||||
torch::PackedTensorAccessor32<int, 1, torch::RestrictPtrTraits>
|
||||
compactedVoxelArray,
|
||||
torch::PackedTensorAccessor32<int, 1, torch::RestrictPtrTraits>
|
||||
numVertsScanned,
|
||||
const uint activeVoxels,
|
||||
const at::PackedTensorAccessor32<float, 3, at::RestrictPtrTraits> vol,
|
||||
const at::PackedTensorAccessor32<int, 2, at::RestrictPtrTraits> faceTable,
|
||||
// const at::PackedTensorAccessor32<int, 1, at::RestrictPtrTraits>
|
||||
// numVertsTable,
|
||||
const float isolevel) {
|
||||
uint id = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
uint num_threads = gridDim.x * blockDim.x;
|
||||
const int faces_size = faces.size(0);
|
||||
// Table mapping each edge to the corresponding cube vertices offsets
|
||||
const int edgeToVertsTable[12][2] = {
|
||||
{0, 1},
|
||||
{1, 5},
|
||||
{4, 5},
|
||||
{0, 4},
|
||||
{2, 3},
|
||||
{3, 7},
|
||||
{6, 7},
|
||||
{2, 6},
|
||||
{0, 2},
|
||||
{1, 3},
|
||||
{5, 7},
|
||||
{4, 6},
|
||||
};
|
||||
|
||||
// Table mapping from cubeindex to number of vertices in the configuration
|
||||
const unsigned char numVertsTable[256] = {
|
||||
0, 3, 3, 6, 3, 6, 6, 9, 3, 6, 6, 9, 6, 9, 9, 6, 3, 6,
|
||||
6, 9, 6, 9, 9, 12, 6, 9, 9, 12, 9, 12, 12, 9, 3, 6, 6, 9,
|
||||
6, 9, 9, 12, 6, 9, 9, 12, 9, 12, 12, 9, 6, 9, 9, 6, 9, 12,
|
||||
12, 9, 9, 12, 12, 9, 12, 15, 15, 6, 3, 6, 6, 9, 6, 9, 9, 12,
|
||||
6, 9, 9, 12, 9, 12, 12, 9, 6, 9, 9, 12, 9, 12, 12, 15, 9, 12,
|
||||
12, 15, 12, 15, 15, 12, 6, 9, 9, 12, 9, 12, 6, 9, 9, 12, 12, 15,
|
||||
12, 15, 9, 6, 9, 12, 12, 9, 12, 15, 9, 6, 12, 15, 15, 12, 15, 6,
|
||||
12, 3, 3, 6, 6, 9, 6, 9, 9, 12, 6, 9, 9, 12, 9, 12, 12, 9,
|
||||
6, 9, 9, 12, 9, 12, 12, 15, 9, 6, 12, 9, 12, 9, 15, 6, 6, 9,
|
||||
9, 12, 9, 12, 12, 15, 9, 12, 12, 15, 12, 15, 15, 12, 9, 12, 12, 9,
|
||||
12, 15, 15, 12, 12, 9, 15, 6, 15, 12, 6, 3, 6, 9, 9, 12, 9, 12,
|
||||
12, 15, 9, 12, 12, 15, 6, 9, 9, 6, 9, 12, 12, 15, 12, 15, 15, 6,
|
||||
12, 9, 15, 12, 9, 6, 12, 3, 9, 12, 12, 15, 12, 15, 9, 12, 12, 15,
|
||||
15, 6, 9, 12, 6, 3, 6, 9, 9, 6, 9, 12, 6, 3, 9, 6, 12, 3,
|
||||
6, 3, 3, 0,
|
||||
};
|
||||
|
||||
for (uint tid = id; tid < activeVoxels; tid += num_threads) {
|
||||
uint voxel = compactedVoxelArray[tid]; // maps from accumulated id to
|
||||
// original 3d voxel id
|
||||
// mapping from offsets to vi index
|
||||
int indexTable[8]{0, 1, 4, 5, 3, 2, 7, 6};
|
||||
// field value for each vertex
|
||||
float val[8];
|
||||
// position for each vertex
|
||||
float3 p[8];
|
||||
// 3d address
|
||||
const uint D = vol.size(0) - 1;
|
||||
const uint H = vol.size(1) - 1;
|
||||
const uint W = vol.size(2) - 1;
|
||||
|
||||
const int gx = voxel % W;
|
||||
const int gy = voxel / W % H;
|
||||
const int gz = voxel / (W * H);
|
||||
|
||||
// recalculate cubeindex;
|
||||
uint cubeindex = 0;
|
||||
for (int i = 0; i < 8; i++) {
|
||||
const int dx = i & 1;
|
||||
const int dy = i >> 1 & 1;
|
||||
const int dz = i >> 2 & 1;
|
||||
|
||||
const int x = gx + dx;
|
||||
const int y = gy + dy;
|
||||
const int z = gz + dz;
|
||||
|
||||
if (vol[z][y][x] < isolevel) {
|
||||
cubeindex |= 1 << indexTable[i];
|
||||
}
|
||||
val[indexTable[i]] = vol[z][y][x]; // maps from vi to volume
|
||||
p[indexTable[i]] = make_float3(x, y, z); // maps from vi to position
|
||||
}
|
||||
|
||||
// Interpolate vertices where the surface intersects the cube
|
||||
float3 vertlist[12];
|
||||
vertlist[0] = vertexInterp(isolevel, p[0], p[1], val[0], val[1]);
|
||||
vertlist[1] = vertexInterp(isolevel, p[1], p[2], val[1], val[2]);
|
||||
vertlist[2] = vertexInterp(isolevel, p[3], p[2], val[3], val[2]);
|
||||
vertlist[3] = vertexInterp(isolevel, p[0], p[3], val[0], val[3]);
|
||||
|
||||
vertlist[4] = vertexInterp(isolevel, p[4], p[5], val[4], val[5]);
|
||||
vertlist[5] = vertexInterp(isolevel, p[5], p[6], val[5], val[6]);
|
||||
vertlist[6] = vertexInterp(isolevel, p[7], p[6], val[7], val[6]);
|
||||
vertlist[7] = vertexInterp(isolevel, p[4], p[7], val[4], val[7]);
|
||||
|
||||
vertlist[8] = vertexInterp(isolevel, p[0], p[4], val[0], val[4]);
|
||||
vertlist[9] = vertexInterp(isolevel, p[1], p[5], val[1], val[5]);
|
||||
vertlist[10] = vertexInterp(isolevel, p[2], p[6], val[2], val[6]);
|
||||
vertlist[11] = vertexInterp(isolevel, p[3], p[7], val[3], val[7]);
|
||||
|
||||
// output triangle faces
|
||||
uint numVerts = numVertsTable[cubeindex];
|
||||
|
||||
for (int i = 0; i < numVerts; i++) {
|
||||
int index = numVertsScanned[voxel] + i;
|
||||
unsigned char edge = faceTable[cubeindex][i];
|
||||
|
||||
uint v1 = edgeToVertsTable[edge][0];
|
||||
uint v2 = edgeToVertsTable[edge][1];
|
||||
uint v1_id = localToGlobal(v1, gx, gy, gz, W + 1, H + 1);
|
||||
uint v2_id = localToGlobal(v2, gx, gy, gz, W + 1, H + 1);
|
||||
int64_t edge_id = hashVpair(v1_id, v2_id, W + 1, H + 1, D + 1);
|
||||
|
||||
verts[index][0] = vertlist[edge].x;
|
||||
verts[index][1] = vertlist[edge].y;
|
||||
verts[index][2] = vertlist[edge].z;
|
||||
|
||||
if (index < faces_size) {
|
||||
faces[index][0] = index * 3 + 0;
|
||||
faces[index][1] = index * 3 + 1;
|
||||
faces[index][2] = index * 3 + 2;
|
||||
}
|
||||
|
||||
ids[index] = edge_id;
|
||||
}
|
||||
} // end for grid-strided kernel
|
||||
}
|
||||
|
||||
// Entrance for marching cubes cuda extension. Marching Cubes is an algorithm to
|
||||
// create triangle meshes from an implicit function (one of the form f(x, y, z)
|
||||
// = 0). It works by iteratively checking a grid of cubes superimposed over a
|
||||
// region of the function. The number of faces and positions of the vertices in
|
||||
// each cube are determined by the the isolevel as well as the volume values
|
||||
// from the eight vertices of the cube.
|
||||
//
|
||||
// We implement this algorithm with two forward passes where the first pass
|
||||
// checks the occupancy and collects number of vertices for each cube. The
|
||||
// second pass will skip empty voxels and generate vertices as well as faces for
|
||||
// each cube through table lookup. The vertex positions, faces and identifiers
|
||||
// for each vertex will be returned.
|
||||
//
|
||||
//
|
||||
// Args:
|
||||
// vol: torch tensor of shape (D, H, W) for volume scalar field
|
||||
// isolevel: threshold to determine isosurface intesection
|
||||
//
|
||||
// Returns:
|
||||
// tuple of <verts, faces, ids>: which stores vertex positions, face
|
||||
// indices and integer identifiers for each vertex.
|
||||
// verts: (N_verts, 3) FloatTensor for vertex positions
|
||||
// faces: (N_faces, 3) LongTensor of face indices
|
||||
// ids: (N_verts,) LongTensor used to identify each vertex. Vertices from
|
||||
// adjacent edges can share the same 3d position. To reduce memory
|
||||
// redudancy, we tag each vertex with a unique id for deduplication. In
|
||||
// contrast to deduping on vertices, this has the benefit to avoid
|
||||
// floating point precision issues.
|
||||
//
|
||||
std::tuple<at::Tensor, at::Tensor, at::Tensor> MarchingCubesCuda(
|
||||
const at::Tensor& vol,
|
||||
const float isolevel) {
|
||||
// Set the device for the kernel launch based on the device of vol
|
||||
at::cuda::CUDAGuard device_guard(vol.device());
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
// transfer _FACE_TABLE data to device
|
||||
torch::Tensor face_table_tensor = torch::zeros(
|
||||
{256, 16}, torch::TensorOptions().dtype(at::kInt).device(at::kCPU));
|
||||
auto face_table_a = face_table_tensor.accessor<int, 2>();
|
||||
for (int i = 0; i < 256; i++) {
|
||||
for (int j = 0; j < 16; j++) {
|
||||
face_table_a[i][j] = _FACE_TABLE[i][j];
|
||||
}
|
||||
}
|
||||
torch::Tensor faceTable = face_table_tensor.to(vol.device());
|
||||
|
||||
// get numVoxels
|
||||
int threads = 128;
|
||||
const uint D = vol.size(0);
|
||||
const uint H = vol.size(1);
|
||||
const uint W = vol.size(2);
|
||||
const int numVoxels = (D - 1) * (H - 1) * (W - 1);
|
||||
dim3 grid((numVoxels + threads - 1) / threads, 1, 1);
|
||||
if (grid.x > 65535) {
|
||||
grid.x = 65535;
|
||||
}
|
||||
|
||||
auto d_voxelVerts =
|
||||
torch::zeros({numVoxels}, torch::TensorOptions().dtype(at::kInt))
|
||||
.to(vol.device());
|
||||
auto d_voxelOccupied =
|
||||
torch::zeros({numVoxels}, torch::TensorOptions().dtype(at::kInt))
|
||||
.to(vol.device());
|
||||
|
||||
// Execute "ClassifyVoxelKernel" kernel to precompute
|
||||
// two arrays - d_voxelOccupied and d_voxelVertices to global memory,
|
||||
// which stores the occupancy state and number of voxel vertices per voxel.
|
||||
ClassifyVoxelKernel<<<grid, threads, 0, stream>>>(
|
||||
d_voxelVerts.packed_accessor32<int, 1, at::RestrictPtrTraits>(),
|
||||
d_voxelOccupied.packed_accessor32<int, 1, at::RestrictPtrTraits>(),
|
||||
vol.packed_accessor32<float, 3, at::RestrictPtrTraits>(),
|
||||
isolevel);
|
||||
AT_CUDA_CHECK(cudaGetLastError());
|
||||
cudaDeviceSynchronize();
|
||||
|
||||
// Scan "d_voxelOccupied" array to generate accumulated voxel occupancy
|
||||
// count for voxels in the grid and compute the number of active voxels.
|
||||
// If the number of active voxels is 0, return zero tensor for verts and
|
||||
// faces.
|
||||
auto d_voxelOccupiedScan =
|
||||
torch::zeros({numVoxels}, torch::TensorOptions().dtype(at::kInt))
|
||||
.to(vol.device());
|
||||
ThrustScanWrapper(
|
||||
d_voxelOccupiedScan.data_ptr<int>(),
|
||||
d_voxelOccupied.data_ptr<int>(),
|
||||
numVoxels);
|
||||
|
||||
// number of active voxels
|
||||
int lastElement = d_voxelVerts[numVoxels - 1].cpu().item<int>();
|
||||
int lastScan = d_voxelOccupiedScan[numVoxels - 1].cpu().item<int>();
|
||||
int activeVoxels = lastElement + lastScan;
|
||||
|
||||
const int device_id = vol.device().index();
|
||||
auto opt =
|
||||
torch::TensorOptions().dtype(torch::kInt).device(torch::kCUDA, device_id);
|
||||
auto opt_long = torch::TensorOptions()
|
||||
.dtype(torch::kInt64)
|
||||
.device(torch::kCUDA, device_id);
|
||||
|
||||
if (activeVoxels == 0) {
|
||||
int ntris = 0;
|
||||
torch::Tensor verts = torch::zeros({ntris * 3, 3}, vol.options());
|
||||
torch::Tensor faces = torch::zeros({ntris, 3}, opt_long);
|
||||
torch::Tensor ids = torch::zeros({ntris}, opt_long);
|
||||
return std::make_tuple(verts, faces, ids);
|
||||
}
|
||||
|
||||
// Execute "CompactVoxelsKernel" kernel to compress voxels for accleration.
|
||||
// This allows us to run triangle generation on only the occupied voxels.
|
||||
auto d_compVoxelArray = torch::zeros({activeVoxels}, opt);
|
||||
CompactVoxelsKernel<<<grid, threads, 0, stream>>>(
|
||||
d_compVoxelArray.packed_accessor32<int, 1, at::RestrictPtrTraits>(),
|
||||
d_voxelOccupied.packed_accessor32<int, 1, at::RestrictPtrTraits>(),
|
||||
d_voxelOccupiedScan.packed_accessor32<int, 1, at::RestrictPtrTraits>(),
|
||||
numVoxels);
|
||||
AT_CUDA_CHECK(cudaGetLastError());
|
||||
cudaDeviceSynchronize();
|
||||
|
||||
// Scan d_voxelVerts array to generate offsets of vertices for each voxel
|
||||
auto d_voxelVertsScan = torch::zeros({numVoxels}, opt);
|
||||
ThrustScanWrapper(
|
||||
d_voxelVertsScan.data_ptr<int>(),
|
||||
d_voxelVerts.data_ptr<int>(),
|
||||
numVoxels);
|
||||
|
||||
// total number of vertices
|
||||
lastElement = d_voxelVerts[numVoxels - 1].cpu().item<int>();
|
||||
lastScan = d_voxelVertsScan[numVoxels - 1].cpu().item<int>();
|
||||
int totalVerts = lastElement + lastScan;
|
||||
|
||||
// Execute "GenerateFacesKernel" kernel
|
||||
// This runs only on the occupied voxels.
|
||||
// It looks up the field values and generates the triangle data.
|
||||
torch::Tensor verts = torch::zeros({totalVerts, 3}, vol.options());
|
||||
torch::Tensor faces = torch::zeros({totalVerts / 3, 3}, opt_long);
|
||||
|
||||
torch::Tensor ids = torch::zeros({totalVerts}, opt_long);
|
||||
|
||||
dim3 grid2((activeVoxels + threads - 1) / threads, 1, 1);
|
||||
if (grid2.x > 65535) {
|
||||
grid2.x = 65535;
|
||||
}
|
||||
|
||||
GenerateFacesKernel<<<grid2, threads, 0, stream>>>(
|
||||
verts.packed_accessor32<float, 2, at::RestrictPtrTraits>(),
|
||||
faces.packed_accessor<int64_t, 2, at::RestrictPtrTraits>(),
|
||||
ids.packed_accessor<int64_t, 1, at::RestrictPtrTraits>(),
|
||||
d_compVoxelArray.packed_accessor32<int, 1, at::RestrictPtrTraits>(),
|
||||
d_voxelVertsScan.packed_accessor32<int, 1, at::RestrictPtrTraits>(),
|
||||
activeVoxels,
|
||||
vol.packed_accessor32<float, 3, at::RestrictPtrTraits>(),
|
||||
faceTable.packed_accessor32<int, 2, at::RestrictPtrTraits>(),
|
||||
isolevel);
|
||||
AT_CUDA_CHECK(cudaGetLastError());
|
||||
|
||||
return std::make_tuple(verts, faces, ids);
|
||||
}
|
@ -23,17 +23,40 @@
|
||||
// the points are within a volume.
|
||||
//
|
||||
// Returns:
|
||||
// vertices: List of N FloatTensors of vertices
|
||||
// faces: List of N LongTensors of faces
|
||||
// vertices: (N_verts, 3) FloatTensor of vertices
|
||||
// faces: (N_faces, 3) LongTensor of faces
|
||||
// ids: (N_verts,) LongTensor used to identify each vertex and deduplication
|
||||
// to avoid floating point precision issues.
|
||||
// For Cuda, will be used to dedupe redundant vertices.
|
||||
// For cpp implementation, this tensor is just a placeholder.
|
||||
|
||||
// CPU implementation
|
||||
std::tuple<at::Tensor, at::Tensor> MarchingCubesCpu(
|
||||
std::tuple<at::Tensor, at::Tensor, at::Tensor> MarchingCubesCpu(
|
||||
const at::Tensor& vol,
|
||||
const float isolevel);
|
||||
|
||||
// CUDA implementation
|
||||
std::tuple<at::Tensor, at::Tensor, at::Tensor> MarchingCubesCuda(
|
||||
const at::Tensor& vol,
|
||||
const float isolevel);
|
||||
|
||||
// Implementation which is exposed
|
||||
inline std::tuple<at::Tensor, at::Tensor> MarchingCubes(
|
||||
inline std::tuple<at::Tensor, at::Tensor, at::Tensor> MarchingCubes(
|
||||
const at::Tensor& vol,
|
||||
const float isolevel) {
|
||||
if (vol.is_cuda()) {
|
||||
#ifdef WITH_CUDA
|
||||
CHECK_CUDA(vol);
|
||||
const int D = vol.size(0);
|
||||
const int H = vol.size(1);
|
||||
const int W = vol.size(2);
|
||||
if (D > 1024 || H > 1024 || W > 1024) {
|
||||
AT_ERROR("Maximum volume size allowed 1K x 1K x 1K");
|
||||
}
|
||||
return MarchingCubesCuda(vol.contiguous(), isolevel);
|
||||
#else
|
||||
AT_ERROR("Not compiled with GPU support.");
|
||||
#endif
|
||||
}
|
||||
return MarchingCubesCpu(vol.contiguous(), isolevel);
|
||||
}
|
||||
|
@ -13,6 +13,7 @@
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
#include "marching_cubes/marching_cubes_utils.h"
|
||||
#include "marching_cubes/tables.h"
|
||||
|
||||
// Cpu implementation for Marching Cubes
|
||||
// Args:
|
||||
@ -21,10 +22,11 @@
|
||||
// whether points are within a volume.
|
||||
//
|
||||
// Returns:
|
||||
// vertices: a float tensor of shape (N, 3) for positions of the mesh
|
||||
// faces: a long tensor of shape (N, 3) for indices of the face vertices
|
||||
// vertices: a float tensor of shape (N_verts, 3) for positions of the mesh
|
||||
// faces: a long tensor of shape (N_faces, 3) for indices of the face
|
||||
// ids: a long tensor of shape (N_verts) as placeholder
|
||||
//
|
||||
std::tuple<at::Tensor, at::Tensor> MarchingCubesCpu(
|
||||
std::tuple<at::Tensor, at::Tensor, at::Tensor> MarchingCubesCpu(
|
||||
const at::Tensor& vol,
|
||||
const float isolevel) {
|
||||
// volume shapes
|
||||
@ -48,7 +50,7 @@ std::tuple<at::Tensor, at::Tensor> MarchingCubesCpu(
|
||||
for (int x = 0; x < W - 1; x++) {
|
||||
Cube cube(x, y, z, vol_a, isolevel);
|
||||
// Cube is entirely in/out of the surface
|
||||
if (_FACE_TABLE[cube.cubeindex][0] == -1) {
|
||||
if (_FACE_TABLE[cube.cubeindex][0] == 255) {
|
||||
continue;
|
||||
}
|
||||
// store all boundary vertices that intersect with the edges
|
||||
@ -58,7 +60,7 @@ std::tuple<at::Tensor, at::Tensor> MarchingCubesCpu(
|
||||
std::vector<Vertex> ps;
|
||||
|
||||
// Interpolate the vertices where the surface intersects with the cube
|
||||
for (int j = 0; _FACE_TABLE[cube.cubeindex][j] != -1; j++) {
|
||||
for (int j = 0; _FACE_TABLE[cube.cubeindex][j] != 255; j++) {
|
||||
const int e = _FACE_TABLE[cube.cubeindex][j];
|
||||
interp_points[e] = cube.VertexInterp(isolevel, e, vol_a);
|
||||
|
||||
@ -95,6 +97,7 @@ std::tuple<at::Tensor, at::Tensor> MarchingCubesCpu(
|
||||
const int n_vertices = verts.size();
|
||||
const int64_t n_faces = (int64_t)faces.size() / 3;
|
||||
auto vert_tensor = torch::zeros({n_vertices, 3}, torch::kFloat);
|
||||
auto id_tensor = torch::zeros({n_vertices}, torch::kInt64); // placeholder
|
||||
auto face_tensor = torch::zeros({n_faces, 3}, torch::kInt64);
|
||||
|
||||
auto vert_a = vert_tensor.accessor<float, 2>();
|
||||
@ -111,5 +114,5 @@ std::tuple<at::Tensor, at::Tensor> MarchingCubesCpu(
|
||||
face_a[i][2] = faces.at(i * 3 + 2);
|
||||
}
|
||||
|
||||
return std::make_tuple(vert_tensor, face_tensor);
|
||||
return std::make_tuple(vert_tensor, face_tensor, id_tensor);
|
||||
}
|
||||
|
@ -12,291 +12,11 @@
|
||||
#include <cstdint>
|
||||
#include <vector>
|
||||
#include "ATen/core/TensorAccessor.h"
|
||||
#include "marching_cubes/tables.h"
|
||||
|
||||
// EPS: Used to assess whether two float values are close
|
||||
const float EPS = 1e-5;
|
||||
|
||||
// A table mapping from cubeindex to a list of face configurations.
|
||||
// Each list contains at most 5 faces, where each face is represented with
|
||||
// 3 consecutive numbers
|
||||
// Table taken from http://paulbourke.net/geometry/polygonise/
|
||||
const int _FACE_TABLE[256][16] = {
|
||||
{-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1},
|
||||
{0, 8, 3, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1},
|
||||
{0, 1, 9, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1},
|
||||
{1, 8, 3, 9, 8, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1},
|
||||
{1, 2, 10, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1},
|
||||
{0, 8, 3, 1, 2, 10, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1},
|
||||
{9, 2, 10, 0, 2, 9, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1},
|
||||
{2, 8, 3, 2, 10, 8, 10, 9, 8, -1, -1, -1, -1, -1, -1, -1},
|
||||
{3, 11, 2, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1},
|
||||
{0, 11, 2, 8, 11, 0, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1},
|
||||
{1, 9, 0, 2, 3, 11, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1},
|
||||
{1, 11, 2, 1, 9, 11, 9, 8, 11, -1, -1, -1, -1, -1, -1, -1},
|
||||
{3, 10, 1, 11, 10, 3, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1},
|
||||
{0, 10, 1, 0, 8, 10, 8, 11, 10, -1, -1, -1, -1, -1, -1, -1},
|
||||
{3, 9, 0, 3, 11, 9, 11, 10, 9, -1, -1, -1, -1, -1, -1, -1},
|
||||
{9, 8, 10, 10, 8, 11, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1},
|
||||
{4, 7, 8, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1},
|
||||
{4, 3, 0, 7, 3, 4, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1},
|
||||
{0, 1, 9, 8, 4, 7, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1},
|
||||
{4, 1, 9, 4, 7, 1, 7, 3, 1, -1, -1, -1, -1, -1, -1, -1},
|
||||
{1, 2, 10, 8, 4, 7, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1},
|
||||
{3, 4, 7, 3, 0, 4, 1, 2, 10, -1, -1, -1, -1, -1, -1, -1},
|
||||
{9, 2, 10, 9, 0, 2, 8, 4, 7, -1, -1, -1, -1, -1, -1, -1},
|
||||
{2, 10, 9, 2, 9, 7, 2, 7, 3, 7, 9, 4, -1, -1, -1, -1},
|
||||
{8, 4, 7, 3, 11, 2, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1},
|
||||
{11, 4, 7, 11, 2, 4, 2, 0, 4, -1, -1, -1, -1, -1, -1, -1},
|
||||
{9, 0, 1, 8, 4, 7, 2, 3, 11, -1, -1, -1, -1, -1, -1, -1},
|
||||
{4, 7, 11, 9, 4, 11, 9, 11, 2, 9, 2, 1, -1, -1, -1, -1},
|
||||
{3, 10, 1, 3, 11, 10, 7, 8, 4, -1, -1, -1, -1, -1, -1, -1},
|
||||
{1, 11, 10, 1, 4, 11, 1, 0, 4, 7, 11, 4, -1, -1, -1, -1},
|
||||
{4, 7, 8, 9, 0, 11, 9, 11, 10, 11, 0, 3, -1, -1, -1, -1},
|
||||
{4, 7, 11, 4, 11, 9, 9, 11, 10, -1, -1, -1, -1, -1, -1, -1},
|
||||
{9, 5, 4, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1},
|
||||
{9, 5, 4, 0, 8, 3, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1},
|
||||
{0, 5, 4, 1, 5, 0, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1},
|
||||
{8, 5, 4, 8, 3, 5, 3, 1, 5, -1, -1, -1, -1, -1, -1, -1},
|
||||
{1, 2, 10, 9, 5, 4, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1},
|
||||
{3, 0, 8, 1, 2, 10, 4, 9, 5, -1, -1, -1, -1, -1, -1, -1},
|
||||
{5, 2, 10, 5, 4, 2, 4, 0, 2, -1, -1, -1, -1, -1, -1, -1},
|
||||
{2, 10, 5, 3, 2, 5, 3, 5, 4, 3, 4, 8, -1, -1, -1, -1},
|
||||
{9, 5, 4, 2, 3, 11, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1},
|
||||
{0, 11, 2, 0, 8, 11, 4, 9, 5, -1, -1, -1, -1, -1, -1, -1},
|
||||
{0, 5, 4, 0, 1, 5, 2, 3, 11, -1, -1, -1, -1, -1, -1, -1},
|
||||
{2, 1, 5, 2, 5, 8, 2, 8, 11, 4, 8, 5, -1, -1, -1, -1},
|
||||
{10, 3, 11, 10, 1, 3, 9, 5, 4, -1, -1, -1, -1, -1, -1, -1},
|
||||
{4, 9, 5, 0, 8, 1, 8, 10, 1, 8, 11, 10, -1, -1, -1, -1},
|
||||
{5, 4, 0, 5, 0, 11, 5, 11, 10, 11, 0, 3, -1, -1, -1, -1},
|
||||
{5, 4, 8, 5, 8, 10, 10, 8, 11, -1, -1, -1, -1, -1, -1, -1},
|
||||
{9, 7, 8, 5, 7, 9, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1},
|
||||
{9, 3, 0, 9, 5, 3, 5, 7, 3, -1, -1, -1, -1, -1, -1, -1},
|
||||
{0, 7, 8, 0, 1, 7, 1, 5, 7, -1, -1, -1, -1, -1, -1, -1},
|
||||
{1, 5, 3, 3, 5, 7, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1},
|
||||
{9, 7, 8, 9, 5, 7, 10, 1, 2, -1, -1, -1, -1, -1, -1, -1},
|
||||
{10, 1, 2, 9, 5, 0, 5, 3, 0, 5, 7, 3, -1, -1, -1, -1},
|
||||
{8, 0, 2, 8, 2, 5, 8, 5, 7, 10, 5, 2, -1, -1, -1, -1},
|
||||
{2, 10, 5, 2, 5, 3, 3, 5, 7, -1, -1, -1, -1, -1, -1, -1},
|
||||
{7, 9, 5, 7, 8, 9, 3, 11, 2, -1, -1, -1, -1, -1, -1, -1},
|
||||
{9, 5, 7, 9, 7, 2, 9, 2, 0, 2, 7, 11, -1, -1, -1, -1},
|
||||
{2, 3, 11, 0, 1, 8, 1, 7, 8, 1, 5, 7, -1, -1, -1, -1},
|
||||
{11, 2, 1, 11, 1, 7, 7, 1, 5, -1, -1, -1, -1, -1, -1, -1},
|
||||
{9, 5, 8, 8, 5, 7, 10, 1, 3, 10, 3, 11, -1, -1, -1, -1},
|
||||
{5, 7, 0, 5, 0, 9, 7, 11, 0, 1, 0, 10, 11, 10, 0, -1},
|
||||
{11, 10, 0, 11, 0, 3, 10, 5, 0, 8, 0, 7, 5, 7, 0, -1},
|
||||
{11, 10, 5, 7, 11, 5, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1},
|
||||
{10, 6, 5, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1},
|
||||
{0, 8, 3, 5, 10, 6, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1},
|
||||
{9, 0, 1, 5, 10, 6, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1},
|
||||
{1, 8, 3, 1, 9, 8, 5, 10, 6, -1, -1, -1, -1, -1, -1, -1},
|
||||
{1, 6, 5, 2, 6, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1},
|
||||
{1, 6, 5, 1, 2, 6, 3, 0, 8, -1, -1, -1, -1, -1, -1, -1},
|
||||
{9, 6, 5, 9, 0, 6, 0, 2, 6, -1, -1, -1, -1, -1, -1, -1},
|
||||
{5, 9, 8, 5, 8, 2, 5, 2, 6, 3, 2, 8, -1, -1, -1, -1},
|
||||
{2, 3, 11, 10, 6, 5, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1},
|
||||
{11, 0, 8, 11, 2, 0, 10, 6, 5, -1, -1, -1, -1, -1, -1, -1},
|
||||
{0, 1, 9, 2, 3, 11, 5, 10, 6, -1, -1, -1, -1, -1, -1, -1},
|
||||
{5, 10, 6, 1, 9, 2, 9, 11, 2, 9, 8, 11, -1, -1, -1, -1},
|
||||
{6, 3, 11, 6, 5, 3, 5, 1, 3, -1, -1, -1, -1, -1, -1, -1},
|
||||
{0, 8, 11, 0, 11, 5, 0, 5, 1, 5, 11, 6, -1, -1, -1, -1},
|
||||
{3, 11, 6, 0, 3, 6, 0, 6, 5, 0, 5, 9, -1, -1, -1, -1},
|
||||
{6, 5, 9, 6, 9, 11, 11, 9, 8, -1, -1, -1, -1, -1, -1, -1},
|
||||
{5, 10, 6, 4, 7, 8, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1},
|
||||
{4, 3, 0, 4, 7, 3, 6, 5, 10, -1, -1, -1, -1, -1, -1, -1},
|
||||
{1, 9, 0, 5, 10, 6, 8, 4, 7, -1, -1, -1, -1, -1, -1, -1},
|
||||
{10, 6, 5, 1, 9, 7, 1, 7, 3, 7, 9, 4, -1, -1, -1, -1},
|
||||
{6, 1, 2, 6, 5, 1, 4, 7, 8, -1, -1, -1, -1, -1, -1, -1},
|
||||
{1, 2, 5, 5, 2, 6, 3, 0, 4, 3, 4, 7, -1, -1, -1, -1},
|
||||
{8, 4, 7, 9, 0, 5, 0, 6, 5, 0, 2, 6, -1, -1, -1, -1},
|
||||
{7, 3, 9, 7, 9, 4, 3, 2, 9, 5, 9, 6, 2, 6, 9, -1},
|
||||
{3, 11, 2, 7, 8, 4, 10, 6, 5, -1, -1, -1, -1, -1, -1, -1},
|
||||
{5, 10, 6, 4, 7, 2, 4, 2, 0, 2, 7, 11, -1, -1, -1, -1},
|
||||
{0, 1, 9, 4, 7, 8, 2, 3, 11, 5, 10, 6, -1, -1, -1, -1},
|
||||
{9, 2, 1, 9, 11, 2, 9, 4, 11, 7, 11, 4, 5, 10, 6, -1},
|
||||
{8, 4, 7, 3, 11, 5, 3, 5, 1, 5, 11, 6, -1, -1, -1, -1},
|
||||
{5, 1, 11, 5, 11, 6, 1, 0, 11, 7, 11, 4, 0, 4, 11, -1},
|
||||
{0, 5, 9, 0, 6, 5, 0, 3, 6, 11, 6, 3, 8, 4, 7, -1},
|
||||
{6, 5, 9, 6, 9, 11, 4, 7, 9, 7, 11, 9, -1, -1, -1, -1},
|
||||
{10, 4, 9, 6, 4, 10, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1},
|
||||
{4, 10, 6, 4, 9, 10, 0, 8, 3, -1, -1, -1, -1, -1, -1, -1},
|
||||
{10, 0, 1, 10, 6, 0, 6, 4, 0, -1, -1, -1, -1, -1, -1, -1},
|
||||
{8, 3, 1, 8, 1, 6, 8, 6, 4, 6, 1, 10, -1, -1, -1, -1},
|
||||
{1, 4, 9, 1, 2, 4, 2, 6, 4, -1, -1, -1, -1, -1, -1, -1},
|
||||
{3, 0, 8, 1, 2, 9, 2, 4, 9, 2, 6, 4, -1, -1, -1, -1},
|
||||
{0, 2, 4, 4, 2, 6, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1},
|
||||
{8, 3, 2, 8, 2, 4, 4, 2, 6, -1, -1, -1, -1, -1, -1, -1},
|
||||
{10, 4, 9, 10, 6, 4, 11, 2, 3, -1, -1, -1, -1, -1, -1, -1},
|
||||
{0, 8, 2, 2, 8, 11, 4, 9, 10, 4, 10, 6, -1, -1, -1, -1},
|
||||
{3, 11, 2, 0, 1, 6, 0, 6, 4, 6, 1, 10, -1, -1, -1, -1},
|
||||
{6, 4, 1, 6, 1, 10, 4, 8, 1, 2, 1, 11, 8, 11, 1, -1},
|
||||
{9, 6, 4, 9, 3, 6, 9, 1, 3, 11, 6, 3, -1, -1, -1, -1},
|
||||
{8, 11, 1, 8, 1, 0, 11, 6, 1, 9, 1, 4, 6, 4, 1, -1},
|
||||
{3, 11, 6, 3, 6, 0, 0, 6, 4, -1, -1, -1, -1, -1, -1, -1},
|
||||
{6, 4, 8, 11, 6, 8, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1},
|
||||
{7, 10, 6, 7, 8, 10, 8, 9, 10, -1, -1, -1, -1, -1, -1, -1},
|
||||
{0, 7, 3, 0, 10, 7, 0, 9, 10, 6, 7, 10, -1, -1, -1, -1},
|
||||
{10, 6, 7, 1, 10, 7, 1, 7, 8, 1, 8, 0, -1, -1, -1, -1},
|
||||
{10, 6, 7, 10, 7, 1, 1, 7, 3, -1, -1, -1, -1, -1, -1, -1},
|
||||
{1, 2, 6, 1, 6, 8, 1, 8, 9, 8, 6, 7, -1, -1, -1, -1},
|
||||
{2, 6, 9, 2, 9, 1, 6, 7, 9, 0, 9, 3, 7, 3, 9, -1},
|
||||
{7, 8, 0, 7, 0, 6, 6, 0, 2, -1, -1, -1, -1, -1, -1, -1},
|
||||
{7, 3, 2, 6, 7, 2, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1},
|
||||
{2, 3, 11, 10, 6, 8, 10, 8, 9, 8, 6, 7, -1, -1, -1, -1},
|
||||
{2, 0, 7, 2, 7, 11, 0, 9, 7, 6, 7, 10, 9, 10, 7, -1},
|
||||
{1, 8, 0, 1, 7, 8, 1, 10, 7, 6, 7, 10, 2, 3, 11, -1},
|
||||
{11, 2, 1, 11, 1, 7, 10, 6, 1, 6, 7, 1, -1, -1, -1, -1},
|
||||
{8, 9, 6, 8, 6, 7, 9, 1, 6, 11, 6, 3, 1, 3, 6, -1},
|
||||
{0, 9, 1, 11, 6, 7, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1},
|
||||
{7, 8, 0, 7, 0, 6, 3, 11, 0, 11, 6, 0, -1, -1, -1, -1},
|
||||
{7, 11, 6, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1},
|
||||
{7, 6, 11, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1},
|
||||
{3, 0, 8, 11, 7, 6, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1},
|
||||
{0, 1, 9, 11, 7, 6, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1},
|
||||
{8, 1, 9, 8, 3, 1, 11, 7, 6, -1, -1, -1, -1, -1, -1, -1},
|
||||
{10, 1, 2, 6, 11, 7, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1},
|
||||
{1, 2, 10, 3, 0, 8, 6, 11, 7, -1, -1, -1, -1, -1, -1, -1},
|
||||
{2, 9, 0, 2, 10, 9, 6, 11, 7, -1, -1, -1, -1, -1, -1, -1},
|
||||
{6, 11, 7, 2, 10, 3, 10, 8, 3, 10, 9, 8, -1, -1, -1, -1},
|
||||
{7, 2, 3, 6, 2, 7, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1},
|
||||
{7, 0, 8, 7, 6, 0, 6, 2, 0, -1, -1, -1, -1, -1, -1, -1},
|
||||
{2, 7, 6, 2, 3, 7, 0, 1, 9, -1, -1, -1, -1, -1, -1, -1},
|
||||
{1, 6, 2, 1, 8, 6, 1, 9, 8, 8, 7, 6, -1, -1, -1, -1},
|
||||
{10, 7, 6, 10, 1, 7, 1, 3, 7, -1, -1, -1, -1, -1, -1, -1},
|
||||
{10, 7, 6, 1, 7, 10, 1, 8, 7, 1, 0, 8, -1, -1, -1, -1},
|
||||
{0, 3, 7, 0, 7, 10, 0, 10, 9, 6, 10, 7, -1, -1, -1, -1},
|
||||
{7, 6, 10, 7, 10, 8, 8, 10, 9, -1, -1, -1, -1, -1, -1, -1},
|
||||
{6, 8, 4, 11, 8, 6, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1},
|
||||
{3, 6, 11, 3, 0, 6, 0, 4, 6, -1, -1, -1, -1, -1, -1, -1},
|
||||
{8, 6, 11, 8, 4, 6, 9, 0, 1, -1, -1, -1, -1, -1, -1, -1},
|
||||
{9, 4, 6, 9, 6, 3, 9, 3, 1, 11, 3, 6, -1, -1, -1, -1},
|
||||
{6, 8, 4, 6, 11, 8, 2, 10, 1, -1, -1, -1, -1, -1, -1, -1},
|
||||
{1, 2, 10, 3, 0, 11, 0, 6, 11, 0, 4, 6, -1, -1, -1, -1},
|
||||
{4, 11, 8, 4, 6, 11, 0, 2, 9, 2, 10, 9, -1, -1, -1, -1},
|
||||
{10, 9, 3, 10, 3, 2, 9, 4, 3, 11, 3, 6, 4, 6, 3, -1},
|
||||
{8, 2, 3, 8, 4, 2, 4, 6, 2, -1, -1, -1, -1, -1, -1, -1},
|
||||
{0, 4, 2, 4, 6, 2, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1},
|
||||
{1, 9, 0, 2, 3, 4, 2, 4, 6, 4, 3, 8, -1, -1, -1, -1},
|
||||
{1, 9, 4, 1, 4, 2, 2, 4, 6, -1, -1, -1, -1, -1, -1, -1},
|
||||
{8, 1, 3, 8, 6, 1, 8, 4, 6, 6, 10, 1, -1, -1, -1, -1},
|
||||
{10, 1, 0, 10, 0, 6, 6, 0, 4, -1, -1, -1, -1, -1, -1, -1},
|
||||
{4, 6, 3, 4, 3, 8, 6, 10, 3, 0, 3, 9, 10, 9, 3, -1},
|
||||
{10, 9, 4, 6, 10, 4, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1},
|
||||
{4, 9, 5, 7, 6, 11, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1},
|
||||
{0, 8, 3, 4, 9, 5, 11, 7, 6, -1, -1, -1, -1, -1, -1, -1},
|
||||
{5, 0, 1, 5, 4, 0, 7, 6, 11, -1, -1, -1, -1, -1, -1, -1},
|
||||
{11, 7, 6, 8, 3, 4, 3, 5, 4, 3, 1, 5, -1, -1, -1, -1},
|
||||
{9, 5, 4, 10, 1, 2, 7, 6, 11, -1, -1, -1, -1, -1, -1, -1},
|
||||
{6, 11, 7, 1, 2, 10, 0, 8, 3, 4, 9, 5, -1, -1, -1, -1},
|
||||
{7, 6, 11, 5, 4, 10, 4, 2, 10, 4, 0, 2, -1, -1, -1, -1},
|
||||
{3, 4, 8, 3, 5, 4, 3, 2, 5, 10, 5, 2, 11, 7, 6, -1},
|
||||
{7, 2, 3, 7, 6, 2, 5, 4, 9, -1, -1, -1, -1, -1, -1, -1},
|
||||
{9, 5, 4, 0, 8, 6, 0, 6, 2, 6, 8, 7, -1, -1, -1, -1},
|
||||
{3, 6, 2, 3, 7, 6, 1, 5, 0, 5, 4, 0, -1, -1, -1, -1},
|
||||
{6, 2, 8, 6, 8, 7, 2, 1, 8, 4, 8, 5, 1, 5, 8, -1},
|
||||
{9, 5, 4, 10, 1, 6, 1, 7, 6, 1, 3, 7, -1, -1, -1, -1},
|
||||
{1, 6, 10, 1, 7, 6, 1, 0, 7, 8, 7, 0, 9, 5, 4, -1},
|
||||
{4, 0, 10, 4, 10, 5, 0, 3, 10, 6, 10, 7, 3, 7, 10, -1},
|
||||
{7, 6, 10, 7, 10, 8, 5, 4, 10, 4, 8, 10, -1, -1, -1, -1},
|
||||
{6, 9, 5, 6, 11, 9, 11, 8, 9, -1, -1, -1, -1, -1, -1, -1},
|
||||
{3, 6, 11, 0, 6, 3, 0, 5, 6, 0, 9, 5, -1, -1, -1, -1},
|
||||
{0, 11, 8, 0, 5, 11, 0, 1, 5, 5, 6, 11, -1, -1, -1, -1},
|
||||
{6, 11, 3, 6, 3, 5, 5, 3, 1, -1, -1, -1, -1, -1, -1, -1},
|
||||
{1, 2, 10, 9, 5, 11, 9, 11, 8, 11, 5, 6, -1, -1, -1, -1},
|
||||
{0, 11, 3, 0, 6, 11, 0, 9, 6, 5, 6, 9, 1, 2, 10, -1},
|
||||
{11, 8, 5, 11, 5, 6, 8, 0, 5, 10, 5, 2, 0, 2, 5, -1},
|
||||
{6, 11, 3, 6, 3, 5, 2, 10, 3, 10, 5, 3, -1, -1, -1, -1},
|
||||
{5, 8, 9, 5, 2, 8, 5, 6, 2, 3, 8, 2, -1, -1, -1, -1},
|
||||
{9, 5, 6, 9, 6, 0, 0, 6, 2, -1, -1, -1, -1, -1, -1, -1},
|
||||
{1, 5, 8, 1, 8, 0, 5, 6, 8, 3, 8, 2, 6, 2, 8, -1},
|
||||
{1, 5, 6, 2, 1, 6, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1},
|
||||
{1, 3, 6, 1, 6, 10, 3, 8, 6, 5, 6, 9, 8, 9, 6, -1},
|
||||
{10, 1, 0, 10, 0, 6, 9, 5, 0, 5, 6, 0, -1, -1, -1, -1},
|
||||
{0, 3, 8, 5, 6, 10, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1},
|
||||
{10, 5, 6, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1},
|
||||
{11, 5, 10, 7, 5, 11, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1},
|
||||
{11, 5, 10, 11, 7, 5, 8, 3, 0, -1, -1, -1, -1, -1, -1, -1},
|
||||
{5, 11, 7, 5, 10, 11, 1, 9, 0, -1, -1, -1, -1, -1, -1, -1},
|
||||
{10, 7, 5, 10, 11, 7, 9, 8, 1, 8, 3, 1, -1, -1, -1, -1},
|
||||
{11, 1, 2, 11, 7, 1, 7, 5, 1, -1, -1, -1, -1, -1, -1, -1},
|
||||
{0, 8, 3, 1, 2, 7, 1, 7, 5, 7, 2, 11, -1, -1, -1, -1},
|
||||
{9, 7, 5, 9, 2, 7, 9, 0, 2, 2, 11, 7, -1, -1, -1, -1},
|
||||
{7, 5, 2, 7, 2, 11, 5, 9, 2, 3, 2, 8, 9, 8, 2, -1},
|
||||
{2, 5, 10, 2, 3, 5, 3, 7, 5, -1, -1, -1, -1, -1, -1, -1},
|
||||
{8, 2, 0, 8, 5, 2, 8, 7, 5, 10, 2, 5, -1, -1, -1, -1},
|
||||
{9, 0, 1, 5, 10, 3, 5, 3, 7, 3, 10, 2, -1, -1, -1, -1},
|
||||
{9, 8, 2, 9, 2, 1, 8, 7, 2, 10, 2, 5, 7, 5, 2, -1},
|
||||
{1, 3, 5, 3, 7, 5, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1},
|
||||
{0, 8, 7, 0, 7, 1, 1, 7, 5, -1, -1, -1, -1, -1, -1, -1},
|
||||
{9, 0, 3, 9, 3, 5, 5, 3, 7, -1, -1, -1, -1, -1, -1, -1},
|
||||
{9, 8, 7, 5, 9, 7, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1},
|
||||
{5, 8, 4, 5, 10, 8, 10, 11, 8, -1, -1, -1, -1, -1, -1, -1},
|
||||
{5, 0, 4, 5, 11, 0, 5, 10, 11, 11, 3, 0, -1, -1, -1, -1},
|
||||
{0, 1, 9, 8, 4, 10, 8, 10, 11, 10, 4, 5, -1, -1, -1, -1},
|
||||
{10, 11, 4, 10, 4, 5, 11, 3, 4, 9, 4, 1, 3, 1, 4, -1},
|
||||
{2, 5, 1, 2, 8, 5, 2, 11, 8, 4, 5, 8, -1, -1, -1, -1},
|
||||
{0, 4, 11, 0, 11, 3, 4, 5, 11, 2, 11, 1, 5, 1, 11, -1},
|
||||
{0, 2, 5, 0, 5, 9, 2, 11, 5, 4, 5, 8, 11, 8, 5, -1},
|
||||
{9, 4, 5, 2, 11, 3, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1},
|
||||
{2, 5, 10, 3, 5, 2, 3, 4, 5, 3, 8, 4, -1, -1, -1, -1},
|
||||
{5, 10, 2, 5, 2, 4, 4, 2, 0, -1, -1, -1, -1, -1, -1, -1},
|
||||
{3, 10, 2, 3, 5, 10, 3, 8, 5, 4, 5, 8, 0, 1, 9, -1},
|
||||
{5, 10, 2, 5, 2, 4, 1, 9, 2, 9, 4, 2, -1, -1, -1, -1},
|
||||
{8, 4, 5, 8, 5, 3, 3, 5, 1, -1, -1, -1, -1, -1, -1, -1},
|
||||
{0, 4, 5, 1, 0, 5, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1},
|
||||
{8, 4, 5, 8, 5, 3, 9, 0, 5, 0, 3, 5, -1, -1, -1, -1},
|
||||
{9, 4, 5, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1},
|
||||
{4, 11, 7, 4, 9, 11, 9, 10, 11, -1, -1, -1, -1, -1, -1, -1},
|
||||
{0, 8, 3, 4, 9, 7, 9, 11, 7, 9, 10, 11, -1, -1, -1, -1},
|
||||
{1, 10, 11, 1, 11, 4, 1, 4, 0, 7, 4, 11, -1, -1, -1, -1},
|
||||
{3, 1, 4, 3, 4, 8, 1, 10, 4, 7, 4, 11, 10, 11, 4, -1},
|
||||
{4, 11, 7, 9, 11, 4, 9, 2, 11, 9, 1, 2, -1, -1, -1, -1},
|
||||
{9, 7, 4, 9, 11, 7, 9, 1, 11, 2, 11, 1, 0, 8, 3, -1},
|
||||
{11, 7, 4, 11, 4, 2, 2, 4, 0, -1, -1, -1, -1, -1, -1, -1},
|
||||
{11, 7, 4, 11, 4, 2, 8, 3, 4, 3, 2, 4, -1, -1, -1, -1},
|
||||
{2, 9, 10, 2, 7, 9, 2, 3, 7, 7, 4, 9, -1, -1, -1, -1},
|
||||
{9, 10, 7, 9, 7, 4, 10, 2, 7, 8, 7, 0, 2, 0, 7, -1},
|
||||
{3, 7, 10, 3, 10, 2, 7, 4, 10, 1, 10, 0, 4, 0, 10, -1},
|
||||
{1, 10, 2, 8, 7, 4, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1},
|
||||
{4, 9, 1, 4, 1, 7, 7, 1, 3, -1, -1, -1, -1, -1, -1, -1},
|
||||
{4, 9, 1, 4, 1, 7, 0, 8, 1, 8, 7, 1, -1, -1, -1, -1},
|
||||
{4, 0, 3, 7, 4, 3, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1},
|
||||
{4, 8, 7, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1},
|
||||
{9, 10, 8, 10, 11, 8, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1},
|
||||
{3, 0, 9, 3, 9, 11, 11, 9, 10, -1, -1, -1, -1, -1, -1, -1},
|
||||
{0, 1, 10, 0, 10, 8, 8, 10, 11, -1, -1, -1, -1, -1, -1, -1},
|
||||
{3, 1, 10, 11, 3, 10, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1},
|
||||
{1, 2, 11, 1, 11, 9, 9, 11, 8, -1, -1, -1, -1, -1, -1, -1},
|
||||
{3, 0, 9, 3, 9, 11, 1, 2, 9, 2, 11, 9, -1, -1, -1, -1},
|
||||
{0, 2, 11, 8, 0, 11, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1},
|
||||
{3, 2, 11, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1},
|
||||
{2, 3, 8, 2, 8, 10, 10, 8, 9, -1, -1, -1, -1, -1, -1, -1},
|
||||
{9, 10, 2, 0, 9, 2, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1},
|
||||
{2, 3, 8, 2, 8, 10, 0, 1, 8, 1, 10, 8, -1, -1, -1, -1},
|
||||
{1, 10, 2, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1},
|
||||
{1, 3, 8, 9, 1, 8, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1},
|
||||
{0, 9, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1},
|
||||
{0, 3, 8, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1},
|
||||
{-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1}};
|
||||
|
||||
// Table mapping each edge to the corresponding cube vertices
|
||||
const int _EDGE_TO_VERTICES[12][2] = {
|
||||
{0, 1},
|
||||
{1, 5},
|
||||
{4, 5},
|
||||
{0, 4},
|
||||
{2, 3},
|
||||
{3, 7},
|
||||
{6, 7},
|
||||
{2, 6},
|
||||
{0, 2},
|
||||
{1, 3},
|
||||
{5, 7},
|
||||
{4, 6},
|
||||
};
|
||||
|
||||
// Table mapping from 0-7 to v0-v7 in cube.vertices
|
||||
const int _INDEX_TABLE[8] = {0, 1, 5, 4, 2, 3, 7, 6};
|
||||
|
||||
// Data structures for the marching cubes
|
||||
struct Vertex {
|
||||
// Constructor used when performing marching cube in each cell
|
||||
|
294
pytorch3d/csrc/marching_cubes/tables.h
Normal file
294
pytorch3d/csrc/marching_cubes/tables.h
Normal file
@ -0,0 +1,294 @@
|
||||
/*
|
||||
* Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
* All rights reserved.
|
||||
*
|
||||
* This source code is licensed under the BSD-style license found in the
|
||||
* LICENSE file in the root directory of this source tree.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
using uint = unsigned int;
|
||||
|
||||
// A table mapping from cubeindex to a list of face configurations.
|
||||
// Each list contains at most 5 faces, where each face is represented with
|
||||
// 3 consecutive numbers
|
||||
// Table adapted from http://paulbourke.net/geometry/polygonise/
|
||||
//
|
||||
#define X 255
|
||||
const unsigned char _FACE_TABLE[256][16] = {
|
||||
{X, X, X, X, X, X, X, X, X, X, X, X, X, X, X, X},
|
||||
{0, 8, 3, X, X, X, X, X, X, X, X, X, X, X, X, X},
|
||||
{0, 1, 9, X, X, X, X, X, X, X, X, X, X, X, X, X},
|
||||
{1, 8, 3, 9, 8, 1, X, X, X, X, X, X, X, X, X, X},
|
||||
{1, 2, 10, X, X, X, X, X, X, X, X, X, X, X, X, X},
|
||||
{0, 8, 3, 1, 2, 10, X, X, X, X, X, X, X, X, X, X},
|
||||
{9, 2, 10, 0, 2, 9, X, X, X, X, X, X, X, X, X, X},
|
||||
{2, 8, 3, 2, 10, 8, 10, 9, 8, X, X, X, X, X, X, X},
|
||||
{3, 11, 2, X, X, X, X, X, X, X, X, X, X, X, X, X},
|
||||
{0, 11, 2, 8, 11, 0, X, X, X, X, X, X, X, X, X, X},
|
||||
{1, 9, 0, 2, 3, 11, X, X, X, X, X, X, X, X, X, X},
|
||||
{1, 11, 2, 1, 9, 11, 9, 8, 11, X, X, X, X, X, X, X},
|
||||
{3, 10, 1, 11, 10, 3, X, X, X, X, X, X, X, X, X, X},
|
||||
{0, 10, 1, 0, 8, 10, 8, 11, 10, X, X, X, X, X, X, X},
|
||||
{3, 9, 0, 3, 11, 9, 11, 10, 9, X, X, X, X, X, X, X},
|
||||
{9, 8, 10, 10, 8, 11, X, X, X, X, X, X, X, X, X, X},
|
||||
{4, 7, 8, X, X, X, X, X, X, X, X, X, X, X, X, X},
|
||||
{4, 3, 0, 7, 3, 4, X, X, X, X, X, X, X, X, X, X},
|
||||
{0, 1, 9, 8, 4, 7, X, X, X, X, X, X, X, X, X, X},
|
||||
{4, 1, 9, 4, 7, 1, 7, 3, 1, X, X, X, X, X, X, X},
|
||||
{1, 2, 10, 8, 4, 7, X, X, X, X, X, X, X, X, X, X},
|
||||
{3, 4, 7, 3, 0, 4, 1, 2, 10, X, X, X, X, X, X, X},
|
||||
{9, 2, 10, 9, 0, 2, 8, 4, 7, X, X, X, X, X, X, X},
|
||||
{2, 10, 9, 2, 9, 7, 2, 7, 3, 7, 9, 4, X, X, X, X},
|
||||
{8, 4, 7, 3, 11, 2, X, X, X, X, X, X, X, X, X, X},
|
||||
{11, 4, 7, 11, 2, 4, 2, 0, 4, X, X, X, X, X, X, X},
|
||||
{9, 0, 1, 8, 4, 7, 2, 3, 11, X, X, X, X, X, X, X},
|
||||
{4, 7, 11, 9, 4, 11, 9, 11, 2, 9, 2, 1, X, X, X, X},
|
||||
{3, 10, 1, 3, 11, 10, 7, 8, 4, X, X, X, X, X, X, X},
|
||||
{1, 11, 10, 1, 4, 11, 1, 0, 4, 7, 11, 4, X, X, X, X},
|
||||
{4, 7, 8, 9, 0, 11, 9, 11, 10, 11, 0, 3, X, X, X, X},
|
||||
{4, 7, 11, 4, 11, 9, 9, 11, 10, X, X, X, X, X, X, X},
|
||||
{9, 5, 4, X, X, X, X, X, X, X, X, X, X, X, X, X},
|
||||
{9, 5, 4, 0, 8, 3, X, X, X, X, X, X, X, X, X, X},
|
||||
{0, 5, 4, 1, 5, 0, X, X, X, X, X, X, X, X, X, X},
|
||||
{8, 5, 4, 8, 3, 5, 3, 1, 5, X, X, X, X, X, X, X},
|
||||
{1, 2, 10, 9, 5, 4, X, X, X, X, X, X, X, X, X, X},
|
||||
{3, 0, 8, 1, 2, 10, 4, 9, 5, X, X, X, X, X, X, X},
|
||||
{5, 2, 10, 5, 4, 2, 4, 0, 2, X, X, X, X, X, X, X},
|
||||
{2, 10, 5, 3, 2, 5, 3, 5, 4, 3, 4, 8, X, X, X, X},
|
||||
{9, 5, 4, 2, 3, 11, X, X, X, X, X, X, X, X, X, X},
|
||||
{0, 11, 2, 0, 8, 11, 4, 9, 5, X, X, X, X, X, X, X},
|
||||
{0, 5, 4, 0, 1, 5, 2, 3, 11, X, X, X, X, X, X, X},
|
||||
{2, 1, 5, 2, 5, 8, 2, 8, 11, 4, 8, 5, X, X, X, X},
|
||||
{10, 3, 11, 10, 1, 3, 9, 5, 4, X, X, X, X, X, X, X},
|
||||
{4, 9, 5, 0, 8, 1, 8, 10, 1, 8, 11, 10, X, X, X, X},
|
||||
{5, 4, 0, 5, 0, 11, 5, 11, 10, 11, 0, 3, X, X, X, X},
|
||||
{5, 4, 8, 5, 8, 10, 10, 8, 11, X, X, X, X, X, X, X},
|
||||
{9, 7, 8, 5, 7, 9, X, X, X, X, X, X, X, X, X, X},
|
||||
{9, 3, 0, 9, 5, 3, 5, 7, 3, X, X, X, X, X, X, X},
|
||||
{0, 7, 8, 0, 1, 7, 1, 5, 7, X, X, X, X, X, X, X},
|
||||
{1, 5, 3, 3, 5, 7, X, X, X, X, X, X, X, X, X, X},
|
||||
{9, 7, 8, 9, 5, 7, 10, 1, 2, X, X, X, X, X, X, X},
|
||||
{10, 1, 2, 9, 5, 0, 5, 3, 0, 5, 7, 3, X, X, X, X},
|
||||
{8, 0, 2, 8, 2, 5, 8, 5, 7, 10, 5, 2, X, X, X, X},
|
||||
{2, 10, 5, 2, 5, 3, 3, 5, 7, X, X, X, X, X, X, X},
|
||||
{7, 9, 5, 7, 8, 9, 3, 11, 2, X, X, X, X, X, X, X},
|
||||
{9, 5, 7, 9, 7, 2, 9, 2, 0, 2, 7, 11, X, X, X, X},
|
||||
{2, 3, 11, 0, 1, 8, 1, 7, 8, 1, 5, 7, X, X, X, X},
|
||||
{11, 2, 1, 11, 1, 7, 7, 1, 5, X, X, X, X, X, X, X},
|
||||
{9, 5, 8, 8, 5, 7, 10, 1, 3, 10, 3, 11, X, X, X, X},
|
||||
{5, 7, 0, 5, 0, 9, 7, 11, 0, 1, 0, 10, 11, 10, 0, X},
|
||||
{11, 10, 0, 11, 0, 3, 10, 5, 0, 8, 0, 7, 5, 7, 0, X},
|
||||
{11, 10, 5, 7, 11, 5, X, X, X, X, X, X, X, X, X, X},
|
||||
{10, 6, 5, X, X, X, X, X, X, X, X, X, X, X, X, X},
|
||||
{0, 8, 3, 5, 10, 6, X, X, X, X, X, X, X, X, X, X},
|
||||
{9, 0, 1, 5, 10, 6, X, X, X, X, X, X, X, X, X, X},
|
||||
{1, 8, 3, 1, 9, 8, 5, 10, 6, X, X, X, X, X, X, X},
|
||||
{1, 6, 5, 2, 6, 1, X, X, X, X, X, X, X, X, X, X},
|
||||
{1, 6, 5, 1, 2, 6, 3, 0, 8, X, X, X, X, X, X, X},
|
||||
{9, 6, 5, 9, 0, 6, 0, 2, 6, X, X, X, X, X, X, X},
|
||||
{5, 9, 8, 5, 8, 2, 5, 2, 6, 3, 2, 8, X, X, X, X},
|
||||
{2, 3, 11, 10, 6, 5, X, X, X, X, X, X, X, X, X, X},
|
||||
{11, 0, 8, 11, 2, 0, 10, 6, 5, X, X, X, X, X, X, X},
|
||||
{0, 1, 9, 2, 3, 11, 5, 10, 6, X, X, X, X, X, X, X},
|
||||
{5, 10, 6, 1, 9, 2, 9, 11, 2, 9, 8, 11, X, X, X, X},
|
||||
{6, 3, 11, 6, 5, 3, 5, 1, 3, X, X, X, X, X, X, X},
|
||||
{0, 8, 11, 0, 11, 5, 0, 5, 1, 5, 11, 6, X, X, X, X},
|
||||
{3, 11, 6, 0, 3, 6, 0, 6, 5, 0, 5, 9, X, X, X, X},
|
||||
{6, 5, 9, 6, 9, 11, 11, 9, 8, X, X, X, X, X, X, X},
|
||||
{5, 10, 6, 4, 7, 8, X, X, X, X, X, X, X, X, X, X},
|
||||
{4, 3, 0, 4, 7, 3, 6, 5, 10, X, X, X, X, X, X, X},
|
||||
{1, 9, 0, 5, 10, 6, 8, 4, 7, X, X, X, X, X, X, X},
|
||||
{10, 6, 5, 1, 9, 7, 1, 7, 3, 7, 9, 4, X, X, X, X},
|
||||
{6, 1, 2, 6, 5, 1, 4, 7, 8, X, X, X, X, X, X, X},
|
||||
{1, 2, 5, 5, 2, 6, 3, 0, 4, 3, 4, 7, X, X, X, X},
|
||||
{8, 4, 7, 9, 0, 5, 0, 6, 5, 0, 2, 6, X, X, X, X},
|
||||
{7, 3, 9, 7, 9, 4, 3, 2, 9, 5, 9, 6, 2, 6, 9, X},
|
||||
{3, 11, 2, 7, 8, 4, 10, 6, 5, X, X, X, X, X, X, X},
|
||||
{5, 10, 6, 4, 7, 2, 4, 2, 0, 2, 7, 11, X, X, X, X},
|
||||
{0, 1, 9, 4, 7, 8, 2, 3, 11, 5, 10, 6, X, X, X, X},
|
||||
{9, 2, 1, 9, 11, 2, 9, 4, 11, 7, 11, 4, 5, 10, 6, X},
|
||||
{8, 4, 7, 3, 11, 5, 3, 5, 1, 5, 11, 6, X, X, X, X},
|
||||
{5, 1, 11, 5, 11, 6, 1, 0, 11, 7, 11, 4, 0, 4, 11, X},
|
||||
{0, 5, 9, 0, 6, 5, 0, 3, 6, 11, 6, 3, 8, 4, 7, X},
|
||||
{6, 5, 9, 6, 9, 11, 4, 7, 9, 7, 11, 9, X, X, X, X},
|
||||
{10, 4, 9, 6, 4, 10, X, X, X, X, X, X, X, X, X, X},
|
||||
{4, 10, 6, 4, 9, 10, 0, 8, 3, X, X, X, X, X, X, X},
|
||||
{10, 0, 1, 10, 6, 0, 6, 4, 0, X, X, X, X, X, X, X},
|
||||
{8, 3, 1, 8, 1, 6, 8, 6, 4, 6, 1, 10, X, X, X, X},
|
||||
{1, 4, 9, 1, 2, 4, 2, 6, 4, X, X, X, X, X, X, X},
|
||||
{3, 0, 8, 1, 2, 9, 2, 4, 9, 2, 6, 4, X, X, X, X},
|
||||
{0, 2, 4, 4, 2, 6, X, X, X, X, X, X, X, X, X, X},
|
||||
{8, 3, 2, 8, 2, 4, 4, 2, 6, X, X, X, X, X, X, X},
|
||||
{10, 4, 9, 10, 6, 4, 11, 2, 3, X, X, X, X, X, X, X},
|
||||
{0, 8, 2, 2, 8, 11, 4, 9, 10, 4, 10, 6, X, X, X, X},
|
||||
{3, 11, 2, 0, 1, 6, 0, 6, 4, 6, 1, 10, X, X, X, X},
|
||||
{6, 4, 1, 6, 1, 10, 4, 8, 1, 2, 1, 11, 8, 11, 1, X},
|
||||
{9, 6, 4, 9, 3, 6, 9, 1, 3, 11, 6, 3, X, X, X, X},
|
||||
{8, 11, 1, 8, 1, 0, 11, 6, 1, 9, 1, 4, 6, 4, 1, X},
|
||||
{3, 11, 6, 3, 6, 0, 0, 6, 4, X, X, X, X, X, X, X},
|
||||
{6, 4, 8, 11, 6, 8, X, X, X, X, X, X, X, X, X, X},
|
||||
{7, 10, 6, 7, 8, 10, 8, 9, 10, X, X, X, X, X, X, X},
|
||||
{0, 7, 3, 0, 10, 7, 0, 9, 10, 6, 7, 10, X, X, X, X},
|
||||
{10, 6, 7, 1, 10, 7, 1, 7, 8, 1, 8, 0, X, X, X, X},
|
||||
{10, 6, 7, 10, 7, 1, 1, 7, 3, X, X, X, X, X, X, X},
|
||||
{1, 2, 6, 1, 6, 8, 1, 8, 9, 8, 6, 7, X, X, X, X},
|
||||
{2, 6, 9, 2, 9, 1, 6, 7, 9, 0, 9, 3, 7, 3, 9, X},
|
||||
{7, 8, 0, 7, 0, 6, 6, 0, 2, X, X, X, X, X, X, X},
|
||||
{7, 3, 2, 6, 7, 2, X, X, X, X, X, X, X, X, X, X},
|
||||
{2, 3, 11, 10, 6, 8, 10, 8, 9, 8, 6, 7, X, X, X, X},
|
||||
{2, 0, 7, 2, 7, 11, 0, 9, 7, 6, 7, 10, 9, 10, 7, X},
|
||||
{1, 8, 0, 1, 7, 8, 1, 10, 7, 6, 7, 10, 2, 3, 11, X},
|
||||
{11, 2, 1, 11, 1, 7, 10, 6, 1, 6, 7, 1, X, X, X, X},
|
||||
{8, 9, 6, 8, 6, 7, 9, 1, 6, 11, 6, 3, 1, 3, 6, X},
|
||||
{0, 9, 1, 11, 6, 7, X, X, X, X, X, X, X, X, X, X},
|
||||
{7, 8, 0, 7, 0, 6, 3, 11, 0, 11, 6, 0, X, X, X, X},
|
||||
{7, 11, 6, X, X, X, X, X, X, X, X, X, X, X, X, X},
|
||||
{7, 6, 11, X, X, X, X, X, X, X, X, X, X, X, X, X},
|
||||
{3, 0, 8, 11, 7, 6, X, X, X, X, X, X, X, X, X, X},
|
||||
{0, 1, 9, 11, 7, 6, X, X, X, X, X, X, X, X, X, X},
|
||||
{8, 1, 9, 8, 3, 1, 11, 7, 6, X, X, X, X, X, X, X},
|
||||
{10, 1, 2, 6, 11, 7, X, X, X, X, X, X, X, X, X, X},
|
||||
{1, 2, 10, 3, 0, 8, 6, 11, 7, X, X, X, X, X, X, X},
|
||||
{2, 9, 0, 2, 10, 9, 6, 11, 7, X, X, X, X, X, X, X},
|
||||
{6, 11, 7, 2, 10, 3, 10, 8, 3, 10, 9, 8, X, X, X, X},
|
||||
{7, 2, 3, 6, 2, 7, X, X, X, X, X, X, X, X, X, X},
|
||||
{7, 0, 8, 7, 6, 0, 6, 2, 0, X, X, X, X, X, X, X},
|
||||
{2, 7, 6, 2, 3, 7, 0, 1, 9, X, X, X, X, X, X, X},
|
||||
{1, 6, 2, 1, 8, 6, 1, 9, 8, 8, 7, 6, X, X, X, X},
|
||||
{10, 7, 6, 10, 1, 7, 1, 3, 7, X, X, X, X, X, X, X},
|
||||
{10, 7, 6, 1, 7, 10, 1, 8, 7, 1, 0, 8, X, X, X, X},
|
||||
{0, 3, 7, 0, 7, 10, 0, 10, 9, 6, 10, 7, X, X, X, X},
|
||||
{7, 6, 10, 7, 10, 8, 8, 10, 9, X, X, X, X, X, X, X},
|
||||
{6, 8, 4, 11, 8, 6, X, X, X, X, X, X, X, X, X, X},
|
||||
{3, 6, 11, 3, 0, 6, 0, 4, 6, X, X, X, X, X, X, X},
|
||||
{8, 6, 11, 8, 4, 6, 9, 0, 1, X, X, X, X, X, X, X},
|
||||
{9, 4, 6, 9, 6, 3, 9, 3, 1, 11, 3, 6, X, X, X, X},
|
||||
{6, 8, 4, 6, 11, 8, 2, 10, 1, X, X, X, X, X, X, X},
|
||||
{1, 2, 10, 3, 0, 11, 0, 6, 11, 0, 4, 6, X, X, X, X},
|
||||
{4, 11, 8, 4, 6, 11, 0, 2, 9, 2, 10, 9, X, X, X, X},
|
||||
{10, 9, 3, 10, 3, 2, 9, 4, 3, 11, 3, 6, 4, 6, 3, X},
|
||||
{8, 2, 3, 8, 4, 2, 4, 6, 2, X, X, X, X, X, X, X},
|
||||
{0, 4, 2, 4, 6, 2, X, X, X, X, X, X, X, X, X, X},
|
||||
{1, 9, 0, 2, 3, 4, 2, 4, 6, 4, 3, 8, X, X, X, X},
|
||||
{1, 9, 4, 1, 4, 2, 2, 4, 6, X, X, X, X, X, X, X},
|
||||
{8, 1, 3, 8, 6, 1, 8, 4, 6, 6, 10, 1, X, X, X, X},
|
||||
{10, 1, 0, 10, 0, 6, 6, 0, 4, X, X, X, X, X, X, X},
|
||||
{4, 6, 3, 4, 3, 8, 6, 10, 3, 0, 3, 9, 10, 9, 3, X},
|
||||
{10, 9, 4, 6, 10, 4, X, X, X, X, X, X, X, X, X, X},
|
||||
{4, 9, 5, 7, 6, 11, X, X, X, X, X, X, X, X, X, X},
|
||||
{0, 8, 3, 4, 9, 5, 11, 7, 6, X, X, X, X, X, X, X},
|
||||
{5, 0, 1, 5, 4, 0, 7, 6, 11, X, X, X, X, X, X, X},
|
||||
{11, 7, 6, 8, 3, 4, 3, 5, 4, 3, 1, 5, X, X, X, X},
|
||||
{9, 5, 4, 10, 1, 2, 7, 6, 11, X, X, X, X, X, X, X},
|
||||
{6, 11, 7, 1, 2, 10, 0, 8, 3, 4, 9, 5, X, X, X, X},
|
||||
{7, 6, 11, 5, 4, 10, 4, 2, 10, 4, 0, 2, X, X, X, X},
|
||||
{3, 4, 8, 3, 5, 4, 3, 2, 5, 10, 5, 2, 11, 7, 6, X},
|
||||
{7, 2, 3, 7, 6, 2, 5, 4, 9, X, X, X, X, X, X, X},
|
||||
{9, 5, 4, 0, 8, 6, 0, 6, 2, 6, 8, 7, X, X, X, X},
|
||||
{3, 6, 2, 3, 7, 6, 1, 5, 0, 5, 4, 0, X, X, X, X},
|
||||
{6, 2, 8, 6, 8, 7, 2, 1, 8, 4, 8, 5, 1, 5, 8, X},
|
||||
{9, 5, 4, 10, 1, 6, 1, 7, 6, 1, 3, 7, X, X, X, X},
|
||||
{1, 6, 10, 1, 7, 6, 1, 0, 7, 8, 7, 0, 9, 5, 4, X},
|
||||
{4, 0, 10, 4, 10, 5, 0, 3, 10, 6, 10, 7, 3, 7, 10, X},
|
||||
{7, 6, 10, 7, 10, 8, 5, 4, 10, 4, 8, 10, X, X, X, X},
|
||||
{6, 9, 5, 6, 11, 9, 11, 8, 9, X, X, X, X, X, X, X},
|
||||
{3, 6, 11, 0, 6, 3, 0, 5, 6, 0, 9, 5, X, X, X, X},
|
||||
{0, 11, 8, 0, 5, 11, 0, 1, 5, 5, 6, 11, X, X, X, X},
|
||||
{6, 11, 3, 6, 3, 5, 5, 3, 1, X, X, X, X, X, X, X},
|
||||
{1, 2, 10, 9, 5, 11, 9, 11, 8, 11, 5, 6, X, X, X, X},
|
||||
{0, 11, 3, 0, 6, 11, 0, 9, 6, 5, 6, 9, 1, 2, 10, X},
|
||||
{11, 8, 5, 11, 5, 6, 8, 0, 5, 10, 5, 2, 0, 2, 5, X},
|
||||
{6, 11, 3, 6, 3, 5, 2, 10, 3, 10, 5, 3, X, X, X, X},
|
||||
{5, 8, 9, 5, 2, 8, 5, 6, 2, 3, 8, 2, X, X, X, X},
|
||||
{9, 5, 6, 9, 6, 0, 0, 6, 2, X, X, X, X, X, X, X},
|
||||
{1, 5, 8, 1, 8, 0, 5, 6, 8, 3, 8, 2, 6, 2, 8, X},
|
||||
{1, 5, 6, 2, 1, 6, X, X, X, X, X, X, X, X, X, X},
|
||||
{1, 3, 6, 1, 6, 10, 3, 8, 6, 5, 6, 9, 8, 9, 6, X},
|
||||
{10, 1, 0, 10, 0, 6, 9, 5, 0, 5, 6, 0, X, X, X, X},
|
||||
{0, 3, 8, 5, 6, 10, X, X, X, X, X, X, X, X, X, X},
|
||||
{10, 5, 6, X, X, X, X, X, X, X, X, X, X, X, X, X},
|
||||
{11, 5, 10, 7, 5, 11, X, X, X, X, X, X, X, X, X, X},
|
||||
{11, 5, 10, 11, 7, 5, 8, 3, 0, X, X, X, X, X, X, X},
|
||||
{5, 11, 7, 5, 10, 11, 1, 9, 0, X, X, X, X, X, X, X},
|
||||
{10, 7, 5, 10, 11, 7, 9, 8, 1, 8, 3, 1, X, X, X, X},
|
||||
{11, 1, 2, 11, 7, 1, 7, 5, 1, X, X, X, X, X, X, X},
|
||||
{0, 8, 3, 1, 2, 7, 1, 7, 5, 7, 2, 11, X, X, X, X},
|
||||
{9, 7, 5, 9, 2, 7, 9, 0, 2, 2, 11, 7, X, X, X, X},
|
||||
{7, 5, 2, 7, 2, 11, 5, 9, 2, 3, 2, 8, 9, 8, 2, X},
|
||||
{2, 5, 10, 2, 3, 5, 3, 7, 5, X, X, X, X, X, X, X},
|
||||
{8, 2, 0, 8, 5, 2, 8, 7, 5, 10, 2, 5, X, X, X, X},
|
||||
{9, 0, 1, 5, 10, 3, 5, 3, 7, 3, 10, 2, X, X, X, X},
|
||||
{9, 8, 2, 9, 2, 1, 8, 7, 2, 10, 2, 5, 7, 5, 2, X},
|
||||
{1, 3, 5, 3, 7, 5, X, X, X, X, X, X, X, X, X, X},
|
||||
{0, 8, 7, 0, 7, 1, 1, 7, 5, X, X, X, X, X, X, X},
|
||||
{9, 0, 3, 9, 3, 5, 5, 3, 7, X, X, X, X, X, X, X},
|
||||
{9, 8, 7, 5, 9, 7, X, X, X, X, X, X, X, X, X, X},
|
||||
{5, 8, 4, 5, 10, 8, 10, 11, 8, X, X, X, X, X, X, X},
|
||||
{5, 0, 4, 5, 11, 0, 5, 10, 11, 11, 3, 0, X, X, X, X},
|
||||
{0, 1, 9, 8, 4, 10, 8, 10, 11, 10, 4, 5, X, X, X, X},
|
||||
{10, 11, 4, 10, 4, 5, 11, 3, 4, 9, 4, 1, 3, 1, 4, X},
|
||||
{2, 5, 1, 2, 8, 5, 2, 11, 8, 4, 5, 8, X, X, X, X},
|
||||
{0, 4, 11, 0, 11, 3, 4, 5, 11, 2, 11, 1, 5, 1, 11, X},
|
||||
{0, 2, 5, 0, 5, 9, 2, 11, 5, 4, 5, 8, 11, 8, 5, X},
|
||||
{9, 4, 5, 2, 11, 3, X, X, X, X, X, X, X, X, X, X},
|
||||
{2, 5, 10, 3, 5, 2, 3, 4, 5, 3, 8, 4, X, X, X, X},
|
||||
{5, 10, 2, 5, 2, 4, 4, 2, 0, X, X, X, X, X, X, X},
|
||||
{3, 10, 2, 3, 5, 10, 3, 8, 5, 4, 5, 8, 0, 1, 9, X},
|
||||
{5, 10, 2, 5, 2, 4, 1, 9, 2, 9, 4, 2, X, X, X, X},
|
||||
{8, 4, 5, 8, 5, 3, 3, 5, 1, X, X, X, X, X, X, X},
|
||||
{0, 4, 5, 1, 0, 5, X, X, X, X, X, X, X, X, X, X},
|
||||
{8, 4, 5, 8, 5, 3, 9, 0, 5, 0, 3, 5, X, X, X, X},
|
||||
{9, 4, 5, X, X, X, X, X, X, X, X, X, X, X, X, X},
|
||||
{4, 11, 7, 4, 9, 11, 9, 10, 11, X, X, X, X, X, X, X},
|
||||
{0, 8, 3, 4, 9, 7, 9, 11, 7, 9, 10, 11, X, X, X, X},
|
||||
{1, 10, 11, 1, 11, 4, 1, 4, 0, 7, 4, 11, X, X, X, X},
|
||||
{3, 1, 4, 3, 4, 8, 1, 10, 4, 7, 4, 11, 10, 11, 4, X},
|
||||
{4, 11, 7, 9, 11, 4, 9, 2, 11, 9, 1, 2, X, X, X, X},
|
||||
{9, 7, 4, 9, 11, 7, 9, 1, 11, 2, 11, 1, 0, 8, 3, X},
|
||||
{11, 7, 4, 11, 4, 2, 2, 4, 0, X, X, X, X, X, X, X},
|
||||
{11, 7, 4, 11, 4, 2, 8, 3, 4, 3, 2, 4, X, X, X, X},
|
||||
{2, 9, 10, 2, 7, 9, 2, 3, 7, 7, 4, 9, X, X, X, X},
|
||||
{9, 10, 7, 9, 7, 4, 10, 2, 7, 8, 7, 0, 2, 0, 7, X},
|
||||
{3, 7, 10, 3, 10, 2, 7, 4, 10, 1, 10, 0, 4, 0, 10, X},
|
||||
{1, 10, 2, 8, 7, 4, X, X, X, X, X, X, X, X, X, X},
|
||||
{4, 9, 1, 4, 1, 7, 7, 1, 3, X, X, X, X, X, X, X},
|
||||
{4, 9, 1, 4, 1, 7, 0, 8, 1, 8, 7, 1, X, X, X, X},
|
||||
{4, 0, 3, 7, 4, 3, X, X, X, X, X, X, X, X, X, X},
|
||||
{4, 8, 7, X, X, X, X, X, X, X, X, X, X, X, X, X},
|
||||
{9, 10, 8, 10, 11, 8, X, X, X, X, X, X, X, X, X, X},
|
||||
{3, 0, 9, 3, 9, 11, 11, 9, 10, X, X, X, X, X, X, X},
|
||||
{0, 1, 10, 0, 10, 8, 8, 10, 11, X, X, X, X, X, X, X},
|
||||
{3, 1, 10, 11, 3, 10, X, X, X, X, X, X, X, X, X, X},
|
||||
{1, 2, 11, 1, 11, 9, 9, 11, 8, X, X, X, X, X, X, X},
|
||||
{3, 0, 9, 3, 9, 11, 1, 2, 9, 2, 11, 9, X, X, X, X},
|
||||
{0, 2, 11, 8, 0, 11, X, X, X, X, X, X, X, X, X, X},
|
||||
{3, 2, 11, X, X, X, X, X, X, X, X, X, X, X, X, X},
|
||||
{2, 3, 8, 2, 8, 10, 10, 8, 9, X, X, X, X, X, X, X},
|
||||
{9, 10, 2, 0, 9, 2, X, X, X, X, X, X, X, X, X, X},
|
||||
{2, 3, 8, 2, 8, 10, 0, 1, 8, 1, 10, 8, X, X, X, X},
|
||||
{1, 10, 2, X, X, X, X, X, X, X, X, X, X, X, X, X},
|
||||
{1, 3, 8, 9, 1, 8, X, X, X, X, X, X, X, X, X, X},
|
||||
{0, 9, 1, X, X, X, X, X, X, X, X, X, X, X, X, X},
|
||||
{0, 3, 8, X, X, X, X, X, X, X, X, X, X, X, X, X},
|
||||
{X, X, X, X, X, X, X, X, X, X, X, X, X, X, X, X}};
|
||||
#undef X
|
||||
|
||||
// Table mapping each edge to the corresponding cube vertices offsets
|
||||
const uint _EDGE_TO_VERTICES[12][2] = {
|
||||
{0, 1},
|
||||
{1, 5},
|
||||
{4, 5},
|
||||
{0, 4},
|
||||
{2, 3},
|
||||
{3, 7},
|
||||
{6, 7},
|
||||
{2, 6},
|
||||
{0, 2},
|
||||
{1, 3},
|
||||
{5, 7},
|
||||
{4, 6},
|
||||
};
|
||||
|
||||
// Table mapping from 0-7 to v0-v7 in cube.vertices
|
||||
const int _INDEX_TABLE[8] = {0, 1, 5, 4, 2, 3, 7, 6};
|
@ -230,18 +230,19 @@ def marching_cubes_naive(
|
||||
|
||||
|
||||
########################################
|
||||
# Marching Cubes Implementation in C++
|
||||
# Marching Cubes Implementation in C++/Cuda
|
||||
########################################
|
||||
class _marching_cubes(Function):
|
||||
"""
|
||||
Torch Function wrapper for marching_cubes C++ implementation
|
||||
Backward is not supported.
|
||||
Torch Function wrapper for marching_cubes implementation.
|
||||
This function is not differentiable. An autograd wrapper is used
|
||||
to ensure an error if user tries to get gradients.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, vol, isolevel):
|
||||
verts, faces = _C.marching_cubes(vol, isolevel)
|
||||
return verts, faces
|
||||
verts, faces, ids = _C.marching_cubes(vol, isolevel)
|
||||
return verts, faces, ids
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_verts, grad_faces):
|
||||
@ -268,7 +269,6 @@ def marching_cubes(
|
||||
the range [-1, 1] x [-1, 1] x [-1, 1]. If False they will be in the range
|
||||
[0, W-1] x [0, H-1] x [0, D-1]
|
||||
|
||||
|
||||
Returns:
|
||||
verts: [{V_0}, {V_1}, ...] List of N sets of vertices of shape (|V_i|, 3) in FloatTensor
|
||||
faces: [{F_0}, {F_1}, ...] List of N sets of faces of shape (|F_i|, 3) in LongTensors
|
||||
@ -279,7 +279,7 @@ def marching_cubes(
|
||||
vol = vol_batch[i]
|
||||
thresh = ((vol.max() + vol.min()) / 2).item() if isolevel is None else isolevel
|
||||
# pyre-fixme[16]: `_marching_cubes` has no attribute `apply`.
|
||||
verts, faces = _marching_cubes.apply(vol, thresh)
|
||||
verts, faces, ids = _marching_cubes.apply(vol, thresh)
|
||||
if len(faces) > 0 and len(verts) > 0:
|
||||
# Convert from world coordinates ([0, D-1], [0, H-1], [0, W-1]) to
|
||||
# local coordinates in the range [-1, 1]
|
||||
@ -289,6 +289,13 @@ def marching_cubes(
|
||||
.scale((vol.new_tensor([W, H, D])[None] - 1) * 0.5)
|
||||
.inverse()
|
||||
).transform_points(verts[None])[0]
|
||||
# deduplication for cuda
|
||||
if vol.is_cuda:
|
||||
unique_ids, inverse_idx = torch.unique(ids, return_inverse=True)
|
||||
verts_ = verts.new_zeros(unique_ids.shape[0], 3)
|
||||
verts_[inverse_idx] = verts
|
||||
verts = verts_
|
||||
faces = inverse_idx[faces]
|
||||
batched_verts.append(verts)
|
||||
batched_faces.append(faces)
|
||||
else:
|
||||
|
@ -14,10 +14,11 @@ def bm_marching_cubes() -> None:
|
||||
case_grid = {
|
||||
"algo_type": [
|
||||
"naive",
|
||||
"cextension",
|
||||
"extension",
|
||||
],
|
||||
"batch_size": [1, 5, 20],
|
||||
"V": [5, 10, 20],
|
||||
"batch_size": [1, 2],
|
||||
"V": [5, 10, 20, 100, 512],
|
||||
"device": ["cpu", "cuda:0"],
|
||||
}
|
||||
test_cases = itertools.product(*case_grid.values())
|
||||
kwargs_list = [dict(zip(case_grid.keys(), case)) for case in test_cases]
|
||||
|
@ -37,7 +37,6 @@ class TestCubeConfiguration(TestCaseMixin, unittest.TestCase):
|
||||
self.assertClose(verts, expected_verts)
|
||||
self.assertClose(faces, expected_faces)
|
||||
|
||||
# test C++ implementation
|
||||
verts, faces = marching_cubes(volume_data, return_local_coords=False)
|
||||
self.assertClose(verts, expected_verts)
|
||||
self.assertClose(faces, expected_faces)
|
||||
@ -46,7 +45,6 @@ class TestCubeConfiguration(TestCaseMixin, unittest.TestCase):
|
||||
volume_data = torch.ones(1, 2, 2, 2) # (B, W, H, D)
|
||||
volume_data[0, 0, 0, 0] = 0
|
||||
volume_data = volume_data.permute(0, 3, 2, 1) # (B, D, H, W)
|
||||
verts, faces = marching_cubes_naive(volume_data, return_local_coords=False)
|
||||
expected_verts = torch.tensor(
|
||||
[
|
||||
[0.5, 0, 0],
|
||||
@ -54,22 +52,21 @@ class TestCubeConfiguration(TestCaseMixin, unittest.TestCase):
|
||||
[0, 0, 0.5],
|
||||
]
|
||||
)
|
||||
|
||||
expected_faces = torch.tensor([[0, 1, 2]])
|
||||
|
||||
verts, faces = marching_cubes_naive(volume_data, return_local_coords=False)
|
||||
self.assertClose(verts[0], expected_verts)
|
||||
self.assertClose(faces[0], expected_faces)
|
||||
|
||||
# test C++ implementation
|
||||
verts, faces = marching_cubes(volume_data, return_local_coords=False)
|
||||
self.assertClose(verts[0], expected_verts)
|
||||
self.assertClose(faces[0], expected_faces)
|
||||
|
||||
verts, faces = marching_cubes_naive(volume_data, return_local_coords=True)
|
||||
expected_verts = convert_to_local(expected_verts, 2)
|
||||
verts, faces = marching_cubes_naive(volume_data, return_local_coords=True)
|
||||
self.assertClose(verts[0], expected_verts)
|
||||
self.assertClose(faces[0], expected_faces)
|
||||
|
||||
# test C++ implementation
|
||||
verts, faces = marching_cubes(volume_data, return_local_coords=True)
|
||||
self.assertClose(verts[0], expected_verts)
|
||||
self.assertClose(faces[0], expected_faces)
|
||||
@ -92,7 +89,6 @@ class TestCubeConfiguration(TestCaseMixin, unittest.TestCase):
|
||||
self.assertClose(verts[0], expected_verts)
|
||||
self.assertClose(faces[0], expected_faces)
|
||||
|
||||
# test C++ implementation
|
||||
verts, faces = marching_cubes(volume_data, return_local_coords=False)
|
||||
self.assertClose(verts[0], expected_verts)
|
||||
self.assertClose(faces[0], expected_faces)
|
||||
@ -102,7 +98,6 @@ class TestCubeConfiguration(TestCaseMixin, unittest.TestCase):
|
||||
self.assertClose(verts[0], expected_verts)
|
||||
self.assertClose(faces[0], expected_faces)
|
||||
|
||||
# test C++ implementation
|
||||
verts, faces = marching_cubes(volume_data, return_local_coords=True)
|
||||
self.assertClose(verts[0], expected_verts)
|
||||
self.assertClose(faces[0], expected_faces)
|
||||
@ -128,7 +123,6 @@ class TestCubeConfiguration(TestCaseMixin, unittest.TestCase):
|
||||
self.assertClose(verts[0], expected_verts)
|
||||
self.assertClose(faces[0], expected_faces)
|
||||
|
||||
# test C++ implementation
|
||||
verts, faces = marching_cubes(volume_data, return_local_coords=False)
|
||||
self.assertClose(verts[0], expected_verts)
|
||||
self.assertClose(faces[0], expected_faces)
|
||||
@ -138,7 +132,6 @@ class TestCubeConfiguration(TestCaseMixin, unittest.TestCase):
|
||||
self.assertClose(verts[0], expected_verts)
|
||||
self.assertClose(faces[0], expected_faces)
|
||||
|
||||
# test C++ implementation
|
||||
verts, faces = marching_cubes(volume_data, return_local_coords=True)
|
||||
self.assertClose(verts[0], expected_verts)
|
||||
self.assertClose(faces[0], expected_faces)
|
||||
@ -164,7 +157,6 @@ class TestCubeConfiguration(TestCaseMixin, unittest.TestCase):
|
||||
self.assertClose(verts[0], expected_verts)
|
||||
self.assertClose(faces[0], expected_faces)
|
||||
|
||||
# test C++ implementation
|
||||
verts, faces = marching_cubes(volume_data, return_local_coords=False)
|
||||
self.assertClose(verts[0], expected_verts)
|
||||
self.assertClose(faces[0], expected_faces)
|
||||
@ -174,7 +166,6 @@ class TestCubeConfiguration(TestCaseMixin, unittest.TestCase):
|
||||
self.assertClose(verts[0], expected_verts)
|
||||
self.assertClose(faces[0], expected_faces)
|
||||
|
||||
# test C++ implementation
|
||||
verts, faces = marching_cubes(volume_data, return_local_coords=True)
|
||||
self.assertClose(verts[0], expected_verts)
|
||||
self.assertClose(faces[0], expected_faces)
|
||||
@ -198,7 +189,6 @@ class TestCubeConfiguration(TestCaseMixin, unittest.TestCase):
|
||||
self.assertClose(verts[0], expected_verts)
|
||||
self.assertClose(faces[0], expected_faces)
|
||||
|
||||
# test C++ implementation
|
||||
verts, faces = marching_cubes(volume_data, return_local_coords=False)
|
||||
self.assertClose(verts[0], expected_verts)
|
||||
self.assertClose(faces[0], expected_faces)
|
||||
@ -208,7 +198,6 @@ class TestCubeConfiguration(TestCaseMixin, unittest.TestCase):
|
||||
self.assertClose(verts[0], expected_verts)
|
||||
self.assertClose(faces[0], expected_faces)
|
||||
|
||||
# test C++ implementation
|
||||
verts, faces = marching_cubes(volume_data, return_local_coords=True)
|
||||
self.assertClose(verts[0], expected_verts)
|
||||
self.assertClose(faces[0], expected_faces)
|
||||
@ -239,7 +228,6 @@ class TestCubeConfiguration(TestCaseMixin, unittest.TestCase):
|
||||
self.assertClose(verts[0], expected_verts)
|
||||
self.assertClose(faces[0], expected_faces)
|
||||
|
||||
# test C++ implementation
|
||||
verts, faces = marching_cubes(volume_data, return_local_coords=False)
|
||||
self.assertClose(verts[0], expected_verts)
|
||||
self.assertClose(faces[0], expected_faces)
|
||||
@ -249,7 +237,6 @@ class TestCubeConfiguration(TestCaseMixin, unittest.TestCase):
|
||||
self.assertClose(verts[0], expected_verts)
|
||||
self.assertClose(faces[0], expected_faces)
|
||||
|
||||
# test C++ implementation
|
||||
verts, faces = marching_cubes(volume_data, return_local_coords=True)
|
||||
self.assertClose(verts[0], expected_verts)
|
||||
self.assertClose(faces[0], expected_faces)
|
||||
@ -285,7 +272,6 @@ class TestCubeConfiguration(TestCaseMixin, unittest.TestCase):
|
||||
self.assertClose(verts[0], expected_verts)
|
||||
self.assertClose(faces[0], expected_faces)
|
||||
|
||||
# test C++ implementation
|
||||
verts, faces = marching_cubes(volume_data, return_local_coords=False)
|
||||
self.assertClose(verts[0], expected_verts)
|
||||
self.assertClose(faces[0], expected_faces)
|
||||
@ -295,7 +281,6 @@ class TestCubeConfiguration(TestCaseMixin, unittest.TestCase):
|
||||
self.assertClose(verts[0], expected_verts)
|
||||
self.assertClose(faces[0], expected_faces)
|
||||
|
||||
# test C++ implementation
|
||||
verts, faces = marching_cubes(volume_data, return_local_coords=True)
|
||||
self.assertClose(verts[0], expected_verts)
|
||||
self.assertClose(faces[0], expected_faces)
|
||||
@ -324,7 +309,6 @@ class TestCubeConfiguration(TestCaseMixin, unittest.TestCase):
|
||||
self.assertClose(verts[0], expected_verts)
|
||||
self.assertClose(faces[0], expected_faces)
|
||||
|
||||
# test C++ implementation
|
||||
verts, faces = marching_cubes(volume_data, return_local_coords=False)
|
||||
self.assertClose(verts[0], expected_verts)
|
||||
self.assertClose(faces[0], expected_faces)
|
||||
@ -334,7 +318,6 @@ class TestCubeConfiguration(TestCaseMixin, unittest.TestCase):
|
||||
self.assertClose(verts[0], expected_verts)
|
||||
self.assertClose(faces[0], expected_faces)
|
||||
|
||||
# test C++ implementation
|
||||
verts, faces = marching_cubes(volume_data, return_local_coords=True)
|
||||
self.assertClose(verts[0], expected_verts)
|
||||
self.assertClose(faces[0], expected_faces)
|
||||
@ -363,7 +346,6 @@ class TestCubeConfiguration(TestCaseMixin, unittest.TestCase):
|
||||
self.assertClose(verts[0], expected_verts)
|
||||
self.assertClose(faces[0], expected_faces)
|
||||
|
||||
# test C++ implementation
|
||||
verts, faces = marching_cubes(volume_data, return_local_coords=False)
|
||||
self.assertClose(verts[0], expected_verts)
|
||||
self.assertClose(faces[0], expected_faces)
|
||||
@ -373,7 +355,6 @@ class TestCubeConfiguration(TestCaseMixin, unittest.TestCase):
|
||||
self.assertClose(verts[0], expected_verts)
|
||||
self.assertClose(faces[0], expected_faces)
|
||||
|
||||
# test C++ implementation
|
||||
verts, faces = marching_cubes(volume_data, return_local_coords=True)
|
||||
self.assertClose(verts[0], expected_verts)
|
||||
self.assertClose(faces[0], expected_faces)
|
||||
@ -401,7 +382,6 @@ class TestCubeConfiguration(TestCaseMixin, unittest.TestCase):
|
||||
self.assertClose(verts[0], expected_verts)
|
||||
self.assertClose(faces[0], expected_faces)
|
||||
|
||||
# test C++ implementation
|
||||
verts, faces = marching_cubes(volume_data, return_local_coords=False)
|
||||
self.assertClose(verts[0], expected_verts)
|
||||
self.assertClose(faces[0], expected_faces)
|
||||
@ -411,7 +391,6 @@ class TestCubeConfiguration(TestCaseMixin, unittest.TestCase):
|
||||
self.assertClose(verts[0], expected_verts)
|
||||
self.assertClose(faces[0], expected_faces)
|
||||
|
||||
# test C++ implementation
|
||||
verts, faces = marching_cubes(volume_data, return_local_coords=True)
|
||||
self.assertClose(verts[0], expected_verts)
|
||||
self.assertClose(faces[0], expected_faces)
|
||||
@ -441,7 +420,6 @@ class TestCubeConfiguration(TestCaseMixin, unittest.TestCase):
|
||||
self.assertClose(verts[0], expected_verts)
|
||||
self.assertClose(faces[0], expected_faces)
|
||||
|
||||
# test C++ implementation
|
||||
verts, faces = marching_cubes(volume_data, return_local_coords=False)
|
||||
self.assertClose(verts[0], expected_verts)
|
||||
self.assertClose(faces[0], expected_faces)
|
||||
@ -451,7 +429,6 @@ class TestCubeConfiguration(TestCaseMixin, unittest.TestCase):
|
||||
self.assertClose(verts[0], expected_verts)
|
||||
self.assertClose(faces[0], expected_faces)
|
||||
|
||||
# test C++ implementation
|
||||
verts, faces = marching_cubes(volume_data, return_local_coords=True)
|
||||
self.assertClose(verts[0], expected_verts)
|
||||
self.assertClose(faces[0], expected_faces)
|
||||
@ -483,7 +460,6 @@ class TestCubeConfiguration(TestCaseMixin, unittest.TestCase):
|
||||
self.assertClose(verts[0], expected_verts)
|
||||
self.assertClose(faces[0], expected_faces)
|
||||
|
||||
# test C++ implementation
|
||||
verts, faces = marching_cubes(volume_data, return_local_coords=False)
|
||||
self.assertClose(verts[0], expected_verts)
|
||||
self.assertClose(faces[0], expected_faces)
|
||||
@ -493,7 +469,6 @@ class TestCubeConfiguration(TestCaseMixin, unittest.TestCase):
|
||||
self.assertClose(verts[0], expected_verts)
|
||||
self.assertClose(faces[0], expected_faces)
|
||||
|
||||
# test C++ implementation
|
||||
verts, faces = marching_cubes(volume_data, return_local_coords=True)
|
||||
self.assertClose(verts[0], expected_verts)
|
||||
self.assertClose(faces[0], expected_faces)
|
||||
@ -525,7 +500,6 @@ class TestCubeConfiguration(TestCaseMixin, unittest.TestCase):
|
||||
self.assertClose(verts[0], expected_verts)
|
||||
self.assertClose(faces[0], expected_faces)
|
||||
|
||||
# test C++ implementation
|
||||
verts, faces = marching_cubes(volume_data, return_local_coords=False)
|
||||
self.assertClose(verts[0], expected_verts)
|
||||
self.assertClose(faces[0], expected_faces)
|
||||
@ -535,7 +509,6 @@ class TestCubeConfiguration(TestCaseMixin, unittest.TestCase):
|
||||
self.assertClose(verts[0], expected_verts)
|
||||
self.assertClose(faces[0], expected_faces)
|
||||
|
||||
# test C++ implementation
|
||||
verts, faces = marching_cubes(volume_data, return_local_coords=True)
|
||||
self.assertClose(verts[0], expected_verts)
|
||||
self.assertClose(faces[0], expected_faces)
|
||||
@ -565,7 +538,6 @@ class TestCubeConfiguration(TestCaseMixin, unittest.TestCase):
|
||||
self.assertClose(verts[0], expected_verts)
|
||||
self.assertClose(faces[0], expected_faces)
|
||||
|
||||
# test C++ implementation
|
||||
verts, faces = marching_cubes(volume_data, return_local_coords=False)
|
||||
self.assertClose(verts[0], expected_verts)
|
||||
self.assertClose(faces[0], expected_faces)
|
||||
@ -575,7 +547,6 @@ class TestCubeConfiguration(TestCaseMixin, unittest.TestCase):
|
||||
self.assertClose(verts[0], expected_verts)
|
||||
self.assertClose(faces[0], expected_faces)
|
||||
|
||||
# test C++ implementation
|
||||
verts, faces = marching_cubes(volume_data, return_local_coords=True)
|
||||
self.assertClose(verts[0], expected_verts)
|
||||
self.assertClose(faces[0], expected_faces)
|
||||
@ -613,7 +584,6 @@ class TestMarchingCubes(TestCaseMixin, unittest.TestCase):
|
||||
self.assertClose(verts[0], expected_verts)
|
||||
self.assertClose(faces[0], expected_faces)
|
||||
|
||||
# test C++ implementation
|
||||
verts, faces = marching_cubes(volume_data, return_local_coords=False)
|
||||
self.assertClose(verts[0], expected_verts)
|
||||
self.assertClose(faces[0], expected_faces)
|
||||
@ -624,7 +594,6 @@ class TestMarchingCubes(TestCaseMixin, unittest.TestCase):
|
||||
self.assertClose(faces[0], expected_faces)
|
||||
self.assertTrue(verts[0].ge(-1).all() and verts[0].le(1).all())
|
||||
|
||||
# test C++ implementation
|
||||
verts, faces = marching_cubes(volume_data, return_local_coords=True)
|
||||
self.assertClose(verts[0], expected_verts)
|
||||
self.assertClose(faces[0], expected_faces)
|
||||
@ -641,8 +610,6 @@ class TestMarchingCubes(TestCaseMixin, unittest.TestCase):
|
||||
volume_data[0, 2, 2, 1] = 1
|
||||
volume_data[0, 2, 2, 2] = 1
|
||||
volume_data = volume_data.permute(0, 3, 2, 1) # (B, D, H, W)
|
||||
verts, faces = marching_cubes_naive(volume_data, 0.9, return_local_coords=False)
|
||||
|
||||
expected_verts = torch.tensor(
|
||||
[
|
||||
[1.0000, 0.9000, 1.0000],
|
||||
@ -720,11 +687,13 @@ class TestMarchingCubes(TestCaseMixin, unittest.TestCase):
|
||||
[17, 23, 19],
|
||||
]
|
||||
)
|
||||
verts, faces = marching_cubes_naive(volume_data, 0.9, return_local_coords=False)
|
||||
self.assertClose(verts[0], expected_verts)
|
||||
self.assertClose(faces[0], expected_faces)
|
||||
|
||||
# test C++ implementation
|
||||
verts, faces = marching_cubes(volume_data, 0.9, return_local_coords=False)
|
||||
verts2, faces2 = marching_cubes(volume_data, 0.9, return_local_coords=False)
|
||||
|
||||
self.assertClose(verts[0], expected_verts)
|
||||
self.assertClose(faces[0], expected_faces)
|
||||
|
||||
@ -736,7 +705,6 @@ class TestMarchingCubes(TestCaseMixin, unittest.TestCase):
|
||||
# Check all values are in the range [-1, 1]
|
||||
self.assertTrue(verts[0].ge(-1).all() and verts[0].le(1).all())
|
||||
|
||||
# test C++ implementation
|
||||
verts, faces = marching_cubes(volume_data, 0.9, return_local_coords=True)
|
||||
self.assertClose(verts[0], expected_verts)
|
||||
self.assertClose(faces[0], expected_faces)
|
||||
@ -803,12 +771,14 @@ class TestMarchingCubes(TestCaseMixin, unittest.TestCase):
|
||||
self.assertClose(verts[0], expected_verts)
|
||||
self.assertClose(faces[0], expected_faces)
|
||||
|
||||
verts, faces = marching_cubes(volume_data, 1, return_local_coords=False)
|
||||
self.assertClose(verts[0], expected_verts)
|
||||
self.assertClose(faces[0], expected_faces)
|
||||
|
||||
verts, faces = marching_cubes_naive(volume_data, 1, return_local_coords=True)
|
||||
expected_verts = convert_to_local(expected_verts, 5)
|
||||
self.assertClose(verts[0], expected_verts)
|
||||
self.assertClose(faces[0], expected_faces)
|
||||
|
||||
# Check all values are in the range [-1, 1]
|
||||
self.assertTrue(verts[0].ge(-1).all() and verts[0].le(1).all())
|
||||
|
||||
def test_sphere(self):
|
||||
@ -837,7 +807,6 @@ class TestMarchingCubes(TestCaseMixin, unittest.TestCase):
|
||||
self.assertClose(verts[0], expected_verts)
|
||||
self.assertClose(faces[0], expected_faces)
|
||||
|
||||
# test C++ implementation
|
||||
verts, faces = marching_cubes(volume, 64, return_local_coords=False)
|
||||
self.assertClose(verts[0], expected_verts)
|
||||
self.assertClose(faces[0], expected_faces)
|
||||
@ -853,7 +822,6 @@ class TestMarchingCubes(TestCaseMixin, unittest.TestCase):
|
||||
# Check all values are in the range [-1, 1]
|
||||
self.assertTrue(verts[0].ge(-1).all() and verts[0].le(1).all())
|
||||
|
||||
# test C++ implementation
|
||||
verts, faces = marching_cubes(volume, 64, return_local_coords=True)
|
||||
self.assertClose(verts[0], expected_verts)
|
||||
self.assertClose(faces[0], expected_faces)
|
||||
@ -964,7 +932,7 @@ class TestMarchingCubes(TestCaseMixin, unittest.TestCase):
|
||||
self.assertClose(surf, surf_c)
|
||||
|
||||
def test_ball_example(self):
|
||||
N = 15
|
||||
N = 30
|
||||
axis_tensor = torch.arange(0, N)
|
||||
X, Y, Z = torch.meshgrid(axis_tensor, axis_tensor, axis_tensor, indexing="ij")
|
||||
u = (X - 15) ** 2 + (Y - 15) ** 2 + (Z - 15) ** 2 - 8**2
|
||||
@ -975,14 +943,14 @@ class TestMarchingCubes(TestCaseMixin, unittest.TestCase):
|
||||
self.assertClose(faces[0], faces2[0])
|
||||
|
||||
@staticmethod
|
||||
def marching_cubes_with_init(algo_type: str, batch_size: int, V: int):
|
||||
device = torch.device("cuda:0")
|
||||
def marching_cubes_with_init(algo_type: str, batch_size: int, V: int, device: str):
|
||||
device = torch.device(device)
|
||||
volume_data = torch.rand(
|
||||
(batch_size, V, V, V), dtype=torch.float32, device=device
|
||||
)
|
||||
algo_table = {
|
||||
"naive": marching_cubes_naive,
|
||||
"cextension": marching_cubes,
|
||||
"extension": marching_cubes,
|
||||
}
|
||||
|
||||
def convert():
|
||||
|
Loading…
x
Reference in New Issue
Block a user