Marching Cubes C++ torch extension

Summary:
Torch C++ extension for Marching Cubes

- Add torch C++ extension for marching cubes. Observe a speed up of ~255x-324x speed up (over varying batch sizes and spatial resolutions)

- Add C++ impl in existing unit-tests.

(Note: this ignores all push blocking failures!)

Reviewed By: kjchalup

Differential Revision: D39590638

fbshipit-source-id: e44d2852a24c2c398e5ea9db20f0dfaa1817e457
This commit is contained in:
Jiali Duan 2022-10-06 11:13:53 -07:00 committed by Facebook GitHub Bot
parent 850efdf706
commit 0d8608b9f9
7 changed files with 879 additions and 9 deletions

View File

@ -22,6 +22,7 @@
#include "interp_face_attrs/interp_face_attrs.h"
#include "iou_box3d/iou_box3d.h"
#include "knn/knn.h"
#include "marching_cubes/marching_cubes.h"
#include "mesh_normal_consistency/mesh_normal_consistency.h"
#include "packed_to_padded_tensor/packed_to_padded_tensor.h"
#include "point_mesh/point_mesh_cuda.h"
@ -94,6 +95,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
// 3D IoU
m.def("iou_box3d", &IoUBox3D);
// Marching cubes
m.def("marching_cubes", &MarchingCubes);
// Pulsar.
#ifdef PULSAR_LOGGING_ENABLED
c10::ShowLogInfoToStderr();

View File

@ -0,0 +1,39 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#pragma once
#include <torch/extension.h>
#include <tuple>
#include <vector>
#include "utils/pytorch3d_cutils.h"
// Run Marching Cubes algorithm over a batch of volume scalar fields
// with a pre-defined threshold and return a mesh composed of vertices
// and faces for the mesh.
//
// Args:
// vol: FloatTensor of shape (D, H, W) giving a volume
// scalar grids.
// isolevel: isosurface value to use as the threshoold to determine whether
// the points are within a volume.
//
// Returns:
// vertices: List of N FloatTensors of vertices
// faces: List of N LongTensors of faces
// CPU implementation
std::tuple<at::Tensor, at::Tensor> MarchingCubesCpu(
const at::Tensor& vol,
const float isolevel);
// Implementation which is exposed
inline std::tuple<at::Tensor, at::Tensor> MarchingCubes(
const at::Tensor& vol,
const float isolevel) {
return MarchingCubesCpu(vol.contiguous(), isolevel);
}

View File

@ -0,0 +1,115 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#include <torch/extension.h>
#include <algorithm>
#include <array>
#include <cstring>
#include <unordered_map>
#include <vector>
#include "marching_cubes/marching_cubes_utils.h"
// Cpu implementation for Marching Cubes
// Args:
// vol: a Tensor of size (D, H, W) corresponding to a 3D scalar field
// isolevel: the isosurface value to use as the threshold to determine
// whether points are within a volume.
//
// Returns:
// vertices: a float tensor of shape (N, 3) for positions of the mesh
// faces: a long tensor of shape (N, 3) for indices of the face vertices
//
std::tuple<at::Tensor, at::Tensor> MarchingCubesCpu(
const at::Tensor& vol,
const float isolevel) {
// volume shapes
const int D = vol.size(0);
const int H = vol.size(1);
const int W = vol.size(2);
// Create tensor accessors
auto vol_a = vol.accessor<float, 3>();
// vpair_to_edge maps a pair of vertex ids to its corresponding edge id
std::unordered_map<std::pair<int, int>, int64_t> vpair_to_edge;
// edge_id_to_v maps from an edge id to a vertex position
std::unordered_map<int64_t, Vertex> edge_id_to_v;
// uniq_edge_id: used to remove redundant edge ids
std::unordered_map<int64_t, int64_t> uniq_edge_id;
std::vector<int64_t> faces; // store face indices
std::vector<Vertex> verts; // store vertex positions
// enumerate each cell in the 3d grid
for (int z = 0; z < D - 1; z++) {
for (int y = 0; y < H - 1; y++) {
for (int x = 0; x < W - 1; x++) {
Cube cube(x, y, z, vol_a, isolevel);
// Cube is entirely in/out of the surface
if (_FACE_TABLE[cube.cubeindex][0] == -1) {
continue;
}
// store all boundary vertices that intersect with the edges
std::array<Vertex, 12> interp_points;
// triangle vertex IDs and positions
std::vector<int64_t> tri;
std::vector<Vertex> ps;
// Interpolate the vertices where the surface intersects with the cube
for (int j = 0; _FACE_TABLE[cube.cubeindex][j] != -1; j++) {
const int e = _FACE_TABLE[cube.cubeindex][j];
interp_points[e] = cube.VertexInterp(isolevel, e, vol_a);
auto vpair = cube.GetVPairFromEdge(e, W, H);
if (!vpair_to_edge.count(vpair)) {
vpair_to_edge[vpair] = vpair_to_edge.size();
}
int64_t edge = vpair_to_edge[vpair];
tri.push_back(edge);
ps.push_back(interp_points[e]);
// Check if the triangle face is degenerate. A triangle face
// is degenerate if any of the two verices share the same 3D position
if ((j + 1) % 3 == 0 && ps[0] != ps[1] && ps[1] != ps[2] &&
ps[2] != ps[0]) {
for (int k = 0; k < 3; k++) {
int v = tri[k];
edge_id_to_v[tri.at(k)] = ps.at(k);
if (!uniq_edge_id.count(v)) {
uniq_edge_id[v] = verts.size();
verts.push_back(edge_id_to_v[v]);
}
faces.push_back(uniq_edge_id[v]);
}
tri.clear();
ps.clear();
}
} // endif
} // endfor x
} // endfor y
} // endfor z
// Collect returning tensor
const int n_vertices = verts.size();
const int64_t n_faces = (int64_t)faces.size() / 3;
auto vert_tensor = torch::zeros({n_vertices, 3}, torch::kFloat);
auto face_tensor = torch::zeros({n_faces, 3}, torch::kInt64);
auto vert_a = vert_tensor.accessor<float, 2>();
for (int i = 0; i < n_vertices; i++) {
vert_a[i][0] = verts.at(i).x;
vert_a[i][1] = verts.at(i).y;
vert_a[i][2] = verts.at(i).z;
}
auto face_a = face_tensor.accessor<int64_t, 2>();
for (int64_t i = 0; i < n_faces; i++) {
face_a[i][0] = faces.at(i * 3 + 0);
face_a[i][1] = faces.at(i * 3 + 1);
face_a[i][2] = faces.at(i * 3 + 2);
}
return std::make_tuple(vert_tensor, face_tensor);
}

View File

@ -0,0 +1,443 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#pragma once
#include <torch/extension.h>
#include <cmath>
#include <cstdint>
#include <vector>
#include "ATen/core/TensorAccessor.h"
// EPS: Used to assess whether two float values are close
const float EPS = 1e-5;
// A table mapping from cubeindex to a list of face configurations.
// Each list contains at most 5 faces, where each face is represented with
// 3 consecutive numbers
// Table taken from http://paulbourke.net/geometry/polygonise/
const int _FACE_TABLE[256][16] = {
{-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1},
{0, 8, 3, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1},
{0, 1, 9, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1},
{1, 8, 3, 9, 8, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1},
{1, 2, 10, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1},
{0, 8, 3, 1, 2, 10, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1},
{9, 2, 10, 0, 2, 9, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1},
{2, 8, 3, 2, 10, 8, 10, 9, 8, -1, -1, -1, -1, -1, -1, -1},
{3, 11, 2, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1},
{0, 11, 2, 8, 11, 0, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1},
{1, 9, 0, 2, 3, 11, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1},
{1, 11, 2, 1, 9, 11, 9, 8, 11, -1, -1, -1, -1, -1, -1, -1},
{3, 10, 1, 11, 10, 3, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1},
{0, 10, 1, 0, 8, 10, 8, 11, 10, -1, -1, -1, -1, -1, -1, -1},
{3, 9, 0, 3, 11, 9, 11, 10, 9, -1, -1, -1, -1, -1, -1, -1},
{9, 8, 10, 10, 8, 11, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1},
{4, 7, 8, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1},
{4, 3, 0, 7, 3, 4, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1},
{0, 1, 9, 8, 4, 7, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1},
{4, 1, 9, 4, 7, 1, 7, 3, 1, -1, -1, -1, -1, -1, -1, -1},
{1, 2, 10, 8, 4, 7, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1},
{3, 4, 7, 3, 0, 4, 1, 2, 10, -1, -1, -1, -1, -1, -1, -1},
{9, 2, 10, 9, 0, 2, 8, 4, 7, -1, -1, -1, -1, -1, -1, -1},
{2, 10, 9, 2, 9, 7, 2, 7, 3, 7, 9, 4, -1, -1, -1, -1},
{8, 4, 7, 3, 11, 2, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1},
{11, 4, 7, 11, 2, 4, 2, 0, 4, -1, -1, -1, -1, -1, -1, -1},
{9, 0, 1, 8, 4, 7, 2, 3, 11, -1, -1, -1, -1, -1, -1, -1},
{4, 7, 11, 9, 4, 11, 9, 11, 2, 9, 2, 1, -1, -1, -1, -1},
{3, 10, 1, 3, 11, 10, 7, 8, 4, -1, -1, -1, -1, -1, -1, -1},
{1, 11, 10, 1, 4, 11, 1, 0, 4, 7, 11, 4, -1, -1, -1, -1},
{4, 7, 8, 9, 0, 11, 9, 11, 10, 11, 0, 3, -1, -1, -1, -1},
{4, 7, 11, 4, 11, 9, 9, 11, 10, -1, -1, -1, -1, -1, -1, -1},
{9, 5, 4, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1},
{9, 5, 4, 0, 8, 3, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1},
{0, 5, 4, 1, 5, 0, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1},
{8, 5, 4, 8, 3, 5, 3, 1, 5, -1, -1, -1, -1, -1, -1, -1},
{1, 2, 10, 9, 5, 4, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1},
{3, 0, 8, 1, 2, 10, 4, 9, 5, -1, -1, -1, -1, -1, -1, -1},
{5, 2, 10, 5, 4, 2, 4, 0, 2, -1, -1, -1, -1, -1, -1, -1},
{2, 10, 5, 3, 2, 5, 3, 5, 4, 3, 4, 8, -1, -1, -1, -1},
{9, 5, 4, 2, 3, 11, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1},
{0, 11, 2, 0, 8, 11, 4, 9, 5, -1, -1, -1, -1, -1, -1, -1},
{0, 5, 4, 0, 1, 5, 2, 3, 11, -1, -1, -1, -1, -1, -1, -1},
{2, 1, 5, 2, 5, 8, 2, 8, 11, 4, 8, 5, -1, -1, -1, -1},
{10, 3, 11, 10, 1, 3, 9, 5, 4, -1, -1, -1, -1, -1, -1, -1},
{4, 9, 5, 0, 8, 1, 8, 10, 1, 8, 11, 10, -1, -1, -1, -1},
{5, 4, 0, 5, 0, 11, 5, 11, 10, 11, 0, 3, -1, -1, -1, -1},
{5, 4, 8, 5, 8, 10, 10, 8, 11, -1, -1, -1, -1, -1, -1, -1},
{9, 7, 8, 5, 7, 9, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1},
{9, 3, 0, 9, 5, 3, 5, 7, 3, -1, -1, -1, -1, -1, -1, -1},
{0, 7, 8, 0, 1, 7, 1, 5, 7, -1, -1, -1, -1, -1, -1, -1},
{1, 5, 3, 3, 5, 7, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1},
{9, 7, 8, 9, 5, 7, 10, 1, 2, -1, -1, -1, -1, -1, -1, -1},
{10, 1, 2, 9, 5, 0, 5, 3, 0, 5, 7, 3, -1, -1, -1, -1},
{8, 0, 2, 8, 2, 5, 8, 5, 7, 10, 5, 2, -1, -1, -1, -1},
{2, 10, 5, 2, 5, 3, 3, 5, 7, -1, -1, -1, -1, -1, -1, -1},
{7, 9, 5, 7, 8, 9, 3, 11, 2, -1, -1, -1, -1, -1, -1, -1},
{9, 5, 7, 9, 7, 2, 9, 2, 0, 2, 7, 11, -1, -1, -1, -1},
{2, 3, 11, 0, 1, 8, 1, 7, 8, 1, 5, 7, -1, -1, -1, -1},
{11, 2, 1, 11, 1, 7, 7, 1, 5, -1, -1, -1, -1, -1, -1, -1},
{9, 5, 8, 8, 5, 7, 10, 1, 3, 10, 3, 11, -1, -1, -1, -1},
{5, 7, 0, 5, 0, 9, 7, 11, 0, 1, 0, 10, 11, 10, 0, -1},
{11, 10, 0, 11, 0, 3, 10, 5, 0, 8, 0, 7, 5, 7, 0, -1},
{11, 10, 5, 7, 11, 5, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1},
{10, 6, 5, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1},
{0, 8, 3, 5, 10, 6, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1},
{9, 0, 1, 5, 10, 6, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1},
{1, 8, 3, 1, 9, 8, 5, 10, 6, -1, -1, -1, -1, -1, -1, -1},
{1, 6, 5, 2, 6, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1},
{1, 6, 5, 1, 2, 6, 3, 0, 8, -1, -1, -1, -1, -1, -1, -1},
{9, 6, 5, 9, 0, 6, 0, 2, 6, -1, -1, -1, -1, -1, -1, -1},
{5, 9, 8, 5, 8, 2, 5, 2, 6, 3, 2, 8, -1, -1, -1, -1},
{2, 3, 11, 10, 6, 5, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1},
{11, 0, 8, 11, 2, 0, 10, 6, 5, -1, -1, -1, -1, -1, -1, -1},
{0, 1, 9, 2, 3, 11, 5, 10, 6, -1, -1, -1, -1, -1, -1, -1},
{5, 10, 6, 1, 9, 2, 9, 11, 2, 9, 8, 11, -1, -1, -1, -1},
{6, 3, 11, 6, 5, 3, 5, 1, 3, -1, -1, -1, -1, -1, -1, -1},
{0, 8, 11, 0, 11, 5, 0, 5, 1, 5, 11, 6, -1, -1, -1, -1},
{3, 11, 6, 0, 3, 6, 0, 6, 5, 0, 5, 9, -1, -1, -1, -1},
{6, 5, 9, 6, 9, 11, 11, 9, 8, -1, -1, -1, -1, -1, -1, -1},
{5, 10, 6, 4, 7, 8, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1},
{4, 3, 0, 4, 7, 3, 6, 5, 10, -1, -1, -1, -1, -1, -1, -1},
{1, 9, 0, 5, 10, 6, 8, 4, 7, -1, -1, -1, -1, -1, -1, -1},
{10, 6, 5, 1, 9, 7, 1, 7, 3, 7, 9, 4, -1, -1, -1, -1},
{6, 1, 2, 6, 5, 1, 4, 7, 8, -1, -1, -1, -1, -1, -1, -1},
{1, 2, 5, 5, 2, 6, 3, 0, 4, 3, 4, 7, -1, -1, -1, -1},
{8, 4, 7, 9, 0, 5, 0, 6, 5, 0, 2, 6, -1, -1, -1, -1},
{7, 3, 9, 7, 9, 4, 3, 2, 9, 5, 9, 6, 2, 6, 9, -1},
{3, 11, 2, 7, 8, 4, 10, 6, 5, -1, -1, -1, -1, -1, -1, -1},
{5, 10, 6, 4, 7, 2, 4, 2, 0, 2, 7, 11, -1, -1, -1, -1},
{0, 1, 9, 4, 7, 8, 2, 3, 11, 5, 10, 6, -1, -1, -1, -1},
{9, 2, 1, 9, 11, 2, 9, 4, 11, 7, 11, 4, 5, 10, 6, -1},
{8, 4, 7, 3, 11, 5, 3, 5, 1, 5, 11, 6, -1, -1, -1, -1},
{5, 1, 11, 5, 11, 6, 1, 0, 11, 7, 11, 4, 0, 4, 11, -1},
{0, 5, 9, 0, 6, 5, 0, 3, 6, 11, 6, 3, 8, 4, 7, -1},
{6, 5, 9, 6, 9, 11, 4, 7, 9, 7, 11, 9, -1, -1, -1, -1},
{10, 4, 9, 6, 4, 10, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1},
{4, 10, 6, 4, 9, 10, 0, 8, 3, -1, -1, -1, -1, -1, -1, -1},
{10, 0, 1, 10, 6, 0, 6, 4, 0, -1, -1, -1, -1, -1, -1, -1},
{8, 3, 1, 8, 1, 6, 8, 6, 4, 6, 1, 10, -1, -1, -1, -1},
{1, 4, 9, 1, 2, 4, 2, 6, 4, -1, -1, -1, -1, -1, -1, -1},
{3, 0, 8, 1, 2, 9, 2, 4, 9, 2, 6, 4, -1, -1, -1, -1},
{0, 2, 4, 4, 2, 6, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1},
{8, 3, 2, 8, 2, 4, 4, 2, 6, -1, -1, -1, -1, -1, -1, -1},
{10, 4, 9, 10, 6, 4, 11, 2, 3, -1, -1, -1, -1, -1, -1, -1},
{0, 8, 2, 2, 8, 11, 4, 9, 10, 4, 10, 6, -1, -1, -1, -1},
{3, 11, 2, 0, 1, 6, 0, 6, 4, 6, 1, 10, -1, -1, -1, -1},
{6, 4, 1, 6, 1, 10, 4, 8, 1, 2, 1, 11, 8, 11, 1, -1},
{9, 6, 4, 9, 3, 6, 9, 1, 3, 11, 6, 3, -1, -1, -1, -1},
{8, 11, 1, 8, 1, 0, 11, 6, 1, 9, 1, 4, 6, 4, 1, -1},
{3, 11, 6, 3, 6, 0, 0, 6, 4, -1, -1, -1, -1, -1, -1, -1},
{6, 4, 8, 11, 6, 8, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1},
{7, 10, 6, 7, 8, 10, 8, 9, 10, -1, -1, -1, -1, -1, -1, -1},
{0, 7, 3, 0, 10, 7, 0, 9, 10, 6, 7, 10, -1, -1, -1, -1},
{10, 6, 7, 1, 10, 7, 1, 7, 8, 1, 8, 0, -1, -1, -1, -1},
{10, 6, 7, 10, 7, 1, 1, 7, 3, -1, -1, -1, -1, -1, -1, -1},
{1, 2, 6, 1, 6, 8, 1, 8, 9, 8, 6, 7, -1, -1, -1, -1},
{2, 6, 9, 2, 9, 1, 6, 7, 9, 0, 9, 3, 7, 3, 9, -1},
{7, 8, 0, 7, 0, 6, 6, 0, 2, -1, -1, -1, -1, -1, -1, -1},
{7, 3, 2, 6, 7, 2, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1},
{2, 3, 11, 10, 6, 8, 10, 8, 9, 8, 6, 7, -1, -1, -1, -1},
{2, 0, 7, 2, 7, 11, 0, 9, 7, 6, 7, 10, 9, 10, 7, -1},
{1, 8, 0, 1, 7, 8, 1, 10, 7, 6, 7, 10, 2, 3, 11, -1},
{11, 2, 1, 11, 1, 7, 10, 6, 1, 6, 7, 1, -1, -1, -1, -1},
{8, 9, 6, 8, 6, 7, 9, 1, 6, 11, 6, 3, 1, 3, 6, -1},
{0, 9, 1, 11, 6, 7, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1},
{7, 8, 0, 7, 0, 6, 3, 11, 0, 11, 6, 0, -1, -1, -1, -1},
{7, 11, 6, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1},
{7, 6, 11, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1},
{3, 0, 8, 11, 7, 6, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1},
{0, 1, 9, 11, 7, 6, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1},
{8, 1, 9, 8, 3, 1, 11, 7, 6, -1, -1, -1, -1, -1, -1, -1},
{10, 1, 2, 6, 11, 7, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1},
{1, 2, 10, 3, 0, 8, 6, 11, 7, -1, -1, -1, -1, -1, -1, -1},
{2, 9, 0, 2, 10, 9, 6, 11, 7, -1, -1, -1, -1, -1, -1, -1},
{6, 11, 7, 2, 10, 3, 10, 8, 3, 10, 9, 8, -1, -1, -1, -1},
{7, 2, 3, 6, 2, 7, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1},
{7, 0, 8, 7, 6, 0, 6, 2, 0, -1, -1, -1, -1, -1, -1, -1},
{2, 7, 6, 2, 3, 7, 0, 1, 9, -1, -1, -1, -1, -1, -1, -1},
{1, 6, 2, 1, 8, 6, 1, 9, 8, 8, 7, 6, -1, -1, -1, -1},
{10, 7, 6, 10, 1, 7, 1, 3, 7, -1, -1, -1, -1, -1, -1, -1},
{10, 7, 6, 1, 7, 10, 1, 8, 7, 1, 0, 8, -1, -1, -1, -1},
{0, 3, 7, 0, 7, 10, 0, 10, 9, 6, 10, 7, -1, -1, -1, -1},
{7, 6, 10, 7, 10, 8, 8, 10, 9, -1, -1, -1, -1, -1, -1, -1},
{6, 8, 4, 11, 8, 6, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1},
{3, 6, 11, 3, 0, 6, 0, 4, 6, -1, -1, -1, -1, -1, -1, -1},
{8, 6, 11, 8, 4, 6, 9, 0, 1, -1, -1, -1, -1, -1, -1, -1},
{9, 4, 6, 9, 6, 3, 9, 3, 1, 11, 3, 6, -1, -1, -1, -1},
{6, 8, 4, 6, 11, 8, 2, 10, 1, -1, -1, -1, -1, -1, -1, -1},
{1, 2, 10, 3, 0, 11, 0, 6, 11, 0, 4, 6, -1, -1, -1, -1},
{4, 11, 8, 4, 6, 11, 0, 2, 9, 2, 10, 9, -1, -1, -1, -1},
{10, 9, 3, 10, 3, 2, 9, 4, 3, 11, 3, 6, 4, 6, 3, -1},
{8, 2, 3, 8, 4, 2, 4, 6, 2, -1, -1, -1, -1, -1, -1, -1},
{0, 4, 2, 4, 6, 2, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1},
{1, 9, 0, 2, 3, 4, 2, 4, 6, 4, 3, 8, -1, -1, -1, -1},
{1, 9, 4, 1, 4, 2, 2, 4, 6, -1, -1, -1, -1, -1, -1, -1},
{8, 1, 3, 8, 6, 1, 8, 4, 6, 6, 10, 1, -1, -1, -1, -1},
{10, 1, 0, 10, 0, 6, 6, 0, 4, -1, -1, -1, -1, -1, -1, -1},
{4, 6, 3, 4, 3, 8, 6, 10, 3, 0, 3, 9, 10, 9, 3, -1},
{10, 9, 4, 6, 10, 4, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1},
{4, 9, 5, 7, 6, 11, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1},
{0, 8, 3, 4, 9, 5, 11, 7, 6, -1, -1, -1, -1, -1, -1, -1},
{5, 0, 1, 5, 4, 0, 7, 6, 11, -1, -1, -1, -1, -1, -1, -1},
{11, 7, 6, 8, 3, 4, 3, 5, 4, 3, 1, 5, -1, -1, -1, -1},
{9, 5, 4, 10, 1, 2, 7, 6, 11, -1, -1, -1, -1, -1, -1, -1},
{6, 11, 7, 1, 2, 10, 0, 8, 3, 4, 9, 5, -1, -1, -1, -1},
{7, 6, 11, 5, 4, 10, 4, 2, 10, 4, 0, 2, -1, -1, -1, -1},
{3, 4, 8, 3, 5, 4, 3, 2, 5, 10, 5, 2, 11, 7, 6, -1},
{7, 2, 3, 7, 6, 2, 5, 4, 9, -1, -1, -1, -1, -1, -1, -1},
{9, 5, 4, 0, 8, 6, 0, 6, 2, 6, 8, 7, -1, -1, -1, -1},
{3, 6, 2, 3, 7, 6, 1, 5, 0, 5, 4, 0, -1, -1, -1, -1},
{6, 2, 8, 6, 8, 7, 2, 1, 8, 4, 8, 5, 1, 5, 8, -1},
{9, 5, 4, 10, 1, 6, 1, 7, 6, 1, 3, 7, -1, -1, -1, -1},
{1, 6, 10, 1, 7, 6, 1, 0, 7, 8, 7, 0, 9, 5, 4, -1},
{4, 0, 10, 4, 10, 5, 0, 3, 10, 6, 10, 7, 3, 7, 10, -1},
{7, 6, 10, 7, 10, 8, 5, 4, 10, 4, 8, 10, -1, -1, -1, -1},
{6, 9, 5, 6, 11, 9, 11, 8, 9, -1, -1, -1, -1, -1, -1, -1},
{3, 6, 11, 0, 6, 3, 0, 5, 6, 0, 9, 5, -1, -1, -1, -1},
{0, 11, 8, 0, 5, 11, 0, 1, 5, 5, 6, 11, -1, -1, -1, -1},
{6, 11, 3, 6, 3, 5, 5, 3, 1, -1, -1, -1, -1, -1, -1, -1},
{1, 2, 10, 9, 5, 11, 9, 11, 8, 11, 5, 6, -1, -1, -1, -1},
{0, 11, 3, 0, 6, 11, 0, 9, 6, 5, 6, 9, 1, 2, 10, -1},
{11, 8, 5, 11, 5, 6, 8, 0, 5, 10, 5, 2, 0, 2, 5, -1},
{6, 11, 3, 6, 3, 5, 2, 10, 3, 10, 5, 3, -1, -1, -1, -1},
{5, 8, 9, 5, 2, 8, 5, 6, 2, 3, 8, 2, -1, -1, -1, -1},
{9, 5, 6, 9, 6, 0, 0, 6, 2, -1, -1, -1, -1, -1, -1, -1},
{1, 5, 8, 1, 8, 0, 5, 6, 8, 3, 8, 2, 6, 2, 8, -1},
{1, 5, 6, 2, 1, 6, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1},
{1, 3, 6, 1, 6, 10, 3, 8, 6, 5, 6, 9, 8, 9, 6, -1},
{10, 1, 0, 10, 0, 6, 9, 5, 0, 5, 6, 0, -1, -1, -1, -1},
{0, 3, 8, 5, 6, 10, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1},
{10, 5, 6, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1},
{11, 5, 10, 7, 5, 11, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1},
{11, 5, 10, 11, 7, 5, 8, 3, 0, -1, -1, -1, -1, -1, -1, -1},
{5, 11, 7, 5, 10, 11, 1, 9, 0, -1, -1, -1, -1, -1, -1, -1},
{10, 7, 5, 10, 11, 7, 9, 8, 1, 8, 3, 1, -1, -1, -1, -1},
{11, 1, 2, 11, 7, 1, 7, 5, 1, -1, -1, -1, -1, -1, -1, -1},
{0, 8, 3, 1, 2, 7, 1, 7, 5, 7, 2, 11, -1, -1, -1, -1},
{9, 7, 5, 9, 2, 7, 9, 0, 2, 2, 11, 7, -1, -1, -1, -1},
{7, 5, 2, 7, 2, 11, 5, 9, 2, 3, 2, 8, 9, 8, 2, -1},
{2, 5, 10, 2, 3, 5, 3, 7, 5, -1, -1, -1, -1, -1, -1, -1},
{8, 2, 0, 8, 5, 2, 8, 7, 5, 10, 2, 5, -1, -1, -1, -1},
{9, 0, 1, 5, 10, 3, 5, 3, 7, 3, 10, 2, -1, -1, -1, -1},
{9, 8, 2, 9, 2, 1, 8, 7, 2, 10, 2, 5, 7, 5, 2, -1},
{1, 3, 5, 3, 7, 5, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1},
{0, 8, 7, 0, 7, 1, 1, 7, 5, -1, -1, -1, -1, -1, -1, -1},
{9, 0, 3, 9, 3, 5, 5, 3, 7, -1, -1, -1, -1, -1, -1, -1},
{9, 8, 7, 5, 9, 7, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1},
{5, 8, 4, 5, 10, 8, 10, 11, 8, -1, -1, -1, -1, -1, -1, -1},
{5, 0, 4, 5, 11, 0, 5, 10, 11, 11, 3, 0, -1, -1, -1, -1},
{0, 1, 9, 8, 4, 10, 8, 10, 11, 10, 4, 5, -1, -1, -1, -1},
{10, 11, 4, 10, 4, 5, 11, 3, 4, 9, 4, 1, 3, 1, 4, -1},
{2, 5, 1, 2, 8, 5, 2, 11, 8, 4, 5, 8, -1, -1, -1, -1},
{0, 4, 11, 0, 11, 3, 4, 5, 11, 2, 11, 1, 5, 1, 11, -1},
{0, 2, 5, 0, 5, 9, 2, 11, 5, 4, 5, 8, 11, 8, 5, -1},
{9, 4, 5, 2, 11, 3, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1},
{2, 5, 10, 3, 5, 2, 3, 4, 5, 3, 8, 4, -1, -1, -1, -1},
{5, 10, 2, 5, 2, 4, 4, 2, 0, -1, -1, -1, -1, -1, -1, -1},
{3, 10, 2, 3, 5, 10, 3, 8, 5, 4, 5, 8, 0, 1, 9, -1},
{5, 10, 2, 5, 2, 4, 1, 9, 2, 9, 4, 2, -1, -1, -1, -1},
{8, 4, 5, 8, 5, 3, 3, 5, 1, -1, -1, -1, -1, -1, -1, -1},
{0, 4, 5, 1, 0, 5, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1},
{8, 4, 5, 8, 5, 3, 9, 0, 5, 0, 3, 5, -1, -1, -1, -1},
{9, 4, 5, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1},
{4, 11, 7, 4, 9, 11, 9, 10, 11, -1, -1, -1, -1, -1, -1, -1},
{0, 8, 3, 4, 9, 7, 9, 11, 7, 9, 10, 11, -1, -1, -1, -1},
{1, 10, 11, 1, 11, 4, 1, 4, 0, 7, 4, 11, -1, -1, -1, -1},
{3, 1, 4, 3, 4, 8, 1, 10, 4, 7, 4, 11, 10, 11, 4, -1},
{4, 11, 7, 9, 11, 4, 9, 2, 11, 9, 1, 2, -1, -1, -1, -1},
{9, 7, 4, 9, 11, 7, 9, 1, 11, 2, 11, 1, 0, 8, 3, -1},
{11, 7, 4, 11, 4, 2, 2, 4, 0, -1, -1, -1, -1, -1, -1, -1},
{11, 7, 4, 11, 4, 2, 8, 3, 4, 3, 2, 4, -1, -1, -1, -1},
{2, 9, 10, 2, 7, 9, 2, 3, 7, 7, 4, 9, -1, -1, -1, -1},
{9, 10, 7, 9, 7, 4, 10, 2, 7, 8, 7, 0, 2, 0, 7, -1},
{3, 7, 10, 3, 10, 2, 7, 4, 10, 1, 10, 0, 4, 0, 10, -1},
{1, 10, 2, 8, 7, 4, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1},
{4, 9, 1, 4, 1, 7, 7, 1, 3, -1, -1, -1, -1, -1, -1, -1},
{4, 9, 1, 4, 1, 7, 0, 8, 1, 8, 7, 1, -1, -1, -1, -1},
{4, 0, 3, 7, 4, 3, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1},
{4, 8, 7, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1},
{9, 10, 8, 10, 11, 8, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1},
{3, 0, 9, 3, 9, 11, 11, 9, 10, -1, -1, -1, -1, -1, -1, -1},
{0, 1, 10, 0, 10, 8, 8, 10, 11, -1, -1, -1, -1, -1, -1, -1},
{3, 1, 10, 11, 3, 10, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1},
{1, 2, 11, 1, 11, 9, 9, 11, 8, -1, -1, -1, -1, -1, -1, -1},
{3, 0, 9, 3, 9, 11, 1, 2, 9, 2, 11, 9, -1, -1, -1, -1},
{0, 2, 11, 8, 0, 11, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1},
{3, 2, 11, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1},
{2, 3, 8, 2, 8, 10, 10, 8, 9, -1, -1, -1, -1, -1, -1, -1},
{9, 10, 2, 0, 9, 2, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1},
{2, 3, 8, 2, 8, 10, 0, 1, 8, 1, 10, 8, -1, -1, -1, -1},
{1, 10, 2, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1},
{1, 3, 8, 9, 1, 8, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1},
{0, 9, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1},
{0, 3, 8, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1},
{-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1}};
// Table mapping each edge to the corresponding cube vertices
const int _EDGE_TO_VERTICES[12][2] = {
{0, 1},
{1, 5},
{4, 5},
{0, 4},
{2, 3},
{3, 7},
{6, 7},
{2, 6},
{0, 2},
{1, 3},
{5, 7},
{4, 6},
};
// Table mapping from 0-7 to v0-v7 in cube.vertices
const int _INDEX_TABLE[8] = {0, 1, 5, 4, 2, 3, 7, 6};
// Data structures for the marching cubes
struct Vertex {
// Constructor used when performing marching cube in each cell
explicit Vertex(float x = 0.0f, float y = 0.0f, float z = 0.0f)
: x(x), y(y), z(z) {}
// The */+ operator overrides are used for vertex interpolation
Vertex operator*(float s) const {
return Vertex(x * s, y * s, z * s);
}
Vertex operator+(const Vertex& xyz) const {
return Vertex(x + xyz.x, y + xyz.y, z + xyz.z);
}
// The == operator overrides is used for checking degenerate triangles
bool operator==(const Vertex& xyz) const {
if (std::abs(x - xyz.x) < EPS && std::abs(y - xyz.y) < EPS &&
std::abs(z - xyz.z) < EPS) {
return true;
}
return false;
}
// vertex position
float x, y, z;
};
struct Cube {
// Edge and vertex convention:
// v4_______e4____________v5
// /| /|
// / | / |
// e7/ | e5/ |
// /___|______e6_________/ |
// v7| | |v6 |e9
// | | | |
// | |e8 |e10|
// e11| | | |
// | |_________________|___|
// | / v0 e0 | /v1
// | / | /
// | /e3 | /e1
// |/_____________________|/
// v3 e2 v2
Vertex p[8];
int x, y, z;
int cubeindex = 0;
Cube(
int x,
int y,
int z,
const at::TensorAccessor<float, 3>& vol_a,
const float isolevel)
: x(x), y(y), z(z) {
// vertex position (x, y, z) for v0-v1-v4-v5-v3-v2-v7-v6
for (int v = 0; v < 8; v++) {
p[v] = Vertex(x + (v & 1), y + (v >> 1 & 1), z + (v >> 2 & 1));
}
// Calculates cube configuration index given values of the cube vertices
for (int i = 0; i < 8; i++) {
const int idx = _INDEX_TABLE[i];
Vertex v = p[idx];
if (vol_a[v.z][v.y][v.x] < isolevel) {
cubeindex |= (1 << i);
}
}
}
// Linearly interpolate the position where an isosurface cuts an edge
// between two vertices, based on their scalar values
//
// Args:
// isolevel: float value used as threshold
// edge: edge (ID) to interpolate
// cube: current cube vertices
// vol_a: 3D scalar field
//
// Returns:
// point: interpolated vertex
Vertex VertexInterp(
float isolevel,
const int edge,
const at::TensorAccessor<float, 3>& vol_a) {
const int v1 = _EDGE_TO_VERTICES[edge][0];
const int v2 = _EDGE_TO_VERTICES[edge][1];
Vertex p1 = p[v1];
Vertex p2 = p[v2];
float val1 = vol_a[p1.z][p1.y][p1.x];
float val2 = vol_a[p2.z][p2.y][p2.x];
float ratio = 1.0f;
if (std::abs(isolevel - val1) < EPS) {
return p1;
} else if (std::abs(isolevel - val2) < EPS) {
return p2;
} else if (std::abs(val1 - val2) < EPS) {
return p1;
}
// interpolate vertex p based on two vertices on the edge
ratio = (isolevel - val1) / (val2 - val1);
return p1 * (1 - ratio) + p2 * ratio;
}
// Get a tuple of global vertex ID from a local edge ID
// Global vertex ID is calculated as (x + dx) + (y + dy) * W + (z + dz) * W *
// H
// Args:
// edge: local edge ID in the cube
// W: width of x dimension
// H: height of y dimension
//
// Returns:
// a pair of global vertex ID
//
std::pair<int, int> GetVPairFromEdge(const int edge, int W, int H) {
const int v1 = _EDGE_TO_VERTICES[edge][0];
const int v2 = _EDGE_TO_VERTICES[edge][1];
const int v1_id = p[v1].x + p[v1].y * W + p[v1].z * W * H;
const int v2_id = p[v2].x + p[v2].y * W + p[v2].z * W * H;
return std::make_pair(v1_id, v2_id);
}
};
// helper functions for hashing
namespace std {
// Taken from boost library to combine several hash functions
template <class T>
inline void hash_combine(size_t& s, const T& v) {
std::hash<T> h;
s ^= h(v) + 0x9e3779b9 + (s << 6) + (s >> 2);
}
// Function for hashing a pair of vertices
template <>
struct hash<std::pair<int, int>> {
size_t operator()(const std::pair<int, int>& p) const {
size_t res = 0;
hash_combine(res, p.first);
hash_combine(res, p.second);
return res;
}
};
} // namespace std

View File

@ -7,8 +7,10 @@
from typing import List, Optional, Tuple
import torch
from pytorch3d import _C
from pytorch3d.ops.marching_cubes_data import EDGE_TO_VERTICES, FACE_TABLE, INDEX
from pytorch3d.transforms import Translate
from torch.autograd import Function
EPS = 0.00001
@ -225,3 +227,71 @@ def marching_cubes_naive(
batched_verts.append([])
batched_faces.append([])
return batched_verts, batched_faces
########################################
# Marching Cubes Implementation in C++
########################################
class _marching_cubes(Function):
"""
Torch Function wrapper for marching_cubes C++ implementation
Backward is not supported.
"""
@staticmethod
def forward(ctx, vol, isolevel):
verts, faces = _C.marching_cubes(vol, isolevel)
return verts, faces
@staticmethod
def backward(ctx, grad_verts, grad_faces):
raise ValueError("marching_cubes backward is not supported")
def marching_cubes(
vol_batch: torch.Tensor,
isolevel: Optional[float] = None,
return_local_coords: bool = True,
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
"""
Run marching cubes over a volume scalar field with a designated isolevel.
Returns vertices and faces of the obtained mesh.
This operation is non-differentiable.
Args:
vol_batch: a Tensor of size (N, D, H, W) corresponding to
a batch of 3D scalar fields
isolevel: float used as threshold to determine if a point is inside/outside
the volume. If None, then the average of the maximum and minimum value
of the scalar field is used.
return_local_coords: bool. If True the output vertices will be in local coordinates in
the range [-1, 1] x [-1, 1] x [-1, 1]. If False they will be in the range
[0, W-1] x [0, H-1] x [0, D-1]
Returns:
verts: [{V_0}, {V_1}, ...] List of N sets of vertices of shape (|V_i|, 3) in FloatTensor
faces: [{F_0}, {F_1}, ...] List of N sets of faces of shape (|F_i|, 3) in LongTensors
"""
batched_verts, batched_faces = [], []
D, H, W = vol_batch.shape[1:]
for i in range(len(vol_batch)):
vol = vol_batch[i]
thresh = ((vol.max() + vol.min()) / 2).item() if isolevel is None else isolevel
# pyre-fixme[16]: `_marching_cubes` has no attribute `apply`.
verts, faces = _marching_cubes.apply(vol, thresh)
if len(faces) > 0 and len(verts) > 0:
# Convert from world coordinates ([0, D-1], [0, H-1], [0, W-1]) to
# local coordinates in the range [-1, 1]
if return_local_coords:
verts = (
Translate(x=+1.0, y=+1.0, z=+1.0, device=vol.device)
.scale((vol.new_tensor([W, H, D])[None] - 1) * 0.5)
.inverse()
).transform_points(verts[None])[0]
batched_verts.append(verts)
batched_faces.append(faces)
else:
batched_verts.append([])
batched_faces.append([])
return batched_verts, batched_faces

View File

@ -4,19 +4,24 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import itertools
from fvcore.common.benchmark import benchmark
from tests.test_marching_cubes import TestMarchingCubes
def bm_marching_cubes() -> None:
kwargs_list = [
{"batch_size": 1, "V": 5},
{"batch_size": 1, "V": 10},
{"batch_size": 1, "V": 20},
{"batch_size": 1, "V": 40},
{"batch_size": 5, "V": 5},
{"batch_size": 20, "V": 20},
]
case_grid = {
"algo_type": [
"naive",
"cextension",
],
"batch_size": [1, 5, 20],
"V": [5, 10, 20],
}
test_cases = itertools.product(*case_grid.values())
kwargs_list = [dict(zip(case_grid.keys(), case)) for case in test_cases]
benchmark(
TestMarchingCubes.marching_cubes_with_init,
"MARCHING_CUBES",

View File

@ -9,7 +9,7 @@ import pickle
import unittest
import torch
from pytorch3d.ops.marching_cubes import marching_cubes_naive
from pytorch3d.ops.marching_cubes import marching_cubes, marching_cubes_naive
from .common_testing import get_tests_dir, TestCaseMixin
@ -37,6 +37,11 @@ class TestCubeConfiguration(TestCaseMixin, unittest.TestCase):
self.assertClose(verts, expected_verts)
self.assertClose(faces, expected_faces)
# test C++ implementation
verts, faces = marching_cubes(volume_data, return_local_coords=False)
self.assertClose(verts, expected_verts)
self.assertClose(faces, expected_faces)
def test_case1(self): # case 1
volume_data = torch.ones(1, 2, 2, 2) # (B, W, H, D)
volume_data[0, 0, 0, 0] = 0
@ -54,11 +59,21 @@ class TestCubeConfiguration(TestCaseMixin, unittest.TestCase):
self.assertClose(verts[0], expected_verts)
self.assertClose(faces[0], expected_faces)
# test C++ implementation
verts, faces = marching_cubes(volume_data, return_local_coords=False)
self.assertClose(verts[0], expected_verts)
self.assertClose(faces[0], expected_faces)
verts, faces = marching_cubes_naive(volume_data, return_local_coords=True)
expected_verts = convert_to_local(expected_verts, 2)
self.assertClose(verts[0], expected_verts)
self.assertClose(faces[0], expected_faces)
# test C++ implementation
verts, faces = marching_cubes(volume_data, return_local_coords=True)
self.assertClose(verts[0], expected_verts)
self.assertClose(faces[0], expected_faces)
def test_case2(self):
volume_data = torch.ones(1, 2, 2, 2) # (B, W, H, D)
volume_data[0, 0:2, 0, 0] = 0
@ -77,11 +92,21 @@ class TestCubeConfiguration(TestCaseMixin, unittest.TestCase):
self.assertClose(verts[0], expected_verts)
self.assertClose(faces[0], expected_faces)
# test C++ implementation
verts, faces = marching_cubes(volume_data, return_local_coords=False)
self.assertClose(verts[0], expected_verts)
self.assertClose(faces[0], expected_faces)
verts, faces = marching_cubes_naive(volume_data, return_local_coords=True)
expected_verts = convert_to_local(expected_verts, 2)
self.assertClose(verts[0], expected_verts)
self.assertClose(faces[0], expected_faces)
# test C++ implementation
verts, faces = marching_cubes(volume_data, return_local_coords=True)
self.assertClose(verts[0], expected_verts)
self.assertClose(faces[0], expected_faces)
def test_case3(self):
volume_data = torch.ones(1, 2, 2, 2) # (B, W, H, D)
volume_data[0, 0, 0, 0] = 0
@ -103,11 +128,21 @@ class TestCubeConfiguration(TestCaseMixin, unittest.TestCase):
self.assertClose(verts[0], expected_verts)
self.assertClose(faces[0], expected_faces)
# test C++ implementation
verts, faces = marching_cubes(volume_data, return_local_coords=False)
self.assertClose(verts[0], expected_verts)
self.assertClose(faces[0], expected_faces)
verts, faces = marching_cubes_naive(volume_data, return_local_coords=True)
expected_verts = convert_to_local(expected_verts, 2)
self.assertClose(verts[0], expected_verts)
self.assertClose(faces[0], expected_faces)
# test C++ implementation
verts, faces = marching_cubes(volume_data, return_local_coords=True)
self.assertClose(verts[0], expected_verts)
self.assertClose(faces[0], expected_faces)
def test_case4(self):
volume_data = torch.ones(1, 2, 2, 2) # (B, W, H, D)
volume_data[0, 1, 0, 0] = 0
@ -129,11 +164,21 @@ class TestCubeConfiguration(TestCaseMixin, unittest.TestCase):
self.assertClose(verts[0], expected_verts)
self.assertClose(faces[0], expected_faces)
# test C++ implementation
verts, faces = marching_cubes(volume_data, return_local_coords=False)
self.assertClose(verts[0], expected_verts)
self.assertClose(faces[0], expected_faces)
verts, faces = marching_cubes_naive(volume_data, return_local_coords=True)
expected_verts = convert_to_local(expected_verts, 2)
self.assertClose(verts[0], expected_verts)
self.assertClose(faces[0], expected_faces)
# test C++ implementation
verts, faces = marching_cubes(volume_data, return_local_coords=True)
self.assertClose(verts[0], expected_verts)
self.assertClose(faces[0], expected_faces)
def test_case5(self):
volume_data = torch.ones(1, 2, 2, 2) # (B, W, H, D)
volume_data[0, 0:2, 0, 0:2] = 0
@ -153,11 +198,21 @@ class TestCubeConfiguration(TestCaseMixin, unittest.TestCase):
self.assertClose(verts[0], expected_verts)
self.assertClose(faces[0], expected_faces)
# test C++ implementation
verts, faces = marching_cubes(volume_data, return_local_coords=False)
self.assertClose(verts[0], expected_verts)
self.assertClose(faces[0], expected_faces)
verts, faces = marching_cubes_naive(volume_data, return_local_coords=True)
expected_verts = convert_to_local(expected_verts, 2)
self.assertClose(verts[0], expected_verts)
self.assertClose(faces[0], expected_faces)
# test C++ implementation
verts, faces = marching_cubes(volume_data, return_local_coords=True)
self.assertClose(verts[0], expected_verts)
self.assertClose(faces[0], expected_faces)
def test_case6(self):
volume_data = torch.ones(1, 2, 2, 2) # (B, W, H, D)
volume_data[0, 1, 0, 0] = 0
@ -184,11 +239,21 @@ class TestCubeConfiguration(TestCaseMixin, unittest.TestCase):
self.assertClose(verts[0], expected_verts)
self.assertClose(faces[0], expected_faces)
# test C++ implementation
verts, faces = marching_cubes(volume_data, return_local_coords=False)
self.assertClose(verts[0], expected_verts)
self.assertClose(faces[0], expected_faces)
verts, faces = marching_cubes_naive(volume_data, return_local_coords=True)
expected_verts = convert_to_local(expected_verts, 2)
self.assertClose(verts[0], expected_verts)
self.assertClose(faces[0], expected_faces)
# test C++ implementation
verts, faces = marching_cubes(volume_data, return_local_coords=True)
self.assertClose(verts[0], expected_verts)
self.assertClose(faces[0], expected_faces)
def test_case7(self):
volume_data = torch.ones(1, 2, 2, 2) # (B, W, H, D)
volume_data[0, 0, 0, 0] = 0
@ -220,11 +285,21 @@ class TestCubeConfiguration(TestCaseMixin, unittest.TestCase):
self.assertClose(verts[0], expected_verts)
self.assertClose(faces[0], expected_faces)
# test C++ implementation
verts, faces = marching_cubes(volume_data, return_local_coords=False)
self.assertClose(verts[0], expected_verts)
self.assertClose(faces[0], expected_faces)
verts, faces = marching_cubes_naive(volume_data, return_local_coords=True)
expected_verts = convert_to_local(expected_verts, 2)
self.assertClose(verts[0], expected_verts)
self.assertClose(faces[0], expected_faces)
# test C++ implementation
verts, faces = marching_cubes(volume_data, return_local_coords=True)
self.assertClose(verts[0], expected_verts)
self.assertClose(faces[0], expected_faces)
def test_case8(self):
volume_data = torch.ones(1, 2, 2, 2) # (B, W, H, D)
volume_data[0, 0, 0, 0] = 0
@ -249,11 +324,21 @@ class TestCubeConfiguration(TestCaseMixin, unittest.TestCase):
self.assertClose(verts[0], expected_verts)
self.assertClose(faces[0], expected_faces)
# test C++ implementation
verts, faces = marching_cubes(volume_data, return_local_coords=False)
self.assertClose(verts[0], expected_verts)
self.assertClose(faces[0], expected_faces)
verts, faces = marching_cubes_naive(volume_data, return_local_coords=True)
expected_verts = convert_to_local(expected_verts, 2)
self.assertClose(verts[0], expected_verts)
self.assertClose(faces[0], expected_faces)
# test C++ implementation
verts, faces = marching_cubes(volume_data, return_local_coords=True)
self.assertClose(verts[0], expected_verts)
self.assertClose(faces[0], expected_faces)
def test_case9(self):
volume_data = torch.ones(1, 2, 2, 2) # (B, W, H, D)
volume_data[0, 1, 0, 0] = 0
@ -278,11 +363,21 @@ class TestCubeConfiguration(TestCaseMixin, unittest.TestCase):
self.assertClose(verts[0], expected_verts)
self.assertClose(faces[0], expected_faces)
# test C++ implementation
verts, faces = marching_cubes(volume_data, return_local_coords=False)
self.assertClose(verts[0], expected_verts)
self.assertClose(faces[0], expected_faces)
verts, faces = marching_cubes_naive(volume_data, return_local_coords=True)
expected_verts = convert_to_local(expected_verts, 2)
self.assertClose(verts[0], expected_verts)
self.assertClose(faces[0], expected_faces)
# test C++ implementation
verts, faces = marching_cubes(volume_data, return_local_coords=True)
self.assertClose(verts[0], expected_verts)
self.assertClose(faces[0], expected_faces)
def test_case10(self):
volume_data = torch.ones(1, 2, 2, 2) # (B, W, H, D)
volume_data[0, 0, 0, 0] = 0
@ -306,11 +401,21 @@ class TestCubeConfiguration(TestCaseMixin, unittest.TestCase):
self.assertClose(verts[0], expected_verts)
self.assertClose(faces[0], expected_faces)
# test C++ implementation
verts, faces = marching_cubes(volume_data, return_local_coords=False)
self.assertClose(verts[0], expected_verts)
self.assertClose(faces[0], expected_faces)
verts, faces = marching_cubes_naive(volume_data, return_local_coords=True)
expected_verts = convert_to_local(expected_verts, 2)
self.assertClose(verts[0], expected_verts)
self.assertClose(faces[0], expected_faces)
# test C++ implementation
verts, faces = marching_cubes(volume_data, return_local_coords=True)
self.assertClose(verts[0], expected_verts)
self.assertClose(faces[0], expected_faces)
def test_case11(self):
volume_data = torch.ones(1, 2, 2, 2) # (B, W, H, D)
volume_data[0, 0, 0, 0] = 0
@ -336,11 +441,21 @@ class TestCubeConfiguration(TestCaseMixin, unittest.TestCase):
self.assertClose(verts[0], expected_verts)
self.assertClose(faces[0], expected_faces)
# test C++ implementation
verts, faces = marching_cubes(volume_data, return_local_coords=False)
self.assertClose(verts[0], expected_verts)
self.assertClose(faces[0], expected_faces)
verts, faces = marching_cubes_naive(volume_data, return_local_coords=True)
expected_verts = convert_to_local(expected_verts, 2)
self.assertClose(verts[0], expected_verts)
self.assertClose(faces[0], expected_faces)
# test C++ implementation
verts, faces = marching_cubes(volume_data, return_local_coords=True)
self.assertClose(verts[0], expected_verts)
self.assertClose(faces[0], expected_faces)
def test_case12(self):
volume_data = torch.ones(1, 2, 2, 2) # (B, W, H, D)
volume_data[0, 1, 0, 0] = 0
@ -368,11 +483,21 @@ class TestCubeConfiguration(TestCaseMixin, unittest.TestCase):
self.assertClose(verts[0], expected_verts)
self.assertClose(faces[0], expected_faces)
# test C++ implementation
verts, faces = marching_cubes(volume_data, return_local_coords=False)
self.assertClose(verts[0], expected_verts)
self.assertClose(faces[0], expected_faces)
verts, faces = marching_cubes_naive(volume_data, return_local_coords=True)
expected_verts = convert_to_local(expected_verts, 2)
self.assertClose(verts[0], expected_verts)
self.assertClose(faces[0], expected_faces)
# test C++ implementation
verts, faces = marching_cubes(volume_data, return_local_coords=True)
self.assertClose(verts[0], expected_verts)
self.assertClose(faces[0], expected_faces)
def test_case13(self):
volume_data = torch.ones(1, 2, 2, 2) # (B, W, H, D)
volume_data[0, 0, 0, 0] = 0
@ -400,11 +525,21 @@ class TestCubeConfiguration(TestCaseMixin, unittest.TestCase):
self.assertClose(verts[0], expected_verts)
self.assertClose(faces[0], expected_faces)
# test C++ implementation
verts, faces = marching_cubes(volume_data, return_local_coords=False)
self.assertClose(verts[0], expected_verts)
self.assertClose(faces[0], expected_faces)
verts, faces = marching_cubes_naive(volume_data, return_local_coords=True)
expected_verts = convert_to_local(expected_verts, 2)
self.assertClose(verts[0], expected_verts)
self.assertClose(faces[0], expected_faces)
# test C++ implementation
verts, faces = marching_cubes(volume_data, return_local_coords=True)
self.assertClose(verts[0], expected_verts)
self.assertClose(faces[0], expected_faces)
def test_case14(self):
volume_data = torch.ones(1, 2, 2, 2) # (B, W, H, D)
volume_data[0, 0, 0, 0] = 0
@ -430,11 +565,21 @@ class TestCubeConfiguration(TestCaseMixin, unittest.TestCase):
self.assertClose(verts[0], expected_verts)
self.assertClose(faces[0], expected_faces)
# test C++ implementation
verts, faces = marching_cubes(volume_data, return_local_coords=False)
self.assertClose(verts[0], expected_verts)
self.assertClose(faces[0], expected_faces)
verts, faces = marching_cubes_naive(volume_data, return_local_coords=True)
expected_verts = convert_to_local(expected_verts, 2)
self.assertClose(verts[0], expected_verts)
self.assertClose(faces[0], expected_faces)
# test C++ implementation
verts, faces = marching_cubes(volume_data, return_local_coords=True)
self.assertClose(verts[0], expected_verts)
self.assertClose(faces[0], expected_faces)
class TestMarchingCubes(TestCaseMixin, unittest.TestCase):
def test_single_point(self):
@ -468,12 +613,23 @@ class TestMarchingCubes(TestCaseMixin, unittest.TestCase):
self.assertClose(verts[0], expected_verts)
self.assertClose(faces[0], expected_faces)
# test C++ implementation
verts, faces = marching_cubes(volume_data, return_local_coords=False)
self.assertClose(verts[0], expected_verts)
self.assertClose(faces[0], expected_faces)
verts, faces = marching_cubes_naive(volume_data, return_local_coords=True)
expected_verts = convert_to_local(expected_verts, 3)
self.assertClose(verts[0], expected_verts)
self.assertClose(faces[0], expected_faces)
self.assertTrue(verts[0].ge(-1).all() and verts[0].le(1).all())
# test C++ implementation
verts, faces = marching_cubes(volume_data, return_local_coords=True)
self.assertClose(verts[0], expected_verts)
self.assertClose(faces[0], expected_faces)
self.assertTrue(verts[0].ge(-1).all() and verts[0].le(1).all())
def test_cube(self):
volume_data = torch.zeros(1, 5, 5, 5) # (B, W, H, D)
volume_data[0, 1, 1, 1] = 1
@ -567,6 +723,11 @@ class TestMarchingCubes(TestCaseMixin, unittest.TestCase):
self.assertClose(verts[0], expected_verts)
self.assertClose(faces[0], expected_faces)
# test C++ implementation
verts, faces = marching_cubes(volume_data, 0.9, return_local_coords=False)
self.assertClose(verts[0], expected_verts)
self.assertClose(faces[0], expected_faces)
verts, faces = marching_cubes_naive(volume_data, 0.9, return_local_coords=True)
expected_verts = convert_to_local(expected_verts, 5)
self.assertClose(verts[0], expected_verts)
@ -575,6 +736,12 @@ class TestMarchingCubes(TestCaseMixin, unittest.TestCase):
# Check all values are in the range [-1, 1]
self.assertTrue(verts[0].ge(-1).all() and verts[0].le(1).all())
# test C++ implementation
verts, faces = marching_cubes(volume_data, 0.9, return_local_coords=True)
self.assertClose(verts[0], expected_verts)
self.assertClose(faces[0], expected_faces)
self.assertTrue(verts[0].ge(-1).all() and verts[0].le(1).all())
def test_cube_no_duplicate_verts(self):
volume_data = torch.zeros(1, 5, 5, 5) # (B, W, H, D)
volume_data[0, 1, 1, 1] = 1
@ -670,6 +837,11 @@ class TestMarchingCubes(TestCaseMixin, unittest.TestCase):
self.assertClose(verts[0], expected_verts)
self.assertClose(faces[0], expected_faces)
# test C++ implementation
verts, faces = marching_cubes(volume, 64, return_local_coords=False)
self.assertClose(verts[0], expected_verts)
self.assertClose(faces[0], expected_faces)
verts, faces = marching_cubes_naive(
volume, isolevel=64, return_local_coords=True
)
@ -681,6 +853,12 @@ class TestMarchingCubes(TestCaseMixin, unittest.TestCase):
# Check all values are in the range [-1, 1]
self.assertTrue(verts[0].ge(-1).all() and verts[0].le(1).all())
# test C++ implementation
verts, faces = marching_cubes(volume, 64, return_local_coords=True)
self.assertClose(verts[0], expected_verts)
self.assertClose(faces[0], expected_faces)
self.assertTrue(verts[0].ge(-1).all() and verts[0].le(1).all())
# Uses skimage.draw.ellipsoid
def test_double_ellipsoid(self):
if USE_SCIKIT:
@ -694,6 +872,7 @@ class TestMarchingCubes(TestCaseMixin, unittest.TestCase):
volume = torch.Tensor(ellip_double).unsqueeze(0)
volume = volume.permute(0, 3, 2, 1) # (B, D, H, W)
verts, faces = marching_cubes_naive(volume, isolevel=0.001)
verts2, faces2 = marching_cubes(volume, isolevel=0.001)
data_filename = "test_marching_cubes_data/double_ellipsoid.pickle"
filename = os.path.join(DATA_DIR, data_filename)
@ -704,6 +883,8 @@ class TestMarchingCubes(TestCaseMixin, unittest.TestCase):
self.assertClose(verts[0], expected_verts)
self.assertClose(faces[0], expected_faces)
self.assertClose(verts2[0], expected_verts)
self.assertClose(faces2[0], expected_faces)
def test_cube_surface_area(self):
if USE_SCIKIT:
@ -720,12 +901,15 @@ class TestMarchingCubes(TestCaseMixin, unittest.TestCase):
volume_data[0, 2, 2, 2] = 1
volume_data = volume_data.permute(0, 3, 2, 1) # (B, D, H, W)
verts, faces = marching_cubes_naive(volume_data, return_local_coords=False)
verts_c, faces_c = marching_cubes(volume_data, return_local_coords=False)
verts_sci, faces_sci = marching_cubes_classic(volume_data[0])
surf = mesh_surface_area(verts[0], faces[0])
surf_c = mesh_surface_area(verts_c[0], faces_c[0])
surf_sci = mesh_surface_area(verts_sci, faces_sci)
self.assertClose(surf, surf_sci)
self.assertClose(surf, surf_c)
def test_sphere_surface_area(self):
if USE_SCIKIT:
@ -746,12 +930,15 @@ class TestMarchingCubes(TestCaseMixin, unittest.TestCase):
).unsqueeze(0)
volume = volume.permute(0, 3, 2, 1) # (B, D, H, W)
verts, faces = marching_cubes_naive(volume, isolevel=64)
verts_c, faces_c = marching_cubes(volume, isolevel=64)
verts_sci, faces_sci = marching_cubes_classic(volume[0], level=64)
surf = mesh_surface_area(verts[0], faces[0])
surf_c = mesh_surface_area(verts_c[0], faces_c[0])
surf_sci = mesh_surface_area(verts_sci, faces_sci)
self.assertClose(surf, surf_sci)
self.assertClose(surf, surf_c)
def test_double_ellipsoid_surface_area(self):
if USE_SCIKIT:
@ -766,12 +953,15 @@ class TestMarchingCubes(TestCaseMixin, unittest.TestCase):
volume = torch.Tensor(ellip_double).unsqueeze(0)
volume = volume.permute(0, 3, 2, 1) # (B, D, H, W)
verts, faces = marching_cubes_naive(volume, isolevel=0)
verts_c, faces_c = marching_cubes(volume, isolevel=0)
verts_sci, faces_sci = marching_cubes_classic(volume[0], level=0)
surf = mesh_surface_area(verts[0], faces[0])
surf_c = mesh_surface_area(verts_c[0], faces_c[0])
surf_sci = mesh_surface_area(verts_sci, faces_sci)
self.assertClose(surf, surf_sci)
self.assertClose(surf, surf_c)
def test_ball_example(self):
N = 15
@ -780,6 +970,9 @@ class TestMarchingCubes(TestCaseMixin, unittest.TestCase):
u = (X - 15) ** 2 + (Y - 15) ** 2 + (Z - 15) ** 2 - 8**2
u = u[None].float()
verts, faces = marching_cubes_naive(u, 0, return_local_coords=False)
verts2, faces2 = marching_cubes(u, 0, return_local_coords=False)
self.assertClose(verts[0], verts2[0])
self.assertClose(faces[0], faces2[0])
@staticmethod
def marching_cubes_with_init(algo_type: str, batch_size: int, V: int):
@ -789,6 +982,7 @@ class TestMarchingCubes(TestCaseMixin, unittest.TestCase):
)
algo_table = {
"naive": marching_cubes_naive,
"cextension": marching_cubes,
}
def convert():