diff --git a/pytorch3d/csrc/ext.cpp b/pytorch3d/csrc/ext.cpp index 44cfddfc..6a17dbb0 100644 --- a/pytorch3d/csrc/ext.cpp +++ b/pytorch3d/csrc/ext.cpp @@ -22,6 +22,7 @@ #include "interp_face_attrs/interp_face_attrs.h" #include "iou_box3d/iou_box3d.h" #include "knn/knn.h" +#include "marching_cubes/marching_cubes.h" #include "mesh_normal_consistency/mesh_normal_consistency.h" #include "packed_to_padded_tensor/packed_to_padded_tensor.h" #include "point_mesh/point_mesh_cuda.h" @@ -94,6 +95,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { // 3D IoU m.def("iou_box3d", &IoUBox3D); + // Marching cubes + m.def("marching_cubes", &MarchingCubes); + // Pulsar. #ifdef PULSAR_LOGGING_ENABLED c10::ShowLogInfoToStderr(); diff --git a/pytorch3d/csrc/marching_cubes/marching_cubes.h b/pytorch3d/csrc/marching_cubes/marching_cubes.h new file mode 100644 index 00000000..5984b0b6 --- /dev/null +++ b/pytorch3d/csrc/marching_cubes/marching_cubes.h @@ -0,0 +1,39 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once +#include +#include +#include +#include "utils/pytorch3d_cutils.h" + +// Run Marching Cubes algorithm over a batch of volume scalar fields +// with a pre-defined threshold and return a mesh composed of vertices +// and faces for the mesh. +// +// Args: +// vol: FloatTensor of shape (D, H, W) giving a volume +// scalar grids. +// isolevel: isosurface value to use as the threshoold to determine whether +// the points are within a volume. +// +// Returns: +// vertices: List of N FloatTensors of vertices +// faces: List of N LongTensors of faces + +// CPU implementation +std::tuple MarchingCubesCpu( + const at::Tensor& vol, + const float isolevel); + +// Implementation which is exposed +inline std::tuple MarchingCubes( + const at::Tensor& vol, + const float isolevel) { + return MarchingCubesCpu(vol.contiguous(), isolevel); +} diff --git a/pytorch3d/csrc/marching_cubes/marching_cubes_cpu.cpp b/pytorch3d/csrc/marching_cubes/marching_cubes_cpu.cpp new file mode 100644 index 00000000..3c801001 --- /dev/null +++ b/pytorch3d/csrc/marching_cubes/marching_cubes_cpu.cpp @@ -0,0 +1,115 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include +#include +#include +#include "marching_cubes/marching_cubes_utils.h" + +// Cpu implementation for Marching Cubes +// Args: +// vol: a Tensor of size (D, H, W) corresponding to a 3D scalar field +// isolevel: the isosurface value to use as the threshold to determine +// whether points are within a volume. +// +// Returns: +// vertices: a float tensor of shape (N, 3) for positions of the mesh +// faces: a long tensor of shape (N, 3) for indices of the face vertices +// +std::tuple MarchingCubesCpu( + const at::Tensor& vol, + const float isolevel) { + // volume shapes + const int D = vol.size(0); + const int H = vol.size(1); + const int W = vol.size(2); + + // Create tensor accessors + auto vol_a = vol.accessor(); + // vpair_to_edge maps a pair of vertex ids to its corresponding edge id + std::unordered_map, int64_t> vpair_to_edge; + // edge_id_to_v maps from an edge id to a vertex position + std::unordered_map edge_id_to_v; + // uniq_edge_id: used to remove redundant edge ids + std::unordered_map uniq_edge_id; + std::vector faces; // store face indices + std::vector verts; // store vertex positions + // enumerate each cell in the 3d grid + for (int z = 0; z < D - 1; z++) { + for (int y = 0; y < H - 1; y++) { + for (int x = 0; x < W - 1; x++) { + Cube cube(x, y, z, vol_a, isolevel); + // Cube is entirely in/out of the surface + if (_FACE_TABLE[cube.cubeindex][0] == -1) { + continue; + } + // store all boundary vertices that intersect with the edges + std::array interp_points; + // triangle vertex IDs and positions + std::vector tri; + std::vector ps; + + // Interpolate the vertices where the surface intersects with the cube + for (int j = 0; _FACE_TABLE[cube.cubeindex][j] != -1; j++) { + const int e = _FACE_TABLE[cube.cubeindex][j]; + interp_points[e] = cube.VertexInterp(isolevel, e, vol_a); + + auto vpair = cube.GetVPairFromEdge(e, W, H); + if (!vpair_to_edge.count(vpair)) { + vpair_to_edge[vpair] = vpair_to_edge.size(); + } + + int64_t edge = vpair_to_edge[vpair]; + tri.push_back(edge); + ps.push_back(interp_points[e]); + + // Check if the triangle face is degenerate. A triangle face + // is degenerate if any of the two verices share the same 3D position + if ((j + 1) % 3 == 0 && ps[0] != ps[1] && ps[1] != ps[2] && + ps[2] != ps[0]) { + for (int k = 0; k < 3; k++) { + int v = tri[k]; + edge_id_to_v[tri.at(k)] = ps.at(k); + if (!uniq_edge_id.count(v)) { + uniq_edge_id[v] = verts.size(); + verts.push_back(edge_id_to_v[v]); + } + faces.push_back(uniq_edge_id[v]); + } + tri.clear(); + ps.clear(); + } + } // endif + } // endfor x + } // endfor y + } // endfor z + // Collect returning tensor + const int n_vertices = verts.size(); + const int64_t n_faces = (int64_t)faces.size() / 3; + auto vert_tensor = torch::zeros({n_vertices, 3}, torch::kFloat); + auto face_tensor = torch::zeros({n_faces, 3}, torch::kInt64); + + auto vert_a = vert_tensor.accessor(); + for (int i = 0; i < n_vertices; i++) { + vert_a[i][0] = verts.at(i).x; + vert_a[i][1] = verts.at(i).y; + vert_a[i][2] = verts.at(i).z; + } + + auto face_a = face_tensor.accessor(); + for (int64_t i = 0; i < n_faces; i++) { + face_a[i][0] = faces.at(i * 3 + 0); + face_a[i][1] = faces.at(i * 3 + 1); + face_a[i][2] = faces.at(i * 3 + 2); + } + + return std::make_tuple(vert_tensor, face_tensor); +} diff --git a/pytorch3d/csrc/marching_cubes/marching_cubes_utils.h b/pytorch3d/csrc/marching_cubes/marching_cubes_utils.h new file mode 100644 index 00000000..7d417a3a --- /dev/null +++ b/pytorch3d/csrc/marching_cubes/marching_cubes_utils.h @@ -0,0 +1,443 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once +#include +#include +#include +#include +#include "ATen/core/TensorAccessor.h" + +// EPS: Used to assess whether two float values are close +const float EPS = 1e-5; + +// A table mapping from cubeindex to a list of face configurations. +// Each list contains at most 5 faces, where each face is represented with +// 3 consecutive numbers +// Table taken from http://paulbourke.net/geometry/polygonise/ +const int _FACE_TABLE[256][16] = { + {-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1}, + {0, 8, 3, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1}, + {0, 1, 9, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1}, + {1, 8, 3, 9, 8, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1}, + {1, 2, 10, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1}, + {0, 8, 3, 1, 2, 10, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1}, + {9, 2, 10, 0, 2, 9, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1}, + {2, 8, 3, 2, 10, 8, 10, 9, 8, -1, -1, -1, -1, -1, -1, -1}, + {3, 11, 2, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1}, + {0, 11, 2, 8, 11, 0, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1}, + {1, 9, 0, 2, 3, 11, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1}, + {1, 11, 2, 1, 9, 11, 9, 8, 11, -1, -1, -1, -1, -1, -1, -1}, + {3, 10, 1, 11, 10, 3, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1}, + {0, 10, 1, 0, 8, 10, 8, 11, 10, -1, -1, -1, -1, -1, -1, -1}, + {3, 9, 0, 3, 11, 9, 11, 10, 9, -1, -1, -1, -1, -1, -1, -1}, + {9, 8, 10, 10, 8, 11, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1}, + {4, 7, 8, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1}, + {4, 3, 0, 7, 3, 4, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1}, + {0, 1, 9, 8, 4, 7, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1}, + {4, 1, 9, 4, 7, 1, 7, 3, 1, -1, -1, -1, -1, -1, -1, -1}, + {1, 2, 10, 8, 4, 7, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1}, + {3, 4, 7, 3, 0, 4, 1, 2, 10, -1, -1, -1, -1, -1, -1, -1}, + {9, 2, 10, 9, 0, 2, 8, 4, 7, -1, -1, -1, -1, -1, -1, -1}, + {2, 10, 9, 2, 9, 7, 2, 7, 3, 7, 9, 4, -1, -1, -1, -1}, + {8, 4, 7, 3, 11, 2, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1}, + {11, 4, 7, 11, 2, 4, 2, 0, 4, -1, -1, -1, -1, -1, -1, -1}, + {9, 0, 1, 8, 4, 7, 2, 3, 11, -1, -1, -1, -1, -1, -1, -1}, + {4, 7, 11, 9, 4, 11, 9, 11, 2, 9, 2, 1, -1, -1, -1, -1}, + {3, 10, 1, 3, 11, 10, 7, 8, 4, -1, -1, -1, -1, -1, -1, -1}, + {1, 11, 10, 1, 4, 11, 1, 0, 4, 7, 11, 4, -1, -1, -1, -1}, + {4, 7, 8, 9, 0, 11, 9, 11, 10, 11, 0, 3, -1, -1, -1, -1}, + {4, 7, 11, 4, 11, 9, 9, 11, 10, -1, -1, -1, -1, -1, -1, -1}, + {9, 5, 4, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1}, + {9, 5, 4, 0, 8, 3, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1}, + {0, 5, 4, 1, 5, 0, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1}, + {8, 5, 4, 8, 3, 5, 3, 1, 5, -1, -1, -1, -1, -1, -1, -1}, + {1, 2, 10, 9, 5, 4, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1}, + {3, 0, 8, 1, 2, 10, 4, 9, 5, -1, -1, -1, -1, -1, -1, -1}, + {5, 2, 10, 5, 4, 2, 4, 0, 2, -1, -1, -1, -1, -1, -1, -1}, + {2, 10, 5, 3, 2, 5, 3, 5, 4, 3, 4, 8, -1, -1, -1, -1}, + {9, 5, 4, 2, 3, 11, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1}, + {0, 11, 2, 0, 8, 11, 4, 9, 5, -1, -1, -1, -1, -1, -1, -1}, + {0, 5, 4, 0, 1, 5, 2, 3, 11, -1, -1, -1, -1, -1, -1, -1}, + {2, 1, 5, 2, 5, 8, 2, 8, 11, 4, 8, 5, -1, -1, -1, -1}, + {10, 3, 11, 10, 1, 3, 9, 5, 4, -1, -1, -1, -1, -1, -1, -1}, + {4, 9, 5, 0, 8, 1, 8, 10, 1, 8, 11, 10, -1, -1, -1, -1}, + {5, 4, 0, 5, 0, 11, 5, 11, 10, 11, 0, 3, -1, -1, -1, -1}, + {5, 4, 8, 5, 8, 10, 10, 8, 11, -1, -1, -1, -1, -1, -1, -1}, + {9, 7, 8, 5, 7, 9, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1}, + {9, 3, 0, 9, 5, 3, 5, 7, 3, -1, -1, -1, -1, -1, -1, -1}, + {0, 7, 8, 0, 1, 7, 1, 5, 7, -1, -1, -1, -1, -1, -1, -1}, + {1, 5, 3, 3, 5, 7, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1}, + {9, 7, 8, 9, 5, 7, 10, 1, 2, -1, -1, -1, -1, -1, -1, -1}, + {10, 1, 2, 9, 5, 0, 5, 3, 0, 5, 7, 3, -1, -1, -1, -1}, + {8, 0, 2, 8, 2, 5, 8, 5, 7, 10, 5, 2, -1, -1, -1, -1}, + {2, 10, 5, 2, 5, 3, 3, 5, 7, -1, -1, -1, -1, -1, -1, -1}, + {7, 9, 5, 7, 8, 9, 3, 11, 2, -1, -1, -1, -1, -1, -1, -1}, + {9, 5, 7, 9, 7, 2, 9, 2, 0, 2, 7, 11, -1, -1, -1, -1}, + {2, 3, 11, 0, 1, 8, 1, 7, 8, 1, 5, 7, -1, -1, -1, -1}, + {11, 2, 1, 11, 1, 7, 7, 1, 5, -1, -1, -1, -1, -1, -1, -1}, + {9, 5, 8, 8, 5, 7, 10, 1, 3, 10, 3, 11, -1, -1, -1, -1}, + {5, 7, 0, 5, 0, 9, 7, 11, 0, 1, 0, 10, 11, 10, 0, -1}, + {11, 10, 0, 11, 0, 3, 10, 5, 0, 8, 0, 7, 5, 7, 0, -1}, + {11, 10, 5, 7, 11, 5, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1}, + {10, 6, 5, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1}, + {0, 8, 3, 5, 10, 6, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1}, + {9, 0, 1, 5, 10, 6, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1}, + {1, 8, 3, 1, 9, 8, 5, 10, 6, -1, -1, -1, -1, -1, -1, -1}, + {1, 6, 5, 2, 6, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1}, + {1, 6, 5, 1, 2, 6, 3, 0, 8, -1, -1, -1, -1, -1, -1, -1}, + {9, 6, 5, 9, 0, 6, 0, 2, 6, -1, -1, -1, -1, -1, -1, -1}, + {5, 9, 8, 5, 8, 2, 5, 2, 6, 3, 2, 8, -1, -1, -1, -1}, + {2, 3, 11, 10, 6, 5, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1}, + {11, 0, 8, 11, 2, 0, 10, 6, 5, -1, -1, -1, -1, -1, -1, -1}, + {0, 1, 9, 2, 3, 11, 5, 10, 6, -1, -1, -1, -1, -1, -1, -1}, + {5, 10, 6, 1, 9, 2, 9, 11, 2, 9, 8, 11, -1, -1, -1, -1}, + {6, 3, 11, 6, 5, 3, 5, 1, 3, -1, -1, -1, -1, -1, -1, -1}, + {0, 8, 11, 0, 11, 5, 0, 5, 1, 5, 11, 6, -1, -1, -1, -1}, + {3, 11, 6, 0, 3, 6, 0, 6, 5, 0, 5, 9, -1, -1, -1, -1}, + {6, 5, 9, 6, 9, 11, 11, 9, 8, -1, -1, -1, -1, -1, -1, -1}, + {5, 10, 6, 4, 7, 8, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1}, + {4, 3, 0, 4, 7, 3, 6, 5, 10, -1, -1, -1, -1, -1, -1, -1}, + {1, 9, 0, 5, 10, 6, 8, 4, 7, -1, -1, -1, -1, -1, -1, -1}, + {10, 6, 5, 1, 9, 7, 1, 7, 3, 7, 9, 4, -1, -1, -1, -1}, + {6, 1, 2, 6, 5, 1, 4, 7, 8, -1, -1, -1, -1, -1, -1, -1}, + {1, 2, 5, 5, 2, 6, 3, 0, 4, 3, 4, 7, -1, -1, -1, -1}, + {8, 4, 7, 9, 0, 5, 0, 6, 5, 0, 2, 6, -1, -1, -1, -1}, + {7, 3, 9, 7, 9, 4, 3, 2, 9, 5, 9, 6, 2, 6, 9, -1}, + {3, 11, 2, 7, 8, 4, 10, 6, 5, -1, -1, -1, -1, -1, -1, -1}, + {5, 10, 6, 4, 7, 2, 4, 2, 0, 2, 7, 11, -1, -1, -1, -1}, + {0, 1, 9, 4, 7, 8, 2, 3, 11, 5, 10, 6, -1, -1, -1, -1}, + {9, 2, 1, 9, 11, 2, 9, 4, 11, 7, 11, 4, 5, 10, 6, -1}, + {8, 4, 7, 3, 11, 5, 3, 5, 1, 5, 11, 6, -1, -1, -1, -1}, + {5, 1, 11, 5, 11, 6, 1, 0, 11, 7, 11, 4, 0, 4, 11, -1}, + {0, 5, 9, 0, 6, 5, 0, 3, 6, 11, 6, 3, 8, 4, 7, -1}, + {6, 5, 9, 6, 9, 11, 4, 7, 9, 7, 11, 9, -1, -1, -1, -1}, + {10, 4, 9, 6, 4, 10, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1}, + {4, 10, 6, 4, 9, 10, 0, 8, 3, -1, -1, -1, -1, -1, -1, -1}, + {10, 0, 1, 10, 6, 0, 6, 4, 0, -1, -1, -1, -1, -1, -1, -1}, + {8, 3, 1, 8, 1, 6, 8, 6, 4, 6, 1, 10, -1, -1, -1, -1}, + {1, 4, 9, 1, 2, 4, 2, 6, 4, -1, -1, -1, -1, -1, -1, -1}, + {3, 0, 8, 1, 2, 9, 2, 4, 9, 2, 6, 4, -1, -1, -1, -1}, + {0, 2, 4, 4, 2, 6, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1}, + {8, 3, 2, 8, 2, 4, 4, 2, 6, -1, -1, -1, -1, -1, -1, -1}, + {10, 4, 9, 10, 6, 4, 11, 2, 3, -1, -1, -1, -1, -1, -1, -1}, + {0, 8, 2, 2, 8, 11, 4, 9, 10, 4, 10, 6, -1, -1, -1, -1}, + {3, 11, 2, 0, 1, 6, 0, 6, 4, 6, 1, 10, -1, -1, -1, -1}, + {6, 4, 1, 6, 1, 10, 4, 8, 1, 2, 1, 11, 8, 11, 1, -1}, + {9, 6, 4, 9, 3, 6, 9, 1, 3, 11, 6, 3, -1, -1, -1, -1}, + {8, 11, 1, 8, 1, 0, 11, 6, 1, 9, 1, 4, 6, 4, 1, -1}, + {3, 11, 6, 3, 6, 0, 0, 6, 4, -1, -1, -1, -1, -1, -1, -1}, + {6, 4, 8, 11, 6, 8, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1}, + {7, 10, 6, 7, 8, 10, 8, 9, 10, -1, -1, -1, -1, -1, -1, -1}, + {0, 7, 3, 0, 10, 7, 0, 9, 10, 6, 7, 10, -1, -1, -1, -1}, + {10, 6, 7, 1, 10, 7, 1, 7, 8, 1, 8, 0, -1, -1, -1, -1}, + {10, 6, 7, 10, 7, 1, 1, 7, 3, -1, -1, -1, -1, -1, -1, -1}, + {1, 2, 6, 1, 6, 8, 1, 8, 9, 8, 6, 7, -1, -1, -1, -1}, + {2, 6, 9, 2, 9, 1, 6, 7, 9, 0, 9, 3, 7, 3, 9, -1}, + {7, 8, 0, 7, 0, 6, 6, 0, 2, -1, -1, -1, -1, -1, -1, -1}, + {7, 3, 2, 6, 7, 2, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1}, + {2, 3, 11, 10, 6, 8, 10, 8, 9, 8, 6, 7, -1, -1, -1, -1}, + {2, 0, 7, 2, 7, 11, 0, 9, 7, 6, 7, 10, 9, 10, 7, -1}, + {1, 8, 0, 1, 7, 8, 1, 10, 7, 6, 7, 10, 2, 3, 11, -1}, + {11, 2, 1, 11, 1, 7, 10, 6, 1, 6, 7, 1, -1, -1, -1, -1}, + {8, 9, 6, 8, 6, 7, 9, 1, 6, 11, 6, 3, 1, 3, 6, -1}, + {0, 9, 1, 11, 6, 7, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1}, + {7, 8, 0, 7, 0, 6, 3, 11, 0, 11, 6, 0, -1, -1, -1, -1}, + {7, 11, 6, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1}, + {7, 6, 11, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1}, + {3, 0, 8, 11, 7, 6, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1}, + {0, 1, 9, 11, 7, 6, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1}, + {8, 1, 9, 8, 3, 1, 11, 7, 6, -1, -1, -1, -1, -1, -1, -1}, + {10, 1, 2, 6, 11, 7, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1}, + {1, 2, 10, 3, 0, 8, 6, 11, 7, -1, -1, -1, -1, -1, -1, -1}, + {2, 9, 0, 2, 10, 9, 6, 11, 7, -1, -1, -1, -1, -1, -1, -1}, + {6, 11, 7, 2, 10, 3, 10, 8, 3, 10, 9, 8, -1, -1, -1, -1}, + {7, 2, 3, 6, 2, 7, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1}, + {7, 0, 8, 7, 6, 0, 6, 2, 0, -1, -1, -1, -1, -1, -1, -1}, + {2, 7, 6, 2, 3, 7, 0, 1, 9, -1, -1, -1, -1, -1, -1, -1}, + {1, 6, 2, 1, 8, 6, 1, 9, 8, 8, 7, 6, -1, -1, -1, -1}, + {10, 7, 6, 10, 1, 7, 1, 3, 7, -1, -1, -1, -1, -1, -1, -1}, + {10, 7, 6, 1, 7, 10, 1, 8, 7, 1, 0, 8, -1, -1, -1, -1}, + {0, 3, 7, 0, 7, 10, 0, 10, 9, 6, 10, 7, -1, -1, -1, -1}, + {7, 6, 10, 7, 10, 8, 8, 10, 9, -1, -1, -1, -1, -1, -1, -1}, + {6, 8, 4, 11, 8, 6, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1}, + {3, 6, 11, 3, 0, 6, 0, 4, 6, -1, -1, -1, -1, -1, -1, -1}, + {8, 6, 11, 8, 4, 6, 9, 0, 1, -1, -1, -1, -1, -1, -1, -1}, + {9, 4, 6, 9, 6, 3, 9, 3, 1, 11, 3, 6, -1, -1, -1, -1}, + {6, 8, 4, 6, 11, 8, 2, 10, 1, -1, -1, -1, -1, -1, -1, -1}, + {1, 2, 10, 3, 0, 11, 0, 6, 11, 0, 4, 6, -1, -1, -1, -1}, + {4, 11, 8, 4, 6, 11, 0, 2, 9, 2, 10, 9, -1, -1, -1, -1}, + {10, 9, 3, 10, 3, 2, 9, 4, 3, 11, 3, 6, 4, 6, 3, -1}, + {8, 2, 3, 8, 4, 2, 4, 6, 2, -1, -1, -1, -1, -1, -1, -1}, + {0, 4, 2, 4, 6, 2, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1}, + {1, 9, 0, 2, 3, 4, 2, 4, 6, 4, 3, 8, -1, -1, -1, -1}, + {1, 9, 4, 1, 4, 2, 2, 4, 6, -1, -1, -1, -1, -1, -1, -1}, + {8, 1, 3, 8, 6, 1, 8, 4, 6, 6, 10, 1, -1, -1, -1, -1}, + {10, 1, 0, 10, 0, 6, 6, 0, 4, -1, -1, -1, -1, -1, -1, -1}, + {4, 6, 3, 4, 3, 8, 6, 10, 3, 0, 3, 9, 10, 9, 3, -1}, + {10, 9, 4, 6, 10, 4, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1}, + {4, 9, 5, 7, 6, 11, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1}, + {0, 8, 3, 4, 9, 5, 11, 7, 6, -1, -1, -1, -1, -1, -1, -1}, + {5, 0, 1, 5, 4, 0, 7, 6, 11, -1, -1, -1, -1, -1, -1, -1}, + {11, 7, 6, 8, 3, 4, 3, 5, 4, 3, 1, 5, -1, -1, -1, -1}, + {9, 5, 4, 10, 1, 2, 7, 6, 11, -1, -1, -1, -1, -1, -1, -1}, + {6, 11, 7, 1, 2, 10, 0, 8, 3, 4, 9, 5, -1, -1, -1, -1}, + {7, 6, 11, 5, 4, 10, 4, 2, 10, 4, 0, 2, -1, -1, -1, -1}, + {3, 4, 8, 3, 5, 4, 3, 2, 5, 10, 5, 2, 11, 7, 6, -1}, + {7, 2, 3, 7, 6, 2, 5, 4, 9, -1, -1, -1, -1, -1, -1, -1}, + {9, 5, 4, 0, 8, 6, 0, 6, 2, 6, 8, 7, -1, -1, -1, -1}, + {3, 6, 2, 3, 7, 6, 1, 5, 0, 5, 4, 0, -1, -1, -1, -1}, + {6, 2, 8, 6, 8, 7, 2, 1, 8, 4, 8, 5, 1, 5, 8, -1}, + {9, 5, 4, 10, 1, 6, 1, 7, 6, 1, 3, 7, -1, -1, -1, -1}, + {1, 6, 10, 1, 7, 6, 1, 0, 7, 8, 7, 0, 9, 5, 4, -1}, + {4, 0, 10, 4, 10, 5, 0, 3, 10, 6, 10, 7, 3, 7, 10, -1}, + {7, 6, 10, 7, 10, 8, 5, 4, 10, 4, 8, 10, -1, -1, -1, -1}, + {6, 9, 5, 6, 11, 9, 11, 8, 9, -1, -1, -1, -1, -1, -1, -1}, + {3, 6, 11, 0, 6, 3, 0, 5, 6, 0, 9, 5, -1, -1, -1, -1}, + {0, 11, 8, 0, 5, 11, 0, 1, 5, 5, 6, 11, -1, -1, -1, -1}, + {6, 11, 3, 6, 3, 5, 5, 3, 1, -1, -1, -1, -1, -1, -1, -1}, + {1, 2, 10, 9, 5, 11, 9, 11, 8, 11, 5, 6, -1, -1, -1, -1}, + {0, 11, 3, 0, 6, 11, 0, 9, 6, 5, 6, 9, 1, 2, 10, -1}, + {11, 8, 5, 11, 5, 6, 8, 0, 5, 10, 5, 2, 0, 2, 5, -1}, + {6, 11, 3, 6, 3, 5, 2, 10, 3, 10, 5, 3, -1, -1, -1, -1}, + {5, 8, 9, 5, 2, 8, 5, 6, 2, 3, 8, 2, -1, -1, -1, -1}, + {9, 5, 6, 9, 6, 0, 0, 6, 2, -1, -1, -1, -1, -1, -1, -1}, + {1, 5, 8, 1, 8, 0, 5, 6, 8, 3, 8, 2, 6, 2, 8, -1}, + {1, 5, 6, 2, 1, 6, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1}, + {1, 3, 6, 1, 6, 10, 3, 8, 6, 5, 6, 9, 8, 9, 6, -1}, + {10, 1, 0, 10, 0, 6, 9, 5, 0, 5, 6, 0, -1, -1, -1, -1}, + {0, 3, 8, 5, 6, 10, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1}, + {10, 5, 6, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1}, + {11, 5, 10, 7, 5, 11, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1}, + {11, 5, 10, 11, 7, 5, 8, 3, 0, -1, -1, -1, -1, -1, -1, -1}, + {5, 11, 7, 5, 10, 11, 1, 9, 0, -1, -1, -1, -1, -1, -1, -1}, + {10, 7, 5, 10, 11, 7, 9, 8, 1, 8, 3, 1, -1, -1, -1, -1}, + {11, 1, 2, 11, 7, 1, 7, 5, 1, -1, -1, -1, -1, -1, -1, -1}, + {0, 8, 3, 1, 2, 7, 1, 7, 5, 7, 2, 11, -1, -1, -1, -1}, + {9, 7, 5, 9, 2, 7, 9, 0, 2, 2, 11, 7, -1, -1, -1, -1}, + {7, 5, 2, 7, 2, 11, 5, 9, 2, 3, 2, 8, 9, 8, 2, -1}, + {2, 5, 10, 2, 3, 5, 3, 7, 5, -1, -1, -1, -1, -1, -1, -1}, + {8, 2, 0, 8, 5, 2, 8, 7, 5, 10, 2, 5, -1, -1, -1, -1}, + {9, 0, 1, 5, 10, 3, 5, 3, 7, 3, 10, 2, -1, -1, -1, -1}, + {9, 8, 2, 9, 2, 1, 8, 7, 2, 10, 2, 5, 7, 5, 2, -1}, + {1, 3, 5, 3, 7, 5, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1}, + {0, 8, 7, 0, 7, 1, 1, 7, 5, -1, -1, -1, -1, -1, -1, -1}, + {9, 0, 3, 9, 3, 5, 5, 3, 7, -1, -1, -1, -1, -1, -1, -1}, + {9, 8, 7, 5, 9, 7, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1}, + {5, 8, 4, 5, 10, 8, 10, 11, 8, -1, -1, -1, -1, -1, -1, -1}, + {5, 0, 4, 5, 11, 0, 5, 10, 11, 11, 3, 0, -1, -1, -1, -1}, + {0, 1, 9, 8, 4, 10, 8, 10, 11, 10, 4, 5, -1, -1, -1, -1}, + {10, 11, 4, 10, 4, 5, 11, 3, 4, 9, 4, 1, 3, 1, 4, -1}, + {2, 5, 1, 2, 8, 5, 2, 11, 8, 4, 5, 8, -1, -1, -1, -1}, + {0, 4, 11, 0, 11, 3, 4, 5, 11, 2, 11, 1, 5, 1, 11, -1}, + {0, 2, 5, 0, 5, 9, 2, 11, 5, 4, 5, 8, 11, 8, 5, -1}, + {9, 4, 5, 2, 11, 3, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1}, + {2, 5, 10, 3, 5, 2, 3, 4, 5, 3, 8, 4, -1, -1, -1, -1}, + {5, 10, 2, 5, 2, 4, 4, 2, 0, -1, -1, -1, -1, -1, -1, -1}, + {3, 10, 2, 3, 5, 10, 3, 8, 5, 4, 5, 8, 0, 1, 9, -1}, + {5, 10, 2, 5, 2, 4, 1, 9, 2, 9, 4, 2, -1, -1, -1, -1}, + {8, 4, 5, 8, 5, 3, 3, 5, 1, -1, -1, -1, -1, -1, -1, -1}, + {0, 4, 5, 1, 0, 5, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1}, + {8, 4, 5, 8, 5, 3, 9, 0, 5, 0, 3, 5, -1, -1, -1, -1}, + {9, 4, 5, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1}, + {4, 11, 7, 4, 9, 11, 9, 10, 11, -1, -1, -1, -1, -1, -1, -1}, + {0, 8, 3, 4, 9, 7, 9, 11, 7, 9, 10, 11, -1, -1, -1, -1}, + {1, 10, 11, 1, 11, 4, 1, 4, 0, 7, 4, 11, -1, -1, -1, -1}, + {3, 1, 4, 3, 4, 8, 1, 10, 4, 7, 4, 11, 10, 11, 4, -1}, + {4, 11, 7, 9, 11, 4, 9, 2, 11, 9, 1, 2, -1, -1, -1, -1}, + {9, 7, 4, 9, 11, 7, 9, 1, 11, 2, 11, 1, 0, 8, 3, -1}, + {11, 7, 4, 11, 4, 2, 2, 4, 0, -1, -1, -1, -1, -1, -1, -1}, + {11, 7, 4, 11, 4, 2, 8, 3, 4, 3, 2, 4, -1, -1, -1, -1}, + {2, 9, 10, 2, 7, 9, 2, 3, 7, 7, 4, 9, -1, -1, -1, -1}, + {9, 10, 7, 9, 7, 4, 10, 2, 7, 8, 7, 0, 2, 0, 7, -1}, + {3, 7, 10, 3, 10, 2, 7, 4, 10, 1, 10, 0, 4, 0, 10, -1}, + {1, 10, 2, 8, 7, 4, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1}, + {4, 9, 1, 4, 1, 7, 7, 1, 3, -1, -1, -1, -1, -1, -1, -1}, + {4, 9, 1, 4, 1, 7, 0, 8, 1, 8, 7, 1, -1, -1, -1, -1}, + {4, 0, 3, 7, 4, 3, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1}, + {4, 8, 7, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1}, + {9, 10, 8, 10, 11, 8, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1}, + {3, 0, 9, 3, 9, 11, 11, 9, 10, -1, -1, -1, -1, -1, -1, -1}, + {0, 1, 10, 0, 10, 8, 8, 10, 11, -1, -1, -1, -1, -1, -1, -1}, + {3, 1, 10, 11, 3, 10, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1}, + {1, 2, 11, 1, 11, 9, 9, 11, 8, -1, -1, -1, -1, -1, -1, -1}, + {3, 0, 9, 3, 9, 11, 1, 2, 9, 2, 11, 9, -1, -1, -1, -1}, + {0, 2, 11, 8, 0, 11, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1}, + {3, 2, 11, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1}, + {2, 3, 8, 2, 8, 10, 10, 8, 9, -1, -1, -1, -1, -1, -1, -1}, + {9, 10, 2, 0, 9, 2, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1}, + {2, 3, 8, 2, 8, 10, 0, 1, 8, 1, 10, 8, -1, -1, -1, -1}, + {1, 10, 2, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1}, + {1, 3, 8, 9, 1, 8, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1}, + {0, 9, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1}, + {0, 3, 8, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1}, + {-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1}}; + +// Table mapping each edge to the corresponding cube vertices +const int _EDGE_TO_VERTICES[12][2] = { + {0, 1}, + {1, 5}, + {4, 5}, + {0, 4}, + {2, 3}, + {3, 7}, + {6, 7}, + {2, 6}, + {0, 2}, + {1, 3}, + {5, 7}, + {4, 6}, +}; + +// Table mapping from 0-7 to v0-v7 in cube.vertices +const int _INDEX_TABLE[8] = {0, 1, 5, 4, 2, 3, 7, 6}; + +// Data structures for the marching cubes +struct Vertex { + // Constructor used when performing marching cube in each cell + explicit Vertex(float x = 0.0f, float y = 0.0f, float z = 0.0f) + : x(x), y(y), z(z) {} + + // The */+ operator overrides are used for vertex interpolation + Vertex operator*(float s) const { + return Vertex(x * s, y * s, z * s); + } + Vertex operator+(const Vertex& xyz) const { + return Vertex(x + xyz.x, y + xyz.y, z + xyz.z); + } + // The == operator overrides is used for checking degenerate triangles + bool operator==(const Vertex& xyz) const { + if (std::abs(x - xyz.x) < EPS && std::abs(y - xyz.y) < EPS && + std::abs(z - xyz.z) < EPS) { + return true; + } + return false; + } + // vertex position + float x, y, z; +}; + +struct Cube { + // Edge and vertex convention: + // v4_______e4____________v5 + // /| /| + // / | / | + // e7/ | e5/ | + // /___|______e6_________/ | + // v7| | |v6 |e9 + // | | | | + // | |e8 |e10| + // e11| | | | + // | |_________________|___| + // | / v0 e0 | /v1 + // | / | / + // | /e3 | /e1 + // |/_____________________|/ + // v3 e2 v2 + + Vertex p[8]; + int x, y, z; + int cubeindex = 0; + Cube( + int x, + int y, + int z, + const at::TensorAccessor& vol_a, + const float isolevel) + : x(x), y(y), z(z) { + // vertex position (x, y, z) for v0-v1-v4-v5-v3-v2-v7-v6 + for (int v = 0; v < 8; v++) { + p[v] = Vertex(x + (v & 1), y + (v >> 1 & 1), z + (v >> 2 & 1)); + } + // Calculates cube configuration index given values of the cube vertices + for (int i = 0; i < 8; i++) { + const int idx = _INDEX_TABLE[i]; + Vertex v = p[idx]; + if (vol_a[v.z][v.y][v.x] < isolevel) { + cubeindex |= (1 << i); + } + } + } + + // Linearly interpolate the position where an isosurface cuts an edge + // between two vertices, based on their scalar values + // + // Args: + // isolevel: float value used as threshold + // edge: edge (ID) to interpolate + // cube: current cube vertices + // vol_a: 3D scalar field + // + // Returns: + // point: interpolated vertex + Vertex VertexInterp( + float isolevel, + const int edge, + const at::TensorAccessor& vol_a) { + const int v1 = _EDGE_TO_VERTICES[edge][0]; + const int v2 = _EDGE_TO_VERTICES[edge][1]; + Vertex p1 = p[v1]; + Vertex p2 = p[v2]; + float val1 = vol_a[p1.z][p1.y][p1.x]; + float val2 = vol_a[p2.z][p2.y][p2.x]; + + float ratio = 1.0f; + if (std::abs(isolevel - val1) < EPS) { + return p1; + } else if (std::abs(isolevel - val2) < EPS) { + return p2; + } else if (std::abs(val1 - val2) < EPS) { + return p1; + } + // interpolate vertex p based on two vertices on the edge + ratio = (isolevel - val1) / (val2 - val1); + return p1 * (1 - ratio) + p2 * ratio; + } + + // Get a tuple of global vertex ID from a local edge ID + // Global vertex ID is calculated as (x + dx) + (y + dy) * W + (z + dz) * W * + // H + + // Args: + // edge: local edge ID in the cube + // W: width of x dimension + // H: height of y dimension + // + // Returns: + // a pair of global vertex ID + // + std::pair GetVPairFromEdge(const int edge, int W, int H) { + const int v1 = _EDGE_TO_VERTICES[edge][0]; + const int v2 = _EDGE_TO_VERTICES[edge][1]; + const int v1_id = p[v1].x + p[v1].y * W + p[v1].z * W * H; + const int v2_id = p[v2].x + p[v2].y * W + p[v2].z * W * H; + return std::make_pair(v1_id, v2_id); + } +}; + +// helper functions for hashing +namespace std { +// Taken from boost library to combine several hash functions +template +inline void hash_combine(size_t& s, const T& v) { + std::hash h; + s ^= h(v) + 0x9e3779b9 + (s << 6) + (s >> 2); +} + +// Function for hashing a pair of vertices +template <> +struct hash> { + size_t operator()(const std::pair& p) const { + size_t res = 0; + hash_combine(res, p.first); + hash_combine(res, p.second); + return res; + } +}; + +} // namespace std diff --git a/pytorch3d/ops/marching_cubes.py b/pytorch3d/ops/marching_cubes.py index e3d621db..98d72ce7 100644 --- a/pytorch3d/ops/marching_cubes.py +++ b/pytorch3d/ops/marching_cubes.py @@ -7,8 +7,10 @@ from typing import List, Optional, Tuple import torch +from pytorch3d import _C from pytorch3d.ops.marching_cubes_data import EDGE_TO_VERTICES, FACE_TABLE, INDEX from pytorch3d.transforms import Translate +from torch.autograd import Function EPS = 0.00001 @@ -225,3 +227,71 @@ def marching_cubes_naive( batched_verts.append([]) batched_faces.append([]) return batched_verts, batched_faces + + +######################################## +# Marching Cubes Implementation in C++ +######################################## +class _marching_cubes(Function): + """ + Torch Function wrapper for marching_cubes C++ implementation + Backward is not supported. + """ + + @staticmethod + def forward(ctx, vol, isolevel): + verts, faces = _C.marching_cubes(vol, isolevel) + return verts, faces + + @staticmethod + def backward(ctx, grad_verts, grad_faces): + raise ValueError("marching_cubes backward is not supported") + + +def marching_cubes( + vol_batch: torch.Tensor, + isolevel: Optional[float] = None, + return_local_coords: bool = True, +) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: + """ + Run marching cubes over a volume scalar field with a designated isolevel. + Returns vertices and faces of the obtained mesh. + This operation is non-differentiable. + + Args: + vol_batch: a Tensor of size (N, D, H, W) corresponding to + a batch of 3D scalar fields + isolevel: float used as threshold to determine if a point is inside/outside + the volume. If None, then the average of the maximum and minimum value + of the scalar field is used. + return_local_coords: bool. If True the output vertices will be in local coordinates in + the range [-1, 1] x [-1, 1] x [-1, 1]. If False they will be in the range + [0, W-1] x [0, H-1] x [0, D-1] + + + Returns: + verts: [{V_0}, {V_1}, ...] List of N sets of vertices of shape (|V_i|, 3) in FloatTensor + faces: [{F_0}, {F_1}, ...] List of N sets of faces of shape (|F_i|, 3) in LongTensors + """ + batched_verts, batched_faces = [], [] + D, H, W = vol_batch.shape[1:] + for i in range(len(vol_batch)): + vol = vol_batch[i] + thresh = ((vol.max() + vol.min()) / 2).item() if isolevel is None else isolevel + # pyre-fixme[16]: `_marching_cubes` has no attribute `apply`. + verts, faces = _marching_cubes.apply(vol, thresh) + if len(faces) > 0 and len(verts) > 0: + # Convert from world coordinates ([0, D-1], [0, H-1], [0, W-1]) to + # local coordinates in the range [-1, 1] + if return_local_coords: + verts = ( + Translate(x=+1.0, y=+1.0, z=+1.0, device=vol.device) + .scale((vol.new_tensor([W, H, D])[None] - 1) * 0.5) + .inverse() + ).transform_points(verts[None])[0] + batched_verts.append(verts) + batched_faces.append(faces) + else: + batched_verts.append([]) + batched_faces.append([]) + return batched_verts, batched_faces diff --git a/tests/benchmarks/bm_marching_cubes.py b/tests/benchmarks/bm_marching_cubes.py index 1ee7bc4d..c212b561 100644 --- a/tests/benchmarks/bm_marching_cubes.py +++ b/tests/benchmarks/bm_marching_cubes.py @@ -4,19 +4,24 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import itertools + from fvcore.common.benchmark import benchmark from tests.test_marching_cubes import TestMarchingCubes def bm_marching_cubes() -> None: - kwargs_list = [ - {"batch_size": 1, "V": 5}, - {"batch_size": 1, "V": 10}, - {"batch_size": 1, "V": 20}, - {"batch_size": 1, "V": 40}, - {"batch_size": 5, "V": 5}, - {"batch_size": 20, "V": 20}, - ] + case_grid = { + "algo_type": [ + "naive", + "cextension", + ], + "batch_size": [1, 5, 20], + "V": [5, 10, 20], + } + test_cases = itertools.product(*case_grid.values()) + kwargs_list = [dict(zip(case_grid.keys(), case)) for case in test_cases] + benchmark( TestMarchingCubes.marching_cubes_with_init, "MARCHING_CUBES", diff --git a/tests/test_marching_cubes.py b/tests/test_marching_cubes.py index 06442e94..50026b8d 100644 --- a/tests/test_marching_cubes.py +++ b/tests/test_marching_cubes.py @@ -9,7 +9,7 @@ import pickle import unittest import torch -from pytorch3d.ops.marching_cubes import marching_cubes_naive +from pytorch3d.ops.marching_cubes import marching_cubes, marching_cubes_naive from .common_testing import get_tests_dir, TestCaseMixin @@ -37,6 +37,11 @@ class TestCubeConfiguration(TestCaseMixin, unittest.TestCase): self.assertClose(verts, expected_verts) self.assertClose(faces, expected_faces) + # test C++ implementation + verts, faces = marching_cubes(volume_data, return_local_coords=False) + self.assertClose(verts, expected_verts) + self.assertClose(faces, expected_faces) + def test_case1(self): # case 1 volume_data = torch.ones(1, 2, 2, 2) # (B, W, H, D) volume_data[0, 0, 0, 0] = 0 @@ -54,11 +59,21 @@ class TestCubeConfiguration(TestCaseMixin, unittest.TestCase): self.assertClose(verts[0], expected_verts) self.assertClose(faces[0], expected_faces) + # test C++ implementation + verts, faces = marching_cubes(volume_data, return_local_coords=False) + self.assertClose(verts[0], expected_verts) + self.assertClose(faces[0], expected_faces) + verts, faces = marching_cubes_naive(volume_data, return_local_coords=True) expected_verts = convert_to_local(expected_verts, 2) self.assertClose(verts[0], expected_verts) self.assertClose(faces[0], expected_faces) + # test C++ implementation + verts, faces = marching_cubes(volume_data, return_local_coords=True) + self.assertClose(verts[0], expected_verts) + self.assertClose(faces[0], expected_faces) + def test_case2(self): volume_data = torch.ones(1, 2, 2, 2) # (B, W, H, D) volume_data[0, 0:2, 0, 0] = 0 @@ -77,11 +92,21 @@ class TestCubeConfiguration(TestCaseMixin, unittest.TestCase): self.assertClose(verts[0], expected_verts) self.assertClose(faces[0], expected_faces) + # test C++ implementation + verts, faces = marching_cubes(volume_data, return_local_coords=False) + self.assertClose(verts[0], expected_verts) + self.assertClose(faces[0], expected_faces) + verts, faces = marching_cubes_naive(volume_data, return_local_coords=True) expected_verts = convert_to_local(expected_verts, 2) self.assertClose(verts[0], expected_verts) self.assertClose(faces[0], expected_faces) + # test C++ implementation + verts, faces = marching_cubes(volume_data, return_local_coords=True) + self.assertClose(verts[0], expected_verts) + self.assertClose(faces[0], expected_faces) + def test_case3(self): volume_data = torch.ones(1, 2, 2, 2) # (B, W, H, D) volume_data[0, 0, 0, 0] = 0 @@ -103,11 +128,21 @@ class TestCubeConfiguration(TestCaseMixin, unittest.TestCase): self.assertClose(verts[0], expected_verts) self.assertClose(faces[0], expected_faces) + # test C++ implementation + verts, faces = marching_cubes(volume_data, return_local_coords=False) + self.assertClose(verts[0], expected_verts) + self.assertClose(faces[0], expected_faces) + verts, faces = marching_cubes_naive(volume_data, return_local_coords=True) expected_verts = convert_to_local(expected_verts, 2) self.assertClose(verts[0], expected_verts) self.assertClose(faces[0], expected_faces) + # test C++ implementation + verts, faces = marching_cubes(volume_data, return_local_coords=True) + self.assertClose(verts[0], expected_verts) + self.assertClose(faces[0], expected_faces) + def test_case4(self): volume_data = torch.ones(1, 2, 2, 2) # (B, W, H, D) volume_data[0, 1, 0, 0] = 0 @@ -129,11 +164,21 @@ class TestCubeConfiguration(TestCaseMixin, unittest.TestCase): self.assertClose(verts[0], expected_verts) self.assertClose(faces[0], expected_faces) + # test C++ implementation + verts, faces = marching_cubes(volume_data, return_local_coords=False) + self.assertClose(verts[0], expected_verts) + self.assertClose(faces[0], expected_faces) + verts, faces = marching_cubes_naive(volume_data, return_local_coords=True) expected_verts = convert_to_local(expected_verts, 2) self.assertClose(verts[0], expected_verts) self.assertClose(faces[0], expected_faces) + # test C++ implementation + verts, faces = marching_cubes(volume_data, return_local_coords=True) + self.assertClose(verts[0], expected_verts) + self.assertClose(faces[0], expected_faces) + def test_case5(self): volume_data = torch.ones(1, 2, 2, 2) # (B, W, H, D) volume_data[0, 0:2, 0, 0:2] = 0 @@ -153,11 +198,21 @@ class TestCubeConfiguration(TestCaseMixin, unittest.TestCase): self.assertClose(verts[0], expected_verts) self.assertClose(faces[0], expected_faces) + # test C++ implementation + verts, faces = marching_cubes(volume_data, return_local_coords=False) + self.assertClose(verts[0], expected_verts) + self.assertClose(faces[0], expected_faces) + verts, faces = marching_cubes_naive(volume_data, return_local_coords=True) expected_verts = convert_to_local(expected_verts, 2) self.assertClose(verts[0], expected_verts) self.assertClose(faces[0], expected_faces) + # test C++ implementation + verts, faces = marching_cubes(volume_data, return_local_coords=True) + self.assertClose(verts[0], expected_verts) + self.assertClose(faces[0], expected_faces) + def test_case6(self): volume_data = torch.ones(1, 2, 2, 2) # (B, W, H, D) volume_data[0, 1, 0, 0] = 0 @@ -184,11 +239,21 @@ class TestCubeConfiguration(TestCaseMixin, unittest.TestCase): self.assertClose(verts[0], expected_verts) self.assertClose(faces[0], expected_faces) + # test C++ implementation + verts, faces = marching_cubes(volume_data, return_local_coords=False) + self.assertClose(verts[0], expected_verts) + self.assertClose(faces[0], expected_faces) + verts, faces = marching_cubes_naive(volume_data, return_local_coords=True) expected_verts = convert_to_local(expected_verts, 2) self.assertClose(verts[0], expected_verts) self.assertClose(faces[0], expected_faces) + # test C++ implementation + verts, faces = marching_cubes(volume_data, return_local_coords=True) + self.assertClose(verts[0], expected_verts) + self.assertClose(faces[0], expected_faces) + def test_case7(self): volume_data = torch.ones(1, 2, 2, 2) # (B, W, H, D) volume_data[0, 0, 0, 0] = 0 @@ -220,11 +285,21 @@ class TestCubeConfiguration(TestCaseMixin, unittest.TestCase): self.assertClose(verts[0], expected_verts) self.assertClose(faces[0], expected_faces) + # test C++ implementation + verts, faces = marching_cubes(volume_data, return_local_coords=False) + self.assertClose(verts[0], expected_verts) + self.assertClose(faces[0], expected_faces) + verts, faces = marching_cubes_naive(volume_data, return_local_coords=True) expected_verts = convert_to_local(expected_verts, 2) self.assertClose(verts[0], expected_verts) self.assertClose(faces[0], expected_faces) + # test C++ implementation + verts, faces = marching_cubes(volume_data, return_local_coords=True) + self.assertClose(verts[0], expected_verts) + self.assertClose(faces[0], expected_faces) + def test_case8(self): volume_data = torch.ones(1, 2, 2, 2) # (B, W, H, D) volume_data[0, 0, 0, 0] = 0 @@ -249,11 +324,21 @@ class TestCubeConfiguration(TestCaseMixin, unittest.TestCase): self.assertClose(verts[0], expected_verts) self.assertClose(faces[0], expected_faces) + # test C++ implementation + verts, faces = marching_cubes(volume_data, return_local_coords=False) + self.assertClose(verts[0], expected_verts) + self.assertClose(faces[0], expected_faces) + verts, faces = marching_cubes_naive(volume_data, return_local_coords=True) expected_verts = convert_to_local(expected_verts, 2) self.assertClose(verts[0], expected_verts) self.assertClose(faces[0], expected_faces) + # test C++ implementation + verts, faces = marching_cubes(volume_data, return_local_coords=True) + self.assertClose(verts[0], expected_verts) + self.assertClose(faces[0], expected_faces) + def test_case9(self): volume_data = torch.ones(1, 2, 2, 2) # (B, W, H, D) volume_data[0, 1, 0, 0] = 0 @@ -278,11 +363,21 @@ class TestCubeConfiguration(TestCaseMixin, unittest.TestCase): self.assertClose(verts[0], expected_verts) self.assertClose(faces[0], expected_faces) + # test C++ implementation + verts, faces = marching_cubes(volume_data, return_local_coords=False) + self.assertClose(verts[0], expected_verts) + self.assertClose(faces[0], expected_faces) + verts, faces = marching_cubes_naive(volume_data, return_local_coords=True) expected_verts = convert_to_local(expected_verts, 2) self.assertClose(verts[0], expected_verts) self.assertClose(faces[0], expected_faces) + # test C++ implementation + verts, faces = marching_cubes(volume_data, return_local_coords=True) + self.assertClose(verts[0], expected_verts) + self.assertClose(faces[0], expected_faces) + def test_case10(self): volume_data = torch.ones(1, 2, 2, 2) # (B, W, H, D) volume_data[0, 0, 0, 0] = 0 @@ -306,11 +401,21 @@ class TestCubeConfiguration(TestCaseMixin, unittest.TestCase): self.assertClose(verts[0], expected_verts) self.assertClose(faces[0], expected_faces) + # test C++ implementation + verts, faces = marching_cubes(volume_data, return_local_coords=False) + self.assertClose(verts[0], expected_verts) + self.assertClose(faces[0], expected_faces) + verts, faces = marching_cubes_naive(volume_data, return_local_coords=True) expected_verts = convert_to_local(expected_verts, 2) self.assertClose(verts[0], expected_verts) self.assertClose(faces[0], expected_faces) + # test C++ implementation + verts, faces = marching_cubes(volume_data, return_local_coords=True) + self.assertClose(verts[0], expected_verts) + self.assertClose(faces[0], expected_faces) + def test_case11(self): volume_data = torch.ones(1, 2, 2, 2) # (B, W, H, D) volume_data[0, 0, 0, 0] = 0 @@ -336,11 +441,21 @@ class TestCubeConfiguration(TestCaseMixin, unittest.TestCase): self.assertClose(verts[0], expected_verts) self.assertClose(faces[0], expected_faces) + # test C++ implementation + verts, faces = marching_cubes(volume_data, return_local_coords=False) + self.assertClose(verts[0], expected_verts) + self.assertClose(faces[0], expected_faces) + verts, faces = marching_cubes_naive(volume_data, return_local_coords=True) expected_verts = convert_to_local(expected_verts, 2) self.assertClose(verts[0], expected_verts) self.assertClose(faces[0], expected_faces) + # test C++ implementation + verts, faces = marching_cubes(volume_data, return_local_coords=True) + self.assertClose(verts[0], expected_verts) + self.assertClose(faces[0], expected_faces) + def test_case12(self): volume_data = torch.ones(1, 2, 2, 2) # (B, W, H, D) volume_data[0, 1, 0, 0] = 0 @@ -368,11 +483,21 @@ class TestCubeConfiguration(TestCaseMixin, unittest.TestCase): self.assertClose(verts[0], expected_verts) self.assertClose(faces[0], expected_faces) + # test C++ implementation + verts, faces = marching_cubes(volume_data, return_local_coords=False) + self.assertClose(verts[0], expected_verts) + self.assertClose(faces[0], expected_faces) + verts, faces = marching_cubes_naive(volume_data, return_local_coords=True) expected_verts = convert_to_local(expected_verts, 2) self.assertClose(verts[0], expected_verts) self.assertClose(faces[0], expected_faces) + # test C++ implementation + verts, faces = marching_cubes(volume_data, return_local_coords=True) + self.assertClose(verts[0], expected_verts) + self.assertClose(faces[0], expected_faces) + def test_case13(self): volume_data = torch.ones(1, 2, 2, 2) # (B, W, H, D) volume_data[0, 0, 0, 0] = 0 @@ -400,11 +525,21 @@ class TestCubeConfiguration(TestCaseMixin, unittest.TestCase): self.assertClose(verts[0], expected_verts) self.assertClose(faces[0], expected_faces) + # test C++ implementation + verts, faces = marching_cubes(volume_data, return_local_coords=False) + self.assertClose(verts[0], expected_verts) + self.assertClose(faces[0], expected_faces) + verts, faces = marching_cubes_naive(volume_data, return_local_coords=True) expected_verts = convert_to_local(expected_verts, 2) self.assertClose(verts[0], expected_verts) self.assertClose(faces[0], expected_faces) + # test C++ implementation + verts, faces = marching_cubes(volume_data, return_local_coords=True) + self.assertClose(verts[0], expected_verts) + self.assertClose(faces[0], expected_faces) + def test_case14(self): volume_data = torch.ones(1, 2, 2, 2) # (B, W, H, D) volume_data[0, 0, 0, 0] = 0 @@ -430,11 +565,21 @@ class TestCubeConfiguration(TestCaseMixin, unittest.TestCase): self.assertClose(verts[0], expected_verts) self.assertClose(faces[0], expected_faces) + # test C++ implementation + verts, faces = marching_cubes(volume_data, return_local_coords=False) + self.assertClose(verts[0], expected_verts) + self.assertClose(faces[0], expected_faces) + verts, faces = marching_cubes_naive(volume_data, return_local_coords=True) expected_verts = convert_to_local(expected_verts, 2) self.assertClose(verts[0], expected_verts) self.assertClose(faces[0], expected_faces) + # test C++ implementation + verts, faces = marching_cubes(volume_data, return_local_coords=True) + self.assertClose(verts[0], expected_verts) + self.assertClose(faces[0], expected_faces) + class TestMarchingCubes(TestCaseMixin, unittest.TestCase): def test_single_point(self): @@ -468,12 +613,23 @@ class TestMarchingCubes(TestCaseMixin, unittest.TestCase): self.assertClose(verts[0], expected_verts) self.assertClose(faces[0], expected_faces) + # test C++ implementation + verts, faces = marching_cubes(volume_data, return_local_coords=False) + self.assertClose(verts[0], expected_verts) + self.assertClose(faces[0], expected_faces) + verts, faces = marching_cubes_naive(volume_data, return_local_coords=True) expected_verts = convert_to_local(expected_verts, 3) self.assertClose(verts[0], expected_verts) self.assertClose(faces[0], expected_faces) self.assertTrue(verts[0].ge(-1).all() and verts[0].le(1).all()) + # test C++ implementation + verts, faces = marching_cubes(volume_data, return_local_coords=True) + self.assertClose(verts[0], expected_verts) + self.assertClose(faces[0], expected_faces) + self.assertTrue(verts[0].ge(-1).all() and verts[0].le(1).all()) + def test_cube(self): volume_data = torch.zeros(1, 5, 5, 5) # (B, W, H, D) volume_data[0, 1, 1, 1] = 1 @@ -567,6 +723,11 @@ class TestMarchingCubes(TestCaseMixin, unittest.TestCase): self.assertClose(verts[0], expected_verts) self.assertClose(faces[0], expected_faces) + # test C++ implementation + verts, faces = marching_cubes(volume_data, 0.9, return_local_coords=False) + self.assertClose(verts[0], expected_verts) + self.assertClose(faces[0], expected_faces) + verts, faces = marching_cubes_naive(volume_data, 0.9, return_local_coords=True) expected_verts = convert_to_local(expected_verts, 5) self.assertClose(verts[0], expected_verts) @@ -575,6 +736,12 @@ class TestMarchingCubes(TestCaseMixin, unittest.TestCase): # Check all values are in the range [-1, 1] self.assertTrue(verts[0].ge(-1).all() and verts[0].le(1).all()) + # test C++ implementation + verts, faces = marching_cubes(volume_data, 0.9, return_local_coords=True) + self.assertClose(verts[0], expected_verts) + self.assertClose(faces[0], expected_faces) + self.assertTrue(verts[0].ge(-1).all() and verts[0].le(1).all()) + def test_cube_no_duplicate_verts(self): volume_data = torch.zeros(1, 5, 5, 5) # (B, W, H, D) volume_data[0, 1, 1, 1] = 1 @@ -670,6 +837,11 @@ class TestMarchingCubes(TestCaseMixin, unittest.TestCase): self.assertClose(verts[0], expected_verts) self.assertClose(faces[0], expected_faces) + # test C++ implementation + verts, faces = marching_cubes(volume, 64, return_local_coords=False) + self.assertClose(verts[0], expected_verts) + self.assertClose(faces[0], expected_faces) + verts, faces = marching_cubes_naive( volume, isolevel=64, return_local_coords=True ) @@ -681,6 +853,12 @@ class TestMarchingCubes(TestCaseMixin, unittest.TestCase): # Check all values are in the range [-1, 1] self.assertTrue(verts[0].ge(-1).all() and verts[0].le(1).all()) + # test C++ implementation + verts, faces = marching_cubes(volume, 64, return_local_coords=True) + self.assertClose(verts[0], expected_verts) + self.assertClose(faces[0], expected_faces) + self.assertTrue(verts[0].ge(-1).all() and verts[0].le(1).all()) + # Uses skimage.draw.ellipsoid def test_double_ellipsoid(self): if USE_SCIKIT: @@ -694,6 +872,7 @@ class TestMarchingCubes(TestCaseMixin, unittest.TestCase): volume = torch.Tensor(ellip_double).unsqueeze(0) volume = volume.permute(0, 3, 2, 1) # (B, D, H, W) verts, faces = marching_cubes_naive(volume, isolevel=0.001) + verts2, faces2 = marching_cubes(volume, isolevel=0.001) data_filename = "test_marching_cubes_data/double_ellipsoid.pickle" filename = os.path.join(DATA_DIR, data_filename) @@ -704,6 +883,8 @@ class TestMarchingCubes(TestCaseMixin, unittest.TestCase): self.assertClose(verts[0], expected_verts) self.assertClose(faces[0], expected_faces) + self.assertClose(verts2[0], expected_verts) + self.assertClose(faces2[0], expected_faces) def test_cube_surface_area(self): if USE_SCIKIT: @@ -720,12 +901,15 @@ class TestMarchingCubes(TestCaseMixin, unittest.TestCase): volume_data[0, 2, 2, 2] = 1 volume_data = volume_data.permute(0, 3, 2, 1) # (B, D, H, W) verts, faces = marching_cubes_naive(volume_data, return_local_coords=False) + verts_c, faces_c = marching_cubes(volume_data, return_local_coords=False) verts_sci, faces_sci = marching_cubes_classic(volume_data[0]) surf = mesh_surface_area(verts[0], faces[0]) + surf_c = mesh_surface_area(verts_c[0], faces_c[0]) surf_sci = mesh_surface_area(verts_sci, faces_sci) self.assertClose(surf, surf_sci) + self.assertClose(surf, surf_c) def test_sphere_surface_area(self): if USE_SCIKIT: @@ -746,12 +930,15 @@ class TestMarchingCubes(TestCaseMixin, unittest.TestCase): ).unsqueeze(0) volume = volume.permute(0, 3, 2, 1) # (B, D, H, W) verts, faces = marching_cubes_naive(volume, isolevel=64) + verts_c, faces_c = marching_cubes(volume, isolevel=64) verts_sci, faces_sci = marching_cubes_classic(volume[0], level=64) surf = mesh_surface_area(verts[0], faces[0]) + surf_c = mesh_surface_area(verts_c[0], faces_c[0]) surf_sci = mesh_surface_area(verts_sci, faces_sci) self.assertClose(surf, surf_sci) + self.assertClose(surf, surf_c) def test_double_ellipsoid_surface_area(self): if USE_SCIKIT: @@ -766,12 +953,15 @@ class TestMarchingCubes(TestCaseMixin, unittest.TestCase): volume = torch.Tensor(ellip_double).unsqueeze(0) volume = volume.permute(0, 3, 2, 1) # (B, D, H, W) verts, faces = marching_cubes_naive(volume, isolevel=0) + verts_c, faces_c = marching_cubes(volume, isolevel=0) verts_sci, faces_sci = marching_cubes_classic(volume[0], level=0) surf = mesh_surface_area(verts[0], faces[0]) + surf_c = mesh_surface_area(verts_c[0], faces_c[0]) surf_sci = mesh_surface_area(verts_sci, faces_sci) self.assertClose(surf, surf_sci) + self.assertClose(surf, surf_c) def test_ball_example(self): N = 15 @@ -780,6 +970,9 @@ class TestMarchingCubes(TestCaseMixin, unittest.TestCase): u = (X - 15) ** 2 + (Y - 15) ** 2 + (Z - 15) ** 2 - 8**2 u = u[None].float() verts, faces = marching_cubes_naive(u, 0, return_local_coords=False) + verts2, faces2 = marching_cubes(u, 0, return_local_coords=False) + self.assertClose(verts[0], verts2[0]) + self.assertClose(faces[0], faces2[0]) @staticmethod def marching_cubes_with_init(algo_type: str, batch_size: int, V: int): @@ -789,6 +982,7 @@ class TestMarchingCubes(TestCaseMixin, unittest.TestCase): ) algo_table = { "naive": marching_cubes_naive, + "cextension": marching_cubes, } def convert():