diff --git a/pytorch3d/ops/marching_cubes.py b/pytorch3d/ops/marching_cubes.py new file mode 100644 index 00000000..e1d6bca8 --- /dev/null +++ b/pytorch3d/ops/marching_cubes.py @@ -0,0 +1,347 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. + +from typing import Dict, List, Optional, Tuple + +import torch +from pytorch3d.ops.marching_cubes_data import EDGE_TABLE, EDGE_TO_VERTICES, FACE_TABLE +from pytorch3d.transforms import Translate + + +EPS = 0.00001 + + +class Cube: + def __init__(self, bfl_vertex: Tuple[int, int, int], spacing: int = 1): + """ + Initializes a cube given the bottom front left vertex coordinate + and the cube spacing + + Edge and vertex convention: + + v4_______e4____________v5 + /| /| + / | / | + e7/ | e5/ | + /___|______e6_________/ | + v7| | |v6 |e9 + | | | | + | |e8 |e10| + e11| | | | + | |_________________|___| + | / v0 e0 | /v1 + | / | / + | /e3 | /e1 + |/_____________________|/ + v3 e2 v2 + + Args: + bfl_vertex: a tuple of size 3 corresponding to the bottom front left vertex + of the cube in (x, y, z) format + spacing: the length of each edge of the cube + """ + # match corner orders to algorithm convention + if len(bfl_vertex) != 3: + msg = "The vertex {} is size {} instead of size 3".format( + bfl_vertex, len(bfl_vertex) + ) + raise ValueError(msg) + + x, y, z = bfl_vertex + self.vertices = torch.tensor( + [ + [x, y, z + spacing], + [x + spacing, y, z + spacing], + [x + spacing, y, z], + [x, y, z], + [x, y + spacing, z + spacing], + [x + spacing, y + spacing, z + spacing], + [x + spacing, y + spacing, z], + [x, y + spacing, z], + ] + ) + + def get_index(self, volume_data: torch.Tensor, isolevel: float) -> int: + """ + Calculates the cube_index in the range 0-255 to index + into EDGE_TABLE and FACE_TABLE + Args: + volume_data: the 3D scalar data + isolevel: the isosurface value used as a threshold + for determining whether a point is inside/outside + the volume + """ + cube_index = 0 + bit = 1 + for index in range(len(self.vertices)): + vertex = self.vertices[index] + value = _get_value(vertex, volume_data) + if value < isolevel: + cube_index |= bit + bit *= 2 + return cube_index + + +def marching_cubes_naive( + volume_data_batch: torch.Tensor, + isolevel: Optional[float] = None, + spacing: int = 1, + return_local_coords: bool = True, +) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: + """ + Runs the classic marching cubes algorithm, iterating over + the coordinates of the volume_data and using a given isolevel + for determining intersected edges of cubes of size `spacing`. + Returns vertices and faces of the obtained mesh. + This operation is non-differentiable. + + This is a naive implementation, and is not optimized for efficiency. + + Args: + volume_data_batch: a Tensor of size (N, D, H, W) corresponding to + a batch of 3D scalar fields + isolevel: the isosurface value to use as the threshold to determine + whether points are within a volume. If None, then the average of the + maximum and minimum value of the scalar field will be used. + spacing: an integer specifying the cube size to use + return_local_coords: bool. If True the output vertices will be in local coordinates in + the range [-1, 1] x [-1, 1] x [-1, 1]. If False they will be in the range + [0, W-1] x [0, H-1] x [0, D-1] + Returns: + verts: [(V_0, 3), (V_1, 3), ...] List of N FloatTensors of vertices. + faces: [(F_0, 3), (F_1, 3), ...] List of N LongTensors of faces. + """ + volume_data_batch = volume_data_batch.detach().cpu() + batched_verts, batched_faces = [], [] + D, H, W = volume_data_batch.shape[1:] + # pyre-ignore [16] + volume_size_xyz = volume_data_batch.new_tensor([W, H, D])[None] + + if return_local_coords: + # Convert from local coordinates in the range [-1, 1] range to + # world coordinates in the range [0, D-1], [0, H-1], [0, W-1] + local_to_world_transform = Translate( + x=+1.0, y=+1.0, z=+1.0, device=volume_data_batch.device + ).scale((volume_size_xyz - 1) * spacing * 0.5) + # Perform the inverse to go from world to local + world_to_local_transform = local_to_world_transform.inverse() + + for i in range(len(volume_data_batch)): + volume_data = volume_data_batch[i] + curr_isolevel = ( + ((volume_data.max() + volume_data.min()) / 2).item() + if isolevel is None + else isolevel + ) + edge_vertices_to_index = {} + vertex_coords_to_index = {} + verts, faces = [], [] + # Use length - spacing for the bounds since we are using + # cubes of size spacing, with the lowest x,y,z values + # (bottom front left) + for x in range(0, W - spacing, spacing): + for y in range(0, H - spacing, spacing): + for z in range(0, D - spacing, spacing): + cube = Cube((x, y, z), spacing) + new_verts, new_faces = polygonise( + cube, + curr_isolevel, + volume_data, + edge_vertices_to_index, + vertex_coords_to_index, + ) + verts.extend(new_verts) + faces.extend(new_faces) + if len(faces) > 0 and len(verts) > 0: + verts = torch.tensor(verts, dtype=torch.float32) + # Convert vertices from world to local coords + if return_local_coords: + verts = world_to_local_transform.transform_points(verts[None, ...]) + verts = verts.squeeze() + batched_verts.append(verts) + batched_faces.append(torch.tensor(faces, dtype=torch.int64)) + return batched_verts, batched_faces + + +def polygonise( + cube: Cube, + isolevel: float, + volume_data: torch.Tensor, + edge_vertices_to_index: Dict[Tuple[Tuple, Tuple], int], + vertex_coords_to_index: Dict[Tuple[float, float, float], int], +) -> Tuple[list, list]: + """ + Runs the classic marching cubes algorithm for one Cube in the volume. + Returns the vertices and faces for the given cube. + + Args: + cube: a Cube indicating the cube being examined for edges that intersect + the volume data. + isolevel: the isosurface value to use as the threshold to determine + whether points are within a volume. + volume_data: a Tensor of shape (D, H, W) corresponding to + a 3D scalar field + edge_vertices_to_index: A dictionary which maps an edge's two coordinates + to the index of its interpolated point, if that interpolated point + has already been used by a previous point + vertex_coords_to_index: A dictionary mapping a point (x, y, z) to the corresponding + index of that vertex, if that point has already been marked as a vertex. + Returns: + verts: List of triangle vertices for the given cube in the volume + faces: List of triangle faces for the given cube in the volume + """ + num_existing_verts = max(edge_vertices_to_index.values(), default=-1) + 1 + verts, faces = [], [] + cube_index = cube.get_index(volume_data, isolevel) + edges = EDGE_TABLE[cube_index] + edge_indices = _get_edge_indices(edges) + if len(edge_indices) == 0: + return [], [] + + new_verts, edge_index_to_point_index = _calculate_interp_vertices( + edge_indices, + volume_data, + cube, + isolevel, + edge_vertices_to_index, + vertex_coords_to_index, + num_existing_verts, + ) + + # Create faces + face_triangles = FACE_TABLE[cube_index] + for i in range(0, len(face_triangles), 3): + tri1 = edge_index_to_point_index[face_triangles[i]] + tri2 = edge_index_to_point_index[face_triangles[i + 1]] + tri3 = edge_index_to_point_index[face_triangles[i + 2]] + if tri1 != tri2 and tri2 != tri3 and tri1 != tri3: + faces.append([tri1, tri2, tri3]) + + verts += new_verts + return verts, faces + + +def _get_edge_indices(edges: int) -> List[int]: + """ + Finds which edge numbers are intersected given the bit representation + detailed in marching_cubes_data.EDGE_TABLE. + + Args: + edges: an integer corresponding to the value at cube_index + from the EDGE_TABLE in marching_cubes_data.py + + Returns: + edge_indices: A list of edge indices + """ + if edges == 0: + return [] + + edge_indices = [] + for i in range(12): + if edges & (2 ** i): + edge_indices.append(i) + return edge_indices + + +def _calculate_interp_vertices( + edge_indices: List[int], + volume_data: torch.Tensor, + cube: Cube, + isolevel: float, + edge_vertices_to_index: Dict[Tuple[Tuple, Tuple], int], + vertex_coords_to_index: Dict[Tuple[float, float, float], int], + num_existing_verts: int, +) -> Tuple[List, Dict[int, int]]: + """ + Finds the interpolated vertices for the intersected edges, either referencing + previous calculations or newly calculating and storing the new interpolated + points. + + Args: + edge_indices: the numbers of the edges which are intersected. See the + Cube class for more detail on the edge numbering convention. + volume_data: a Tensor of size (D, H, W) corresponding to + a 3D scalar field + cube: a Cube indicating the cube being examined for edges that intersect + the volume + isolevel: the isosurface value to use as the threshold to determine + whether points are within a volume. + edge_vertices_to_index: A dictionary which maps an edge's two coordinates + to the index of its interpolated point, if that interpolated point + has already been used by a previous point + vertex_coords_to_index: A dictionary mapping a point (x, y, z) to the corresponding + index of that vertex, if that point has already been marked as a vertex. + num_existing_verts: the number of vertices that have been found in previous + calls to polygonise for the given volume_data in the above function, marching_cubes. + This is equal to the 1 + the maximum value in edge_vertices_to_index. + Returns: + interp_points: a list of new interpolated points + edge_index_to_point_index: a dictionary mapping an edge number to the index in the + marching cubes' vertices list of the interpolated point on that edge. To be precise, + it refers to the index within the vertices list after interp_points + has been appended to the verts list constructed in the marching_cubes_naive + function. + """ + interp_points = [] + edge_index_to_point_index = {} + for edge_index in edge_indices: + v1, v2 = EDGE_TO_VERTICES[edge_index] + point1, point2 = cube.vertices[v1], cube.vertices[v2] + p_tuple1, p_tuple2 = tuple(point1.tolist()), tuple(point2.tolist()) + if (p_tuple1, p_tuple2) in edge_vertices_to_index: + edge_index_to_point_index[edge_index] = edge_vertices_to_index[ + (p_tuple1, p_tuple2) + ] + else: + val1, val2 = _get_value(point1, volume_data), _get_value( + point2, volume_data + ) + + point = None + if abs(isolevel - val1) < EPS: + point = point1 + + if abs(isolevel - val2) < EPS: + point = point2 + + if abs(val1 - val2) < EPS: + point = point1 + + if point is None: + mu = (isolevel - val1) / (val2 - val1) + x1, y1, z1 = point1 + x2, y2, z2 = point2 + x = x1 + mu * (x2 - x1) + y = y1 + mu * (y2 - y1) + z = z1 + mu * (z2 - z1) + else: + x, y, z = point + + x, y, z = x.item(), y.item(), z.item() # for dictionary keys + + vert_index = None + if (x, y, z) in vertex_coords_to_index: + vert_index = vertex_coords_to_index[(x, y, z)] + else: + vert_index = num_existing_verts + len(interp_points) + interp_points.append([x, y, z]) + vertex_coords_to_index[(x, y, z)] = vert_index + + edge_vertices_to_index[(p_tuple1, p_tuple2)] = vert_index + edge_index_to_point_index[edge_index] = vert_index + + return interp_points, edge_index_to_point_index + + +def _get_value(point: Tuple[int, int, int], volume_data: torch.Tensor) -> float: + """ + Gets the value at a given coordinate point in the scalar field. + + Args: + point: data of shape (3) corresponding to an xyz coordinate. + volume_data: a Tensor of size (D, H, W) corresponding to + a 3D scalar field + Returns: + data: scalar value in the volume at the given point + """ + x, y, z = point + return volume_data[z][y][x] diff --git a/pytorch3d/ops/marching_cubes_data.py b/pytorch3d/ops/marching_cubes_data.py new file mode 100644 index 00000000..8ad92825 --- /dev/null +++ b/pytorch3d/ops/marching_cubes_data.py @@ -0,0 +1,545 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. + +# A length 256 list which maps a cubeindex to a number +# with the intersected edges' bits set to 1. +# Each cubeindex corresponds to a given cube configuration, where +# it is composed of a bitstring where the 0th bit is flipped if vertex 0 +# is below the isosurface (i.e. 0x01), for each of the 8 vertices. +EDGE_TABLE = [ + 0x0, + 0x109, + 0x203, + 0x30A, + 0x406, + 0x50F, + 0x605, + 0x70C, + 0x80C, + 0x905, + 0xA0F, + 0xB06, + 0xC0A, + 0xD03, + 0xE09, + 0xF00, + 0x190, + 0x99, + 0x393, + 0x29A, + 0x596, + 0x49F, + 0x795, + 0x69C, + 0x99C, + 0x895, + 0xB9F, + 0xA96, + 0xD9A, + 0xC93, + 0xF99, + 0xE90, + 0x230, + 0x339, + 0x33, + 0x13A, + 0x636, + 0x73F, + 0x435, + 0x53C, + 0xA3C, + 0xB35, + 0x83F, + 0x936, + 0xE3A, + 0xF33, + 0xC39, + 0xD30, + 0x3A0, + 0x2A9, + 0x1A3, + 0xAA, + 0x7A6, + 0x6AF, + 0x5A5, + 0x4AC, + 0xBAC, + 0xAA5, + 0x9AF, + 0x8A6, + 0xFAA, + 0xEA3, + 0xDA9, + 0xCA0, + 0x460, + 0x569, + 0x663, + 0x76A, + 0x66, + 0x16F, + 0x265, + 0x36C, + 0xC6C, + 0xD65, + 0xE6F, + 0xF66, + 0x86A, + 0x963, + 0xA69, + 0xB60, + 0x5F0, + 0x4F9, + 0x7F3, + 0x6FA, + 0x1F6, + 0xFF, + 0x3F5, + 0x2FC, + 0xDFC, + 0xCF5, + 0xFFF, + 0xEF6, + 0x9FA, + 0x8F3, + 0xBF9, + 0xAF0, + 0x650, + 0x759, + 0x453, + 0x55A, + 0x256, + 0x35F, + 0x55, + 0x15C, + 0xE5C, + 0xF55, + 0xC5F, + 0xD56, + 0xA5A, + 0xB53, + 0x859, + 0x950, + 0x7C0, + 0x6C9, + 0x5C3, + 0x4CA, + 0x3C6, + 0x2CF, + 0x1C5, + 0xCC, + 0xFCC, + 0xEC5, + 0xDCF, + 0xCC6, + 0xBCA, + 0xAC3, + 0x9C9, + 0x8C0, + 0x8C0, + 0x9C9, + 0xAC3, + 0xBCA, + 0xCC6, + 0xDCF, + 0xEC5, + 0xFCC, + 0xCC, + 0x1C5, + 0x2CF, + 0x3C6, + 0x4CA, + 0x5C3, + 0x6C9, + 0x7C0, + 0x950, + 0x859, + 0xB53, + 0xA5A, + 0xD56, + 0xC5F, + 0xF55, + 0xE5C, + 0x15C, + 0x55, + 0x35F, + 0x256, + 0x55A, + 0x453, + 0x759, + 0x650, + 0xAF0, + 0xBF9, + 0x8F3, + 0x9FA, + 0xEF6, + 0xFFF, + 0xCF5, + 0xDFC, + 0x2FC, + 0x3F5, + 0xFF, + 0x1F6, + 0x6FA, + 0x7F3, + 0x4F9, + 0x5F0, + 0xB60, + 0xA69, + 0x963, + 0x86A, + 0xF66, + 0xE6F, + 0xD65, + 0xC6C, + 0x36C, + 0x265, + 0x16F, + 0x66, + 0x76A, + 0x663, + 0x569, + 0x460, + 0xCA0, + 0xDA9, + 0xEA3, + 0xFAA, + 0x8A6, + 0x9AF, + 0xAA5, + 0xBAC, + 0x4AC, + 0x5A5, + 0x6AF, + 0x7A6, + 0xAA, + 0x1A3, + 0x2A9, + 0x3A0, + 0xD30, + 0xC39, + 0xF33, + 0xE3A, + 0x936, + 0x83F, + 0xB35, + 0xA3C, + 0x53C, + 0x435, + 0x73F, + 0x636, + 0x13A, + 0x33, + 0x339, + 0x230, + 0xE90, + 0xF99, + 0xC93, + 0xD9A, + 0xA96, + 0xB9F, + 0x895, + 0x99C, + 0x69C, + 0x795, + 0x49F, + 0x596, + 0x29A, + 0x393, + 0x99, + 0x190, + 0xF00, + 0xE09, + 0xD03, + 0xC0A, + 0xB06, + 0xA0F, + 0x905, + 0x80C, + 0x70C, + 0x605, + 0x50F, + 0x406, + 0x30A, + 0x203, + 0x109, + 0x0, +] + +# Maps each edge (by index) to the corresponding cube vertices +EDGE_TO_VERTICES = [ + [0, 1], + [1, 2], + [3, 2], + [0, 3], + [4, 5], + [5, 6], + [7, 6], + [4, 7], + [0, 4], + [1, 5], + [2, 6], + [3, 7], +] + +# A list of lists mapping a cube_index (a given configuration) +# to a list of faces corresponding to that configuration. Each face is represented +# by 3 consecutive numbers. A configuration will at most have 5 faces. +# +# Table taken from http://paulbourke.net/geometry/polygonise/ +FACE_TABLE = [ + [], + [0, 8, 3], + [0, 1, 9], + [1, 8, 3, 9, 8, 1], + [1, 2, 10], + [0, 8, 3, 1, 2, 10], + [9, 2, 10, 0, 2, 9], + [2, 8, 3, 2, 10, 8, 10, 9, 8], + [3, 11, 2], + [0, 11, 2, 8, 11, 0], + [1, 9, 0, 2, 3, 11], + [1, 11, 2, 1, 9, 11, 9, 8, 11], + [3, 10, 1, 11, 10, 3], + [0, 10, 1, 0, 8, 10, 8, 11, 10], + [3, 9, 0, 3, 11, 9, 11, 10, 9], + [9, 8, 10, 10, 8, 11], + [4, 7, 8], + [4, 3, 0, 7, 3, 4], + [0, 1, 9, 8, 4, 7], + [4, 1, 9, 4, 7, 1, 7, 3, 1], + [1, 2, 10, 8, 4, 7], + [3, 4, 7, 3, 0, 4, 1, 2, 10], + [9, 2, 10, 9, 0, 2, 8, 4, 7], + [2, 10, 9, 2, 9, 7, 2, 7, 3, 7, 9, 4], + [8, 4, 7, 3, 11, 2], + [11, 4, 7, 11, 2, 4, 2, 0, 4], + [9, 0, 1, 8, 4, 7, 2, 3, 11], + [4, 7, 11, 9, 4, 11, 9, 11, 2, 9, 2, 1], + [3, 10, 1, 3, 11, 10, 7, 8, 4], + [1, 11, 10, 1, 4, 11, 1, 0, 4, 7, 11, 4], + [4, 7, 8, 9, 0, 11, 9, 11, 10, 11, 0, 3], + [4, 7, 11, 4, 11, 9, 9, 11, 10], + [9, 5, 4], + [9, 5, 4, 0, 8, 3], + [0, 5, 4, 1, 5, 0], + [8, 5, 4, 8, 3, 5, 3, 1, 5], + [1, 2, 10, 9, 5, 4], + [3, 0, 8, 1, 2, 10, 4, 9, 5], + [5, 2, 10, 5, 4, 2, 4, 0, 2], + [2, 10, 5, 3, 2, 5, 3, 5, 4, 3, 4, 8], + [9, 5, 4, 2, 3, 11], + [0, 11, 2, 0, 8, 11, 4, 9, 5], + [0, 5, 4, 0, 1, 5, 2, 3, 11], + [2, 1, 5, 2, 5, 8, 2, 8, 11, 4, 8, 5], + [10, 3, 11, 10, 1, 3, 9, 5, 4], + [4, 9, 5, 0, 8, 1, 8, 10, 1, 8, 11, 10], + [5, 4, 0, 5, 0, 11, 5, 11, 10, 11, 0, 3], + [5, 4, 8, 5, 8, 10, 10, 8, 11], + [9, 7, 8, 5, 7, 9], + [9, 3, 0, 9, 5, 3, 5, 7, 3], + [0, 7, 8, 0, 1, 7, 1, 5, 7], + [1, 5, 3, 3, 5, 7], + [9, 7, 8, 9, 5, 7, 10, 1, 2], + [10, 1, 2, 9, 5, 0, 5, 3, 0, 5, 7, 3], + [8, 0, 2, 8, 2, 5, 8, 5, 7, 10, 5, 2], + [2, 10, 5, 2, 5, 3, 3, 5, 7], + [7, 9, 5, 7, 8, 9, 3, 11, 2], + [9, 5, 7, 9, 7, 2, 9, 2, 0, 2, 7, 11], + [2, 3, 11, 0, 1, 8, 1, 7, 8, 1, 5, 7], + [11, 2, 1, 11, 1, 7, 7, 1, 5], + [9, 5, 8, 8, 5, 7, 10, 1, 3, 10, 3, 11], + [5, 7, 0, 5, 0, 9, 7, 11, 0, 1, 0, 10, 11, 10, 0], + [11, 10, 0, 11, 0, 3, 10, 5, 0, 8, 0, 7, 5, 7, 0], + [11, 10, 5, 7, 11, 5], + [10, 6, 5], + [0, 8, 3, 5, 10, 6], + [9, 0, 1, 5, 10, 6], + [1, 8, 3, 1, 9, 8, 5, 10, 6], + [1, 6, 5, 2, 6, 1], + [1, 6, 5, 1, 2, 6, 3, 0, 8], + [9, 6, 5, 9, 0, 6, 0, 2, 6], + [5, 9, 8, 5, 8, 2, 5, 2, 6, 3, 2, 8], + [2, 3, 11, 10, 6, 5], + [11, 0, 8, 11, 2, 0, 10, 6, 5], + [0, 1, 9, 2, 3, 11, 5, 10, 6], + [5, 10, 6, 1, 9, 2, 9, 11, 2, 9, 8, 11], + [6, 3, 11, 6, 5, 3, 5, 1, 3], + [0, 8, 11, 0, 11, 5, 0, 5, 1, 5, 11, 6], + [3, 11, 6, 0, 3, 6, 0, 6, 5, 0, 5, 9], + [6, 5, 9, 6, 9, 11, 11, 9, 8], + [5, 10, 6, 4, 7, 8], + [4, 3, 0, 4, 7, 3, 6, 5, 10], + [1, 9, 0, 5, 10, 6, 8, 4, 7], + [10, 6, 5, 1, 9, 7, 1, 7, 3, 7, 9, 4], + [6, 1, 2, 6, 5, 1, 4, 7, 8], + [1, 2, 5, 5, 2, 6, 3, 0, 4, 3, 4, 7], + [8, 4, 7, 9, 0, 5, 0, 6, 5, 0, 2, 6], + [7, 3, 9, 7, 9, 4, 3, 2, 9, 5, 9, 6, 2, 6, 9], + [3, 11, 2, 7, 8, 4, 10, 6, 5], + [5, 10, 6, 4, 7, 2, 4, 2, 0, 2, 7, 11], + [0, 1, 9, 4, 7, 8, 2, 3, 11, 5, 10, 6], + [9, 2, 1, 9, 11, 2, 9, 4, 11, 7, 11, 4, 5, 10, 6], + [8, 4, 7, 3, 11, 5, 3, 5, 1, 5, 11, 6], + [5, 1, 11, 5, 11, 6, 1, 0, 11, 7, 11, 4, 0, 4, 11], + [0, 5, 9, 0, 6, 5, 0, 3, 6, 11, 6, 3, 8, 4, 7], + [6, 5, 9, 6, 9, 11, 4, 7, 9, 7, 11, 9], + [10, 4, 9, 6, 4, 10], + [4, 10, 6, 4, 9, 10, 0, 8, 3], + [10, 0, 1, 10, 6, 0, 6, 4, 0], + [8, 3, 1, 8, 1, 6, 8, 6, 4, 6, 1, 10], + [1, 4, 9, 1, 2, 4, 2, 6, 4], + [3, 0, 8, 1, 2, 9, 2, 4, 9, 2, 6, 4], + [0, 2, 4, 4, 2, 6], + [8, 3, 2, 8, 2, 4, 4, 2, 6], + [10, 4, 9, 10, 6, 4, 11, 2, 3], + [0, 8, 2, 2, 8, 11, 4, 9, 10, 4, 10, 6], + [3, 11, 2, 0, 1, 6, 0, 6, 4, 6, 1, 10], + [6, 4, 1, 6, 1, 10, 4, 8, 1, 2, 1, 11, 8, 11, 1], + [9, 6, 4, 9, 3, 6, 9, 1, 3, 11, 6, 3], + [8, 11, 1, 8, 1, 0, 11, 6, 1, 9, 1, 4, 6, 4, 1], + [3, 11, 6, 3, 6, 0, 0, 6, 4], + [6, 4, 8, 11, 6, 8], + [7, 10, 6, 7, 8, 10, 8, 9, 10], + [0, 7, 3, 0, 10, 7, 0, 9, 10, 6, 7, 10], + [10, 6, 7, 1, 10, 7, 1, 7, 8, 1, 8, 0], + [10, 6, 7, 10, 7, 1, 1, 7, 3], + [1, 2, 6, 1, 6, 8, 1, 8, 9, 8, 6, 7], + [2, 6, 9, 2, 9, 1, 6, 7, 9, 0, 9, 3, 7, 3, 9], + [7, 8, 0, 7, 0, 6, 6, 0, 2], + [7, 3, 2, 6, 7, 2], + [2, 3, 11, 10, 6, 8, 10, 8, 9, 8, 6, 7], + [2, 0, 7, 2, 7, 11, 0, 9, 7, 6, 7, 10, 9, 10, 7], + [1, 8, 0, 1, 7, 8, 1, 10, 7, 6, 7, 10, 2, 3, 11], + [11, 2, 1, 11, 1, 7, 10, 6, 1, 6, 7, 1], + [8, 9, 6, 8, 6, 7, 9, 1, 6, 11, 6, 3, 1, 3, 6], + [0, 9, 1, 11, 6, 7], + [7, 8, 0, 7, 0, 6, 3, 11, 0, 11, 6, 0], + [7, 11, 6], + [7, 6, 11], + [3, 0, 8, 11, 7, 6], + [0, 1, 9, 11, 7, 6], + [8, 1, 9, 8, 3, 1, 11, 7, 6], + [10, 1, 2, 6, 11, 7], + [1, 2, 10, 3, 0, 8, 6, 11, 7], + [2, 9, 0, 2, 10, 9, 6, 11, 7], + [6, 11, 7, 2, 10, 3, 10, 8, 3, 10, 9, 8], + [7, 2, 3, 6, 2, 7], + [7, 0, 8, 7, 6, 0, 6, 2, 0], + [2, 7, 6, 2, 3, 7, 0, 1, 9], + [1, 6, 2, 1, 8, 6, 1, 9, 8, 8, 7, 6], + [10, 7, 6, 10, 1, 7, 1, 3, 7], + [10, 7, 6, 1, 7, 10, 1, 8, 7, 1, 0, 8], + [0, 3, 7, 0, 7, 10, 0, 10, 9, 6, 10, 7], + [7, 6, 10, 7, 10, 8, 8, 10, 9], + [6, 8, 4, 11, 8, 6], + [3, 6, 11, 3, 0, 6, 0, 4, 6], + [8, 6, 11, 8, 4, 6, 9, 0, 1], + [9, 4, 6, 9, 6, 3, 9, 3, 1, 11, 3, 6], + [6, 8, 4, 6, 11, 8, 2, 10, 1], + [1, 2, 10, 3, 0, 11, 0, 6, 11, 0, 4, 6], + [4, 11, 8, 4, 6, 11, 0, 2, 9, 2, 10, 9], + [10, 9, 3, 10, 3, 2, 9, 4, 3, 11, 3, 6, 4, 6, 3], + [8, 2, 3, 8, 4, 2, 4, 6, 2], + [0, 4, 2, 4, 6, 2], + [1, 9, 0, 2, 3, 4, 2, 4, 6, 4, 3, 8], + [1, 9, 4, 1, 4, 2, 2, 4, 6], + [8, 1, 3, 8, 6, 1, 8, 4, 6, 6, 10, 1], + [10, 1, 0, 10, 0, 6, 6, 0, 4], + [4, 6, 3, 4, 3, 8, 6, 10, 3, 0, 3, 9, 10, 9, 3], + [10, 9, 4, 6, 10, 4], + [4, 9, 5, 7, 6, 11], + [0, 8, 3, 4, 9, 5, 11, 7, 6], + [5, 0, 1, 5, 4, 0, 7, 6, 11], + [11, 7, 6, 8, 3, 4, 3, 5, 4, 3, 1, 5], + [9, 5, 4, 10, 1, 2, 7, 6, 11], + [6, 11, 7, 1, 2, 10, 0, 8, 3, 4, 9, 5], + [7, 6, 11, 5, 4, 10, 4, 2, 10, 4, 0, 2], + [3, 4, 8, 3, 5, 4, 3, 2, 5, 10, 5, 2, 11, 7, 6], + [7, 2, 3, 7, 6, 2, 5, 4, 9], + [9, 5, 4, 0, 8, 6, 0, 6, 2, 6, 8, 7], + [3, 6, 2, 3, 7, 6, 1, 5, 0, 5, 4, 0], + [6, 2, 8, 6, 8, 7, 2, 1, 8, 4, 8, 5, 1, 5, 8], + [9, 5, 4, 10, 1, 6, 1, 7, 6, 1, 3, 7], + [1, 6, 10, 1, 7, 6, 1, 0, 7, 8, 7, 0, 9, 5, 4], + [4, 0, 10, 4, 10, 5, 0, 3, 10, 6, 10, 7, 3, 7, 10], + [7, 6, 10, 7, 10, 8, 5, 4, 10, 4, 8, 10], + [6, 9, 5, 6, 11, 9, 11, 8, 9], + [3, 6, 11, 0, 6, 3, 0, 5, 6, 0, 9, 5], + [0, 11, 8, 0, 5, 11, 0, 1, 5, 5, 6, 11], + [6, 11, 3, 6, 3, 5, 5, 3, 1], + [1, 2, 10, 9, 5, 11, 9, 11, 8, 11, 5, 6], + [0, 11, 3, 0, 6, 11, 0, 9, 6, 5, 6, 9, 1, 2, 10], + [11, 8, 5, 11, 5, 6, 8, 0, 5, 10, 5, 2, 0, 2, 5], + [6, 11, 3, 6, 3, 5, 2, 10, 3, 10, 5, 3], + [5, 8, 9, 5, 2, 8, 5, 6, 2, 3, 8, 2], + [9, 5, 6, 9, 6, 0, 0, 6, 2], + [1, 5, 8, 1, 8, 0, 5, 6, 8, 3, 8, 2, 6, 2, 8], + [1, 5, 6, 2, 1, 6], + [1, 3, 6, 1, 6, 10, 3, 8, 6, 5, 6, 9, 8, 9, 6], + [10, 1, 0, 10, 0, 6, 9, 5, 0, 5, 6, 0], + [0, 3, 8, 5, 6, 10], + [10, 5, 6], + [11, 5, 10, 7, 5, 11], + [11, 5, 10, 11, 7, 5, 8, 3, 0], + [5, 11, 7, 5, 10, 11, 1, 9, 0], + [10, 7, 5, 10, 11, 7, 9, 8, 1, 8, 3, 1], + [11, 1, 2, 11, 7, 1, 7, 5, 1], + [0, 8, 3, 1, 2, 7, 1, 7, 5, 7, 2, 11], + [9, 7, 5, 9, 2, 7, 9, 0, 2, 2, 11, 7], + [7, 5, 2, 7, 2, 11, 5, 9, 2, 3, 2, 8, 9, 8, 2], + [2, 5, 10, 2, 3, 5, 3, 7, 5], + [8, 2, 0, 8, 5, 2, 8, 7, 5, 10, 2, 5], + [9, 0, 1, 5, 10, 3, 5, 3, 7, 3, 10, 2], + [9, 8, 2, 9, 2, 1, 8, 7, 2, 10, 2, 5, 7, 5, 2], + [1, 3, 5, 3, 7, 5], + [0, 8, 7, 0, 7, 1, 1, 7, 5], + [9, 0, 3, 9, 3, 5, 5, 3, 7], + [9, 8, 7, 5, 9, 7], + [5, 8, 4, 5, 10, 8, 10, 11, 8], + [5, 0, 4, 5, 11, 0, 5, 10, 11, 11, 3, 0], + [0, 1, 9, 8, 4, 10, 8, 10, 11, 10, 4, 5], + [10, 11, 4, 10, 4, 5, 11, 3, 4, 9, 4, 1, 3, 1, 4], + [2, 5, 1, 2, 8, 5, 2, 11, 8, 4, 5, 8], + [0, 4, 11, 0, 11, 3, 4, 5, 11, 2, 11, 1, 5, 1, 11], + [0, 2, 5, 0, 5, 9, 2, 11, 5, 4, 5, 8, 11, 8, 5], + [9, 4, 5, 2, 11, 3], + [2, 5, 10, 3, 5, 2, 3, 4, 5, 3, 8, 4], + [5, 10, 2, 5, 2, 4, 4, 2, 0], + [3, 10, 2, 3, 5, 10, 3, 8, 5, 4, 5, 8, 0, 1, 9], + [5, 10, 2, 5, 2, 4, 1, 9, 2, 9, 4, 2], + [8, 4, 5, 8, 5, 3, 3, 5, 1], + [0, 4, 5, 1, 0, 5], + [8, 4, 5, 8, 5, 3, 9, 0, 5, 0, 3, 5], + [9, 4, 5], + [4, 11, 7, 4, 9, 11, 9, 10, 11], + [0, 8, 3, 4, 9, 7, 9, 11, 7, 9, 10, 11], + [1, 10, 11, 1, 11, 4, 1, 4, 0, 7, 4, 11], + [3, 1, 4, 3, 4, 8, 1, 10, 4, 7, 4, 11, 10, 11, 4], + [4, 11, 7, 9, 11, 4, 9, 2, 11, 9, 1, 2], + [9, 7, 4, 9, 11, 7, 9, 1, 11, 2, 11, 1, 0, 8, 3], + [11, 7, 4, 11, 4, 2, 2, 4, 0], + [11, 7, 4, 11, 4, 2, 8, 3, 4, 3, 2, 4], + [2, 9, 10, 2, 7, 9, 2, 3, 7, 7, 4, 9], + [9, 10, 7, 9, 7, 4, 10, 2, 7, 8, 7, 0, 2, 0, 7], + [3, 7, 10, 3, 10, 2, 7, 4, 10, 1, 10, 0, 4, 0, 10], + [1, 10, 2, 8, 7, 4], + [4, 9, 1, 4, 1, 7, 7, 1, 3], + [4, 9, 1, 4, 1, 7, 0, 8, 1, 8, 7, 1], + [4, 0, 3, 7, 4, 3], + [4, 8, 7], + [9, 10, 8, 10, 11, 8], + [3, 0, 9, 3, 9, 11, 11, 9, 10], + [0, 1, 10, 0, 10, 8, 8, 10, 11], + [3, 1, 10, 11, 3, 10], + [1, 2, 11, 1, 11, 9, 9, 11, 8], + [3, 0, 9, 3, 9, 11, 1, 2, 9, 2, 11, 9], + [0, 2, 11, 8, 0, 11], + [3, 2, 11], + [2, 3, 8, 2, 8, 10, 10, 8, 9], + [9, 10, 2, 0, 9, 2], + [2, 3, 8, 2, 8, 10, 0, 1, 8, 1, 10, 8], + [1, 10, 2], + [1, 3, 8, 9, 1, 8], + [0, 9, 1], + [0, 3, 8], + [], +] diff --git a/pytorch3d/transforms/transform3d.py b/pytorch3d/transforms/transform3d.py index 695af174..0d2e330c 100644 --- a/pytorch3d/transforms/transform3d.py +++ b/pytorch3d/transforms/transform3d.py @@ -414,7 +414,7 @@ class Transform3d: class Translate(Transform3d): - def __init__(self, x, y=None, z=None, dtype=torch.float32, device: str = "cpu"): + def __init__(self, x, y=None, z=None, dtype=torch.float32, device="cpu"): """ Create a new Transform3d representing 3D translations. @@ -448,7 +448,7 @@ class Translate(Transform3d): class Scale(Transform3d): - def __init__(self, x, y=None, z=None, dtype=torch.float32, device: str = "cpu"): + def __init__(self, x, y=None, z=None, dtype=torch.float32, device="cpu"): """ A Transform3d representing a scaling operation, with different scale factors along each coordinate axis. @@ -489,7 +489,7 @@ class Scale(Transform3d): class Rotate(Transform3d): def __init__( - self, R, dtype=torch.float32, device: str = "cpu", orthogonal_tol: float = 1e-5 + self, R, dtype=torch.float32, device="cpu", orthogonal_tol: float = 1e-5 ): """ Create a new Transform3d representing 3D rotation using a rotation @@ -528,7 +528,7 @@ class RotateAxisAngle(Rotate): axis: str = "X", degrees: bool = True, dtype=torch.float64, - device: str = "cpu", + device="cpu", ): """ Create a new Transform3d representing 3D rotation about an axis @@ -635,7 +635,7 @@ def _handle_input(x, y, z, dtype, device, name: str, allow_singleton: bool = Fal return xyz -def _handle_angle_input(x, dtype, device: str, name: str): +def _handle_angle_input(x, dtype, device, name: str): """ Helper function for building a rotation function using angles. The output is always of shape (N,). diff --git a/tests/bm_marching_cubes.py b/tests/bm_marching_cubes.py new file mode 100644 index 00000000..288b8345 --- /dev/null +++ b/tests/bm_marching_cubes.py @@ -0,0 +1,25 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. + +from fvcore.common.benchmark import benchmark +from test_marching_cubes import TestMarchingCubes + + +def bm_marching_cubes() -> None: + kwargs_list = [ + {"batch_size": 1, "V": 5}, + {"batch_size": 1, "V": 10}, + {"batch_size": 1, "V": 20}, + {"batch_size": 1, "V": 40}, + {"batch_size": 5, "V": 5}, + {"batch_size": 20, "V": 20}, + ] + benchmark( + TestMarchingCubes.marching_cubes_with_init, + "MARCHING_CUBES", + kwargs_list, + warmup_iters=1, + ) + + +if __name__ == "__main__": + bm_marching_cubes() diff --git a/tests/data/test_marching_cubes_data/double_ellipsoid.pickle b/tests/data/test_marching_cubes_data/double_ellipsoid.pickle new file mode 100644 index 00000000..3642a73d Binary files /dev/null and b/tests/data/test_marching_cubes_data/double_ellipsoid.pickle differ diff --git a/tests/data/test_marching_cubes_data/sphere_level64.pickle b/tests/data/test_marching_cubes_data/sphere_level64.pickle new file mode 100644 index 00000000..2cfafe81 Binary files /dev/null and b/tests/data/test_marching_cubes_data/sphere_level64.pickle differ diff --git a/tests/test_marching_cubes.py b/tests/test_marching_cubes.py new file mode 100644 index 00000000..182b01b4 --- /dev/null +++ b/tests/test_marching_cubes.py @@ -0,0 +1,771 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +import os +import pickle +import unittest +from pathlib import Path + +import torch +from common_testing import TestCaseMixin +from pytorch3d.ops.marching_cubes import marching_cubes_naive + + +USE_SCIKIT = False + + +def convert_to_local(verts, volume_dim): + return (2 * verts) / (volume_dim - 1) - 1 + + +class TestCubeConfiguration(TestCaseMixin, unittest.TestCase): + + # Test single cubes. Each case corresponds to the corresponding + # cube vertex configuration in each case here (0-indexed): + # https://en.wikipedia.org/wiki/Marching_cubes#/media/File:MarchingCubes.svg + + def test_empty_volume(self): # case 0 + volume_data = torch.ones(1, 2, 2, 2) # (B, W, H, D) + verts, faces = marching_cubes_naive(volume_data, return_local_coords=False) + + expected_verts = torch.tensor([]) + expected_faces = torch.tensor([], dtype=torch.int64) + self.assertClose(verts, expected_verts) + self.assertClose(faces, expected_faces) + + def test_case1(self): # case 1 + volume_data = torch.ones(1, 2, 2, 2) # (B, W, H, D) + volume_data[0, 0, 0, 0] = 0 + volume_data = volume_data.permute(0, 3, 2, 1) # (B, D, H, W) + verts, faces = marching_cubes_naive(volume_data, return_local_coords=False) + + expected_verts = torch.tensor( + [ + [0.5, 0, 0], + [0, 0, 0.5], + [0, 0.5, 0], + ] + ) + + expected_faces = torch.tensor([[1, 2, 0]]) + self.assertClose(verts[0], expected_verts) + self.assertClose(faces[0], expected_faces) + + verts, faces = marching_cubes_naive(volume_data, return_local_coords=True) + expected_verts = convert_to_local(expected_verts, 2) + self.assertClose(verts[0], expected_verts) + self.assertClose(faces[0], expected_faces) + + def test_case2(self): + volume_data = torch.ones(1, 2, 2, 2) # (B, W, H, D) + volume_data[0, 0:2, 0, 0] = 0 + volume_data = volume_data.permute(0, 3, 2, 1) # (B, D, H, W) + verts, faces = marching_cubes_naive(volume_data, return_local_coords=False) + + expected_verts = torch.tensor( + [ + [1.0000, 0.0000, 0.5000], + [0.0000, 0.0000, 0.5000], + [1.0000, 0.5000, 0.0000], + [0.0000, 0.5000, 0.0000], + ] + ) + expected_faces = torch.tensor([[1, 2, 0], [3, 2, 1]]) + self.assertClose(verts[0], expected_verts) + self.assertClose(faces[0], expected_faces) + + verts, faces = marching_cubes_naive(volume_data, return_local_coords=True) + expected_verts = convert_to_local(expected_verts, 2) + self.assertClose(verts[0], expected_verts) + self.assertClose(faces[0], expected_faces) + + def test_case3(self): + volume_data = torch.ones(1, 2, 2, 2) # (B, W, H, D) + volume_data[0, 0, 0, 0] = 0 + volume_data[0, 1, 1, 0] = 0 + volume_data = volume_data.permute(0, 3, 2, 1) # (B, D, H, W) + verts, faces = marching_cubes_naive(volume_data, return_local_coords=False) + + expected_verts = torch.tensor( + [ + [0.5000, 0.0000, 0.0000], + [0.0000, 0.0000, 0.5000], + [1.0000, 1.0000, 0.5000], + [0.5000, 1.0000, 0.0000], + [1.0000, 0.5000, 0.0000], + [0.0000, 0.5000, 0.0000], + ] + ) + expected_faces = torch.tensor([[0, 1, 5], [4, 3, 2]]) + self.assertClose(verts[0], expected_verts) + self.assertClose(faces[0], expected_faces) + + verts, faces = marching_cubes_naive(volume_data, return_local_coords=True) + expected_verts = convert_to_local(expected_verts, 2) + self.assertClose(verts[0], expected_verts) + self.assertClose(faces[0], expected_faces) + + def test_case4(self): + volume_data = torch.ones(1, 2, 2, 2) # (B, W, H, D) + volume_data[0, 1, 0, 0] = 0 + volume_data[0, 1, 0, 1] = 0 + volume_data[0, 0, 0, 1] = 0 + volume_data = volume_data.permute(0, 3, 2, 1) # (B, D, H, W) + verts, faces = marching_cubes_naive(volume_data, return_local_coords=False) + + expected_verts = torch.tensor( + [ + [0.5000, 0.0000, 0.0000], + [0.0000, 0.0000, 0.5000], + [0.0000, 0.5000, 1.0000], + [1.0000, 0.5000, 1.0000], + [1.0000, 0.5000, 0.0000], + ] + ) + expected_faces = torch.tensor([[0, 2, 1], [0, 4, 2], [4, 3, 2]]) + self.assertClose(verts[0], expected_verts) + self.assertClose(faces[0], expected_faces) + + verts, faces = marching_cubes_naive(volume_data, return_local_coords=True) + expected_verts = convert_to_local(expected_verts, 2) + self.assertClose(verts[0], expected_verts) + self.assertClose(faces[0], expected_faces) + + def test_case5(self): + volume_data = torch.ones(1, 2, 2, 2) # (B, W, H, D) + volume_data[0, 0:2, 0, 0:2] = 0 + volume_data = volume_data.permute(0, 3, 2, 1) # (B, D, H, W) + verts, faces = marching_cubes_naive(volume_data, return_local_coords=False) + + expected_verts = torch.tensor( + [ + [0.0000, 0.5000, 1.0000], + [1.0000, 0.5000, 1.0000], + [1.0000, 0.5000, 0.0000], + [0.0000, 0.5000, 0.0000], + ] + ) + + expected_faces = torch.tensor([[1, 0, 2], [2, 0, 3]]) + + self.assertClose(verts[0], expected_verts) + self.assertClose(faces[0], expected_faces) + + verts, faces = marching_cubes_naive(volume_data, return_local_coords=True) + expected_verts = convert_to_local(expected_verts, 2) + self.assertClose(verts[0], expected_verts) + self.assertClose(faces[0], expected_faces) + + def test_case6(self): + volume_data = torch.ones(1, 2, 2, 2) # (B, W, H, D) + volume_data[0, 1, 0, 0] = 0 + volume_data[0, 1, 0, 1] = 0 + volume_data[0, 0, 0, 1] = 0 + volume_data[0, 0, 1, 0] = 0 + volume_data = volume_data.permute(0, 3, 2, 1) # (B, D, H, W) + verts, faces = marching_cubes_naive(volume_data, return_local_coords=False) + + expected_verts = torch.tensor( + [ + [0.5000, 0.0000, 0.0000], + [0.0000, 0.0000, 0.5000], + [0.5000, 1.0000, 0.0000], + [0.0000, 1.0000, 0.5000], + [0.0000, 0.5000, 1.0000], + [1.0000, 0.5000, 1.0000], + [1.0000, 0.5000, 0.0000], + [0.0000, 0.5000, 0.0000], + ] + ) + expected_faces = torch.tensor([[2, 7, 3], [0, 6, 1], [6, 4, 1], [6, 5, 4]]) + + self.assertClose(verts[0], expected_verts) + self.assertClose(faces[0], expected_faces) + + verts, faces = marching_cubes_naive(volume_data, return_local_coords=True) + expected_verts = convert_to_local(expected_verts, 2) + self.assertClose(verts[0], expected_verts) + self.assertClose(faces[0], expected_faces) + + def test_case7(self): + volume_data = torch.ones(1, 2, 2, 2) # (B, W, H, D) + volume_data[0, 0, 0, 0] = 0 + volume_data[0, 1, 0, 1] = 0 + volume_data[0, 1, 1, 0] = 0 + volume_data[0, 0, 1, 1] = 0 + volume_data = volume_data.permute(0, 3, 2, 1) # (B, D, H, W) + verts, faces = marching_cubes_naive(volume_data, return_local_coords=False) + + expected_verts = torch.tensor( + [ + [0.5000, 0.0000, 1.0000], + [1.0000, 0.0000, 0.5000], + [0.5000, 0.0000, 0.0000], + [0.0000, 0.0000, 0.5000], + [0.5000, 1.0000, 1.0000], + [1.0000, 1.0000, 0.5000], + [0.5000, 1.0000, 0.0000], + [0.0000, 1.0000, 0.5000], + [0.0000, 0.5000, 1.0000], + [1.0000, 0.5000, 1.0000], + [1.0000, 0.5000, 0.0000], + [0.0000, 0.5000, 0.0000], + ] + ) + + expected_faces = torch.tensor([[0, 1, 9], [4, 7, 8], [2, 3, 11], [5, 10, 6]]) + + self.assertClose(verts[0], expected_verts) + self.assertClose(faces[0], expected_faces) + + verts, faces = marching_cubes_naive(volume_data, return_local_coords=True) + expected_verts = convert_to_local(expected_verts, 2) + self.assertClose(verts[0], expected_verts) + self.assertClose(faces[0], expected_faces) + + def test_case8(self): + volume_data = torch.ones(1, 2, 2, 2) # (B, W, H, D) + volume_data[0, 0, 0, 0] = 0 + volume_data[0, 0, 0, 1] = 0 + volume_data[0, 1, 0, 1] = 0 + volume_data[0, 0, 1, 1] = 0 + volume_data = volume_data.permute(0, 3, 2, 1) # (B, D, H, W) + verts, faces = marching_cubes_naive(volume_data, return_local_coords=False) + + expected_verts = torch.tensor( + [ + [1.0000, 0.0000, 0.5000], + [0.5000, 0.0000, 0.0000], + [0.5000, 1.0000, 1.0000], + [0.0000, 1.0000, 0.5000], + [1.0000, 0.5000, 1.0000], + [0.0000, 0.5000, 0.0000], + ] + ) + expected_faces = torch.tensor([[2, 3, 5], [4, 2, 5], [4, 5, 1], [4, 1, 0]]) + + self.assertClose(verts[0], expected_verts) + self.assertClose(faces[0], expected_faces) + + verts, faces = marching_cubes_naive(volume_data, return_local_coords=True) + expected_verts = convert_to_local(expected_verts, 2) + self.assertClose(verts[0], expected_verts) + self.assertClose(faces[0], expected_faces) + + def test_case9(self): + volume_data = torch.ones(1, 2, 2, 2) # (B, W, H, D) + volume_data[0, 1, 0, 0] = 0 + volume_data[0, 0, 0, 1] = 0 + volume_data[0, 1, 0, 1] = 0 + volume_data[0, 0, 1, 1] = 0 + volume_data = volume_data.permute(0, 3, 2, 1) # (B, D, H, W) + verts, faces = marching_cubes_naive(volume_data, return_local_coords=False) + + expected_verts = torch.tensor( + [ + [0.5000, 0.0000, 0.0000], + [0.0000, 0.0000, 0.5000], + [0.5000, 1.0000, 1.0000], + [0.0000, 1.0000, 0.5000], + [1.0000, 0.5000, 1.0000], + [1.0000, 0.5000, 0.0000], + ] + ) + expected_faces = torch.tensor([[0, 5, 4], [0, 4, 3], [0, 3, 1], [3, 4, 2]]) + + self.assertClose(verts[0], expected_verts) + self.assertClose(faces[0], expected_faces) + + verts, faces = marching_cubes_naive(volume_data, return_local_coords=True) + expected_verts = convert_to_local(expected_verts, 2) + self.assertClose(verts[0], expected_verts) + self.assertClose(faces[0], expected_faces) + + def test_case10(self): + volume_data = torch.ones(1, 2, 2, 2) # (B, W, H, D) + volume_data[0, 0, 0, 0] = 0 + volume_data[0, 1, 1, 1] = 0 + volume_data = volume_data.permute(0, 3, 2, 1) # (B, D, H, W) + verts, faces = marching_cubes_naive(volume_data, return_local_coords=False) + + expected_verts = torch.tensor( + [ + [0.5000, 0.0000, 0.0000], + [0.0000, 0.0000, 0.5000], + [0.5000, 1.0000, 1.0000], + [1.0000, 1.0000, 0.5000], + [1.0000, 0.5000, 1.0000], + [0.0000, 0.5000, 0.0000], + ] + ) + + expected_faces = torch.tensor([[4, 3, 2], [0, 1, 5]]) + + self.assertClose(verts[0], expected_verts) + self.assertClose(faces[0], expected_faces) + + verts, faces = marching_cubes_naive(volume_data, return_local_coords=True) + expected_verts = convert_to_local(expected_verts, 2) + self.assertClose(verts[0], expected_verts) + self.assertClose(faces[0], expected_faces) + + def test_case11(self): + volume_data = torch.ones(1, 2, 2, 2) # (B, W, H, D) + volume_data[0, 0, 0, 0] = 0 + volume_data[0, 1, 0, 0] = 0 + volume_data[0, 1, 1, 1] = 0 + volume_data = volume_data.permute(0, 3, 2, 1) # (B, D, H, W) + verts, faces = marching_cubes_naive(volume_data, return_local_coords=False) + + expected_verts = torch.tensor( + [ + [1.0000, 0.0000, 0.5000], + [0.0000, 0.0000, 0.5000], + [0.5000, 1.0000, 1.0000], + [1.0000, 1.0000, 0.5000], + [1.0000, 0.5000, 1.0000], + [1.0000, 0.5000, 0.0000], + [0.0000, 0.5000, 0.0000], + ] + ) + + expected_faces = torch.tensor([[5, 1, 6], [5, 0, 1], [4, 3, 2]]) + + self.assertClose(verts[0], expected_verts) + self.assertClose(faces[0], expected_faces) + + verts, faces = marching_cubes_naive(volume_data, return_local_coords=True) + expected_verts = convert_to_local(expected_verts, 2) + self.assertClose(verts[0], expected_verts) + self.assertClose(faces[0], expected_faces) + + def test_case12(self): + volume_data = torch.ones(1, 2, 2, 2) # (B, W, H, D) + volume_data[0, 1, 0, 0] = 0 + volume_data[0, 0, 1, 0] = 0 + volume_data[0, 1, 1, 1] = 0 + volume_data = volume_data.permute(0, 3, 2, 1) # (B, D, H, W) + verts, faces = marching_cubes_naive(volume_data, return_local_coords=False) + + expected_verts = torch.tensor( + [ + [1.0000, 0.0000, 0.5000], + [0.5000, 0.0000, 0.0000], + [0.5000, 1.0000, 1.0000], + [1.0000, 1.0000, 0.5000], + [0.5000, 1.0000, 0.0000], + [0.0000, 1.0000, 0.5000], + [1.0000, 0.5000, 1.0000], + [1.0000, 0.5000, 0.0000], + [0.0000, 0.5000, 0.0000], + ] + ) + + expected_faces = torch.tensor([[6, 3, 2], [7, 0, 1], [5, 4, 8]]) + + self.assertClose(verts[0], expected_verts) + self.assertClose(faces[0], expected_faces) + + verts, faces = marching_cubes_naive(volume_data, return_local_coords=True) + expected_verts = convert_to_local(expected_verts, 2) + self.assertClose(verts[0], expected_verts) + self.assertClose(faces[0], expected_faces) + + def test_case13(self): + volume_data = torch.ones(1, 2, 2, 2) # (B, W, H, D) + volume_data[0, 0, 0, 0] = 0 + volume_data[0, 0, 1, 0] = 0 + volume_data[0, 1, 0, 1] = 0 + volume_data[0, 1, 1, 1] = 0 + volume_data = volume_data.permute(0, 3, 2, 1) # (B, D, H, W) + verts, faces = marching_cubes_naive(volume_data, return_local_coords=False) + + expected_verts = torch.tensor( + [ + [0.5000, 0.0000, 1.0000], + [1.0000, 0.0000, 0.5000], + [0.5000, 0.0000, 0.0000], + [0.0000, 0.0000, 0.5000], + [0.5000, 1.0000, 1.0000], + [1.0000, 1.0000, 0.5000], + [0.5000, 1.0000, 0.0000], + [0.0000, 1.0000, 0.5000], + ] + ) + + expected_faces = torch.tensor([[3, 6, 2], [3, 7, 6], [1, 5, 0], [5, 4, 0]]) + + self.assertClose(verts[0], expected_verts) + self.assertClose(faces[0], expected_faces) + + verts, faces = marching_cubes_naive(volume_data, return_local_coords=True) + expected_verts = convert_to_local(expected_verts, 2) + self.assertClose(verts[0], expected_verts) + self.assertClose(faces[0], expected_faces) + + def test_case14(self): + volume_data = torch.ones(1, 2, 2, 2) # (B, W, H, D) + volume_data[0, 0, 0, 0] = 0 + volume_data[0, 0, 0, 1] = 0 + volume_data[0, 1, 0, 1] = 0 + volume_data[0, 1, 1, 1] = 0 + volume_data = volume_data.permute(0, 3, 2, 1) # (B, D, H, W) + verts, faces = marching_cubes_naive(volume_data, return_local_coords=False) + + expected_verts = torch.tensor( + [ + [1.0000, 0.0000, 0.5000], + [0.5000, 0.0000, 0.0000], + [0.5000, 1.0000, 1.0000], + [1.0000, 1.0000, 0.5000], + [0.0000, 0.5000, 1.0000], + [0.0000, 0.5000, 0.0000], + ] + ) + + expected_faces = torch.tensor([[1, 0, 3], [1, 3, 4], [1, 4, 5], [2, 4, 3]]) + + self.assertClose(verts[0], expected_verts) + self.assertClose(faces[0], expected_faces) + + verts, faces = marching_cubes_naive(volume_data, return_local_coords=True) + expected_verts = convert_to_local(expected_verts, 2) + self.assertClose(verts[0], expected_verts) + self.assertClose(faces[0], expected_faces) + + +class TestMarchingCubes(TestCaseMixin, unittest.TestCase): + def test_single_point(self): + volume_data = torch.zeros(1, 3, 3, 3) # (B, W, H, D) + volume_data[0, 1, 1, 1] = 1 + volume_data = volume_data.permute(0, 3, 2, 1) # (B, D, H, W) + verts, faces = marching_cubes_naive(volume_data, return_local_coords=False) + + expected_verts = torch.tensor( + [ + [0.5, 1, 1], + [1, 1, 0.5], + [1, 0.5, 1], + [1, 1, 1.5], + [1, 1.5, 1], + [1.5, 1, 1], + ] + ) + expected_faces = torch.tensor( + [ + [2, 0, 1], + [2, 3, 0], + [0, 4, 1], + [3, 4, 0], + [5, 2, 1], + [3, 2, 5], + [5, 1, 4], + [3, 5, 4], + ] + ) + + self.assertClose(verts[0], expected_verts) + self.assertClose(faces[0], expected_faces) + + verts, faces = marching_cubes_naive(volume_data, return_local_coords=True) + expected_verts = convert_to_local(expected_verts, 3) + self.assertClose(verts[0], expected_verts) + self.assertClose(faces[0], expected_faces) + self.assertTrue(verts[0].ge(-1).all() and verts[0].le(1).all()) + + def test_cube(self): + volume_data = torch.zeros(1, 5, 5, 5) # (B, W, H, D) + volume_data[0, 1, 1, 1] = 1 + volume_data[0, 1, 1, 2] = 1 + volume_data[0, 2, 1, 1] = 1 + volume_data[0, 2, 1, 2] = 1 + volume_data[0, 1, 2, 1] = 1 + volume_data[0, 1, 2, 2] = 1 + volume_data[0, 2, 2, 1] = 1 + volume_data[0, 2, 2, 2] = 1 + volume_data = volume_data.permute(0, 3, 2, 1) # (B, D, H, W) + verts, faces = marching_cubes_naive(volume_data, 0.9, return_local_coords=False) + + expected_verts = torch.tensor( + [ + [0.9000, 1.0000, 1.0000], + [1.0000, 1.0000, 0.9000], + [1.0000, 0.9000, 1.0000], + [0.9000, 1.0000, 2.0000], + [1.0000, 0.9000, 2.0000], + [1.0000, 1.0000, 2.1000], + [0.9000, 2.0000, 1.0000], + [1.0000, 2.0000, 0.9000], + [0.9000, 2.0000, 2.0000], + [1.0000, 2.0000, 2.1000], + [1.0000, 2.1000, 1.0000], + [1.0000, 2.1000, 2.0000], + [2.0000, 1.0000, 0.9000], + [2.0000, 0.9000, 1.0000], + [2.0000, 0.9000, 2.0000], + [2.0000, 1.0000, 2.1000], + [2.0000, 2.0000, 0.9000], + [2.0000, 2.0000, 2.1000], + [2.0000, 2.1000, 1.0000], + [2.0000, 2.1000, 2.0000], + [2.1000, 1.0000, 1.0000], + [2.1000, 1.0000, 2.0000], + [2.1000, 2.0000, 1.0000], + [2.1000, 2.0000, 2.0000], + ] + ) + + expected_faces = torch.tensor( + [ + [2, 0, 1], + [2, 4, 3], + [0, 2, 3], + [4, 5, 3], + [0, 6, 7], + [1, 0, 7], + [3, 8, 0], + [8, 6, 0], + [5, 9, 8], + [3, 5, 8], + [6, 10, 7], + [11, 10, 6], + [8, 11, 6], + [9, 11, 8], + [13, 2, 1], + [12, 13, 1], + [14, 4, 13], + [13, 4, 2], + [4, 14, 15], + [5, 4, 15], + [12, 1, 16], + [1, 7, 16], + [15, 17, 5], + [5, 17, 9], + [16, 7, 10], + [18, 16, 10], + [19, 18, 11], + [18, 10, 11], + [9, 17, 19], + [11, 9, 19], + [20, 13, 12], + [20, 21, 14], + [13, 20, 14], + [15, 14, 21], + [22, 20, 12], + [16, 22, 12], + [21, 20, 23], + [23, 20, 22], + [17, 15, 21], + [23, 17, 21], + [22, 16, 18], + [23, 22, 18], + [19, 23, 18], + [17, 23, 19], + ] + ) + self.assertClose(verts[0], expected_verts) + self.assertClose(faces[0], expected_faces) + verts, faces = marching_cubes_naive(volume_data, 0.9, return_local_coords=True) + expected_verts = convert_to_local(expected_verts, 5) + self.assertClose(verts[0], expected_verts) + self.assertClose(faces[0], expected_faces) + + # Check all values are in the range [-1, 1] + self.assertTrue(verts[0].ge(-1).all() and verts[0].le(1).all()) + + def test_cube_no_duplicate_verts(self): + volume_data = torch.zeros(1, 5, 5, 5) # (B, W, H, D) + volume_data[0, 1, 1, 1] = 1 + volume_data[0, 1, 1, 2] = 1 + volume_data[0, 2, 1, 1] = 1 + volume_data[0, 2, 1, 2] = 1 + volume_data[0, 1, 2, 1] = 1 + volume_data[0, 1, 2, 2] = 1 + volume_data[0, 2, 2, 1] = 1 + volume_data[0, 2, 2, 2] = 1 + volume_data = volume_data.permute(0, 3, 2, 1) # (B, D, H, W) + verts, faces = marching_cubes_naive(volume_data, 1, return_local_coords=False) + + expected_verts = torch.tensor( + [ + [1.0, 1.0, 1.0], + [1.0, 1.0, 2.0], + [1.0, 2.0, 1.0], + [1.0, 2.0, 2.0], + [2.0, 1.0, 1.0], + [2.0, 1.0, 2.0], + [2.0, 2.0, 1.0], + [2.0, 2.0, 2.0], + ] + ) + + expected_faces = torch.tensor( + [ + [1, 3, 0], + [3, 2, 0], + [5, 1, 4], + [4, 1, 0], + [4, 0, 6], + [0, 2, 6], + [5, 7, 1], + [1, 7, 3], + [7, 6, 3], + [6, 2, 3], + [5, 4, 7], + [7, 4, 6], + ] + ) + + self.assertClose(verts[0], expected_verts) + self.assertClose(faces[0], expected_faces) + + verts, faces = marching_cubes_naive(volume_data, 1, return_local_coords=True) + expected_verts = convert_to_local(expected_verts, 5) + self.assertClose(verts[0], expected_verts) + self.assertClose(faces[0], expected_faces) + + # Check all values are in the range [-1, 1] + self.assertTrue(verts[0].ge(-1).all() and verts[0].le(1).all()) + + def test_sphere(self): + # (B, W, H, D) + volume = torch.Tensor( + [ + [ + [(x - 10) ** 2 + (y - 10) ** 2 + (z - 10) ** 2 for z in range(20)] + for y in range(20) + ] + for x in range(20) + ] + ).unsqueeze(0) + volume = volume.permute(0, 3, 2, 1) # (B, D, H, W) + verts, faces = marching_cubes_naive( + volume, isolevel=64, return_local_coords=False + ) + + DATA_DIR = Path(__file__).resolve().parent / "data" + data_filename = "test_marching_cubes_data/sphere_level64.pickle" + filename = os.path.join(DATA_DIR, data_filename) + with open(filename, "rb") as file: + verts_and_faces = pickle.load(file) + expected_verts = verts_and_faces["verts"].squeeze() + expected_faces = verts_and_faces["faces"].squeeze() + + self.assertClose(verts[0], expected_verts) + self.assertClose(faces[0], expected_faces) + + verts, faces = marching_cubes_naive( + volume, isolevel=64, return_local_coords=True + ) + + expected_verts = convert_to_local(expected_verts, 20) + self.assertClose(verts[0], expected_verts) + self.assertClose(faces[0], expected_faces) + + # Check all values are in the range [-1, 1] + self.assertTrue(verts[0].ge(-1).all() and verts[0].le(1).all()) + + # Uses skimage.draw.ellipsoid + def test_double_ellipsoid(self): + if USE_SCIKIT: + import numpy as np + from skimage.draw import ellipsoid + + ellip_base = ellipsoid(6, 10, 16, levelset=True) + ellip_double = np.concatenate( + (ellip_base[:-1, ...], ellip_base[2:, ...]), axis=0 + ) + volume = torch.Tensor(ellip_double).unsqueeze(0) + volume = volume.permute(0, 3, 2, 1) # (B, D, H, W) + verts, faces = marching_cubes_naive(volume, isolevel=0.001) + + DATA_DIR = Path(__file__).resolve().parent / "data" + data_filename = "test_marching_cubes_data/double_ellipsoid.pickle" + filename = os.path.join(DATA_DIR, data_filename) + with open(filename, "rb") as file: + verts_and_faces = pickle.load(file) + expected_verts = verts_and_faces["verts"] + expected_faces = verts_and_faces["faces"] + + self.assertClose(verts[0], expected_verts[0]) + self.assertClose(faces[0], expected_faces[0]) + + def test_cube_surface_area(self): + if USE_SCIKIT: + from skimage.measure import marching_cubes_classic, mesh_surface_area + + volume_data = torch.zeros(1, 5, 5, 5) + volume_data[0, 1, 1, 1] = 1 + volume_data[0, 1, 1, 2] = 1 + volume_data[0, 2, 1, 1] = 1 + volume_data[0, 2, 1, 2] = 1 + volume_data[0, 1, 2, 1] = 1 + volume_data[0, 1, 2, 2] = 1 + volume_data[0, 2, 2, 1] = 1 + volume_data[0, 2, 2, 2] = 1 + volume_data = volume_data.permute(0, 3, 2, 1) # (B, D, H, W) + verts, faces = marching_cubes_naive(volume_data, return_local_coords=False) + verts_sci, faces_sci = marching_cubes_classic(volume_data[0]) + + surf = mesh_surface_area(verts[0], faces[0]) + surf_sci = mesh_surface_area(verts_sci, faces_sci) + + self.assertClose(surf, surf_sci) + + def test_sphere_surface_area(self): + if USE_SCIKIT: + from skimage.measure import marching_cubes_classic, mesh_surface_area + + # (B, W, H, D) + volume = torch.Tensor( + [ + [ + [ + (x - 10) ** 2 + (y - 10) ** 2 + (z - 10) ** 2 + for z in range(20) + ] + for y in range(20) + ] + for x in range(20) + ] + ).unsqueeze(0) + volume = volume.permute(0, 3, 2, 1) # (B, D, H, W) + verts, faces = marching_cubes_naive(volume, isolevel=64) + verts_sci, faces_sci = marching_cubes_classic(volume[0], level=64) + + surf = mesh_surface_area(verts[0], faces[0]) + surf_sci = mesh_surface_area(verts_sci, faces_sci) + + self.assertClose(surf, surf_sci) + + def test_double_ellipsoid_surface_area(self): + if USE_SCIKIT: + import numpy as np + from skimage.draw import ellipsoid + from skimage.measure import marching_cubes_classic, mesh_surface_area + + ellip_base = ellipsoid(6, 10, 16, levelset=True) + ellip_double = np.concatenate( + (ellip_base[:-1, ...], ellip_base[2:, ...]), axis=0 + ) + volume = torch.Tensor(ellip_double).unsqueeze(0) + volume = volume.permute(0, 3, 2, 1) # (B, D, H, W) + verts, faces = marching_cubes_naive(volume, isolevel=0) + verts_sci, faces_sci = marching_cubes_classic(volume[0], level=0) + + surf = mesh_surface_area(verts[0], faces[0]) + surf_sci = mesh_surface_area(verts_sci, faces_sci) + + self.assertClose(surf, surf_sci) + + @staticmethod + def marching_cubes_with_init(batch_size: int, V: int): + device = torch.device("cuda:0") + volume_data = torch.rand( + (batch_size, V, V, V), dtype=torch.float32, device=device + ) + torch.cuda.synchronize() + + def convert(): + marching_cubes_naive(volume_data, return_local_coords=False) + torch.cuda.synchronize() + + return convert