mirror of
				https://github.com/facebookresearch/pytorch3d.git
				synced 2025-11-04 18:02:14 +08:00 
			
		
		
		
	Classic Marching Cubes algorithm implementation
Summary: Defines a function to run marching cubes algorithm on a single or batch of 3D scalar fields. Returns a mesh's faces and vertices. UPDATES (12/18) - Input data is now specified as a (B, D, H, W) tensor as opposed to a (B, W, H, D) tensor. This will now be compatible with the Volumes datastructure. - Add an option to return output vertices in local coordinates instead of world coordinates. Also added a small fix to remove the dype for device in Transforms3D - if passing in a torch.device instead of str it causes a pyre error. Reviewed By: jcjohnson Differential Revision: D24599019 fbshipit-source-id: 90554a200319fed8736a12371cc349e7108aacd0
This commit is contained in:
		
							parent
							
								
									9c6b58c5ad
								
							
						
					
					
						commit
						ebac66daeb
					
				
							
								
								
									
										347
									
								
								pytorch3d/ops/marching_cubes.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										347
									
								
								pytorch3d/ops/marching_cubes.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,347 @@
 | 
			
		||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
 | 
			
		||||
 | 
			
		||||
from typing import Dict, List, Optional, Tuple
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
from pytorch3d.ops.marching_cubes_data import EDGE_TABLE, EDGE_TO_VERTICES, FACE_TABLE
 | 
			
		||||
from pytorch3d.transforms import Translate
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
EPS = 0.00001
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Cube:
 | 
			
		||||
    def __init__(self, bfl_vertex: Tuple[int, int, int], spacing: int = 1):
 | 
			
		||||
        """
 | 
			
		||||
        Initializes a cube given the bottom front left vertex coordinate
 | 
			
		||||
        and the cube spacing
 | 
			
		||||
 | 
			
		||||
        Edge and vertex convention:
 | 
			
		||||
 | 
			
		||||
                    v4_______e4____________v5
 | 
			
		||||
                    /|                    /|
 | 
			
		||||
                   / |                   / |
 | 
			
		||||
                e7/  |                e5/  |
 | 
			
		||||
                 /___|______e6_________/   |
 | 
			
		||||
              v7|    |                 |v6 |e9
 | 
			
		||||
                |    |                 |   |
 | 
			
		||||
                |    |e8               |e10|
 | 
			
		||||
             e11|    |                 |   |
 | 
			
		||||
                |    |_________________|___|
 | 
			
		||||
                |   / v0      e0       |   /v1
 | 
			
		||||
                |  /                   |  /
 | 
			
		||||
                | /e3                  | /e1
 | 
			
		||||
                |/_____________________|/
 | 
			
		||||
                v3         e2          v2
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            bfl_vertex: a tuple of size 3 corresponding to the bottom front left vertex
 | 
			
		||||
                of the cube in (x, y, z) format
 | 
			
		||||
            spacing: the length of each edge of the cube
 | 
			
		||||
        """
 | 
			
		||||
        # match corner orders to algorithm convention
 | 
			
		||||
        if len(bfl_vertex) != 3:
 | 
			
		||||
            msg = "The vertex {} is size {} instead of size 3".format(
 | 
			
		||||
                bfl_vertex, len(bfl_vertex)
 | 
			
		||||
            )
 | 
			
		||||
            raise ValueError(msg)
 | 
			
		||||
 | 
			
		||||
        x, y, z = bfl_vertex
 | 
			
		||||
        self.vertices = torch.tensor(
 | 
			
		||||
            [
 | 
			
		||||
                [x, y, z + spacing],
 | 
			
		||||
                [x + spacing, y, z + spacing],
 | 
			
		||||
                [x + spacing, y, z],
 | 
			
		||||
                [x, y, z],
 | 
			
		||||
                [x, y + spacing, z + spacing],
 | 
			
		||||
                [x + spacing, y + spacing, z + spacing],
 | 
			
		||||
                [x + spacing, y + spacing, z],
 | 
			
		||||
                [x, y + spacing, z],
 | 
			
		||||
            ]
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def get_index(self, volume_data: torch.Tensor, isolevel: float) -> int:
 | 
			
		||||
        """
 | 
			
		||||
        Calculates the cube_index in the range 0-255 to index
 | 
			
		||||
        into EDGE_TABLE and FACE_TABLE
 | 
			
		||||
        Args:
 | 
			
		||||
            volume_data: the 3D scalar data
 | 
			
		||||
            isolevel: the isosurface value used as a threshold
 | 
			
		||||
                for determining whether a point is inside/outside
 | 
			
		||||
                the volume
 | 
			
		||||
        """
 | 
			
		||||
        cube_index = 0
 | 
			
		||||
        bit = 1
 | 
			
		||||
        for index in range(len(self.vertices)):
 | 
			
		||||
            vertex = self.vertices[index]
 | 
			
		||||
            value = _get_value(vertex, volume_data)
 | 
			
		||||
            if value < isolevel:
 | 
			
		||||
                cube_index |= bit
 | 
			
		||||
            bit *= 2
 | 
			
		||||
        return cube_index
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def marching_cubes_naive(
 | 
			
		||||
    volume_data_batch: torch.Tensor,
 | 
			
		||||
    isolevel: Optional[float] = None,
 | 
			
		||||
    spacing: int = 1,
 | 
			
		||||
    return_local_coords: bool = True,
 | 
			
		||||
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
 | 
			
		||||
    """
 | 
			
		||||
    Runs the classic marching cubes algorithm, iterating over
 | 
			
		||||
    the coordinates of the volume_data and using a given isolevel
 | 
			
		||||
    for determining intersected edges of cubes of size `spacing`.
 | 
			
		||||
    Returns vertices and faces of the obtained mesh.
 | 
			
		||||
    This operation is non-differentiable.
 | 
			
		||||
 | 
			
		||||
    This is a naive implementation, and is not optimized for efficiency.
 | 
			
		||||
 | 
			
		||||
    Args:
 | 
			
		||||
        volume_data_batch: a Tensor of size (N, D, H, W) corresponding to
 | 
			
		||||
            a batch of 3D scalar fields
 | 
			
		||||
        isolevel: the isosurface value to use as the threshold to determine
 | 
			
		||||
            whether points are within a volume. If None, then the average of the
 | 
			
		||||
            maximum and minimum value of the scalar field will be used.
 | 
			
		||||
        spacing: an integer specifying the cube size to use
 | 
			
		||||
        return_local_coords: bool. If True the output vertices will be in local coordinates in
 | 
			
		||||
        the range [-1, 1] x [-1, 1] x [-1, 1]. If False they will be in the range
 | 
			
		||||
        [0, W-1] x [0, H-1] x [0, D-1]
 | 
			
		||||
    Returns:
 | 
			
		||||
        verts: [(V_0, 3), (V_1, 3), ...] List of N FloatTensors of vertices.
 | 
			
		||||
        faces: [(F_0, 3), (F_1, 3), ...] List of N LongTensors of faces.
 | 
			
		||||
    """
 | 
			
		||||
    volume_data_batch = volume_data_batch.detach().cpu()
 | 
			
		||||
    batched_verts, batched_faces = [], []
 | 
			
		||||
    D, H, W = volume_data_batch.shape[1:]
 | 
			
		||||
    # pyre-ignore [16]
 | 
			
		||||
    volume_size_xyz = volume_data_batch.new_tensor([W, H, D])[None]
 | 
			
		||||
 | 
			
		||||
    if return_local_coords:
 | 
			
		||||
        # Convert from local coordinates in the range [-1, 1] range to
 | 
			
		||||
        # world coordinates in the range [0, D-1], [0, H-1], [0, W-1]
 | 
			
		||||
        local_to_world_transform = Translate(
 | 
			
		||||
            x=+1.0, y=+1.0, z=+1.0, device=volume_data_batch.device
 | 
			
		||||
        ).scale((volume_size_xyz - 1) * spacing * 0.5)
 | 
			
		||||
        # Perform the inverse to go from world to local
 | 
			
		||||
        world_to_local_transform = local_to_world_transform.inverse()
 | 
			
		||||
 | 
			
		||||
    for i in range(len(volume_data_batch)):
 | 
			
		||||
        volume_data = volume_data_batch[i]
 | 
			
		||||
        curr_isolevel = (
 | 
			
		||||
            ((volume_data.max() + volume_data.min()) / 2).item()
 | 
			
		||||
            if isolevel is None
 | 
			
		||||
            else isolevel
 | 
			
		||||
        )
 | 
			
		||||
        edge_vertices_to_index = {}
 | 
			
		||||
        vertex_coords_to_index = {}
 | 
			
		||||
        verts, faces = [], []
 | 
			
		||||
        # Use length - spacing for the bounds since we are using
 | 
			
		||||
        # cubes of size spacing, with the lowest x,y,z values
 | 
			
		||||
        # (bottom front left)
 | 
			
		||||
        for x in range(0, W - spacing, spacing):
 | 
			
		||||
            for y in range(0, H - spacing, spacing):
 | 
			
		||||
                for z in range(0, D - spacing, spacing):
 | 
			
		||||
                    cube = Cube((x, y, z), spacing)
 | 
			
		||||
                    new_verts, new_faces = polygonise(
 | 
			
		||||
                        cube,
 | 
			
		||||
                        curr_isolevel,
 | 
			
		||||
                        volume_data,
 | 
			
		||||
                        edge_vertices_to_index,
 | 
			
		||||
                        vertex_coords_to_index,
 | 
			
		||||
                    )
 | 
			
		||||
                    verts.extend(new_verts)
 | 
			
		||||
                    faces.extend(new_faces)
 | 
			
		||||
        if len(faces) > 0 and len(verts) > 0:
 | 
			
		||||
            verts = torch.tensor(verts, dtype=torch.float32)
 | 
			
		||||
            # Convert vertices from world to local coords
 | 
			
		||||
            if return_local_coords:
 | 
			
		||||
                verts = world_to_local_transform.transform_points(verts[None, ...])
 | 
			
		||||
                verts = verts.squeeze()
 | 
			
		||||
            batched_verts.append(verts)
 | 
			
		||||
            batched_faces.append(torch.tensor(faces, dtype=torch.int64))
 | 
			
		||||
    return batched_verts, batched_faces
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def polygonise(
 | 
			
		||||
    cube: Cube,
 | 
			
		||||
    isolevel: float,
 | 
			
		||||
    volume_data: torch.Tensor,
 | 
			
		||||
    edge_vertices_to_index: Dict[Tuple[Tuple, Tuple], int],
 | 
			
		||||
    vertex_coords_to_index: Dict[Tuple[float, float, float], int],
 | 
			
		||||
) -> Tuple[list, list]:
 | 
			
		||||
    """
 | 
			
		||||
    Runs the classic marching cubes algorithm for one Cube in the volume.
 | 
			
		||||
    Returns the vertices and faces for the given cube.
 | 
			
		||||
 | 
			
		||||
    Args:
 | 
			
		||||
        cube: a Cube indicating the cube being examined for edges that intersect
 | 
			
		||||
            the volume data.
 | 
			
		||||
        isolevel: the isosurface value to use as the threshold to determine
 | 
			
		||||
            whether points are within a volume.
 | 
			
		||||
        volume_data: a Tensor of shape (D, H, W) corresponding to
 | 
			
		||||
            a 3D scalar field
 | 
			
		||||
        edge_vertices_to_index: A dictionary which maps an edge's two coordinates
 | 
			
		||||
            to the index of its interpolated point, if that interpolated point
 | 
			
		||||
            has already been used by a previous point
 | 
			
		||||
        vertex_coords_to_index: A dictionary mapping a point (x, y, z) to the corresponding
 | 
			
		||||
            index of that vertex, if that point has already been marked as a vertex.
 | 
			
		||||
    Returns:
 | 
			
		||||
        verts: List of triangle vertices for the given cube in the volume
 | 
			
		||||
        faces: List of triangle faces for the given cube in the volume
 | 
			
		||||
    """
 | 
			
		||||
    num_existing_verts = max(edge_vertices_to_index.values(), default=-1) + 1
 | 
			
		||||
    verts, faces = [], []
 | 
			
		||||
    cube_index = cube.get_index(volume_data, isolevel)
 | 
			
		||||
    edges = EDGE_TABLE[cube_index]
 | 
			
		||||
    edge_indices = _get_edge_indices(edges)
 | 
			
		||||
    if len(edge_indices) == 0:
 | 
			
		||||
        return [], []
 | 
			
		||||
 | 
			
		||||
    new_verts, edge_index_to_point_index = _calculate_interp_vertices(
 | 
			
		||||
        edge_indices,
 | 
			
		||||
        volume_data,
 | 
			
		||||
        cube,
 | 
			
		||||
        isolevel,
 | 
			
		||||
        edge_vertices_to_index,
 | 
			
		||||
        vertex_coords_to_index,
 | 
			
		||||
        num_existing_verts,
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    # Create faces
 | 
			
		||||
    face_triangles = FACE_TABLE[cube_index]
 | 
			
		||||
    for i in range(0, len(face_triangles), 3):
 | 
			
		||||
        tri1 = edge_index_to_point_index[face_triangles[i]]
 | 
			
		||||
        tri2 = edge_index_to_point_index[face_triangles[i + 1]]
 | 
			
		||||
        tri3 = edge_index_to_point_index[face_triangles[i + 2]]
 | 
			
		||||
        if tri1 != tri2 and tri2 != tri3 and tri1 != tri3:
 | 
			
		||||
            faces.append([tri1, tri2, tri3])
 | 
			
		||||
 | 
			
		||||
    verts += new_verts
 | 
			
		||||
    return verts, faces
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _get_edge_indices(edges: int) -> List[int]:
 | 
			
		||||
    """
 | 
			
		||||
    Finds which edge numbers are intersected given the bit representation
 | 
			
		||||
    detailed in marching_cubes_data.EDGE_TABLE.
 | 
			
		||||
 | 
			
		||||
    Args:
 | 
			
		||||
        edges: an integer corresponding to the value at cube_index
 | 
			
		||||
            from the EDGE_TABLE in marching_cubes_data.py
 | 
			
		||||
 | 
			
		||||
    Returns:
 | 
			
		||||
        edge_indices: A list of edge indices
 | 
			
		||||
    """
 | 
			
		||||
    if edges == 0:
 | 
			
		||||
        return []
 | 
			
		||||
 | 
			
		||||
    edge_indices = []
 | 
			
		||||
    for i in range(12):
 | 
			
		||||
        if edges & (2 ** i):
 | 
			
		||||
            edge_indices.append(i)
 | 
			
		||||
    return edge_indices
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _calculate_interp_vertices(
 | 
			
		||||
    edge_indices: List[int],
 | 
			
		||||
    volume_data: torch.Tensor,
 | 
			
		||||
    cube: Cube,
 | 
			
		||||
    isolevel: float,
 | 
			
		||||
    edge_vertices_to_index: Dict[Tuple[Tuple, Tuple], int],
 | 
			
		||||
    vertex_coords_to_index: Dict[Tuple[float, float, float], int],
 | 
			
		||||
    num_existing_verts: int,
 | 
			
		||||
) -> Tuple[List, Dict[int, int]]:
 | 
			
		||||
    """
 | 
			
		||||
    Finds the interpolated vertices for the intersected edges, either referencing
 | 
			
		||||
    previous calculations or newly calculating and storing the new interpolated
 | 
			
		||||
    points.
 | 
			
		||||
 | 
			
		||||
    Args:
 | 
			
		||||
        edge_indices: the numbers of the edges which are intersected. See the
 | 
			
		||||
            Cube class for more detail on the edge numbering convention.
 | 
			
		||||
        volume_data: a Tensor of size (D, H, W) corresponding to
 | 
			
		||||
            a 3D scalar field
 | 
			
		||||
        cube: a Cube indicating the cube being examined for edges that intersect
 | 
			
		||||
            the volume
 | 
			
		||||
        isolevel: the isosurface value to use as the threshold to determine
 | 
			
		||||
            whether points are within a volume.
 | 
			
		||||
        edge_vertices_to_index: A dictionary which maps an edge's two coordinates
 | 
			
		||||
            to the index of its interpolated point, if that interpolated point
 | 
			
		||||
            has already been used by a previous point
 | 
			
		||||
        vertex_coords_to_index: A dictionary mapping a point (x, y, z) to the corresponding
 | 
			
		||||
            index of that vertex, if that point has already been marked as a vertex.
 | 
			
		||||
        num_existing_verts: the number of vertices that have been found in previous
 | 
			
		||||
            calls to polygonise for the given volume_data in the above function, marching_cubes.
 | 
			
		||||
            This is equal to the 1 + the maximum value in edge_vertices_to_index.
 | 
			
		||||
    Returns:
 | 
			
		||||
        interp_points: a list of new interpolated points
 | 
			
		||||
        edge_index_to_point_index: a dictionary mapping an edge number to the index in the
 | 
			
		||||
            marching cubes' vertices list of the interpolated point on that edge. To be precise,
 | 
			
		||||
            it refers to the index within the vertices list after interp_points
 | 
			
		||||
            has been appended to the verts list constructed in the marching_cubes_naive
 | 
			
		||||
            function.
 | 
			
		||||
    """
 | 
			
		||||
    interp_points = []
 | 
			
		||||
    edge_index_to_point_index = {}
 | 
			
		||||
    for edge_index in edge_indices:
 | 
			
		||||
        v1, v2 = EDGE_TO_VERTICES[edge_index]
 | 
			
		||||
        point1, point2 = cube.vertices[v1], cube.vertices[v2]
 | 
			
		||||
        p_tuple1, p_tuple2 = tuple(point1.tolist()), tuple(point2.tolist())
 | 
			
		||||
        if (p_tuple1, p_tuple2) in edge_vertices_to_index:
 | 
			
		||||
            edge_index_to_point_index[edge_index] = edge_vertices_to_index[
 | 
			
		||||
                (p_tuple1, p_tuple2)
 | 
			
		||||
            ]
 | 
			
		||||
        else:
 | 
			
		||||
            val1, val2 = _get_value(point1, volume_data), _get_value(
 | 
			
		||||
                point2, volume_data
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
            point = None
 | 
			
		||||
            if abs(isolevel - val1) < EPS:
 | 
			
		||||
                point = point1
 | 
			
		||||
 | 
			
		||||
            if abs(isolevel - val2) < EPS:
 | 
			
		||||
                point = point2
 | 
			
		||||
 | 
			
		||||
            if abs(val1 - val2) < EPS:
 | 
			
		||||
                point = point1
 | 
			
		||||
 | 
			
		||||
            if point is None:
 | 
			
		||||
                mu = (isolevel - val1) / (val2 - val1)
 | 
			
		||||
                x1, y1, z1 = point1
 | 
			
		||||
                x2, y2, z2 = point2
 | 
			
		||||
                x = x1 + mu * (x2 - x1)
 | 
			
		||||
                y = y1 + mu * (y2 - y1)
 | 
			
		||||
                z = z1 + mu * (z2 - z1)
 | 
			
		||||
            else:
 | 
			
		||||
                x, y, z = point
 | 
			
		||||
 | 
			
		||||
            x, y, z = x.item(), y.item(), z.item()  # for dictionary keys
 | 
			
		||||
 | 
			
		||||
            vert_index = None
 | 
			
		||||
            if (x, y, z) in vertex_coords_to_index:
 | 
			
		||||
                vert_index = vertex_coords_to_index[(x, y, z)]
 | 
			
		||||
            else:
 | 
			
		||||
                vert_index = num_existing_verts + len(interp_points)
 | 
			
		||||
                interp_points.append([x, y, z])
 | 
			
		||||
                vertex_coords_to_index[(x, y, z)] = vert_index
 | 
			
		||||
 | 
			
		||||
            edge_vertices_to_index[(p_tuple1, p_tuple2)] = vert_index
 | 
			
		||||
            edge_index_to_point_index[edge_index] = vert_index
 | 
			
		||||
 | 
			
		||||
    return interp_points, edge_index_to_point_index
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _get_value(point: Tuple[int, int, int], volume_data: torch.Tensor) -> float:
 | 
			
		||||
    """
 | 
			
		||||
    Gets the value at a given coordinate point in the scalar field.
 | 
			
		||||
 | 
			
		||||
    Args:
 | 
			
		||||
        point: data of shape (3) corresponding to an xyz coordinate.
 | 
			
		||||
        volume_data: a Tensor of size (D, H, W) corresponding to
 | 
			
		||||
            a 3D scalar field
 | 
			
		||||
    Returns:
 | 
			
		||||
        data: scalar value in the volume at the given point
 | 
			
		||||
    """
 | 
			
		||||
    x, y, z = point
 | 
			
		||||
    return volume_data[z][y][x]
 | 
			
		||||
							
								
								
									
										545
									
								
								pytorch3d/ops/marching_cubes_data.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										545
									
								
								pytorch3d/ops/marching_cubes_data.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,545 @@
 | 
			
		||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
 | 
			
		||||
 | 
			
		||||
# A length 256 list which maps a cubeindex to a number
 | 
			
		||||
# with the intersected edges' bits set to 1.
 | 
			
		||||
# Each cubeindex corresponds to a given cube configuration, where
 | 
			
		||||
# it is composed of a bitstring where the 0th bit is flipped if vertex 0
 | 
			
		||||
# is below the isosurface (i.e. 0x01), for each of the 8 vertices.
 | 
			
		||||
EDGE_TABLE = [
 | 
			
		||||
    0x0,
 | 
			
		||||
    0x109,
 | 
			
		||||
    0x203,
 | 
			
		||||
    0x30A,
 | 
			
		||||
    0x406,
 | 
			
		||||
    0x50F,
 | 
			
		||||
    0x605,
 | 
			
		||||
    0x70C,
 | 
			
		||||
    0x80C,
 | 
			
		||||
    0x905,
 | 
			
		||||
    0xA0F,
 | 
			
		||||
    0xB06,
 | 
			
		||||
    0xC0A,
 | 
			
		||||
    0xD03,
 | 
			
		||||
    0xE09,
 | 
			
		||||
    0xF00,
 | 
			
		||||
    0x190,
 | 
			
		||||
    0x99,
 | 
			
		||||
    0x393,
 | 
			
		||||
    0x29A,
 | 
			
		||||
    0x596,
 | 
			
		||||
    0x49F,
 | 
			
		||||
    0x795,
 | 
			
		||||
    0x69C,
 | 
			
		||||
    0x99C,
 | 
			
		||||
    0x895,
 | 
			
		||||
    0xB9F,
 | 
			
		||||
    0xA96,
 | 
			
		||||
    0xD9A,
 | 
			
		||||
    0xC93,
 | 
			
		||||
    0xF99,
 | 
			
		||||
    0xE90,
 | 
			
		||||
    0x230,
 | 
			
		||||
    0x339,
 | 
			
		||||
    0x33,
 | 
			
		||||
    0x13A,
 | 
			
		||||
    0x636,
 | 
			
		||||
    0x73F,
 | 
			
		||||
    0x435,
 | 
			
		||||
    0x53C,
 | 
			
		||||
    0xA3C,
 | 
			
		||||
    0xB35,
 | 
			
		||||
    0x83F,
 | 
			
		||||
    0x936,
 | 
			
		||||
    0xE3A,
 | 
			
		||||
    0xF33,
 | 
			
		||||
    0xC39,
 | 
			
		||||
    0xD30,
 | 
			
		||||
    0x3A0,
 | 
			
		||||
    0x2A9,
 | 
			
		||||
    0x1A3,
 | 
			
		||||
    0xAA,
 | 
			
		||||
    0x7A6,
 | 
			
		||||
    0x6AF,
 | 
			
		||||
    0x5A5,
 | 
			
		||||
    0x4AC,
 | 
			
		||||
    0xBAC,
 | 
			
		||||
    0xAA5,
 | 
			
		||||
    0x9AF,
 | 
			
		||||
    0x8A6,
 | 
			
		||||
    0xFAA,
 | 
			
		||||
    0xEA3,
 | 
			
		||||
    0xDA9,
 | 
			
		||||
    0xCA0,
 | 
			
		||||
    0x460,
 | 
			
		||||
    0x569,
 | 
			
		||||
    0x663,
 | 
			
		||||
    0x76A,
 | 
			
		||||
    0x66,
 | 
			
		||||
    0x16F,
 | 
			
		||||
    0x265,
 | 
			
		||||
    0x36C,
 | 
			
		||||
    0xC6C,
 | 
			
		||||
    0xD65,
 | 
			
		||||
    0xE6F,
 | 
			
		||||
    0xF66,
 | 
			
		||||
    0x86A,
 | 
			
		||||
    0x963,
 | 
			
		||||
    0xA69,
 | 
			
		||||
    0xB60,
 | 
			
		||||
    0x5F0,
 | 
			
		||||
    0x4F9,
 | 
			
		||||
    0x7F3,
 | 
			
		||||
    0x6FA,
 | 
			
		||||
    0x1F6,
 | 
			
		||||
    0xFF,
 | 
			
		||||
    0x3F5,
 | 
			
		||||
    0x2FC,
 | 
			
		||||
    0xDFC,
 | 
			
		||||
    0xCF5,
 | 
			
		||||
    0xFFF,
 | 
			
		||||
    0xEF6,
 | 
			
		||||
    0x9FA,
 | 
			
		||||
    0x8F3,
 | 
			
		||||
    0xBF9,
 | 
			
		||||
    0xAF0,
 | 
			
		||||
    0x650,
 | 
			
		||||
    0x759,
 | 
			
		||||
    0x453,
 | 
			
		||||
    0x55A,
 | 
			
		||||
    0x256,
 | 
			
		||||
    0x35F,
 | 
			
		||||
    0x55,
 | 
			
		||||
    0x15C,
 | 
			
		||||
    0xE5C,
 | 
			
		||||
    0xF55,
 | 
			
		||||
    0xC5F,
 | 
			
		||||
    0xD56,
 | 
			
		||||
    0xA5A,
 | 
			
		||||
    0xB53,
 | 
			
		||||
    0x859,
 | 
			
		||||
    0x950,
 | 
			
		||||
    0x7C0,
 | 
			
		||||
    0x6C9,
 | 
			
		||||
    0x5C3,
 | 
			
		||||
    0x4CA,
 | 
			
		||||
    0x3C6,
 | 
			
		||||
    0x2CF,
 | 
			
		||||
    0x1C5,
 | 
			
		||||
    0xCC,
 | 
			
		||||
    0xFCC,
 | 
			
		||||
    0xEC5,
 | 
			
		||||
    0xDCF,
 | 
			
		||||
    0xCC6,
 | 
			
		||||
    0xBCA,
 | 
			
		||||
    0xAC3,
 | 
			
		||||
    0x9C9,
 | 
			
		||||
    0x8C0,
 | 
			
		||||
    0x8C0,
 | 
			
		||||
    0x9C9,
 | 
			
		||||
    0xAC3,
 | 
			
		||||
    0xBCA,
 | 
			
		||||
    0xCC6,
 | 
			
		||||
    0xDCF,
 | 
			
		||||
    0xEC5,
 | 
			
		||||
    0xFCC,
 | 
			
		||||
    0xCC,
 | 
			
		||||
    0x1C5,
 | 
			
		||||
    0x2CF,
 | 
			
		||||
    0x3C6,
 | 
			
		||||
    0x4CA,
 | 
			
		||||
    0x5C3,
 | 
			
		||||
    0x6C9,
 | 
			
		||||
    0x7C0,
 | 
			
		||||
    0x950,
 | 
			
		||||
    0x859,
 | 
			
		||||
    0xB53,
 | 
			
		||||
    0xA5A,
 | 
			
		||||
    0xD56,
 | 
			
		||||
    0xC5F,
 | 
			
		||||
    0xF55,
 | 
			
		||||
    0xE5C,
 | 
			
		||||
    0x15C,
 | 
			
		||||
    0x55,
 | 
			
		||||
    0x35F,
 | 
			
		||||
    0x256,
 | 
			
		||||
    0x55A,
 | 
			
		||||
    0x453,
 | 
			
		||||
    0x759,
 | 
			
		||||
    0x650,
 | 
			
		||||
    0xAF0,
 | 
			
		||||
    0xBF9,
 | 
			
		||||
    0x8F3,
 | 
			
		||||
    0x9FA,
 | 
			
		||||
    0xEF6,
 | 
			
		||||
    0xFFF,
 | 
			
		||||
    0xCF5,
 | 
			
		||||
    0xDFC,
 | 
			
		||||
    0x2FC,
 | 
			
		||||
    0x3F5,
 | 
			
		||||
    0xFF,
 | 
			
		||||
    0x1F6,
 | 
			
		||||
    0x6FA,
 | 
			
		||||
    0x7F3,
 | 
			
		||||
    0x4F9,
 | 
			
		||||
    0x5F0,
 | 
			
		||||
    0xB60,
 | 
			
		||||
    0xA69,
 | 
			
		||||
    0x963,
 | 
			
		||||
    0x86A,
 | 
			
		||||
    0xF66,
 | 
			
		||||
    0xE6F,
 | 
			
		||||
    0xD65,
 | 
			
		||||
    0xC6C,
 | 
			
		||||
    0x36C,
 | 
			
		||||
    0x265,
 | 
			
		||||
    0x16F,
 | 
			
		||||
    0x66,
 | 
			
		||||
    0x76A,
 | 
			
		||||
    0x663,
 | 
			
		||||
    0x569,
 | 
			
		||||
    0x460,
 | 
			
		||||
    0xCA0,
 | 
			
		||||
    0xDA9,
 | 
			
		||||
    0xEA3,
 | 
			
		||||
    0xFAA,
 | 
			
		||||
    0x8A6,
 | 
			
		||||
    0x9AF,
 | 
			
		||||
    0xAA5,
 | 
			
		||||
    0xBAC,
 | 
			
		||||
    0x4AC,
 | 
			
		||||
    0x5A5,
 | 
			
		||||
    0x6AF,
 | 
			
		||||
    0x7A6,
 | 
			
		||||
    0xAA,
 | 
			
		||||
    0x1A3,
 | 
			
		||||
    0x2A9,
 | 
			
		||||
    0x3A0,
 | 
			
		||||
    0xD30,
 | 
			
		||||
    0xC39,
 | 
			
		||||
    0xF33,
 | 
			
		||||
    0xE3A,
 | 
			
		||||
    0x936,
 | 
			
		||||
    0x83F,
 | 
			
		||||
    0xB35,
 | 
			
		||||
    0xA3C,
 | 
			
		||||
    0x53C,
 | 
			
		||||
    0x435,
 | 
			
		||||
    0x73F,
 | 
			
		||||
    0x636,
 | 
			
		||||
    0x13A,
 | 
			
		||||
    0x33,
 | 
			
		||||
    0x339,
 | 
			
		||||
    0x230,
 | 
			
		||||
    0xE90,
 | 
			
		||||
    0xF99,
 | 
			
		||||
    0xC93,
 | 
			
		||||
    0xD9A,
 | 
			
		||||
    0xA96,
 | 
			
		||||
    0xB9F,
 | 
			
		||||
    0x895,
 | 
			
		||||
    0x99C,
 | 
			
		||||
    0x69C,
 | 
			
		||||
    0x795,
 | 
			
		||||
    0x49F,
 | 
			
		||||
    0x596,
 | 
			
		||||
    0x29A,
 | 
			
		||||
    0x393,
 | 
			
		||||
    0x99,
 | 
			
		||||
    0x190,
 | 
			
		||||
    0xF00,
 | 
			
		||||
    0xE09,
 | 
			
		||||
    0xD03,
 | 
			
		||||
    0xC0A,
 | 
			
		||||
    0xB06,
 | 
			
		||||
    0xA0F,
 | 
			
		||||
    0x905,
 | 
			
		||||
    0x80C,
 | 
			
		||||
    0x70C,
 | 
			
		||||
    0x605,
 | 
			
		||||
    0x50F,
 | 
			
		||||
    0x406,
 | 
			
		||||
    0x30A,
 | 
			
		||||
    0x203,
 | 
			
		||||
    0x109,
 | 
			
		||||
    0x0,
 | 
			
		||||
]
 | 
			
		||||
 | 
			
		||||
# Maps each edge (by index) to the corresponding cube vertices
 | 
			
		||||
EDGE_TO_VERTICES = [
 | 
			
		||||
    [0, 1],
 | 
			
		||||
    [1, 2],
 | 
			
		||||
    [3, 2],
 | 
			
		||||
    [0, 3],
 | 
			
		||||
    [4, 5],
 | 
			
		||||
    [5, 6],
 | 
			
		||||
    [7, 6],
 | 
			
		||||
    [4, 7],
 | 
			
		||||
    [0, 4],
 | 
			
		||||
    [1, 5],
 | 
			
		||||
    [2, 6],
 | 
			
		||||
    [3, 7],
 | 
			
		||||
]
 | 
			
		||||
 | 
			
		||||
# A list of lists mapping a cube_index (a given configuration)
 | 
			
		||||
# to a list of faces corresponding to that configuration. Each face is represented
 | 
			
		||||
# by 3 consecutive numbers. A configuration will at most have 5 faces.
 | 
			
		||||
#
 | 
			
		||||
# Table taken from http://paulbourke.net/geometry/polygonise/
 | 
			
		||||
FACE_TABLE = [
 | 
			
		||||
    [],
 | 
			
		||||
    [0, 8, 3],
 | 
			
		||||
    [0, 1, 9],
 | 
			
		||||
    [1, 8, 3, 9, 8, 1],
 | 
			
		||||
    [1, 2, 10],
 | 
			
		||||
    [0, 8, 3, 1, 2, 10],
 | 
			
		||||
    [9, 2, 10, 0, 2, 9],
 | 
			
		||||
    [2, 8, 3, 2, 10, 8, 10, 9, 8],
 | 
			
		||||
    [3, 11, 2],
 | 
			
		||||
    [0, 11, 2, 8, 11, 0],
 | 
			
		||||
    [1, 9, 0, 2, 3, 11],
 | 
			
		||||
    [1, 11, 2, 1, 9, 11, 9, 8, 11],
 | 
			
		||||
    [3, 10, 1, 11, 10, 3],
 | 
			
		||||
    [0, 10, 1, 0, 8, 10, 8, 11, 10],
 | 
			
		||||
    [3, 9, 0, 3, 11, 9, 11, 10, 9],
 | 
			
		||||
    [9, 8, 10, 10, 8, 11],
 | 
			
		||||
    [4, 7, 8],
 | 
			
		||||
    [4, 3, 0, 7, 3, 4],
 | 
			
		||||
    [0, 1, 9, 8, 4, 7],
 | 
			
		||||
    [4, 1, 9, 4, 7, 1, 7, 3, 1],
 | 
			
		||||
    [1, 2, 10, 8, 4, 7],
 | 
			
		||||
    [3, 4, 7, 3, 0, 4, 1, 2, 10],
 | 
			
		||||
    [9, 2, 10, 9, 0, 2, 8, 4, 7],
 | 
			
		||||
    [2, 10, 9, 2, 9, 7, 2, 7, 3, 7, 9, 4],
 | 
			
		||||
    [8, 4, 7, 3, 11, 2],
 | 
			
		||||
    [11, 4, 7, 11, 2, 4, 2, 0, 4],
 | 
			
		||||
    [9, 0, 1, 8, 4, 7, 2, 3, 11],
 | 
			
		||||
    [4, 7, 11, 9, 4, 11, 9, 11, 2, 9, 2, 1],
 | 
			
		||||
    [3, 10, 1, 3, 11, 10, 7, 8, 4],
 | 
			
		||||
    [1, 11, 10, 1, 4, 11, 1, 0, 4, 7, 11, 4],
 | 
			
		||||
    [4, 7, 8, 9, 0, 11, 9, 11, 10, 11, 0, 3],
 | 
			
		||||
    [4, 7, 11, 4, 11, 9, 9, 11, 10],
 | 
			
		||||
    [9, 5, 4],
 | 
			
		||||
    [9, 5, 4, 0, 8, 3],
 | 
			
		||||
    [0, 5, 4, 1, 5, 0],
 | 
			
		||||
    [8, 5, 4, 8, 3, 5, 3, 1, 5],
 | 
			
		||||
    [1, 2, 10, 9, 5, 4],
 | 
			
		||||
    [3, 0, 8, 1, 2, 10, 4, 9, 5],
 | 
			
		||||
    [5, 2, 10, 5, 4, 2, 4, 0, 2],
 | 
			
		||||
    [2, 10, 5, 3, 2, 5, 3, 5, 4, 3, 4, 8],
 | 
			
		||||
    [9, 5, 4, 2, 3, 11],
 | 
			
		||||
    [0, 11, 2, 0, 8, 11, 4, 9, 5],
 | 
			
		||||
    [0, 5, 4, 0, 1, 5, 2, 3, 11],
 | 
			
		||||
    [2, 1, 5, 2, 5, 8, 2, 8, 11, 4, 8, 5],
 | 
			
		||||
    [10, 3, 11, 10, 1, 3, 9, 5, 4],
 | 
			
		||||
    [4, 9, 5, 0, 8, 1, 8, 10, 1, 8, 11, 10],
 | 
			
		||||
    [5, 4, 0, 5, 0, 11, 5, 11, 10, 11, 0, 3],
 | 
			
		||||
    [5, 4, 8, 5, 8, 10, 10, 8, 11],
 | 
			
		||||
    [9, 7, 8, 5, 7, 9],
 | 
			
		||||
    [9, 3, 0, 9, 5, 3, 5, 7, 3],
 | 
			
		||||
    [0, 7, 8, 0, 1, 7, 1, 5, 7],
 | 
			
		||||
    [1, 5, 3, 3, 5, 7],
 | 
			
		||||
    [9, 7, 8, 9, 5, 7, 10, 1, 2],
 | 
			
		||||
    [10, 1, 2, 9, 5, 0, 5, 3, 0, 5, 7, 3],
 | 
			
		||||
    [8, 0, 2, 8, 2, 5, 8, 5, 7, 10, 5, 2],
 | 
			
		||||
    [2, 10, 5, 2, 5, 3, 3, 5, 7],
 | 
			
		||||
    [7, 9, 5, 7, 8, 9, 3, 11, 2],
 | 
			
		||||
    [9, 5, 7, 9, 7, 2, 9, 2, 0, 2, 7, 11],
 | 
			
		||||
    [2, 3, 11, 0, 1, 8, 1, 7, 8, 1, 5, 7],
 | 
			
		||||
    [11, 2, 1, 11, 1, 7, 7, 1, 5],
 | 
			
		||||
    [9, 5, 8, 8, 5, 7, 10, 1, 3, 10, 3, 11],
 | 
			
		||||
    [5, 7, 0, 5, 0, 9, 7, 11, 0, 1, 0, 10, 11, 10, 0],
 | 
			
		||||
    [11, 10, 0, 11, 0, 3, 10, 5, 0, 8, 0, 7, 5, 7, 0],
 | 
			
		||||
    [11, 10, 5, 7, 11, 5],
 | 
			
		||||
    [10, 6, 5],
 | 
			
		||||
    [0, 8, 3, 5, 10, 6],
 | 
			
		||||
    [9, 0, 1, 5, 10, 6],
 | 
			
		||||
    [1, 8, 3, 1, 9, 8, 5, 10, 6],
 | 
			
		||||
    [1, 6, 5, 2, 6, 1],
 | 
			
		||||
    [1, 6, 5, 1, 2, 6, 3, 0, 8],
 | 
			
		||||
    [9, 6, 5, 9, 0, 6, 0, 2, 6],
 | 
			
		||||
    [5, 9, 8, 5, 8, 2, 5, 2, 6, 3, 2, 8],
 | 
			
		||||
    [2, 3, 11, 10, 6, 5],
 | 
			
		||||
    [11, 0, 8, 11, 2, 0, 10, 6, 5],
 | 
			
		||||
    [0, 1, 9, 2, 3, 11, 5, 10, 6],
 | 
			
		||||
    [5, 10, 6, 1, 9, 2, 9, 11, 2, 9, 8, 11],
 | 
			
		||||
    [6, 3, 11, 6, 5, 3, 5, 1, 3],
 | 
			
		||||
    [0, 8, 11, 0, 11, 5, 0, 5, 1, 5, 11, 6],
 | 
			
		||||
    [3, 11, 6, 0, 3, 6, 0, 6, 5, 0, 5, 9],
 | 
			
		||||
    [6, 5, 9, 6, 9, 11, 11, 9, 8],
 | 
			
		||||
    [5, 10, 6, 4, 7, 8],
 | 
			
		||||
    [4, 3, 0, 4, 7, 3, 6, 5, 10],
 | 
			
		||||
    [1, 9, 0, 5, 10, 6, 8, 4, 7],
 | 
			
		||||
    [10, 6, 5, 1, 9, 7, 1, 7, 3, 7, 9, 4],
 | 
			
		||||
    [6, 1, 2, 6, 5, 1, 4, 7, 8],
 | 
			
		||||
    [1, 2, 5, 5, 2, 6, 3, 0, 4, 3, 4, 7],
 | 
			
		||||
    [8, 4, 7, 9, 0, 5, 0, 6, 5, 0, 2, 6],
 | 
			
		||||
    [7, 3, 9, 7, 9, 4, 3, 2, 9, 5, 9, 6, 2, 6, 9],
 | 
			
		||||
    [3, 11, 2, 7, 8, 4, 10, 6, 5],
 | 
			
		||||
    [5, 10, 6, 4, 7, 2, 4, 2, 0, 2, 7, 11],
 | 
			
		||||
    [0, 1, 9, 4, 7, 8, 2, 3, 11, 5, 10, 6],
 | 
			
		||||
    [9, 2, 1, 9, 11, 2, 9, 4, 11, 7, 11, 4, 5, 10, 6],
 | 
			
		||||
    [8, 4, 7, 3, 11, 5, 3, 5, 1, 5, 11, 6],
 | 
			
		||||
    [5, 1, 11, 5, 11, 6, 1, 0, 11, 7, 11, 4, 0, 4, 11],
 | 
			
		||||
    [0, 5, 9, 0, 6, 5, 0, 3, 6, 11, 6, 3, 8, 4, 7],
 | 
			
		||||
    [6, 5, 9, 6, 9, 11, 4, 7, 9, 7, 11, 9],
 | 
			
		||||
    [10, 4, 9, 6, 4, 10],
 | 
			
		||||
    [4, 10, 6, 4, 9, 10, 0, 8, 3],
 | 
			
		||||
    [10, 0, 1, 10, 6, 0, 6, 4, 0],
 | 
			
		||||
    [8, 3, 1, 8, 1, 6, 8, 6, 4, 6, 1, 10],
 | 
			
		||||
    [1, 4, 9, 1, 2, 4, 2, 6, 4],
 | 
			
		||||
    [3, 0, 8, 1, 2, 9, 2, 4, 9, 2, 6, 4],
 | 
			
		||||
    [0, 2, 4, 4, 2, 6],
 | 
			
		||||
    [8, 3, 2, 8, 2, 4, 4, 2, 6],
 | 
			
		||||
    [10, 4, 9, 10, 6, 4, 11, 2, 3],
 | 
			
		||||
    [0, 8, 2, 2, 8, 11, 4, 9, 10, 4, 10, 6],
 | 
			
		||||
    [3, 11, 2, 0, 1, 6, 0, 6, 4, 6, 1, 10],
 | 
			
		||||
    [6, 4, 1, 6, 1, 10, 4, 8, 1, 2, 1, 11, 8, 11, 1],
 | 
			
		||||
    [9, 6, 4, 9, 3, 6, 9, 1, 3, 11, 6, 3],
 | 
			
		||||
    [8, 11, 1, 8, 1, 0, 11, 6, 1, 9, 1, 4, 6, 4, 1],
 | 
			
		||||
    [3, 11, 6, 3, 6, 0, 0, 6, 4],
 | 
			
		||||
    [6, 4, 8, 11, 6, 8],
 | 
			
		||||
    [7, 10, 6, 7, 8, 10, 8, 9, 10],
 | 
			
		||||
    [0, 7, 3, 0, 10, 7, 0, 9, 10, 6, 7, 10],
 | 
			
		||||
    [10, 6, 7, 1, 10, 7, 1, 7, 8, 1, 8, 0],
 | 
			
		||||
    [10, 6, 7, 10, 7, 1, 1, 7, 3],
 | 
			
		||||
    [1, 2, 6, 1, 6, 8, 1, 8, 9, 8, 6, 7],
 | 
			
		||||
    [2, 6, 9, 2, 9, 1, 6, 7, 9, 0, 9, 3, 7, 3, 9],
 | 
			
		||||
    [7, 8, 0, 7, 0, 6, 6, 0, 2],
 | 
			
		||||
    [7, 3, 2, 6, 7, 2],
 | 
			
		||||
    [2, 3, 11, 10, 6, 8, 10, 8, 9, 8, 6, 7],
 | 
			
		||||
    [2, 0, 7, 2, 7, 11, 0, 9, 7, 6, 7, 10, 9, 10, 7],
 | 
			
		||||
    [1, 8, 0, 1, 7, 8, 1, 10, 7, 6, 7, 10, 2, 3, 11],
 | 
			
		||||
    [11, 2, 1, 11, 1, 7, 10, 6, 1, 6, 7, 1],
 | 
			
		||||
    [8, 9, 6, 8, 6, 7, 9, 1, 6, 11, 6, 3, 1, 3, 6],
 | 
			
		||||
    [0, 9, 1, 11, 6, 7],
 | 
			
		||||
    [7, 8, 0, 7, 0, 6, 3, 11, 0, 11, 6, 0],
 | 
			
		||||
    [7, 11, 6],
 | 
			
		||||
    [7, 6, 11],
 | 
			
		||||
    [3, 0, 8, 11, 7, 6],
 | 
			
		||||
    [0, 1, 9, 11, 7, 6],
 | 
			
		||||
    [8, 1, 9, 8, 3, 1, 11, 7, 6],
 | 
			
		||||
    [10, 1, 2, 6, 11, 7],
 | 
			
		||||
    [1, 2, 10, 3, 0, 8, 6, 11, 7],
 | 
			
		||||
    [2, 9, 0, 2, 10, 9, 6, 11, 7],
 | 
			
		||||
    [6, 11, 7, 2, 10, 3, 10, 8, 3, 10, 9, 8],
 | 
			
		||||
    [7, 2, 3, 6, 2, 7],
 | 
			
		||||
    [7, 0, 8, 7, 6, 0, 6, 2, 0],
 | 
			
		||||
    [2, 7, 6, 2, 3, 7, 0, 1, 9],
 | 
			
		||||
    [1, 6, 2, 1, 8, 6, 1, 9, 8, 8, 7, 6],
 | 
			
		||||
    [10, 7, 6, 10, 1, 7, 1, 3, 7],
 | 
			
		||||
    [10, 7, 6, 1, 7, 10, 1, 8, 7, 1, 0, 8],
 | 
			
		||||
    [0, 3, 7, 0, 7, 10, 0, 10, 9, 6, 10, 7],
 | 
			
		||||
    [7, 6, 10, 7, 10, 8, 8, 10, 9],
 | 
			
		||||
    [6, 8, 4, 11, 8, 6],
 | 
			
		||||
    [3, 6, 11, 3, 0, 6, 0, 4, 6],
 | 
			
		||||
    [8, 6, 11, 8, 4, 6, 9, 0, 1],
 | 
			
		||||
    [9, 4, 6, 9, 6, 3, 9, 3, 1, 11, 3, 6],
 | 
			
		||||
    [6, 8, 4, 6, 11, 8, 2, 10, 1],
 | 
			
		||||
    [1, 2, 10, 3, 0, 11, 0, 6, 11, 0, 4, 6],
 | 
			
		||||
    [4, 11, 8, 4, 6, 11, 0, 2, 9, 2, 10, 9],
 | 
			
		||||
    [10, 9, 3, 10, 3, 2, 9, 4, 3, 11, 3, 6, 4, 6, 3],
 | 
			
		||||
    [8, 2, 3, 8, 4, 2, 4, 6, 2],
 | 
			
		||||
    [0, 4, 2, 4, 6, 2],
 | 
			
		||||
    [1, 9, 0, 2, 3, 4, 2, 4, 6, 4, 3, 8],
 | 
			
		||||
    [1, 9, 4, 1, 4, 2, 2, 4, 6],
 | 
			
		||||
    [8, 1, 3, 8, 6, 1, 8, 4, 6, 6, 10, 1],
 | 
			
		||||
    [10, 1, 0, 10, 0, 6, 6, 0, 4],
 | 
			
		||||
    [4, 6, 3, 4, 3, 8, 6, 10, 3, 0, 3, 9, 10, 9, 3],
 | 
			
		||||
    [10, 9, 4, 6, 10, 4],
 | 
			
		||||
    [4, 9, 5, 7, 6, 11],
 | 
			
		||||
    [0, 8, 3, 4, 9, 5, 11, 7, 6],
 | 
			
		||||
    [5, 0, 1, 5, 4, 0, 7, 6, 11],
 | 
			
		||||
    [11, 7, 6, 8, 3, 4, 3, 5, 4, 3, 1, 5],
 | 
			
		||||
    [9, 5, 4, 10, 1, 2, 7, 6, 11],
 | 
			
		||||
    [6, 11, 7, 1, 2, 10, 0, 8, 3, 4, 9, 5],
 | 
			
		||||
    [7, 6, 11, 5, 4, 10, 4, 2, 10, 4, 0, 2],
 | 
			
		||||
    [3, 4, 8, 3, 5, 4, 3, 2, 5, 10, 5, 2, 11, 7, 6],
 | 
			
		||||
    [7, 2, 3, 7, 6, 2, 5, 4, 9],
 | 
			
		||||
    [9, 5, 4, 0, 8, 6, 0, 6, 2, 6, 8, 7],
 | 
			
		||||
    [3, 6, 2, 3, 7, 6, 1, 5, 0, 5, 4, 0],
 | 
			
		||||
    [6, 2, 8, 6, 8, 7, 2, 1, 8, 4, 8, 5, 1, 5, 8],
 | 
			
		||||
    [9, 5, 4, 10, 1, 6, 1, 7, 6, 1, 3, 7],
 | 
			
		||||
    [1, 6, 10, 1, 7, 6, 1, 0, 7, 8, 7, 0, 9, 5, 4],
 | 
			
		||||
    [4, 0, 10, 4, 10, 5, 0, 3, 10, 6, 10, 7, 3, 7, 10],
 | 
			
		||||
    [7, 6, 10, 7, 10, 8, 5, 4, 10, 4, 8, 10],
 | 
			
		||||
    [6, 9, 5, 6, 11, 9, 11, 8, 9],
 | 
			
		||||
    [3, 6, 11, 0, 6, 3, 0, 5, 6, 0, 9, 5],
 | 
			
		||||
    [0, 11, 8, 0, 5, 11, 0, 1, 5, 5, 6, 11],
 | 
			
		||||
    [6, 11, 3, 6, 3, 5, 5, 3, 1],
 | 
			
		||||
    [1, 2, 10, 9, 5, 11, 9, 11, 8, 11, 5, 6],
 | 
			
		||||
    [0, 11, 3, 0, 6, 11, 0, 9, 6, 5, 6, 9, 1, 2, 10],
 | 
			
		||||
    [11, 8, 5, 11, 5, 6, 8, 0, 5, 10, 5, 2, 0, 2, 5],
 | 
			
		||||
    [6, 11, 3, 6, 3, 5, 2, 10, 3, 10, 5, 3],
 | 
			
		||||
    [5, 8, 9, 5, 2, 8, 5, 6, 2, 3, 8, 2],
 | 
			
		||||
    [9, 5, 6, 9, 6, 0, 0, 6, 2],
 | 
			
		||||
    [1, 5, 8, 1, 8, 0, 5, 6, 8, 3, 8, 2, 6, 2, 8],
 | 
			
		||||
    [1, 5, 6, 2, 1, 6],
 | 
			
		||||
    [1, 3, 6, 1, 6, 10, 3, 8, 6, 5, 6, 9, 8, 9, 6],
 | 
			
		||||
    [10, 1, 0, 10, 0, 6, 9, 5, 0, 5, 6, 0],
 | 
			
		||||
    [0, 3, 8, 5, 6, 10],
 | 
			
		||||
    [10, 5, 6],
 | 
			
		||||
    [11, 5, 10, 7, 5, 11],
 | 
			
		||||
    [11, 5, 10, 11, 7, 5, 8, 3, 0],
 | 
			
		||||
    [5, 11, 7, 5, 10, 11, 1, 9, 0],
 | 
			
		||||
    [10, 7, 5, 10, 11, 7, 9, 8, 1, 8, 3, 1],
 | 
			
		||||
    [11, 1, 2, 11, 7, 1, 7, 5, 1],
 | 
			
		||||
    [0, 8, 3, 1, 2, 7, 1, 7, 5, 7, 2, 11],
 | 
			
		||||
    [9, 7, 5, 9, 2, 7, 9, 0, 2, 2, 11, 7],
 | 
			
		||||
    [7, 5, 2, 7, 2, 11, 5, 9, 2, 3, 2, 8, 9, 8, 2],
 | 
			
		||||
    [2, 5, 10, 2, 3, 5, 3, 7, 5],
 | 
			
		||||
    [8, 2, 0, 8, 5, 2, 8, 7, 5, 10, 2, 5],
 | 
			
		||||
    [9, 0, 1, 5, 10, 3, 5, 3, 7, 3, 10, 2],
 | 
			
		||||
    [9, 8, 2, 9, 2, 1, 8, 7, 2, 10, 2, 5, 7, 5, 2],
 | 
			
		||||
    [1, 3, 5, 3, 7, 5],
 | 
			
		||||
    [0, 8, 7, 0, 7, 1, 1, 7, 5],
 | 
			
		||||
    [9, 0, 3, 9, 3, 5, 5, 3, 7],
 | 
			
		||||
    [9, 8, 7, 5, 9, 7],
 | 
			
		||||
    [5, 8, 4, 5, 10, 8, 10, 11, 8],
 | 
			
		||||
    [5, 0, 4, 5, 11, 0, 5, 10, 11, 11, 3, 0],
 | 
			
		||||
    [0, 1, 9, 8, 4, 10, 8, 10, 11, 10, 4, 5],
 | 
			
		||||
    [10, 11, 4, 10, 4, 5, 11, 3, 4, 9, 4, 1, 3, 1, 4],
 | 
			
		||||
    [2, 5, 1, 2, 8, 5, 2, 11, 8, 4, 5, 8],
 | 
			
		||||
    [0, 4, 11, 0, 11, 3, 4, 5, 11, 2, 11, 1, 5, 1, 11],
 | 
			
		||||
    [0, 2, 5, 0, 5, 9, 2, 11, 5, 4, 5, 8, 11, 8, 5],
 | 
			
		||||
    [9, 4, 5, 2, 11, 3],
 | 
			
		||||
    [2, 5, 10, 3, 5, 2, 3, 4, 5, 3, 8, 4],
 | 
			
		||||
    [5, 10, 2, 5, 2, 4, 4, 2, 0],
 | 
			
		||||
    [3, 10, 2, 3, 5, 10, 3, 8, 5, 4, 5, 8, 0, 1, 9],
 | 
			
		||||
    [5, 10, 2, 5, 2, 4, 1, 9, 2, 9, 4, 2],
 | 
			
		||||
    [8, 4, 5, 8, 5, 3, 3, 5, 1],
 | 
			
		||||
    [0, 4, 5, 1, 0, 5],
 | 
			
		||||
    [8, 4, 5, 8, 5, 3, 9, 0, 5, 0, 3, 5],
 | 
			
		||||
    [9, 4, 5],
 | 
			
		||||
    [4, 11, 7, 4, 9, 11, 9, 10, 11],
 | 
			
		||||
    [0, 8, 3, 4, 9, 7, 9, 11, 7, 9, 10, 11],
 | 
			
		||||
    [1, 10, 11, 1, 11, 4, 1, 4, 0, 7, 4, 11],
 | 
			
		||||
    [3, 1, 4, 3, 4, 8, 1, 10, 4, 7, 4, 11, 10, 11, 4],
 | 
			
		||||
    [4, 11, 7, 9, 11, 4, 9, 2, 11, 9, 1, 2],
 | 
			
		||||
    [9, 7, 4, 9, 11, 7, 9, 1, 11, 2, 11, 1, 0, 8, 3],
 | 
			
		||||
    [11, 7, 4, 11, 4, 2, 2, 4, 0],
 | 
			
		||||
    [11, 7, 4, 11, 4, 2, 8, 3, 4, 3, 2, 4],
 | 
			
		||||
    [2, 9, 10, 2, 7, 9, 2, 3, 7, 7, 4, 9],
 | 
			
		||||
    [9, 10, 7, 9, 7, 4, 10, 2, 7, 8, 7, 0, 2, 0, 7],
 | 
			
		||||
    [3, 7, 10, 3, 10, 2, 7, 4, 10, 1, 10, 0, 4, 0, 10],
 | 
			
		||||
    [1, 10, 2, 8, 7, 4],
 | 
			
		||||
    [4, 9, 1, 4, 1, 7, 7, 1, 3],
 | 
			
		||||
    [4, 9, 1, 4, 1, 7, 0, 8, 1, 8, 7, 1],
 | 
			
		||||
    [4, 0, 3, 7, 4, 3],
 | 
			
		||||
    [4, 8, 7],
 | 
			
		||||
    [9, 10, 8, 10, 11, 8],
 | 
			
		||||
    [3, 0, 9, 3, 9, 11, 11, 9, 10],
 | 
			
		||||
    [0, 1, 10, 0, 10, 8, 8, 10, 11],
 | 
			
		||||
    [3, 1, 10, 11, 3, 10],
 | 
			
		||||
    [1, 2, 11, 1, 11, 9, 9, 11, 8],
 | 
			
		||||
    [3, 0, 9, 3, 9, 11, 1, 2, 9, 2, 11, 9],
 | 
			
		||||
    [0, 2, 11, 8, 0, 11],
 | 
			
		||||
    [3, 2, 11],
 | 
			
		||||
    [2, 3, 8, 2, 8, 10, 10, 8, 9],
 | 
			
		||||
    [9, 10, 2, 0, 9, 2],
 | 
			
		||||
    [2, 3, 8, 2, 8, 10, 0, 1, 8, 1, 10, 8],
 | 
			
		||||
    [1, 10, 2],
 | 
			
		||||
    [1, 3, 8, 9, 1, 8],
 | 
			
		||||
    [0, 9, 1],
 | 
			
		||||
    [0, 3, 8],
 | 
			
		||||
    [],
 | 
			
		||||
]
 | 
			
		||||
@ -414,7 +414,7 @@ class Transform3d:
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Translate(Transform3d):
 | 
			
		||||
    def __init__(self, x, y=None, z=None, dtype=torch.float32, device: str = "cpu"):
 | 
			
		||||
    def __init__(self, x, y=None, z=None, dtype=torch.float32, device="cpu"):
 | 
			
		||||
        """
 | 
			
		||||
        Create a new Transform3d representing 3D translations.
 | 
			
		||||
 | 
			
		||||
@ -448,7 +448,7 @@ class Translate(Transform3d):
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Scale(Transform3d):
 | 
			
		||||
    def __init__(self, x, y=None, z=None, dtype=torch.float32, device: str = "cpu"):
 | 
			
		||||
    def __init__(self, x, y=None, z=None, dtype=torch.float32, device="cpu"):
 | 
			
		||||
        """
 | 
			
		||||
        A Transform3d representing a scaling operation, with different scale
 | 
			
		||||
        factors along each coordinate axis.
 | 
			
		||||
@ -489,7 +489,7 @@ class Scale(Transform3d):
 | 
			
		||||
 | 
			
		||||
class Rotate(Transform3d):
 | 
			
		||||
    def __init__(
 | 
			
		||||
        self, R, dtype=torch.float32, device: str = "cpu", orthogonal_tol: float = 1e-5
 | 
			
		||||
        self, R, dtype=torch.float32, device="cpu", orthogonal_tol: float = 1e-5
 | 
			
		||||
    ):
 | 
			
		||||
        """
 | 
			
		||||
        Create a new Transform3d representing 3D rotation using a rotation
 | 
			
		||||
@ -528,7 +528,7 @@ class RotateAxisAngle(Rotate):
 | 
			
		||||
        axis: str = "X",
 | 
			
		||||
        degrees: bool = True,
 | 
			
		||||
        dtype=torch.float64,
 | 
			
		||||
        device: str = "cpu",
 | 
			
		||||
        device="cpu",
 | 
			
		||||
    ):
 | 
			
		||||
        """
 | 
			
		||||
        Create a new Transform3d representing 3D rotation about an axis
 | 
			
		||||
@ -635,7 +635,7 @@ def _handle_input(x, y, z, dtype, device, name: str, allow_singleton: bool = Fal
 | 
			
		||||
    return xyz
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _handle_angle_input(x, dtype, device: str, name: str):
 | 
			
		||||
def _handle_angle_input(x, dtype, device, name: str):
 | 
			
		||||
    """
 | 
			
		||||
    Helper function for building a rotation function using angles.
 | 
			
		||||
    The output is always of shape (N,).
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										25
									
								
								tests/bm_marching_cubes.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										25
									
								
								tests/bm_marching_cubes.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,25 @@
 | 
			
		||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
 | 
			
		||||
 | 
			
		||||
from fvcore.common.benchmark import benchmark
 | 
			
		||||
from test_marching_cubes import TestMarchingCubes
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def bm_marching_cubes() -> None:
 | 
			
		||||
    kwargs_list = [
 | 
			
		||||
        {"batch_size": 1, "V": 5},
 | 
			
		||||
        {"batch_size": 1, "V": 10},
 | 
			
		||||
        {"batch_size": 1, "V": 20},
 | 
			
		||||
        {"batch_size": 1, "V": 40},
 | 
			
		||||
        {"batch_size": 5, "V": 5},
 | 
			
		||||
        {"batch_size": 20, "V": 20},
 | 
			
		||||
    ]
 | 
			
		||||
    benchmark(
 | 
			
		||||
        TestMarchingCubes.marching_cubes_with_init,
 | 
			
		||||
        "MARCHING_CUBES",
 | 
			
		||||
        kwargs_list,
 | 
			
		||||
        warmup_iters=1,
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
    bm_marching_cubes()
 | 
			
		||||
							
								
								
									
										
											BIN
										
									
								
								tests/data/test_marching_cubes_data/double_ellipsoid.pickle
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										
											BIN
										
									
								
								tests/data/test_marching_cubes_data/double_ellipsoid.pickle
									
									
									
									
									
										Normal file
									
								
							
										
											Binary file not shown.
										
									
								
							
							
								
								
									
										
											BIN
										
									
								
								tests/data/test_marching_cubes_data/sphere_level64.pickle
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										
											BIN
										
									
								
								tests/data/test_marching_cubes_data/sphere_level64.pickle
									
									
									
									
									
										Normal file
									
								
							
										
											Binary file not shown.
										
									
								
							
							
								
								
									
										771
									
								
								tests/test_marching_cubes.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										771
									
								
								tests/test_marching_cubes.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,771 @@
 | 
			
		||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
 | 
			
		||||
import os
 | 
			
		||||
import pickle
 | 
			
		||||
import unittest
 | 
			
		||||
from pathlib import Path
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
from common_testing import TestCaseMixin
 | 
			
		||||
from pytorch3d.ops.marching_cubes import marching_cubes_naive
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
USE_SCIKIT = False
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def convert_to_local(verts, volume_dim):
 | 
			
		||||
    return (2 * verts) / (volume_dim - 1) - 1
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TestCubeConfiguration(TestCaseMixin, unittest.TestCase):
 | 
			
		||||
 | 
			
		||||
    # Test single cubes. Each case corresponds to the corresponding
 | 
			
		||||
    # cube vertex configuration in each case here (0-indexed):
 | 
			
		||||
    # https://en.wikipedia.org/wiki/Marching_cubes#/media/File:MarchingCubes.svg
 | 
			
		||||
 | 
			
		||||
    def test_empty_volume(self):  # case 0
 | 
			
		||||
        volume_data = torch.ones(1, 2, 2, 2)  # (B, W, H, D)
 | 
			
		||||
        verts, faces = marching_cubes_naive(volume_data, return_local_coords=False)
 | 
			
		||||
 | 
			
		||||
        expected_verts = torch.tensor([])
 | 
			
		||||
        expected_faces = torch.tensor([], dtype=torch.int64)
 | 
			
		||||
        self.assertClose(verts, expected_verts)
 | 
			
		||||
        self.assertClose(faces, expected_faces)
 | 
			
		||||
 | 
			
		||||
    def test_case1(self):  # case 1
 | 
			
		||||
        volume_data = torch.ones(1, 2, 2, 2)  # (B, W, H, D)
 | 
			
		||||
        volume_data[0, 0, 0, 0] = 0
 | 
			
		||||
        volume_data = volume_data.permute(0, 3, 2, 1)  # (B, D, H, W)
 | 
			
		||||
        verts, faces = marching_cubes_naive(volume_data, return_local_coords=False)
 | 
			
		||||
 | 
			
		||||
        expected_verts = torch.tensor(
 | 
			
		||||
            [
 | 
			
		||||
                [0.5, 0, 0],
 | 
			
		||||
                [0, 0, 0.5],
 | 
			
		||||
                [0, 0.5, 0],
 | 
			
		||||
            ]
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        expected_faces = torch.tensor([[1, 2, 0]])
 | 
			
		||||
        self.assertClose(verts[0], expected_verts)
 | 
			
		||||
        self.assertClose(faces[0], expected_faces)
 | 
			
		||||
 | 
			
		||||
        verts, faces = marching_cubes_naive(volume_data, return_local_coords=True)
 | 
			
		||||
        expected_verts = convert_to_local(expected_verts, 2)
 | 
			
		||||
        self.assertClose(verts[0], expected_verts)
 | 
			
		||||
        self.assertClose(faces[0], expected_faces)
 | 
			
		||||
 | 
			
		||||
    def test_case2(self):
 | 
			
		||||
        volume_data = torch.ones(1, 2, 2, 2)  # (B, W, H, D)
 | 
			
		||||
        volume_data[0, 0:2, 0, 0] = 0
 | 
			
		||||
        volume_data = volume_data.permute(0, 3, 2, 1)  # (B, D, H, W)
 | 
			
		||||
        verts, faces = marching_cubes_naive(volume_data, return_local_coords=False)
 | 
			
		||||
 | 
			
		||||
        expected_verts = torch.tensor(
 | 
			
		||||
            [
 | 
			
		||||
                [1.0000, 0.0000, 0.5000],
 | 
			
		||||
                [0.0000, 0.0000, 0.5000],
 | 
			
		||||
                [1.0000, 0.5000, 0.0000],
 | 
			
		||||
                [0.0000, 0.5000, 0.0000],
 | 
			
		||||
            ]
 | 
			
		||||
        )
 | 
			
		||||
        expected_faces = torch.tensor([[1, 2, 0], [3, 2, 1]])
 | 
			
		||||
        self.assertClose(verts[0], expected_verts)
 | 
			
		||||
        self.assertClose(faces[0], expected_faces)
 | 
			
		||||
 | 
			
		||||
        verts, faces = marching_cubes_naive(volume_data, return_local_coords=True)
 | 
			
		||||
        expected_verts = convert_to_local(expected_verts, 2)
 | 
			
		||||
        self.assertClose(verts[0], expected_verts)
 | 
			
		||||
        self.assertClose(faces[0], expected_faces)
 | 
			
		||||
 | 
			
		||||
    def test_case3(self):
 | 
			
		||||
        volume_data = torch.ones(1, 2, 2, 2)  # (B, W, H, D)
 | 
			
		||||
        volume_data[0, 0, 0, 0] = 0
 | 
			
		||||
        volume_data[0, 1, 1, 0] = 0
 | 
			
		||||
        volume_data = volume_data.permute(0, 3, 2, 1)  # (B, D, H, W)
 | 
			
		||||
        verts, faces = marching_cubes_naive(volume_data, return_local_coords=False)
 | 
			
		||||
 | 
			
		||||
        expected_verts = torch.tensor(
 | 
			
		||||
            [
 | 
			
		||||
                [0.5000, 0.0000, 0.0000],
 | 
			
		||||
                [0.0000, 0.0000, 0.5000],
 | 
			
		||||
                [1.0000, 1.0000, 0.5000],
 | 
			
		||||
                [0.5000, 1.0000, 0.0000],
 | 
			
		||||
                [1.0000, 0.5000, 0.0000],
 | 
			
		||||
                [0.0000, 0.5000, 0.0000],
 | 
			
		||||
            ]
 | 
			
		||||
        )
 | 
			
		||||
        expected_faces = torch.tensor([[0, 1, 5], [4, 3, 2]])
 | 
			
		||||
        self.assertClose(verts[0], expected_verts)
 | 
			
		||||
        self.assertClose(faces[0], expected_faces)
 | 
			
		||||
 | 
			
		||||
        verts, faces = marching_cubes_naive(volume_data, return_local_coords=True)
 | 
			
		||||
        expected_verts = convert_to_local(expected_verts, 2)
 | 
			
		||||
        self.assertClose(verts[0], expected_verts)
 | 
			
		||||
        self.assertClose(faces[0], expected_faces)
 | 
			
		||||
 | 
			
		||||
    def test_case4(self):
 | 
			
		||||
        volume_data = torch.ones(1, 2, 2, 2)  # (B, W, H, D)
 | 
			
		||||
        volume_data[0, 1, 0, 0] = 0
 | 
			
		||||
        volume_data[0, 1, 0, 1] = 0
 | 
			
		||||
        volume_data[0, 0, 0, 1] = 0
 | 
			
		||||
        volume_data = volume_data.permute(0, 3, 2, 1)  # (B, D, H, W)
 | 
			
		||||
        verts, faces = marching_cubes_naive(volume_data, return_local_coords=False)
 | 
			
		||||
 | 
			
		||||
        expected_verts = torch.tensor(
 | 
			
		||||
            [
 | 
			
		||||
                [0.5000, 0.0000, 0.0000],
 | 
			
		||||
                [0.0000, 0.0000, 0.5000],
 | 
			
		||||
                [0.0000, 0.5000, 1.0000],
 | 
			
		||||
                [1.0000, 0.5000, 1.0000],
 | 
			
		||||
                [1.0000, 0.5000, 0.0000],
 | 
			
		||||
            ]
 | 
			
		||||
        )
 | 
			
		||||
        expected_faces = torch.tensor([[0, 2, 1], [0, 4, 2], [4, 3, 2]])
 | 
			
		||||
        self.assertClose(verts[0], expected_verts)
 | 
			
		||||
        self.assertClose(faces[0], expected_faces)
 | 
			
		||||
 | 
			
		||||
        verts, faces = marching_cubes_naive(volume_data, return_local_coords=True)
 | 
			
		||||
        expected_verts = convert_to_local(expected_verts, 2)
 | 
			
		||||
        self.assertClose(verts[0], expected_verts)
 | 
			
		||||
        self.assertClose(faces[0], expected_faces)
 | 
			
		||||
 | 
			
		||||
    def test_case5(self):
 | 
			
		||||
        volume_data = torch.ones(1, 2, 2, 2)  # (B, W, H, D)
 | 
			
		||||
        volume_data[0, 0:2, 0, 0:2] = 0
 | 
			
		||||
        volume_data = volume_data.permute(0, 3, 2, 1)  # (B, D, H, W)
 | 
			
		||||
        verts, faces = marching_cubes_naive(volume_data, return_local_coords=False)
 | 
			
		||||
 | 
			
		||||
        expected_verts = torch.tensor(
 | 
			
		||||
            [
 | 
			
		||||
                [0.0000, 0.5000, 1.0000],
 | 
			
		||||
                [1.0000, 0.5000, 1.0000],
 | 
			
		||||
                [1.0000, 0.5000, 0.0000],
 | 
			
		||||
                [0.0000, 0.5000, 0.0000],
 | 
			
		||||
            ]
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        expected_faces = torch.tensor([[1, 0, 2], [2, 0, 3]])
 | 
			
		||||
 | 
			
		||||
        self.assertClose(verts[0], expected_verts)
 | 
			
		||||
        self.assertClose(faces[0], expected_faces)
 | 
			
		||||
 | 
			
		||||
        verts, faces = marching_cubes_naive(volume_data, return_local_coords=True)
 | 
			
		||||
        expected_verts = convert_to_local(expected_verts, 2)
 | 
			
		||||
        self.assertClose(verts[0], expected_verts)
 | 
			
		||||
        self.assertClose(faces[0], expected_faces)
 | 
			
		||||
 | 
			
		||||
    def test_case6(self):
 | 
			
		||||
        volume_data = torch.ones(1, 2, 2, 2)  # (B, W, H, D)
 | 
			
		||||
        volume_data[0, 1, 0, 0] = 0
 | 
			
		||||
        volume_data[0, 1, 0, 1] = 0
 | 
			
		||||
        volume_data[0, 0, 0, 1] = 0
 | 
			
		||||
        volume_data[0, 0, 1, 0] = 0
 | 
			
		||||
        volume_data = volume_data.permute(0, 3, 2, 1)  # (B, D, H, W)
 | 
			
		||||
        verts, faces = marching_cubes_naive(volume_data, return_local_coords=False)
 | 
			
		||||
 | 
			
		||||
        expected_verts = torch.tensor(
 | 
			
		||||
            [
 | 
			
		||||
                [0.5000, 0.0000, 0.0000],
 | 
			
		||||
                [0.0000, 0.0000, 0.5000],
 | 
			
		||||
                [0.5000, 1.0000, 0.0000],
 | 
			
		||||
                [0.0000, 1.0000, 0.5000],
 | 
			
		||||
                [0.0000, 0.5000, 1.0000],
 | 
			
		||||
                [1.0000, 0.5000, 1.0000],
 | 
			
		||||
                [1.0000, 0.5000, 0.0000],
 | 
			
		||||
                [0.0000, 0.5000, 0.0000],
 | 
			
		||||
            ]
 | 
			
		||||
        )
 | 
			
		||||
        expected_faces = torch.tensor([[2, 7, 3], [0, 6, 1], [6, 4, 1], [6, 5, 4]])
 | 
			
		||||
 | 
			
		||||
        self.assertClose(verts[0], expected_verts)
 | 
			
		||||
        self.assertClose(faces[0], expected_faces)
 | 
			
		||||
 | 
			
		||||
        verts, faces = marching_cubes_naive(volume_data, return_local_coords=True)
 | 
			
		||||
        expected_verts = convert_to_local(expected_verts, 2)
 | 
			
		||||
        self.assertClose(verts[0], expected_verts)
 | 
			
		||||
        self.assertClose(faces[0], expected_faces)
 | 
			
		||||
 | 
			
		||||
    def test_case7(self):
 | 
			
		||||
        volume_data = torch.ones(1, 2, 2, 2)  # (B, W, H, D)
 | 
			
		||||
        volume_data[0, 0, 0, 0] = 0
 | 
			
		||||
        volume_data[0, 1, 0, 1] = 0
 | 
			
		||||
        volume_data[0, 1, 1, 0] = 0
 | 
			
		||||
        volume_data[0, 0, 1, 1] = 0
 | 
			
		||||
        volume_data = volume_data.permute(0, 3, 2, 1)  # (B, D, H, W)
 | 
			
		||||
        verts, faces = marching_cubes_naive(volume_data, return_local_coords=False)
 | 
			
		||||
 | 
			
		||||
        expected_verts = torch.tensor(
 | 
			
		||||
            [
 | 
			
		||||
                [0.5000, 0.0000, 1.0000],
 | 
			
		||||
                [1.0000, 0.0000, 0.5000],
 | 
			
		||||
                [0.5000, 0.0000, 0.0000],
 | 
			
		||||
                [0.0000, 0.0000, 0.5000],
 | 
			
		||||
                [0.5000, 1.0000, 1.0000],
 | 
			
		||||
                [1.0000, 1.0000, 0.5000],
 | 
			
		||||
                [0.5000, 1.0000, 0.0000],
 | 
			
		||||
                [0.0000, 1.0000, 0.5000],
 | 
			
		||||
                [0.0000, 0.5000, 1.0000],
 | 
			
		||||
                [1.0000, 0.5000, 1.0000],
 | 
			
		||||
                [1.0000, 0.5000, 0.0000],
 | 
			
		||||
                [0.0000, 0.5000, 0.0000],
 | 
			
		||||
            ]
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        expected_faces = torch.tensor([[0, 1, 9], [4, 7, 8], [2, 3, 11], [5, 10, 6]])
 | 
			
		||||
 | 
			
		||||
        self.assertClose(verts[0], expected_verts)
 | 
			
		||||
        self.assertClose(faces[0], expected_faces)
 | 
			
		||||
 | 
			
		||||
        verts, faces = marching_cubes_naive(volume_data, return_local_coords=True)
 | 
			
		||||
        expected_verts = convert_to_local(expected_verts, 2)
 | 
			
		||||
        self.assertClose(verts[0], expected_verts)
 | 
			
		||||
        self.assertClose(faces[0], expected_faces)
 | 
			
		||||
 | 
			
		||||
    def test_case8(self):
 | 
			
		||||
        volume_data = torch.ones(1, 2, 2, 2)  # (B, W, H, D)
 | 
			
		||||
        volume_data[0, 0, 0, 0] = 0
 | 
			
		||||
        volume_data[0, 0, 0, 1] = 0
 | 
			
		||||
        volume_data[0, 1, 0, 1] = 0
 | 
			
		||||
        volume_data[0, 0, 1, 1] = 0
 | 
			
		||||
        volume_data = volume_data.permute(0, 3, 2, 1)  # (B, D, H, W)
 | 
			
		||||
        verts, faces = marching_cubes_naive(volume_data, return_local_coords=False)
 | 
			
		||||
 | 
			
		||||
        expected_verts = torch.tensor(
 | 
			
		||||
            [
 | 
			
		||||
                [1.0000, 0.0000, 0.5000],
 | 
			
		||||
                [0.5000, 0.0000, 0.0000],
 | 
			
		||||
                [0.5000, 1.0000, 1.0000],
 | 
			
		||||
                [0.0000, 1.0000, 0.5000],
 | 
			
		||||
                [1.0000, 0.5000, 1.0000],
 | 
			
		||||
                [0.0000, 0.5000, 0.0000],
 | 
			
		||||
            ]
 | 
			
		||||
        )
 | 
			
		||||
        expected_faces = torch.tensor([[2, 3, 5], [4, 2, 5], [4, 5, 1], [4, 1, 0]])
 | 
			
		||||
 | 
			
		||||
        self.assertClose(verts[0], expected_verts)
 | 
			
		||||
        self.assertClose(faces[0], expected_faces)
 | 
			
		||||
 | 
			
		||||
        verts, faces = marching_cubes_naive(volume_data, return_local_coords=True)
 | 
			
		||||
        expected_verts = convert_to_local(expected_verts, 2)
 | 
			
		||||
        self.assertClose(verts[0], expected_verts)
 | 
			
		||||
        self.assertClose(faces[0], expected_faces)
 | 
			
		||||
 | 
			
		||||
    def test_case9(self):
 | 
			
		||||
        volume_data = torch.ones(1, 2, 2, 2)  # (B, W, H, D)
 | 
			
		||||
        volume_data[0, 1, 0, 0] = 0
 | 
			
		||||
        volume_data[0, 0, 0, 1] = 0
 | 
			
		||||
        volume_data[0, 1, 0, 1] = 0
 | 
			
		||||
        volume_data[0, 0, 1, 1] = 0
 | 
			
		||||
        volume_data = volume_data.permute(0, 3, 2, 1)  # (B, D, H, W)
 | 
			
		||||
        verts, faces = marching_cubes_naive(volume_data, return_local_coords=False)
 | 
			
		||||
 | 
			
		||||
        expected_verts = torch.tensor(
 | 
			
		||||
            [
 | 
			
		||||
                [0.5000, 0.0000, 0.0000],
 | 
			
		||||
                [0.0000, 0.0000, 0.5000],
 | 
			
		||||
                [0.5000, 1.0000, 1.0000],
 | 
			
		||||
                [0.0000, 1.0000, 0.5000],
 | 
			
		||||
                [1.0000, 0.5000, 1.0000],
 | 
			
		||||
                [1.0000, 0.5000, 0.0000],
 | 
			
		||||
            ]
 | 
			
		||||
        )
 | 
			
		||||
        expected_faces = torch.tensor([[0, 5, 4], [0, 4, 3], [0, 3, 1], [3, 4, 2]])
 | 
			
		||||
 | 
			
		||||
        self.assertClose(verts[0], expected_verts)
 | 
			
		||||
        self.assertClose(faces[0], expected_faces)
 | 
			
		||||
 | 
			
		||||
        verts, faces = marching_cubes_naive(volume_data, return_local_coords=True)
 | 
			
		||||
        expected_verts = convert_to_local(expected_verts, 2)
 | 
			
		||||
        self.assertClose(verts[0], expected_verts)
 | 
			
		||||
        self.assertClose(faces[0], expected_faces)
 | 
			
		||||
 | 
			
		||||
    def test_case10(self):
 | 
			
		||||
        volume_data = torch.ones(1, 2, 2, 2)  # (B, W, H, D)
 | 
			
		||||
        volume_data[0, 0, 0, 0] = 0
 | 
			
		||||
        volume_data[0, 1, 1, 1] = 0
 | 
			
		||||
        volume_data = volume_data.permute(0, 3, 2, 1)  # (B, D, H, W)
 | 
			
		||||
        verts, faces = marching_cubes_naive(volume_data, return_local_coords=False)
 | 
			
		||||
 | 
			
		||||
        expected_verts = torch.tensor(
 | 
			
		||||
            [
 | 
			
		||||
                [0.5000, 0.0000, 0.0000],
 | 
			
		||||
                [0.0000, 0.0000, 0.5000],
 | 
			
		||||
                [0.5000, 1.0000, 1.0000],
 | 
			
		||||
                [1.0000, 1.0000, 0.5000],
 | 
			
		||||
                [1.0000, 0.5000, 1.0000],
 | 
			
		||||
                [0.0000, 0.5000, 0.0000],
 | 
			
		||||
            ]
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        expected_faces = torch.tensor([[4, 3, 2], [0, 1, 5]])
 | 
			
		||||
 | 
			
		||||
        self.assertClose(verts[0], expected_verts)
 | 
			
		||||
        self.assertClose(faces[0], expected_faces)
 | 
			
		||||
 | 
			
		||||
        verts, faces = marching_cubes_naive(volume_data, return_local_coords=True)
 | 
			
		||||
        expected_verts = convert_to_local(expected_verts, 2)
 | 
			
		||||
        self.assertClose(verts[0], expected_verts)
 | 
			
		||||
        self.assertClose(faces[0], expected_faces)
 | 
			
		||||
 | 
			
		||||
    def test_case11(self):
 | 
			
		||||
        volume_data = torch.ones(1, 2, 2, 2)  # (B, W, H, D)
 | 
			
		||||
        volume_data[0, 0, 0, 0] = 0
 | 
			
		||||
        volume_data[0, 1, 0, 0] = 0
 | 
			
		||||
        volume_data[0, 1, 1, 1] = 0
 | 
			
		||||
        volume_data = volume_data.permute(0, 3, 2, 1)  # (B, D, H, W)
 | 
			
		||||
        verts, faces = marching_cubes_naive(volume_data, return_local_coords=False)
 | 
			
		||||
 | 
			
		||||
        expected_verts = torch.tensor(
 | 
			
		||||
            [
 | 
			
		||||
                [1.0000, 0.0000, 0.5000],
 | 
			
		||||
                [0.0000, 0.0000, 0.5000],
 | 
			
		||||
                [0.5000, 1.0000, 1.0000],
 | 
			
		||||
                [1.0000, 1.0000, 0.5000],
 | 
			
		||||
                [1.0000, 0.5000, 1.0000],
 | 
			
		||||
                [1.0000, 0.5000, 0.0000],
 | 
			
		||||
                [0.0000, 0.5000, 0.0000],
 | 
			
		||||
            ]
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        expected_faces = torch.tensor([[5, 1, 6], [5, 0, 1], [4, 3, 2]])
 | 
			
		||||
 | 
			
		||||
        self.assertClose(verts[0], expected_verts)
 | 
			
		||||
        self.assertClose(faces[0], expected_faces)
 | 
			
		||||
 | 
			
		||||
        verts, faces = marching_cubes_naive(volume_data, return_local_coords=True)
 | 
			
		||||
        expected_verts = convert_to_local(expected_verts, 2)
 | 
			
		||||
        self.assertClose(verts[0], expected_verts)
 | 
			
		||||
        self.assertClose(faces[0], expected_faces)
 | 
			
		||||
 | 
			
		||||
    def test_case12(self):
 | 
			
		||||
        volume_data = torch.ones(1, 2, 2, 2)  # (B, W, H, D)
 | 
			
		||||
        volume_data[0, 1, 0, 0] = 0
 | 
			
		||||
        volume_data[0, 0, 1, 0] = 0
 | 
			
		||||
        volume_data[0, 1, 1, 1] = 0
 | 
			
		||||
        volume_data = volume_data.permute(0, 3, 2, 1)  # (B, D, H, W)
 | 
			
		||||
        verts, faces = marching_cubes_naive(volume_data, return_local_coords=False)
 | 
			
		||||
 | 
			
		||||
        expected_verts = torch.tensor(
 | 
			
		||||
            [
 | 
			
		||||
                [1.0000, 0.0000, 0.5000],
 | 
			
		||||
                [0.5000, 0.0000, 0.0000],
 | 
			
		||||
                [0.5000, 1.0000, 1.0000],
 | 
			
		||||
                [1.0000, 1.0000, 0.5000],
 | 
			
		||||
                [0.5000, 1.0000, 0.0000],
 | 
			
		||||
                [0.0000, 1.0000, 0.5000],
 | 
			
		||||
                [1.0000, 0.5000, 1.0000],
 | 
			
		||||
                [1.0000, 0.5000, 0.0000],
 | 
			
		||||
                [0.0000, 0.5000, 0.0000],
 | 
			
		||||
            ]
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        expected_faces = torch.tensor([[6, 3, 2], [7, 0, 1], [5, 4, 8]])
 | 
			
		||||
 | 
			
		||||
        self.assertClose(verts[0], expected_verts)
 | 
			
		||||
        self.assertClose(faces[0], expected_faces)
 | 
			
		||||
 | 
			
		||||
        verts, faces = marching_cubes_naive(volume_data, return_local_coords=True)
 | 
			
		||||
        expected_verts = convert_to_local(expected_verts, 2)
 | 
			
		||||
        self.assertClose(verts[0], expected_verts)
 | 
			
		||||
        self.assertClose(faces[0], expected_faces)
 | 
			
		||||
 | 
			
		||||
    def test_case13(self):
 | 
			
		||||
        volume_data = torch.ones(1, 2, 2, 2)  # (B, W, H, D)
 | 
			
		||||
        volume_data[0, 0, 0, 0] = 0
 | 
			
		||||
        volume_data[0, 0, 1, 0] = 0
 | 
			
		||||
        volume_data[0, 1, 0, 1] = 0
 | 
			
		||||
        volume_data[0, 1, 1, 1] = 0
 | 
			
		||||
        volume_data = volume_data.permute(0, 3, 2, 1)  # (B, D, H, W)
 | 
			
		||||
        verts, faces = marching_cubes_naive(volume_data, return_local_coords=False)
 | 
			
		||||
 | 
			
		||||
        expected_verts = torch.tensor(
 | 
			
		||||
            [
 | 
			
		||||
                [0.5000, 0.0000, 1.0000],
 | 
			
		||||
                [1.0000, 0.0000, 0.5000],
 | 
			
		||||
                [0.5000, 0.0000, 0.0000],
 | 
			
		||||
                [0.0000, 0.0000, 0.5000],
 | 
			
		||||
                [0.5000, 1.0000, 1.0000],
 | 
			
		||||
                [1.0000, 1.0000, 0.5000],
 | 
			
		||||
                [0.5000, 1.0000, 0.0000],
 | 
			
		||||
                [0.0000, 1.0000, 0.5000],
 | 
			
		||||
            ]
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        expected_faces = torch.tensor([[3, 6, 2], [3, 7, 6], [1, 5, 0], [5, 4, 0]])
 | 
			
		||||
 | 
			
		||||
        self.assertClose(verts[0], expected_verts)
 | 
			
		||||
        self.assertClose(faces[0], expected_faces)
 | 
			
		||||
 | 
			
		||||
        verts, faces = marching_cubes_naive(volume_data, return_local_coords=True)
 | 
			
		||||
        expected_verts = convert_to_local(expected_verts, 2)
 | 
			
		||||
        self.assertClose(verts[0], expected_verts)
 | 
			
		||||
        self.assertClose(faces[0], expected_faces)
 | 
			
		||||
 | 
			
		||||
    def test_case14(self):
 | 
			
		||||
        volume_data = torch.ones(1, 2, 2, 2)  # (B, W, H, D)
 | 
			
		||||
        volume_data[0, 0, 0, 0] = 0
 | 
			
		||||
        volume_data[0, 0, 0, 1] = 0
 | 
			
		||||
        volume_data[0, 1, 0, 1] = 0
 | 
			
		||||
        volume_data[0, 1, 1, 1] = 0
 | 
			
		||||
        volume_data = volume_data.permute(0, 3, 2, 1)  # (B, D, H, W)
 | 
			
		||||
        verts, faces = marching_cubes_naive(volume_data, return_local_coords=False)
 | 
			
		||||
 | 
			
		||||
        expected_verts = torch.tensor(
 | 
			
		||||
            [
 | 
			
		||||
                [1.0000, 0.0000, 0.5000],
 | 
			
		||||
                [0.5000, 0.0000, 0.0000],
 | 
			
		||||
                [0.5000, 1.0000, 1.0000],
 | 
			
		||||
                [1.0000, 1.0000, 0.5000],
 | 
			
		||||
                [0.0000, 0.5000, 1.0000],
 | 
			
		||||
                [0.0000, 0.5000, 0.0000],
 | 
			
		||||
            ]
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        expected_faces = torch.tensor([[1, 0, 3], [1, 3, 4], [1, 4, 5], [2, 4, 3]])
 | 
			
		||||
 | 
			
		||||
        self.assertClose(verts[0], expected_verts)
 | 
			
		||||
        self.assertClose(faces[0], expected_faces)
 | 
			
		||||
 | 
			
		||||
        verts, faces = marching_cubes_naive(volume_data, return_local_coords=True)
 | 
			
		||||
        expected_verts = convert_to_local(expected_verts, 2)
 | 
			
		||||
        self.assertClose(verts[0], expected_verts)
 | 
			
		||||
        self.assertClose(faces[0], expected_faces)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TestMarchingCubes(TestCaseMixin, unittest.TestCase):
 | 
			
		||||
    def test_single_point(self):
 | 
			
		||||
        volume_data = torch.zeros(1, 3, 3, 3)  # (B, W, H, D)
 | 
			
		||||
        volume_data[0, 1, 1, 1] = 1
 | 
			
		||||
        volume_data = volume_data.permute(0, 3, 2, 1)  # (B, D, H, W)
 | 
			
		||||
        verts, faces = marching_cubes_naive(volume_data, return_local_coords=False)
 | 
			
		||||
 | 
			
		||||
        expected_verts = torch.tensor(
 | 
			
		||||
            [
 | 
			
		||||
                [0.5, 1, 1],
 | 
			
		||||
                [1, 1, 0.5],
 | 
			
		||||
                [1, 0.5, 1],
 | 
			
		||||
                [1, 1, 1.5],
 | 
			
		||||
                [1, 1.5, 1],
 | 
			
		||||
                [1.5, 1, 1],
 | 
			
		||||
            ]
 | 
			
		||||
        )
 | 
			
		||||
        expected_faces = torch.tensor(
 | 
			
		||||
            [
 | 
			
		||||
                [2, 0, 1],
 | 
			
		||||
                [2, 3, 0],
 | 
			
		||||
                [0, 4, 1],
 | 
			
		||||
                [3, 4, 0],
 | 
			
		||||
                [5, 2, 1],
 | 
			
		||||
                [3, 2, 5],
 | 
			
		||||
                [5, 1, 4],
 | 
			
		||||
                [3, 5, 4],
 | 
			
		||||
            ]
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        self.assertClose(verts[0], expected_verts)
 | 
			
		||||
        self.assertClose(faces[0], expected_faces)
 | 
			
		||||
 | 
			
		||||
        verts, faces = marching_cubes_naive(volume_data, return_local_coords=True)
 | 
			
		||||
        expected_verts = convert_to_local(expected_verts, 3)
 | 
			
		||||
        self.assertClose(verts[0], expected_verts)
 | 
			
		||||
        self.assertClose(faces[0], expected_faces)
 | 
			
		||||
        self.assertTrue(verts[0].ge(-1).all() and verts[0].le(1).all())
 | 
			
		||||
 | 
			
		||||
    def test_cube(self):
 | 
			
		||||
        volume_data = torch.zeros(1, 5, 5, 5)  # (B, W, H, D)
 | 
			
		||||
        volume_data[0, 1, 1, 1] = 1
 | 
			
		||||
        volume_data[0, 1, 1, 2] = 1
 | 
			
		||||
        volume_data[0, 2, 1, 1] = 1
 | 
			
		||||
        volume_data[0, 2, 1, 2] = 1
 | 
			
		||||
        volume_data[0, 1, 2, 1] = 1
 | 
			
		||||
        volume_data[0, 1, 2, 2] = 1
 | 
			
		||||
        volume_data[0, 2, 2, 1] = 1
 | 
			
		||||
        volume_data[0, 2, 2, 2] = 1
 | 
			
		||||
        volume_data = volume_data.permute(0, 3, 2, 1)  # (B, D, H, W)
 | 
			
		||||
        verts, faces = marching_cubes_naive(volume_data, 0.9, return_local_coords=False)
 | 
			
		||||
 | 
			
		||||
        expected_verts = torch.tensor(
 | 
			
		||||
            [
 | 
			
		||||
                [0.9000, 1.0000, 1.0000],
 | 
			
		||||
                [1.0000, 1.0000, 0.9000],
 | 
			
		||||
                [1.0000, 0.9000, 1.0000],
 | 
			
		||||
                [0.9000, 1.0000, 2.0000],
 | 
			
		||||
                [1.0000, 0.9000, 2.0000],
 | 
			
		||||
                [1.0000, 1.0000, 2.1000],
 | 
			
		||||
                [0.9000, 2.0000, 1.0000],
 | 
			
		||||
                [1.0000, 2.0000, 0.9000],
 | 
			
		||||
                [0.9000, 2.0000, 2.0000],
 | 
			
		||||
                [1.0000, 2.0000, 2.1000],
 | 
			
		||||
                [1.0000, 2.1000, 1.0000],
 | 
			
		||||
                [1.0000, 2.1000, 2.0000],
 | 
			
		||||
                [2.0000, 1.0000, 0.9000],
 | 
			
		||||
                [2.0000, 0.9000, 1.0000],
 | 
			
		||||
                [2.0000, 0.9000, 2.0000],
 | 
			
		||||
                [2.0000, 1.0000, 2.1000],
 | 
			
		||||
                [2.0000, 2.0000, 0.9000],
 | 
			
		||||
                [2.0000, 2.0000, 2.1000],
 | 
			
		||||
                [2.0000, 2.1000, 1.0000],
 | 
			
		||||
                [2.0000, 2.1000, 2.0000],
 | 
			
		||||
                [2.1000, 1.0000, 1.0000],
 | 
			
		||||
                [2.1000, 1.0000, 2.0000],
 | 
			
		||||
                [2.1000, 2.0000, 1.0000],
 | 
			
		||||
                [2.1000, 2.0000, 2.0000],
 | 
			
		||||
            ]
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        expected_faces = torch.tensor(
 | 
			
		||||
            [
 | 
			
		||||
                [2, 0, 1],
 | 
			
		||||
                [2, 4, 3],
 | 
			
		||||
                [0, 2, 3],
 | 
			
		||||
                [4, 5, 3],
 | 
			
		||||
                [0, 6, 7],
 | 
			
		||||
                [1, 0, 7],
 | 
			
		||||
                [3, 8, 0],
 | 
			
		||||
                [8, 6, 0],
 | 
			
		||||
                [5, 9, 8],
 | 
			
		||||
                [3, 5, 8],
 | 
			
		||||
                [6, 10, 7],
 | 
			
		||||
                [11, 10, 6],
 | 
			
		||||
                [8, 11, 6],
 | 
			
		||||
                [9, 11, 8],
 | 
			
		||||
                [13, 2, 1],
 | 
			
		||||
                [12, 13, 1],
 | 
			
		||||
                [14, 4, 13],
 | 
			
		||||
                [13, 4, 2],
 | 
			
		||||
                [4, 14, 15],
 | 
			
		||||
                [5, 4, 15],
 | 
			
		||||
                [12, 1, 16],
 | 
			
		||||
                [1, 7, 16],
 | 
			
		||||
                [15, 17, 5],
 | 
			
		||||
                [5, 17, 9],
 | 
			
		||||
                [16, 7, 10],
 | 
			
		||||
                [18, 16, 10],
 | 
			
		||||
                [19, 18, 11],
 | 
			
		||||
                [18, 10, 11],
 | 
			
		||||
                [9, 17, 19],
 | 
			
		||||
                [11, 9, 19],
 | 
			
		||||
                [20, 13, 12],
 | 
			
		||||
                [20, 21, 14],
 | 
			
		||||
                [13, 20, 14],
 | 
			
		||||
                [15, 14, 21],
 | 
			
		||||
                [22, 20, 12],
 | 
			
		||||
                [16, 22, 12],
 | 
			
		||||
                [21, 20, 23],
 | 
			
		||||
                [23, 20, 22],
 | 
			
		||||
                [17, 15, 21],
 | 
			
		||||
                [23, 17, 21],
 | 
			
		||||
                [22, 16, 18],
 | 
			
		||||
                [23, 22, 18],
 | 
			
		||||
                [19, 23, 18],
 | 
			
		||||
                [17, 23, 19],
 | 
			
		||||
            ]
 | 
			
		||||
        )
 | 
			
		||||
        self.assertClose(verts[0], expected_verts)
 | 
			
		||||
        self.assertClose(faces[0], expected_faces)
 | 
			
		||||
        verts, faces = marching_cubes_naive(volume_data, 0.9, return_local_coords=True)
 | 
			
		||||
        expected_verts = convert_to_local(expected_verts, 5)
 | 
			
		||||
        self.assertClose(verts[0], expected_verts)
 | 
			
		||||
        self.assertClose(faces[0], expected_faces)
 | 
			
		||||
 | 
			
		||||
        # Check all values are in the range [-1, 1]
 | 
			
		||||
        self.assertTrue(verts[0].ge(-1).all() and verts[0].le(1).all())
 | 
			
		||||
 | 
			
		||||
    def test_cube_no_duplicate_verts(self):
 | 
			
		||||
        volume_data = torch.zeros(1, 5, 5, 5)  # (B, W, H, D)
 | 
			
		||||
        volume_data[0, 1, 1, 1] = 1
 | 
			
		||||
        volume_data[0, 1, 1, 2] = 1
 | 
			
		||||
        volume_data[0, 2, 1, 1] = 1
 | 
			
		||||
        volume_data[0, 2, 1, 2] = 1
 | 
			
		||||
        volume_data[0, 1, 2, 1] = 1
 | 
			
		||||
        volume_data[0, 1, 2, 2] = 1
 | 
			
		||||
        volume_data[0, 2, 2, 1] = 1
 | 
			
		||||
        volume_data[0, 2, 2, 2] = 1
 | 
			
		||||
        volume_data = volume_data.permute(0, 3, 2, 1)  # (B, D, H, W)
 | 
			
		||||
        verts, faces = marching_cubes_naive(volume_data, 1, return_local_coords=False)
 | 
			
		||||
 | 
			
		||||
        expected_verts = torch.tensor(
 | 
			
		||||
            [
 | 
			
		||||
                [1.0, 1.0, 1.0],
 | 
			
		||||
                [1.0, 1.0, 2.0],
 | 
			
		||||
                [1.0, 2.0, 1.0],
 | 
			
		||||
                [1.0, 2.0, 2.0],
 | 
			
		||||
                [2.0, 1.0, 1.0],
 | 
			
		||||
                [2.0, 1.0, 2.0],
 | 
			
		||||
                [2.0, 2.0, 1.0],
 | 
			
		||||
                [2.0, 2.0, 2.0],
 | 
			
		||||
            ]
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        expected_faces = torch.tensor(
 | 
			
		||||
            [
 | 
			
		||||
                [1, 3, 0],
 | 
			
		||||
                [3, 2, 0],
 | 
			
		||||
                [5, 1, 4],
 | 
			
		||||
                [4, 1, 0],
 | 
			
		||||
                [4, 0, 6],
 | 
			
		||||
                [0, 2, 6],
 | 
			
		||||
                [5, 7, 1],
 | 
			
		||||
                [1, 7, 3],
 | 
			
		||||
                [7, 6, 3],
 | 
			
		||||
                [6, 2, 3],
 | 
			
		||||
                [5, 4, 7],
 | 
			
		||||
                [7, 4, 6],
 | 
			
		||||
            ]
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        self.assertClose(verts[0], expected_verts)
 | 
			
		||||
        self.assertClose(faces[0], expected_faces)
 | 
			
		||||
 | 
			
		||||
        verts, faces = marching_cubes_naive(volume_data, 1, return_local_coords=True)
 | 
			
		||||
        expected_verts = convert_to_local(expected_verts, 5)
 | 
			
		||||
        self.assertClose(verts[0], expected_verts)
 | 
			
		||||
        self.assertClose(faces[0], expected_faces)
 | 
			
		||||
 | 
			
		||||
        # Check all values are in the range [-1, 1]
 | 
			
		||||
        self.assertTrue(verts[0].ge(-1).all() and verts[0].le(1).all())
 | 
			
		||||
 | 
			
		||||
    def test_sphere(self):
 | 
			
		||||
        # (B, W, H, D)
 | 
			
		||||
        volume = torch.Tensor(
 | 
			
		||||
            [
 | 
			
		||||
                [
 | 
			
		||||
                    [(x - 10) ** 2 + (y - 10) ** 2 + (z - 10) ** 2 for z in range(20)]
 | 
			
		||||
                    for y in range(20)
 | 
			
		||||
                ]
 | 
			
		||||
                for x in range(20)
 | 
			
		||||
            ]
 | 
			
		||||
        ).unsqueeze(0)
 | 
			
		||||
        volume = volume.permute(0, 3, 2, 1)  # (B, D, H, W)
 | 
			
		||||
        verts, faces = marching_cubes_naive(
 | 
			
		||||
            volume, isolevel=64, return_local_coords=False
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        DATA_DIR = Path(__file__).resolve().parent / "data"
 | 
			
		||||
        data_filename = "test_marching_cubes_data/sphere_level64.pickle"
 | 
			
		||||
        filename = os.path.join(DATA_DIR, data_filename)
 | 
			
		||||
        with open(filename, "rb") as file:
 | 
			
		||||
            verts_and_faces = pickle.load(file)
 | 
			
		||||
        expected_verts = verts_and_faces["verts"].squeeze()
 | 
			
		||||
        expected_faces = verts_and_faces["faces"].squeeze()
 | 
			
		||||
 | 
			
		||||
        self.assertClose(verts[0], expected_verts)
 | 
			
		||||
        self.assertClose(faces[0], expected_faces)
 | 
			
		||||
 | 
			
		||||
        verts, faces = marching_cubes_naive(
 | 
			
		||||
            volume, isolevel=64, return_local_coords=True
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        expected_verts = convert_to_local(expected_verts, 20)
 | 
			
		||||
        self.assertClose(verts[0], expected_verts)
 | 
			
		||||
        self.assertClose(faces[0], expected_faces)
 | 
			
		||||
 | 
			
		||||
        # Check all values are in the range [-1, 1]
 | 
			
		||||
        self.assertTrue(verts[0].ge(-1).all() and verts[0].le(1).all())
 | 
			
		||||
 | 
			
		||||
    # Uses skimage.draw.ellipsoid
 | 
			
		||||
    def test_double_ellipsoid(self):
 | 
			
		||||
        if USE_SCIKIT:
 | 
			
		||||
            import numpy as np
 | 
			
		||||
            from skimage.draw import ellipsoid
 | 
			
		||||
 | 
			
		||||
            ellip_base = ellipsoid(6, 10, 16, levelset=True)
 | 
			
		||||
            ellip_double = np.concatenate(
 | 
			
		||||
                (ellip_base[:-1, ...], ellip_base[2:, ...]), axis=0
 | 
			
		||||
            )
 | 
			
		||||
            volume = torch.Tensor(ellip_double).unsqueeze(0)
 | 
			
		||||
            volume = volume.permute(0, 3, 2, 1)  # (B, D, H, W)
 | 
			
		||||
            verts, faces = marching_cubes_naive(volume, isolevel=0.001)
 | 
			
		||||
 | 
			
		||||
            DATA_DIR = Path(__file__).resolve().parent / "data"
 | 
			
		||||
            data_filename = "test_marching_cubes_data/double_ellipsoid.pickle"
 | 
			
		||||
            filename = os.path.join(DATA_DIR, data_filename)
 | 
			
		||||
            with open(filename, "rb") as file:
 | 
			
		||||
                verts_and_faces = pickle.load(file)
 | 
			
		||||
            expected_verts = verts_and_faces["verts"]
 | 
			
		||||
            expected_faces = verts_and_faces["faces"]
 | 
			
		||||
 | 
			
		||||
            self.assertClose(verts[0], expected_verts[0])
 | 
			
		||||
            self.assertClose(faces[0], expected_faces[0])
 | 
			
		||||
 | 
			
		||||
    def test_cube_surface_area(self):
 | 
			
		||||
        if USE_SCIKIT:
 | 
			
		||||
            from skimage.measure import marching_cubes_classic, mesh_surface_area
 | 
			
		||||
 | 
			
		||||
            volume_data = torch.zeros(1, 5, 5, 5)
 | 
			
		||||
            volume_data[0, 1, 1, 1] = 1
 | 
			
		||||
            volume_data[0, 1, 1, 2] = 1
 | 
			
		||||
            volume_data[0, 2, 1, 1] = 1
 | 
			
		||||
            volume_data[0, 2, 1, 2] = 1
 | 
			
		||||
            volume_data[0, 1, 2, 1] = 1
 | 
			
		||||
            volume_data[0, 1, 2, 2] = 1
 | 
			
		||||
            volume_data[0, 2, 2, 1] = 1
 | 
			
		||||
            volume_data[0, 2, 2, 2] = 1
 | 
			
		||||
            volume_data = volume_data.permute(0, 3, 2, 1)  # (B, D, H, W)
 | 
			
		||||
            verts, faces = marching_cubes_naive(volume_data, return_local_coords=False)
 | 
			
		||||
            verts_sci, faces_sci = marching_cubes_classic(volume_data[0])
 | 
			
		||||
 | 
			
		||||
            surf = mesh_surface_area(verts[0], faces[0])
 | 
			
		||||
            surf_sci = mesh_surface_area(verts_sci, faces_sci)
 | 
			
		||||
 | 
			
		||||
            self.assertClose(surf, surf_sci)
 | 
			
		||||
 | 
			
		||||
    def test_sphere_surface_area(self):
 | 
			
		||||
        if USE_SCIKIT:
 | 
			
		||||
            from skimage.measure import marching_cubes_classic, mesh_surface_area
 | 
			
		||||
 | 
			
		||||
            # (B, W, H, D)
 | 
			
		||||
            volume = torch.Tensor(
 | 
			
		||||
                [
 | 
			
		||||
                    [
 | 
			
		||||
                        [
 | 
			
		||||
                            (x - 10) ** 2 + (y - 10) ** 2 + (z - 10) ** 2
 | 
			
		||||
                            for z in range(20)
 | 
			
		||||
                        ]
 | 
			
		||||
                        for y in range(20)
 | 
			
		||||
                    ]
 | 
			
		||||
                    for x in range(20)
 | 
			
		||||
                ]
 | 
			
		||||
            ).unsqueeze(0)
 | 
			
		||||
            volume = volume.permute(0, 3, 2, 1)  # (B, D, H, W)
 | 
			
		||||
            verts, faces = marching_cubes_naive(volume, isolevel=64)
 | 
			
		||||
            verts_sci, faces_sci = marching_cubes_classic(volume[0], level=64)
 | 
			
		||||
 | 
			
		||||
            surf = mesh_surface_area(verts[0], faces[0])
 | 
			
		||||
            surf_sci = mesh_surface_area(verts_sci, faces_sci)
 | 
			
		||||
 | 
			
		||||
            self.assertClose(surf, surf_sci)
 | 
			
		||||
 | 
			
		||||
    def test_double_ellipsoid_surface_area(self):
 | 
			
		||||
        if USE_SCIKIT:
 | 
			
		||||
            import numpy as np
 | 
			
		||||
            from skimage.draw import ellipsoid
 | 
			
		||||
            from skimage.measure import marching_cubes_classic, mesh_surface_area
 | 
			
		||||
 | 
			
		||||
            ellip_base = ellipsoid(6, 10, 16, levelset=True)
 | 
			
		||||
            ellip_double = np.concatenate(
 | 
			
		||||
                (ellip_base[:-1, ...], ellip_base[2:, ...]), axis=0
 | 
			
		||||
            )
 | 
			
		||||
            volume = torch.Tensor(ellip_double).unsqueeze(0)
 | 
			
		||||
            volume = volume.permute(0, 3, 2, 1)  # (B, D, H, W)
 | 
			
		||||
            verts, faces = marching_cubes_naive(volume, isolevel=0)
 | 
			
		||||
            verts_sci, faces_sci = marching_cubes_classic(volume[0], level=0)
 | 
			
		||||
 | 
			
		||||
            surf = mesh_surface_area(verts[0], faces[0])
 | 
			
		||||
            surf_sci = mesh_surface_area(verts_sci, faces_sci)
 | 
			
		||||
 | 
			
		||||
            self.assertClose(surf, surf_sci)
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def marching_cubes_with_init(batch_size: int, V: int):
 | 
			
		||||
        device = torch.device("cuda:0")
 | 
			
		||||
        volume_data = torch.rand(
 | 
			
		||||
            (batch_size, V, V, V), dtype=torch.float32, device=device
 | 
			
		||||
        )
 | 
			
		||||
        torch.cuda.synchronize()
 | 
			
		||||
 | 
			
		||||
        def convert():
 | 
			
		||||
            marching_cubes_naive(volume_data, return_local_coords=False)
 | 
			
		||||
            torch.cuda.synchronize()
 | 
			
		||||
 | 
			
		||||
        return convert
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user