diff --git a/pytorch3d/ops/marching_cubes.py b/pytorch3d/ops/marching_cubes.py index 1747667a..e3d621db 100644 --- a/pytorch3d/ops/marching_cubes.py +++ b/pytorch3d/ops/marching_cubes.py @@ -4,10 +4,10 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import Dict, List, Optional, Tuple +from typing import List, Optional, Tuple import torch -from pytorch3d.ops.marching_cubes_data import EDGE_TABLE, EDGE_TO_VERTICES, FACE_TABLE +from pytorch3d.ops.marching_cubes_data import EDGE_TO_VERTICES, FACE_TABLE, INDEX from pytorch3d.transforms import Translate @@ -15,10 +15,15 @@ EPS = 0.00001 class Cube: - def __init__(self, bfl_vertex: Tuple[int, int, int], spacing: int = 1) -> None: + def __init__( + self, + bfl_v: Tuple[int, int, int], + volume: torch.Tensor, + isolevel: float, + ) -> None: """ Initializes a cube given the bottom front left vertex coordinate - and the cube spacing + and computes the cube configuration given vertex values and isolevel. Edge and vertex convention: @@ -31,8 +36,8 @@ class Cube: | | | | | |e8 |e10| e11| | | | - | |_________________|___| - | / v0 e0 | /v1 + | |______e0_________|___| + | / v0(bfl_v) | |v1 | / | / | /e3 | /e1 |/_____________________|/ @@ -41,311 +46,182 @@ class Cube: Args: bfl_vertex: a tuple of size 3 corresponding to the bottom front left vertex of the cube in (x, y, z) format - spacing: the length of each edge of the cube + volume: the 3D scalar data + isolevel: the isosurface value used as a threshold for determining whether a point + is inside/outside the volume """ - # match corner orders to algorithm convention - if len(bfl_vertex) != 3: - msg = "The vertex {} is size {} instead of size 3".format( - bfl_vertex, len(bfl_vertex) - ) - raise ValueError(msg) + x, y, z = bfl_v + self.x, self.y, self.z = x, y, z + self.bfl_v = bfl_v + self.verts = [ + [x + (v & 1), y + (v >> 1 & 1), z + (v >> 2 & 1)] for v in range(8) + ] # vertex position (x, y, z) for v0-v1-v4-v5-v3-v2-v7-v6 - x, y, z = bfl_vertex - self.vertices = torch.tensor( - [ - [x, y, z + spacing], - [x + spacing, y, z + spacing], - [x + spacing, y, z], - [x, y, z], - [x, y + spacing, z + spacing], - [x + spacing, y + spacing, z + spacing], - [x + spacing, y + spacing, z], - [x, y + spacing, z], - ] - ) - - def get_index(self, volume_data: torch.Tensor, isolevel: float) -> int: - """ - Calculates the cube_index in the range 0-255 to index - into EDGE_TABLE and FACE_TABLE - Args: - volume_data: the 3D scalar data - isolevel: the isosurface value used as a threshold - for determining whether a point is inside/outside - the volume - """ - cube_index = 0 - bit = 1 - for index in range(len(self.vertices)): - vertex = self.vertices[index] - value = _get_value(vertex, volume_data) + # Calculates cube configuration index given values of the cube vertices + self.cube_index = 0 + for i in range(8): + v = self.verts[INDEX[i]] + value = volume[v[2]][v[1]][v[0]] if value < isolevel: - cube_index |= bit - bit *= 2 - return cube_index + self.cube_index |= 1 << i + + def get_vpair_from_edge(self, edge: int, W: int, H: int) -> Tuple[int, int]: + """ + Get a tuple of global vertex ID from a local edge ID + Global vertex ID is calculated as (x + dx) + (y + dy) * W + (z + dz) * W * H + + Args: + edge: local edge ID in the cube + bfl_vertex: bottom-front-left coordinate of the cube + + Returns: + a pair of global vertex ID + """ + v1, v2 = EDGE_TO_VERTICES[edge] # two end-points on the edge + v1_id = self.verts[v1][0] + self.verts[v1][1] * W + self.verts[v1][2] * W * H + v2_id = self.verts[v2][0] + self.verts[v2][1] * W + self.verts[v2][2] * W * H + return (v1_id, v2_id) + + def vert_interp( + self, + isolevel: float, + edge: int, + vol: torch.Tensor, + ) -> List: + """ + Linearly interpolate a vertex where an isosurface cuts an edge + between the two endpoint vertices, based on their values + + Args: + isolevel: the isosurface value to use as the threshold to determine + whether points are within a volume. + edge: edge (ID) to interpolate + cube: current cube vertices + vol: 3D scalar field + + Returns: + interpolated vertex: position of the interpolated vertex on the edge + """ + v1, v2 = EDGE_TO_VERTICES[edge] + p1, p2 = self.verts[v1], self.verts[v2] + val1, val2 = ( + vol[p1[2]][p1[1]][p1[0]], + vol[p2[2]][p2[1]][p2[0]], + ) + point = None + if abs(isolevel - val1) < EPS: + point = p1 + elif abs(isolevel - val2) < EPS: + point = p2 + elif abs(val1 - val2) < EPS: + point = p1 + + if point is None: + mu = (isolevel - val1) / (val2 - val1) + x1, y1, z1 = p1 + x2, y2, z2 = p2 + x = x1 + mu * (x2 - x1) + y = y1 + mu * (y2 - y1) + z = z1 + mu * (z2 - z1) + else: + x, y, z = point + return [x, y, z] def marching_cubes_naive( - volume_data_batch: torch.Tensor, + vol_batch: torch.Tensor, isolevel: Optional[float] = None, - spacing: int = 1, return_local_coords: bool = True, ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: """ Runs the classic marching cubes algorithm, iterating over - the coordinates of the volume_data and using a given isolevel - for determining intersected edges of cubes of size `spacing`. + the coordinates of the volume and using a given isolevel + for determining intersected edges of cubes. Returns vertices and faces of the obtained mesh. This operation is non-differentiable. - This is a naive implementation, and is not optimized for efficiency. - Args: - volume_data_batch: a Tensor of size (N, D, H, W) corresponding to + vol_batch: a Tensor of size (N, D, H, W) corresponding to a batch of 3D scalar fields isolevel: the isosurface value to use as the threshold to determine whether points are within a volume. If None, then the average of the maximum and minimum value of the scalar field will be used. - spacing: an integer specifying the cube size to use return_local_coords: bool. If True the output vertices will be in local coordinates in the range [-1, 1] x [-1, 1] x [-1, 1]. If False they will be in the range [0, W-1] x [0, H-1] x [0, D-1] Returns: - verts: [(V_0, 3), (V_1, 3), ...] List of N FloatTensors of vertices. - faces: [(F_0, 3), (F_1, 3), ...] List of N LongTensors of faces. + verts: [{V_0}, {V_1}, ...] List of N sets of vertices of shape (|V_i|, 3) in FloatTensor + faces: [{F_0}, {F_1}, ...] List of N sets of faces of shape (|F_i|, 3) in LongTensors """ - volume_data_batch = volume_data_batch.detach().cpu() batched_verts, batched_faces = [], [] - D, H, W = volume_data_batch.shape[1:] - volume_size_xyz = volume_data_batch.new_tensor([W, H, D])[None] + D, H, W = vol_batch.shape[1:] - if return_local_coords: - # Convert from local coordinates in the range [-1, 1] range to - # world coordinates in the range [0, D-1], [0, H-1], [0, W-1] - local_to_world_transform = Translate( - x=+1.0, y=+1.0, z=+1.0, device=volume_data_batch.device - ).scale((volume_size_xyz - 1) * spacing * 0.5) - # Perform the inverse to go from world to local - world_to_local_transform = local_to_world_transform.inverse() + # each edge is represented with its two endpoints (represented with global id) + for i in range(len(vol_batch)): + vol = vol_batch[i] + thresh = ((vol.max() + vol.min()) / 2).item() if isolevel is None else isolevel + vpair_to_edge = {} # maps from tuple of edge endpoints to edge_id + edge_id_to_v = {} # maps from edge ID to vertex position + uniq_edge_id = {} # unique edge IDs + verts = [] # store vertex positions + faces = [] # store face indices + # enumerate each cell in the 3d grid + for z in range(0, D - 1): + for y in range(0, H - 1): + for x in range(0, W - 1): + cube = Cube((x, y, z), vol, thresh) + edge_indices = FACE_TABLE[cube.cube_index] + # cube is entirely in/out of the surface + if len(edge_indices) == 0: + continue + + # gather mesh vertices/faces by processing each cube + interp_points = [[0.0, 0.0, 0.0]] * 12 + # triangle vertex IDs and positions + tri = [] + ps = [] + for i, edge in enumerate(edge_indices): + interp_points[edge] = cube.vert_interp(thresh, edge, vol) + + # Bind interpolated vertex with a global edge_id, which + # is represented by a pair of vertex ids (v1_id, v2_id) + # corresponding to a local edge. + (v1_id, v2_id) = cube.get_vpair_from_edge(edge, W, H) + edge_id = vpair_to_edge.setdefault( + (v1_id, v2_id), len(vpair_to_edge) + ) + tri.append(edge_id) + ps.append(interp_points[edge]) + # when the isolevel are the same as the edge endpoints, the interploated + # vertices can share the same values, and lead to degenerate triangles. + if ( + (i + 1) % 3 == 0 + and ps[0] != ps[1] + and ps[1] != ps[2] + and ps[2] != ps[0] + ): + for j, edge_id in enumerate(tri): + edge_id_to_v[edge_id] = ps[j] + if edge_id not in uniq_edge_id: + uniq_edge_id[edge_id] = len(verts) + verts.append(edge_id_to_v[edge_id]) + faces.append([uniq_edge_id[tri[j]] for j in range(3)]) + tri = [] + ps = [] - for i in range(len(volume_data_batch)): - volume_data = volume_data_batch[i] - curr_isolevel = ( - ((volume_data.max() + volume_data.min()) / 2).item() - if isolevel is None - else isolevel - ) - edge_vertices_to_index = {} - vertex_coords_to_index = {} - verts, faces = [], [] - # Use length - spacing for the bounds since we are using - # cubes of size spacing, with the lowest x,y,z values - # (bottom front left) - for x in range(0, W - spacing, spacing): - for y in range(0, H - spacing, spacing): - for z in range(0, D - spacing, spacing): - cube = Cube((x, y, z), spacing) - new_verts, new_faces = polygonise( - cube, - curr_isolevel, - volume_data, - edge_vertices_to_index, - vertex_coords_to_index, - ) - verts.extend(new_verts) - faces.extend(new_faces) if len(faces) > 0 and len(verts) > 0: - verts = torch.tensor(verts, dtype=torch.float32) - # Convert vertices from world to local coords + verts = torch.tensor(verts, dtype=vol.dtype) + # Convert from world coordinates ([0, D-1], [0, H-1], [0, W-1]) to + # local coordinates in the range [-1, 1] if return_local_coords: - verts = world_to_local_transform.transform_points(verts[None, ...]) - verts = verts.squeeze() + verts = ( + Translate(x=+1.0, y=+1.0, z=+1.0, device=vol_batch.device) + .scale((vol_batch.new_tensor([W, H, D])[None] - 1) * 0.5) + .inverse() + ).transform_points(verts[None])[0] batched_verts.append(verts) batched_faces.append(torch.tensor(faces, dtype=torch.int64)) - return batched_verts, batched_faces - - -def polygonise( - cube: Cube, - isolevel: float, - volume_data: torch.Tensor, - edge_vertices_to_index: Dict[Tuple[Tuple, Tuple], int], - vertex_coords_to_index: Dict[Tuple[float, float, float], int], -) -> Tuple[list, list]: - """ - Runs the classic marching cubes algorithm for one Cube in the volume. - Returns the vertices and faces for the given cube. - - Args: - cube: a Cube indicating the cube being examined for edges that intersect - the volume data. - isolevel: the isosurface value to use as the threshold to determine - whether points are within a volume. - volume_data: a Tensor of shape (D, H, W) corresponding to - a 3D scalar field - edge_vertices_to_index: A dictionary which maps an edge's two coordinates - to the index of its interpolated point, if that interpolated point - has already been used by a previous point - vertex_coords_to_index: A dictionary mapping a point (x, y, z) to the corresponding - index of that vertex, if that point has already been marked as a vertex. - Returns: - verts: List of triangle vertices for the given cube in the volume - faces: List of triangle faces for the given cube in the volume - """ - num_existing_verts = max(edge_vertices_to_index.values(), default=-1) + 1 - verts, faces = [], [] - cube_index = cube.get_index(volume_data, isolevel) - edges = EDGE_TABLE[cube_index] - edge_indices = _get_edge_indices(edges) - if len(edge_indices) == 0: - return [], [] - - new_verts, edge_index_to_point_index = _calculate_interp_vertices( - edge_indices, - volume_data, - cube, - isolevel, - edge_vertices_to_index, - vertex_coords_to_index, - num_existing_verts, - ) - - # Create faces - face_triangles = FACE_TABLE[cube_index] - for i in range(0, len(face_triangles), 3): - tri1 = edge_index_to_point_index[face_triangles[i]] - tri2 = edge_index_to_point_index[face_triangles[i + 1]] - tri3 = edge_index_to_point_index[face_triangles[i + 2]] - if tri1 != tri2 and tri2 != tri3 and tri1 != tri3: - faces.append([tri1, tri2, tri3]) - - verts += new_verts - return verts, faces - - -def _get_edge_indices(edges: int) -> List[int]: - """ - Finds which edge numbers are intersected given the bit representation - detailed in marching_cubes_data.EDGE_TABLE. - - Args: - edges: an integer corresponding to the value at cube_index - from the EDGE_TABLE in marching_cubes_data.py - - Returns: - edge_indices: A list of edge indices - """ - if edges == 0: - return [] - - edge_indices = [] - for i in range(12): - if edges & (2**i): - edge_indices.append(i) - return edge_indices - - -def _calculate_interp_vertices( - edge_indices: List[int], - volume_data: torch.Tensor, - cube: Cube, - isolevel: float, - edge_vertices_to_index: Dict[Tuple[Tuple, Tuple], int], - vertex_coords_to_index: Dict[Tuple[float, float, float], int], - num_existing_verts: int, -) -> Tuple[List, Dict[int, int]]: - """ - Finds the interpolated vertices for the intersected edges, either referencing - previous calculations or newly calculating and storing the new interpolated - points. - - Args: - edge_indices: the numbers of the edges which are intersected. See the - Cube class for more detail on the edge numbering convention. - volume_data: a Tensor of size (D, H, W) corresponding to - a 3D scalar field - cube: a Cube indicating the cube being examined for edges that intersect - the volume - isolevel: the isosurface value to use as the threshold to determine - whether points are within a volume. - edge_vertices_to_index: A dictionary which maps an edge's two coordinates - to the index of its interpolated point, if that interpolated point - has already been used by a previous point - vertex_coords_to_index: A dictionary mapping a point (x, y, z) to the corresponding - index of that vertex, if that point has already been marked as a vertex. - num_existing_verts: the number of vertices that have been found in previous - calls to polygonise for the given volume_data in the above function, marching_cubes. - This is equal to the 1 + the maximum value in edge_vertices_to_index. - Returns: - interp_points: a list of new interpolated points - edge_index_to_point_index: a dictionary mapping an edge number to the index in the - marching cubes' vertices list of the interpolated point on that edge. To be precise, - it refers to the index within the vertices list after interp_points - has been appended to the verts list constructed in the marching_cubes_naive - function. - """ - interp_points = [] - edge_index_to_point_index = {} - for edge_index in edge_indices: - v1, v2 = EDGE_TO_VERTICES[edge_index] - point1, point2 = cube.vertices[v1], cube.vertices[v2] - p_tuple1, p_tuple2 = tuple(point1.tolist()), tuple(point2.tolist()) - if (p_tuple1, p_tuple2) in edge_vertices_to_index: - edge_index_to_point_index[edge_index] = edge_vertices_to_index[ - (p_tuple1, p_tuple2) - ] else: - val1, val2 = _get_value(point1, volume_data), _get_value( - point2, volume_data - ) - - point = None - if abs(isolevel - val1) < EPS: - point = point1 - - if abs(isolevel - val2) < EPS: - point = point2 - - if abs(val1 - val2) < EPS: - point = point1 - - if point is None: - mu = (isolevel - val1) / (val2 - val1) - x1, y1, z1 = point1 - x2, y2, z2 = point2 - x = x1 + mu * (x2 - x1) - y = y1 + mu * (y2 - y1) - z = z1 + mu * (z2 - z1) - else: - x, y, z = point - - x, y, z = x.item(), y.item(), z.item() # for dictionary keys - - vert_index = None - if (x, y, z) in vertex_coords_to_index: - vert_index = vertex_coords_to_index[(x, y, z)] - else: - vert_index = num_existing_verts + len(interp_points) - interp_points.append([x, y, z]) - vertex_coords_to_index[(x, y, z)] = vert_index - - edge_vertices_to_index[(p_tuple1, p_tuple2)] = vert_index - edge_index_to_point_index[edge_index] = vert_index - - return interp_points, edge_index_to_point_index - - -def _get_value(point: Tuple[int, int, int], volume_data: torch.Tensor) -> float: - """ - Gets the value at a given coordinate point in the scalar field. - - Args: - point: data of shape (3) corresponding to an xyz coordinate. - volume_data: a Tensor of size (D, H, W) corresponding to - a 3D scalar field - Returns: - data: scalar value in the volume at the given point - """ - x, y, z = point - # pyre-fixme[7]: Expected `float` but got `Tensor`. - return volume_data[z][y][x] + batched_verts.append([]) + batched_faces.append([]) + return batched_verts, batched_faces diff --git a/pytorch3d/ops/marching_cubes_data.py b/pytorch3d/ops/marching_cubes_data.py index 8c3203f3..802f67da 100644 --- a/pytorch3d/ops/marching_cubes_data.py +++ b/pytorch3d/ops/marching_cubes_data.py @@ -4,284 +4,21 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -# A length 256 list which maps a cubeindex to a number -# with the intersected edges' bits set to 1. -# Each cubeindex corresponds to a given cube configuration, where -# it is composed of a bitstring where the 0th bit is flipped if vertex 0 -# is below the isosurface (i.e. 0x01), for each of the 8 vertices. -EDGE_TABLE = [ - 0x0, - 0x109, - 0x203, - 0x30A, - 0x406, - 0x50F, - 0x605, - 0x70C, - 0x80C, - 0x905, - 0xA0F, - 0xB06, - 0xC0A, - 0xD03, - 0xE09, - 0xF00, - 0x190, - 0x99, - 0x393, - 0x29A, - 0x596, - 0x49F, - 0x795, - 0x69C, - 0x99C, - 0x895, - 0xB9F, - 0xA96, - 0xD9A, - 0xC93, - 0xF99, - 0xE90, - 0x230, - 0x339, - 0x33, - 0x13A, - 0x636, - 0x73F, - 0x435, - 0x53C, - 0xA3C, - 0xB35, - 0x83F, - 0x936, - 0xE3A, - 0xF33, - 0xC39, - 0xD30, - 0x3A0, - 0x2A9, - 0x1A3, - 0xAA, - 0x7A6, - 0x6AF, - 0x5A5, - 0x4AC, - 0xBAC, - 0xAA5, - 0x9AF, - 0x8A6, - 0xFAA, - 0xEA3, - 0xDA9, - 0xCA0, - 0x460, - 0x569, - 0x663, - 0x76A, - 0x66, - 0x16F, - 0x265, - 0x36C, - 0xC6C, - 0xD65, - 0xE6F, - 0xF66, - 0x86A, - 0x963, - 0xA69, - 0xB60, - 0x5F0, - 0x4F9, - 0x7F3, - 0x6FA, - 0x1F6, - 0xFF, - 0x3F5, - 0x2FC, - 0xDFC, - 0xCF5, - 0xFFF, - 0xEF6, - 0x9FA, - 0x8F3, - 0xBF9, - 0xAF0, - 0x650, - 0x759, - 0x453, - 0x55A, - 0x256, - 0x35F, - 0x55, - 0x15C, - 0xE5C, - 0xF55, - 0xC5F, - 0xD56, - 0xA5A, - 0xB53, - 0x859, - 0x950, - 0x7C0, - 0x6C9, - 0x5C3, - 0x4CA, - 0x3C6, - 0x2CF, - 0x1C5, - 0xCC, - 0xFCC, - 0xEC5, - 0xDCF, - 0xCC6, - 0xBCA, - 0xAC3, - 0x9C9, - 0x8C0, - 0x8C0, - 0x9C9, - 0xAC3, - 0xBCA, - 0xCC6, - 0xDCF, - 0xEC5, - 0xFCC, - 0xCC, - 0x1C5, - 0x2CF, - 0x3C6, - 0x4CA, - 0x5C3, - 0x6C9, - 0x7C0, - 0x950, - 0x859, - 0xB53, - 0xA5A, - 0xD56, - 0xC5F, - 0xF55, - 0xE5C, - 0x15C, - 0x55, - 0x35F, - 0x256, - 0x55A, - 0x453, - 0x759, - 0x650, - 0xAF0, - 0xBF9, - 0x8F3, - 0x9FA, - 0xEF6, - 0xFFF, - 0xCF5, - 0xDFC, - 0x2FC, - 0x3F5, - 0xFF, - 0x1F6, - 0x6FA, - 0x7F3, - 0x4F9, - 0x5F0, - 0xB60, - 0xA69, - 0x963, - 0x86A, - 0xF66, - 0xE6F, - 0xD65, - 0xC6C, - 0x36C, - 0x265, - 0x16F, - 0x66, - 0x76A, - 0x663, - 0x569, - 0x460, - 0xCA0, - 0xDA9, - 0xEA3, - 0xFAA, - 0x8A6, - 0x9AF, - 0xAA5, - 0xBAC, - 0x4AC, - 0x5A5, - 0x6AF, - 0x7A6, - 0xAA, - 0x1A3, - 0x2A9, - 0x3A0, - 0xD30, - 0xC39, - 0xF33, - 0xE3A, - 0x936, - 0x83F, - 0xB35, - 0xA3C, - 0x53C, - 0x435, - 0x73F, - 0x636, - 0x13A, - 0x33, - 0x339, - 0x230, - 0xE90, - 0xF99, - 0xC93, - 0xD9A, - 0xA96, - 0xB9F, - 0x895, - 0x99C, - 0x69C, - 0x795, - 0x49F, - 0x596, - 0x29A, - 0x393, - 0x99, - 0x190, - 0xF00, - 0xE09, - 0xD03, - 0xC0A, - 0xB06, - 0xA0F, - 0x905, - 0x80C, - 0x70C, - 0x605, - 0x50F, - 0x406, - 0x30A, - 0x203, - 0x109, - 0x0, -] # Maps each edge (by index) to the corresponding cube vertices EDGE_TO_VERTICES = [ [0, 1], - [1, 2], - [3, 2], - [0, 3], - [4, 5], - [5, 6], - [7, 6], - [4, 7], - [0, 4], [1, 5], - [2, 6], + [4, 5], + [0, 4], + [2, 3], [3, 7], + [6, 7], + [2, 6], + [0, 2], + [1, 3], + [5, 7], + [4, 6], ] # A list of lists mapping a cube_index (a given configuration) @@ -547,3 +284,6 @@ FACE_TABLE = [ [0, 3, 8], [], ] + +# mapping from 0-7 to v0-v7 in cube.vertices +INDEX = [0, 1, 5, 4, 2, 3, 7, 6] diff --git a/tests/data/test_marching_cubes_data/double_ellipsoid.pickle b/tests/data/test_marching_cubes_data/double_ellipsoid.pickle index 3642a73d..f85dbc7f 100644 Binary files a/tests/data/test_marching_cubes_data/double_ellipsoid.pickle and b/tests/data/test_marching_cubes_data/double_ellipsoid.pickle differ diff --git a/tests/data/test_marching_cubes_data/sphere_level64.pickle b/tests/data/test_marching_cubes_data/sphere_level64.pickle index 2cfafe81..f31bff7b 100644 Binary files a/tests/data/test_marching_cubes_data/sphere_level64.pickle and b/tests/data/test_marching_cubes_data/sphere_level64.pickle differ diff --git a/tests/test_marching_cubes.py b/tests/test_marching_cubes.py index fd63c6ab..06442e94 100644 --- a/tests/test_marching_cubes.py +++ b/tests/test_marching_cubes.py @@ -32,8 +32,8 @@ class TestCubeConfiguration(TestCaseMixin, unittest.TestCase): volume_data = torch.ones(1, 2, 2, 2) # (B, W, H, D) verts, faces = marching_cubes_naive(volume_data, return_local_coords=False) - expected_verts = torch.tensor([]) - expected_faces = torch.tensor([], dtype=torch.int64) + expected_verts = torch.tensor([[]]) + expected_faces = torch.tensor([[]], dtype=torch.int64) self.assertClose(verts, expected_verts) self.assertClose(faces, expected_faces) @@ -42,16 +42,15 @@ class TestCubeConfiguration(TestCaseMixin, unittest.TestCase): volume_data[0, 0, 0, 0] = 0 volume_data = volume_data.permute(0, 3, 2, 1) # (B, D, H, W) verts, faces = marching_cubes_naive(volume_data, return_local_coords=False) - expected_verts = torch.tensor( [ [0.5, 0, 0], - [0, 0, 0.5], [0, 0.5, 0], + [0, 0, 0.5], ] ) - expected_faces = torch.tensor([[1, 2, 0]]) + expected_faces = torch.tensor([[0, 1, 2]]) self.assertClose(verts[0], expected_verts) self.assertClose(faces[0], expected_faces) @@ -69,12 +68,12 @@ class TestCubeConfiguration(TestCaseMixin, unittest.TestCase): expected_verts = torch.tensor( [ [1.0000, 0.0000, 0.5000], + [0.0000, 0.5000, 0.0000], [0.0000, 0.0000, 0.5000], [1.0000, 0.5000, 0.0000], - [0.0000, 0.5000, 0.0000], ] ) - expected_faces = torch.tensor([[1, 2, 0], [3, 2, 1]]) + expected_faces = torch.tensor([[0, 1, 2], [3, 1, 0]]) self.assertClose(verts[0], expected_verts) self.assertClose(faces[0], expected_faces) @@ -92,15 +91,15 @@ class TestCubeConfiguration(TestCaseMixin, unittest.TestCase): expected_verts = torch.tensor( [ - [0.5000, 0.0000, 0.0000], - [0.0000, 0.0000, 0.5000], + [1.0000, 0.5000, 0.0000], [1.0000, 1.0000, 0.5000], [0.5000, 1.0000, 0.0000], - [1.0000, 0.5000, 0.0000], + [0.5000, 0.0000, 0.0000], [0.0000, 0.5000, 0.0000], + [0.0000, 0.0000, 0.5000], ] ) - expected_faces = torch.tensor([[0, 1, 5], [4, 3, 2]]) + expected_faces = torch.tensor([[0, 1, 2], [3, 4, 5]]) self.assertClose(verts[0], expected_verts) self.assertClose(faces[0], expected_faces) @@ -119,14 +118,14 @@ class TestCubeConfiguration(TestCaseMixin, unittest.TestCase): expected_verts = torch.tensor( [ - [0.5000, 0.0000, 0.0000], [0.0000, 0.0000, 0.5000], + [1.0000, 0.5000, 0.0000], + [0.5000, 0.0000, 0.0000], [0.0000, 0.5000, 1.0000], [1.0000, 0.5000, 1.0000], - [1.0000, 0.5000, 0.0000], ] ) - expected_faces = torch.tensor([[0, 2, 1], [0, 4, 2], [4, 3, 2]]) + expected_faces = torch.tensor([[0, 1, 2], [0, 3, 1], [3, 4, 1]]) self.assertClose(verts[0], expected_verts) self.assertClose(faces[0], expected_faces) @@ -143,15 +142,14 @@ class TestCubeConfiguration(TestCaseMixin, unittest.TestCase): expected_verts = torch.tensor( [ - [0.0000, 0.5000, 1.0000], - [1.0000, 0.5000, 1.0000], [1.0000, 0.5000, 0.0000], [0.0000, 0.5000, 0.0000], + [1.0000, 0.5000, 1.0000], + [0.0000, 0.5000, 1.0000], ] ) - expected_faces = torch.tensor([[1, 0, 2], [2, 0, 3]]) - + expected_faces = torch.tensor([[0, 1, 2], [2, 1, 3]]) self.assertClose(verts[0], expected_verts) self.assertClose(faces[0], expected_faces) @@ -171,17 +169,17 @@ class TestCubeConfiguration(TestCaseMixin, unittest.TestCase): expected_verts = torch.tensor( [ - [0.5000, 0.0000, 0.0000], - [0.0000, 0.0000, 0.5000], [0.5000, 1.0000, 0.0000], [0.0000, 1.0000, 0.5000], + [0.0000, 0.5000, 0.0000], + [1.0000, 0.5000, 0.0000], + [0.5000, 0.0000, 0.0000], [0.0000, 0.5000, 1.0000], [1.0000, 0.5000, 1.0000], - [1.0000, 0.5000, 0.0000], - [0.0000, 0.5000, 0.0000], + [0.0000, 0.0000, 0.5000], ] ) - expected_faces = torch.tensor([[2, 7, 3], [0, 6, 1], [6, 4, 1], [6, 5, 4]]) + expected_faces = torch.tensor([[0, 1, 2], [3, 4, 5], [3, 5, 6], [5, 4, 7]]) self.assertClose(verts[0], expected_verts) self.assertClose(faces[0], expected_faces) @@ -202,22 +200,22 @@ class TestCubeConfiguration(TestCaseMixin, unittest.TestCase): expected_verts = torch.tensor( [ - [0.5000, 0.0000, 1.0000], - [1.0000, 0.0000, 0.5000], - [0.5000, 0.0000, 0.0000], - [0.0000, 0.0000, 0.5000], [0.5000, 1.0000, 1.0000], - [1.0000, 1.0000, 0.5000], - [0.5000, 1.0000, 0.0000], - [0.0000, 1.0000, 0.5000], [0.0000, 0.5000, 1.0000], + [0.0000, 1.0000, 0.5000], + [1.0000, 0.0000, 0.5000], + [0.5000, 0.0000, 1.0000], [1.0000, 0.5000, 1.0000], - [1.0000, 0.5000, 0.0000], + [0.5000, 0.0000, 0.0000], [0.0000, 0.5000, 0.0000], + [0.0000, 0.0000, 0.5000], + [0.5000, 1.0000, 0.0000], + [1.0000, 0.5000, 0.0000], + [1.0000, 1.0000, 0.5000], ] ) - expected_faces = torch.tensor([[0, 1, 9], [4, 7, 8], [2, 3, 11], [5, 10, 6]]) + expected_faces = torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]]) self.assertClose(verts[0], expected_verts) self.assertClose(faces[0], expected_faces) @@ -238,15 +236,15 @@ class TestCubeConfiguration(TestCaseMixin, unittest.TestCase): expected_verts = torch.tensor( [ - [1.0000, 0.0000, 0.5000], - [0.5000, 0.0000, 0.0000], - [0.5000, 1.0000, 1.0000], - [0.0000, 1.0000, 0.5000], [1.0000, 0.5000, 1.0000], + [0.0000, 1.0000, 0.5000], + [0.5000, 1.0000, 1.0000], + [1.0000, 0.0000, 0.5000], [0.0000, 0.5000, 0.0000], + [0.5000, 0.0000, 0.0000], ] ) - expected_faces = torch.tensor([[2, 3, 5], [4, 2, 5], [4, 5, 1], [4, 1, 0]]) + expected_faces = torch.tensor([[0, 1, 2], [3, 1, 0], [3, 4, 1], [3, 5, 4]]) self.assertClose(verts[0], expected_verts) self.assertClose(faces[0], expected_faces) @@ -269,13 +267,13 @@ class TestCubeConfiguration(TestCaseMixin, unittest.TestCase): [ [0.5000, 0.0000, 0.0000], [0.0000, 0.0000, 0.5000], - [0.5000, 1.0000, 1.0000], [0.0000, 1.0000, 0.5000], [1.0000, 0.5000, 1.0000], [1.0000, 0.5000, 0.0000], + [0.5000, 1.0000, 1.0000], ] ) - expected_faces = torch.tensor([[0, 5, 4], [0, 4, 3], [0, 3, 1], [3, 4, 2]]) + expected_faces = torch.tensor([[0, 1, 2], [0, 2, 3], [0, 3, 4], [5, 3, 2]]) self.assertClose(verts[0], expected_verts) self.assertClose(faces[0], expected_faces) @@ -295,15 +293,15 @@ class TestCubeConfiguration(TestCaseMixin, unittest.TestCase): expected_verts = torch.tensor( [ [0.5000, 0.0000, 0.0000], + [0.0000, 0.5000, 0.0000], [0.0000, 0.0000, 0.5000], - [0.5000, 1.0000, 1.0000], [1.0000, 1.0000, 0.5000], [1.0000, 0.5000, 1.0000], - [0.0000, 0.5000, 0.0000], + [0.5000, 1.0000, 1.0000], ] ) - expected_faces = torch.tensor([[4, 3, 2], [0, 1, 5]]) + expected_faces = torch.tensor([[0, 1, 2], [3, 4, 5]]) self.assertClose(verts[0], expected_verts) self.assertClose(faces[0], expected_faces) @@ -324,16 +322,16 @@ class TestCubeConfiguration(TestCaseMixin, unittest.TestCase): expected_verts = torch.tensor( [ [1.0000, 0.0000, 0.5000], + [0.0000, 0.5000, 0.0000], [0.0000, 0.0000, 0.5000], - [0.5000, 1.0000, 1.0000], + [1.0000, 0.5000, 0.0000], [1.0000, 1.0000, 0.5000], [1.0000, 0.5000, 1.0000], - [1.0000, 0.5000, 0.0000], - [0.0000, 0.5000, 0.0000], + [0.5000, 1.0000, 1.0000], ] ) - expected_faces = torch.tensor([[5, 1, 6], [5, 0, 1], [4, 3, 2]]) + expected_faces = torch.tensor([[0, 1, 2], [0, 3, 1], [4, 5, 6]]) self.assertClose(verts[0], expected_verts) self.assertClose(faces[0], expected_faces) @@ -354,18 +352,18 @@ class TestCubeConfiguration(TestCaseMixin, unittest.TestCase): expected_verts = torch.tensor( [ [1.0000, 0.0000, 0.5000], + [1.0000, 0.5000, 0.0000], [0.5000, 0.0000, 0.0000], - [0.5000, 1.0000, 1.0000], [1.0000, 1.0000, 0.5000], + [1.0000, 0.5000, 1.0000], + [0.5000, 1.0000, 1.0000], + [0.0000, 0.5000, 0.0000], [0.5000, 1.0000, 0.0000], [0.0000, 1.0000, 0.5000], - [1.0000, 0.5000, 1.0000], - [1.0000, 0.5000, 0.0000], - [0.0000, 0.5000, 0.0000], ] ) - expected_faces = torch.tensor([[6, 3, 2], [7, 0, 1], [5, 4, 8]]) + expected_faces = torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8]]) self.assertClose(verts[0], expected_verts) self.assertClose(faces[0], expected_faces) @@ -386,18 +384,18 @@ class TestCubeConfiguration(TestCaseMixin, unittest.TestCase): expected_verts = torch.tensor( [ - [0.5000, 0.0000, 1.0000], [1.0000, 0.0000, 0.5000], - [0.5000, 0.0000, 0.0000], - [0.0000, 0.0000, 0.5000], - [0.5000, 1.0000, 1.0000], + [0.5000, 0.0000, 1.0000], [1.0000, 1.0000, 0.5000], + [0.5000, 1.0000, 1.0000], + [0.0000, 0.0000, 0.5000], + [0.5000, 0.0000, 0.0000], [0.5000, 1.0000, 0.0000], [0.0000, 1.0000, 0.5000], ] ) - expected_faces = torch.tensor([[3, 6, 2], [3, 7, 6], [1, 5, 0], [5, 4, 0]]) + expected_faces = torch.tensor([[0, 1, 2], [2, 1, 3], [4, 5, 6], [4, 6, 7]]) self.assertClose(verts[0], expected_verts) self.assertClose(faces[0], expected_faces) @@ -418,16 +416,16 @@ class TestCubeConfiguration(TestCaseMixin, unittest.TestCase): expected_verts = torch.tensor( [ - [1.0000, 0.0000, 0.5000], [0.5000, 0.0000, 0.0000], - [0.5000, 1.0000, 1.0000], - [1.0000, 1.0000, 0.5000], - [0.0000, 0.5000, 1.0000], [0.0000, 0.5000, 0.0000], + [0.0000, 0.5000, 1.0000], + [1.0000, 1.0000, 0.5000], + [1.0000, 0.0000, 0.5000], + [0.5000, 1.0000, 1.0000], ] ) - expected_faces = torch.tensor([[1, 0, 3], [1, 3, 4], [1, 4, 5], [2, 4, 3]]) + expected_faces = torch.tensor([[0, 1, 2], [0, 2, 3], [0, 3, 4], [3, 2, 5]]) self.assertClose(verts[0], expected_verts) self.assertClose(faces[0], expected_faces) @@ -447,27 +445,26 @@ class TestMarchingCubes(TestCaseMixin, unittest.TestCase): expected_verts = torch.tensor( [ - [0.5, 1, 1], - [1, 1, 0.5], - [1, 0.5, 1], - [1, 1, 1.5], - [1, 1.5, 1], - [1.5, 1, 1], + [1.0000, 0.5000, 1.0000], + [1.0000, 1.0000, 0.5000], + [0.5000, 1.0000, 1.0000], + [1.5000, 1.0000, 1.0000], + [1.0000, 1.5000, 1.0000], + [1.0000, 1.0000, 1.5000], ] ) expected_faces = torch.tensor( [ - [2, 0, 1], - [2, 3, 0], - [0, 4, 1], - [3, 4, 0], - [5, 2, 1], - [3, 2, 5], - [5, 1, 4], + [0, 1, 2], + [1, 0, 3], + [1, 4, 2], + [1, 3, 4], + [0, 2, 5], + [3, 0, 5], + [2, 4, 5], [3, 5, 4], ] ) - self.assertClose(verts[0], expected_verts) self.assertClose(faces[0], expected_faces) @@ -492,76 +489,76 @@ class TestMarchingCubes(TestCaseMixin, unittest.TestCase): expected_verts = torch.tensor( [ - [0.9000, 1.0000, 1.0000], - [1.0000, 1.0000, 0.9000], [1.0000, 0.9000, 1.0000], - [0.9000, 1.0000, 2.0000], - [1.0000, 0.9000, 2.0000], - [1.0000, 1.0000, 2.1000], - [0.9000, 2.0000, 1.0000], - [1.0000, 2.0000, 0.9000], - [0.9000, 2.0000, 2.0000], - [1.0000, 2.0000, 2.1000], - [1.0000, 2.1000, 1.0000], - [1.0000, 2.1000, 2.0000], - [2.0000, 1.0000, 0.9000], + [1.0000, 1.0000, 0.9000], + [0.9000, 1.0000, 1.0000], [2.0000, 0.9000, 1.0000], - [2.0000, 0.9000, 2.0000], - [2.0000, 1.0000, 2.1000], - [2.0000, 2.0000, 0.9000], - [2.0000, 2.0000, 2.1000], - [2.0000, 2.1000, 1.0000], - [2.0000, 2.1000, 2.0000], + [2.0000, 1.0000, 0.9000], [2.1000, 1.0000, 1.0000], - [2.1000, 1.0000, 2.0000], + [1.0000, 2.0000, 0.9000], + [0.9000, 2.0000, 1.0000], + [2.0000, 2.0000, 0.9000], [2.1000, 2.0000, 1.0000], + [1.0000, 2.1000, 1.0000], + [2.0000, 2.1000, 1.0000], + [1.0000, 0.9000, 2.0000], + [0.9000, 1.0000, 2.0000], + [2.0000, 0.9000, 2.0000], + [2.1000, 1.0000, 2.0000], + [0.9000, 2.0000, 2.0000], [2.1000, 2.0000, 2.0000], + [1.0000, 2.1000, 2.0000], + [2.0000, 2.1000, 2.0000], + [1.0000, 1.0000, 2.1000], + [2.0000, 1.0000, 2.1000], + [1.0000, 2.0000, 2.1000], + [2.0000, 2.0000, 2.1000], ] ) expected_faces = torch.tensor( [ - [2, 0, 1], - [2, 4, 3], - [0, 2, 3], - [4, 5, 3], - [0, 6, 7], - [1, 0, 7], - [3, 8, 0], - [8, 6, 0], - [5, 9, 8], - [3, 5, 8], + [0, 1, 2], + [0, 3, 4], + [1, 0, 4], + [4, 3, 5], + [1, 6, 7], + [2, 1, 7], + [4, 8, 1], + [1, 8, 6], + [8, 4, 5], + [9, 8, 5], [6, 10, 7], - [11, 10, 6], - [8, 11, 6], - [9, 11, 8], - [13, 2, 1], - [12, 13, 1], - [14, 4, 13], - [13, 4, 2], - [4, 14, 15], - [5, 4, 15], - [12, 1, 16], - [1, 7, 16], - [15, 17, 5], - [5, 17, 9], - [16, 7, 10], - [18, 16, 10], - [19, 18, 11], - [18, 10, 11], + [6, 8, 11], + [10, 6, 11], + [8, 9, 11], + [12, 0, 2], + [13, 12, 2], + [3, 0, 14], + [14, 0, 12], + [15, 5, 3], + [14, 15, 3], + [2, 7, 13], + [7, 16, 13], + [5, 15, 9], + [9, 15, 17], + [10, 18, 16], + [7, 10, 16], + [11, 19, 10], + [19, 18, 10], [9, 17, 19], [11, 9, 19], - [20, 13, 12], - [20, 21, 14], - [13, 20, 14], + [12, 13, 20], + [14, 12, 20], + [21, 14, 20], [15, 14, 21], - [22, 20, 12], - [16, 22, 12], + [13, 16, 22], + [20, 13, 22], [21, 20, 23], - [23, 20, 22], + [20, 22, 23], [17, 15, 21], [23, 17, 21], - [22, 16, 18], + [16, 18, 22], [23, 22, 18], [19, 23, 18], [17, 23, 19], @@ -569,6 +566,7 @@ class TestMarchingCubes(TestCaseMixin, unittest.TestCase): ) self.assertClose(verts[0], expected_verts) self.assertClose(faces[0], expected_faces) + verts, faces = marching_cubes_naive(volume_data, 0.9, return_local_coords=True) expected_verts = convert_to_local(expected_verts, 5) self.assertClose(verts[0], expected_verts) @@ -592,34 +590,49 @@ class TestMarchingCubes(TestCaseMixin, unittest.TestCase): expected_verts = torch.tensor( [ + [2.0, 1.0, 1.0], + [2.0, 2.0, 1.0], [1.0, 1.0, 1.0], - [1.0, 1.0, 2.0], [1.0, 2.0, 1.0], + [2.0, 1.0, 1.0], + [1.0, 1.0, 1.0], + [2.0, 1.0, 2.0], + [1.0, 1.0, 2.0], + [1.0, 1.0, 1.0], + [1.0, 2.0, 1.0], + [1.0, 1.0, 2.0], [1.0, 2.0, 2.0], [2.0, 1.0, 1.0], [2.0, 1.0, 2.0], [2.0, 2.0, 1.0], [2.0, 2.0, 2.0], + [2.0, 2.0, 1.0], + [2.0, 2.0, 2.0], + [1.0, 2.0, 1.0], + [1.0, 2.0, 2.0], + [2.0, 1.0, 2.0], + [1.0, 1.0, 2.0], + [2.0, 2.0, 2.0], + [1.0, 2.0, 2.0], ] ) expected_faces = torch.tensor( [ - [1, 3, 0], - [3, 2, 0], - [5, 1, 4], - [4, 1, 0], - [4, 0, 6], - [0, 2, 6], - [5, 7, 1], - [1, 7, 3], - [7, 6, 3], - [6, 2, 3], - [5, 4, 7], - [7, 4, 6], + [0, 1, 2], + [2, 1, 3], + [4, 5, 6], + [6, 5, 7], + [8, 9, 10], + [9, 11, 10], + [12, 13, 14], + [14, 13, 15], + [16, 17, 18], + [17, 19, 18], + [20, 21, 22], + [21, 23, 22], ] ) - self.assertClose(verts[0], expected_verts) self.assertClose(faces[0], expected_faces) @@ -651,8 +664,8 @@ class TestMarchingCubes(TestCaseMixin, unittest.TestCase): filename = os.path.join(DATA_DIR, data_filename) with open(filename, "rb") as file: verts_and_faces = pickle.load(file) - expected_verts = verts_and_faces["verts"].squeeze() - expected_faces = verts_and_faces["faces"].squeeze() + expected_verts = verts_and_faces["verts"] + expected_faces = verts_and_faces["faces"] self.assertClose(verts[0], expected_verts) self.assertClose(faces[0], expected_faces) @@ -689,8 +702,8 @@ class TestMarchingCubes(TestCaseMixin, unittest.TestCase): expected_verts = verts_and_faces["verts"] expected_faces = verts_and_faces["faces"] - self.assertClose(verts[0], expected_verts[0]) - self.assertClose(faces[0], expected_faces[0]) + self.assertClose(verts[0], expected_verts) + self.assertClose(faces[0], expected_faces) def test_cube_surface_area(self): if USE_SCIKIT: @@ -760,16 +773,26 @@ class TestMarchingCubes(TestCaseMixin, unittest.TestCase): self.assertClose(surf, surf_sci) + def test_ball_example(self): + N = 15 + axis_tensor = torch.arange(0, N) + X, Y, Z = torch.meshgrid(axis_tensor, axis_tensor, axis_tensor, indexing="ij") + u = (X - 15) ** 2 + (Y - 15) ** 2 + (Z - 15) ** 2 - 8**2 + u = u[None].float() + verts, faces = marching_cubes_naive(u, 0, return_local_coords=False) + @staticmethod - def marching_cubes_with_init(batch_size: int, V: int): + def marching_cubes_with_init(algo_type: str, batch_size: int, V: int): device = torch.device("cuda:0") volume_data = torch.rand( (batch_size, V, V, V), dtype=torch.float32, device=device ) - torch.cuda.synchronize() + algo_table = { + "naive": marching_cubes_naive, + } def convert(): - marching_cubes_naive(volume_data, return_local_coords=False) + algo_table[algo_type](volume_data, return_local_coords=False) torch.cuda.synchronize() return convert