mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
Classic Marching Cubes algorithm implementation
Summary: Defines a function to run marching cubes algorithm on a single or batch of 3D scalar fields. Returns a mesh's faces and vertices. UPDATES (12/18) - Input data is now specified as a (B, D, H, W) tensor as opposed to a (B, W, H, D) tensor. This will now be compatible with the Volumes datastructure. - Add an option to return output vertices in local coordinates instead of world coordinates. Also added a small fix to remove the dype for device in Transforms3D - if passing in a torch.device instead of str it causes a pyre error. Reviewed By: jcjohnson Differential Revision: D24599019 fbshipit-source-id: 90554a200319fed8736a12371cc349e7108aacd0
This commit is contained in:
parent
9c6b58c5ad
commit
ebac66daeb
347
pytorch3d/ops/marching_cubes.py
Normal file
347
pytorch3d/ops/marching_cubes.py
Normal file
@ -0,0 +1,347 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from pytorch3d.ops.marching_cubes_data import EDGE_TABLE, EDGE_TO_VERTICES, FACE_TABLE
|
||||
from pytorch3d.transforms import Translate
|
||||
|
||||
|
||||
EPS = 0.00001
|
||||
|
||||
|
||||
class Cube:
|
||||
def __init__(self, bfl_vertex: Tuple[int, int, int], spacing: int = 1):
|
||||
"""
|
||||
Initializes a cube given the bottom front left vertex coordinate
|
||||
and the cube spacing
|
||||
|
||||
Edge and vertex convention:
|
||||
|
||||
v4_______e4____________v5
|
||||
/| /|
|
||||
/ | / |
|
||||
e7/ | e5/ |
|
||||
/___|______e6_________/ |
|
||||
v7| | |v6 |e9
|
||||
| | | |
|
||||
| |e8 |e10|
|
||||
e11| | | |
|
||||
| |_________________|___|
|
||||
| / v0 e0 | /v1
|
||||
| / | /
|
||||
| /e3 | /e1
|
||||
|/_____________________|/
|
||||
v3 e2 v2
|
||||
|
||||
Args:
|
||||
bfl_vertex: a tuple of size 3 corresponding to the bottom front left vertex
|
||||
of the cube in (x, y, z) format
|
||||
spacing: the length of each edge of the cube
|
||||
"""
|
||||
# match corner orders to algorithm convention
|
||||
if len(bfl_vertex) != 3:
|
||||
msg = "The vertex {} is size {} instead of size 3".format(
|
||||
bfl_vertex, len(bfl_vertex)
|
||||
)
|
||||
raise ValueError(msg)
|
||||
|
||||
x, y, z = bfl_vertex
|
||||
self.vertices = torch.tensor(
|
||||
[
|
||||
[x, y, z + spacing],
|
||||
[x + spacing, y, z + spacing],
|
||||
[x + spacing, y, z],
|
||||
[x, y, z],
|
||||
[x, y + spacing, z + spacing],
|
||||
[x + spacing, y + spacing, z + spacing],
|
||||
[x + spacing, y + spacing, z],
|
||||
[x, y + spacing, z],
|
||||
]
|
||||
)
|
||||
|
||||
def get_index(self, volume_data: torch.Tensor, isolevel: float) -> int:
|
||||
"""
|
||||
Calculates the cube_index in the range 0-255 to index
|
||||
into EDGE_TABLE and FACE_TABLE
|
||||
Args:
|
||||
volume_data: the 3D scalar data
|
||||
isolevel: the isosurface value used as a threshold
|
||||
for determining whether a point is inside/outside
|
||||
the volume
|
||||
"""
|
||||
cube_index = 0
|
||||
bit = 1
|
||||
for index in range(len(self.vertices)):
|
||||
vertex = self.vertices[index]
|
||||
value = _get_value(vertex, volume_data)
|
||||
if value < isolevel:
|
||||
cube_index |= bit
|
||||
bit *= 2
|
||||
return cube_index
|
||||
|
||||
|
||||
def marching_cubes_naive(
|
||||
volume_data_batch: torch.Tensor,
|
||||
isolevel: Optional[float] = None,
|
||||
spacing: int = 1,
|
||||
return_local_coords: bool = True,
|
||||
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
|
||||
"""
|
||||
Runs the classic marching cubes algorithm, iterating over
|
||||
the coordinates of the volume_data and using a given isolevel
|
||||
for determining intersected edges of cubes of size `spacing`.
|
||||
Returns vertices and faces of the obtained mesh.
|
||||
This operation is non-differentiable.
|
||||
|
||||
This is a naive implementation, and is not optimized for efficiency.
|
||||
|
||||
Args:
|
||||
volume_data_batch: a Tensor of size (N, D, H, W) corresponding to
|
||||
a batch of 3D scalar fields
|
||||
isolevel: the isosurface value to use as the threshold to determine
|
||||
whether points are within a volume. If None, then the average of the
|
||||
maximum and minimum value of the scalar field will be used.
|
||||
spacing: an integer specifying the cube size to use
|
||||
return_local_coords: bool. If True the output vertices will be in local coordinates in
|
||||
the range [-1, 1] x [-1, 1] x [-1, 1]. If False they will be in the range
|
||||
[0, W-1] x [0, H-1] x [0, D-1]
|
||||
Returns:
|
||||
verts: [(V_0, 3), (V_1, 3), ...] List of N FloatTensors of vertices.
|
||||
faces: [(F_0, 3), (F_1, 3), ...] List of N LongTensors of faces.
|
||||
"""
|
||||
volume_data_batch = volume_data_batch.detach().cpu()
|
||||
batched_verts, batched_faces = [], []
|
||||
D, H, W = volume_data_batch.shape[1:]
|
||||
# pyre-ignore [16]
|
||||
volume_size_xyz = volume_data_batch.new_tensor([W, H, D])[None]
|
||||
|
||||
if return_local_coords:
|
||||
# Convert from local coordinates in the range [-1, 1] range to
|
||||
# world coordinates in the range [0, D-1], [0, H-1], [0, W-1]
|
||||
local_to_world_transform = Translate(
|
||||
x=+1.0, y=+1.0, z=+1.0, device=volume_data_batch.device
|
||||
).scale((volume_size_xyz - 1) * spacing * 0.5)
|
||||
# Perform the inverse to go from world to local
|
||||
world_to_local_transform = local_to_world_transform.inverse()
|
||||
|
||||
for i in range(len(volume_data_batch)):
|
||||
volume_data = volume_data_batch[i]
|
||||
curr_isolevel = (
|
||||
((volume_data.max() + volume_data.min()) / 2).item()
|
||||
if isolevel is None
|
||||
else isolevel
|
||||
)
|
||||
edge_vertices_to_index = {}
|
||||
vertex_coords_to_index = {}
|
||||
verts, faces = [], []
|
||||
# Use length - spacing for the bounds since we are using
|
||||
# cubes of size spacing, with the lowest x,y,z values
|
||||
# (bottom front left)
|
||||
for x in range(0, W - spacing, spacing):
|
||||
for y in range(0, H - spacing, spacing):
|
||||
for z in range(0, D - spacing, spacing):
|
||||
cube = Cube((x, y, z), spacing)
|
||||
new_verts, new_faces = polygonise(
|
||||
cube,
|
||||
curr_isolevel,
|
||||
volume_data,
|
||||
edge_vertices_to_index,
|
||||
vertex_coords_to_index,
|
||||
)
|
||||
verts.extend(new_verts)
|
||||
faces.extend(new_faces)
|
||||
if len(faces) > 0 and len(verts) > 0:
|
||||
verts = torch.tensor(verts, dtype=torch.float32)
|
||||
# Convert vertices from world to local coords
|
||||
if return_local_coords:
|
||||
verts = world_to_local_transform.transform_points(verts[None, ...])
|
||||
verts = verts.squeeze()
|
||||
batched_verts.append(verts)
|
||||
batched_faces.append(torch.tensor(faces, dtype=torch.int64))
|
||||
return batched_verts, batched_faces
|
||||
|
||||
|
||||
def polygonise(
|
||||
cube: Cube,
|
||||
isolevel: float,
|
||||
volume_data: torch.Tensor,
|
||||
edge_vertices_to_index: Dict[Tuple[Tuple, Tuple], int],
|
||||
vertex_coords_to_index: Dict[Tuple[float, float, float], int],
|
||||
) -> Tuple[list, list]:
|
||||
"""
|
||||
Runs the classic marching cubes algorithm for one Cube in the volume.
|
||||
Returns the vertices and faces for the given cube.
|
||||
|
||||
Args:
|
||||
cube: a Cube indicating the cube being examined for edges that intersect
|
||||
the volume data.
|
||||
isolevel: the isosurface value to use as the threshold to determine
|
||||
whether points are within a volume.
|
||||
volume_data: a Tensor of shape (D, H, W) corresponding to
|
||||
a 3D scalar field
|
||||
edge_vertices_to_index: A dictionary which maps an edge's two coordinates
|
||||
to the index of its interpolated point, if that interpolated point
|
||||
has already been used by a previous point
|
||||
vertex_coords_to_index: A dictionary mapping a point (x, y, z) to the corresponding
|
||||
index of that vertex, if that point has already been marked as a vertex.
|
||||
Returns:
|
||||
verts: List of triangle vertices for the given cube in the volume
|
||||
faces: List of triangle faces for the given cube in the volume
|
||||
"""
|
||||
num_existing_verts = max(edge_vertices_to_index.values(), default=-1) + 1
|
||||
verts, faces = [], []
|
||||
cube_index = cube.get_index(volume_data, isolevel)
|
||||
edges = EDGE_TABLE[cube_index]
|
||||
edge_indices = _get_edge_indices(edges)
|
||||
if len(edge_indices) == 0:
|
||||
return [], []
|
||||
|
||||
new_verts, edge_index_to_point_index = _calculate_interp_vertices(
|
||||
edge_indices,
|
||||
volume_data,
|
||||
cube,
|
||||
isolevel,
|
||||
edge_vertices_to_index,
|
||||
vertex_coords_to_index,
|
||||
num_existing_verts,
|
||||
)
|
||||
|
||||
# Create faces
|
||||
face_triangles = FACE_TABLE[cube_index]
|
||||
for i in range(0, len(face_triangles), 3):
|
||||
tri1 = edge_index_to_point_index[face_triangles[i]]
|
||||
tri2 = edge_index_to_point_index[face_triangles[i + 1]]
|
||||
tri3 = edge_index_to_point_index[face_triangles[i + 2]]
|
||||
if tri1 != tri2 and tri2 != tri3 and tri1 != tri3:
|
||||
faces.append([tri1, tri2, tri3])
|
||||
|
||||
verts += new_verts
|
||||
return verts, faces
|
||||
|
||||
|
||||
def _get_edge_indices(edges: int) -> List[int]:
|
||||
"""
|
||||
Finds which edge numbers are intersected given the bit representation
|
||||
detailed in marching_cubes_data.EDGE_TABLE.
|
||||
|
||||
Args:
|
||||
edges: an integer corresponding to the value at cube_index
|
||||
from the EDGE_TABLE in marching_cubes_data.py
|
||||
|
||||
Returns:
|
||||
edge_indices: A list of edge indices
|
||||
"""
|
||||
if edges == 0:
|
||||
return []
|
||||
|
||||
edge_indices = []
|
||||
for i in range(12):
|
||||
if edges & (2 ** i):
|
||||
edge_indices.append(i)
|
||||
return edge_indices
|
||||
|
||||
|
||||
def _calculate_interp_vertices(
|
||||
edge_indices: List[int],
|
||||
volume_data: torch.Tensor,
|
||||
cube: Cube,
|
||||
isolevel: float,
|
||||
edge_vertices_to_index: Dict[Tuple[Tuple, Tuple], int],
|
||||
vertex_coords_to_index: Dict[Tuple[float, float, float], int],
|
||||
num_existing_verts: int,
|
||||
) -> Tuple[List, Dict[int, int]]:
|
||||
"""
|
||||
Finds the interpolated vertices for the intersected edges, either referencing
|
||||
previous calculations or newly calculating and storing the new interpolated
|
||||
points.
|
||||
|
||||
Args:
|
||||
edge_indices: the numbers of the edges which are intersected. See the
|
||||
Cube class for more detail on the edge numbering convention.
|
||||
volume_data: a Tensor of size (D, H, W) corresponding to
|
||||
a 3D scalar field
|
||||
cube: a Cube indicating the cube being examined for edges that intersect
|
||||
the volume
|
||||
isolevel: the isosurface value to use as the threshold to determine
|
||||
whether points are within a volume.
|
||||
edge_vertices_to_index: A dictionary which maps an edge's two coordinates
|
||||
to the index of its interpolated point, if that interpolated point
|
||||
has already been used by a previous point
|
||||
vertex_coords_to_index: A dictionary mapping a point (x, y, z) to the corresponding
|
||||
index of that vertex, if that point has already been marked as a vertex.
|
||||
num_existing_verts: the number of vertices that have been found in previous
|
||||
calls to polygonise for the given volume_data in the above function, marching_cubes.
|
||||
This is equal to the 1 + the maximum value in edge_vertices_to_index.
|
||||
Returns:
|
||||
interp_points: a list of new interpolated points
|
||||
edge_index_to_point_index: a dictionary mapping an edge number to the index in the
|
||||
marching cubes' vertices list of the interpolated point on that edge. To be precise,
|
||||
it refers to the index within the vertices list after interp_points
|
||||
has been appended to the verts list constructed in the marching_cubes_naive
|
||||
function.
|
||||
"""
|
||||
interp_points = []
|
||||
edge_index_to_point_index = {}
|
||||
for edge_index in edge_indices:
|
||||
v1, v2 = EDGE_TO_VERTICES[edge_index]
|
||||
point1, point2 = cube.vertices[v1], cube.vertices[v2]
|
||||
p_tuple1, p_tuple2 = tuple(point1.tolist()), tuple(point2.tolist())
|
||||
if (p_tuple1, p_tuple2) in edge_vertices_to_index:
|
||||
edge_index_to_point_index[edge_index] = edge_vertices_to_index[
|
||||
(p_tuple1, p_tuple2)
|
||||
]
|
||||
else:
|
||||
val1, val2 = _get_value(point1, volume_data), _get_value(
|
||||
point2, volume_data
|
||||
)
|
||||
|
||||
point = None
|
||||
if abs(isolevel - val1) < EPS:
|
||||
point = point1
|
||||
|
||||
if abs(isolevel - val2) < EPS:
|
||||
point = point2
|
||||
|
||||
if abs(val1 - val2) < EPS:
|
||||
point = point1
|
||||
|
||||
if point is None:
|
||||
mu = (isolevel - val1) / (val2 - val1)
|
||||
x1, y1, z1 = point1
|
||||
x2, y2, z2 = point2
|
||||
x = x1 + mu * (x2 - x1)
|
||||
y = y1 + mu * (y2 - y1)
|
||||
z = z1 + mu * (z2 - z1)
|
||||
else:
|
||||
x, y, z = point
|
||||
|
||||
x, y, z = x.item(), y.item(), z.item() # for dictionary keys
|
||||
|
||||
vert_index = None
|
||||
if (x, y, z) in vertex_coords_to_index:
|
||||
vert_index = vertex_coords_to_index[(x, y, z)]
|
||||
else:
|
||||
vert_index = num_existing_verts + len(interp_points)
|
||||
interp_points.append([x, y, z])
|
||||
vertex_coords_to_index[(x, y, z)] = vert_index
|
||||
|
||||
edge_vertices_to_index[(p_tuple1, p_tuple2)] = vert_index
|
||||
edge_index_to_point_index[edge_index] = vert_index
|
||||
|
||||
return interp_points, edge_index_to_point_index
|
||||
|
||||
|
||||
def _get_value(point: Tuple[int, int, int], volume_data: torch.Tensor) -> float:
|
||||
"""
|
||||
Gets the value at a given coordinate point in the scalar field.
|
||||
|
||||
Args:
|
||||
point: data of shape (3) corresponding to an xyz coordinate.
|
||||
volume_data: a Tensor of size (D, H, W) corresponding to
|
||||
a 3D scalar field
|
||||
Returns:
|
||||
data: scalar value in the volume at the given point
|
||||
"""
|
||||
x, y, z = point
|
||||
return volume_data[z][y][x]
|
545
pytorch3d/ops/marching_cubes_data.py
Normal file
545
pytorch3d/ops/marching_cubes_data.py
Normal file
@ -0,0 +1,545 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
# A length 256 list which maps a cubeindex to a number
|
||||
# with the intersected edges' bits set to 1.
|
||||
# Each cubeindex corresponds to a given cube configuration, where
|
||||
# it is composed of a bitstring where the 0th bit is flipped if vertex 0
|
||||
# is below the isosurface (i.e. 0x01), for each of the 8 vertices.
|
||||
EDGE_TABLE = [
|
||||
0x0,
|
||||
0x109,
|
||||
0x203,
|
||||
0x30A,
|
||||
0x406,
|
||||
0x50F,
|
||||
0x605,
|
||||
0x70C,
|
||||
0x80C,
|
||||
0x905,
|
||||
0xA0F,
|
||||
0xB06,
|
||||
0xC0A,
|
||||
0xD03,
|
||||
0xE09,
|
||||
0xF00,
|
||||
0x190,
|
||||
0x99,
|
||||
0x393,
|
||||
0x29A,
|
||||
0x596,
|
||||
0x49F,
|
||||
0x795,
|
||||
0x69C,
|
||||
0x99C,
|
||||
0x895,
|
||||
0xB9F,
|
||||
0xA96,
|
||||
0xD9A,
|
||||
0xC93,
|
||||
0xF99,
|
||||
0xE90,
|
||||
0x230,
|
||||
0x339,
|
||||
0x33,
|
||||
0x13A,
|
||||
0x636,
|
||||
0x73F,
|
||||
0x435,
|
||||
0x53C,
|
||||
0xA3C,
|
||||
0xB35,
|
||||
0x83F,
|
||||
0x936,
|
||||
0xE3A,
|
||||
0xF33,
|
||||
0xC39,
|
||||
0xD30,
|
||||
0x3A0,
|
||||
0x2A9,
|
||||
0x1A3,
|
||||
0xAA,
|
||||
0x7A6,
|
||||
0x6AF,
|
||||
0x5A5,
|
||||
0x4AC,
|
||||
0xBAC,
|
||||
0xAA5,
|
||||
0x9AF,
|
||||
0x8A6,
|
||||
0xFAA,
|
||||
0xEA3,
|
||||
0xDA9,
|
||||
0xCA0,
|
||||
0x460,
|
||||
0x569,
|
||||
0x663,
|
||||
0x76A,
|
||||
0x66,
|
||||
0x16F,
|
||||
0x265,
|
||||
0x36C,
|
||||
0xC6C,
|
||||
0xD65,
|
||||
0xE6F,
|
||||
0xF66,
|
||||
0x86A,
|
||||
0x963,
|
||||
0xA69,
|
||||
0xB60,
|
||||
0x5F0,
|
||||
0x4F9,
|
||||
0x7F3,
|
||||
0x6FA,
|
||||
0x1F6,
|
||||
0xFF,
|
||||
0x3F5,
|
||||
0x2FC,
|
||||
0xDFC,
|
||||
0xCF5,
|
||||
0xFFF,
|
||||
0xEF6,
|
||||
0x9FA,
|
||||
0x8F3,
|
||||
0xBF9,
|
||||
0xAF0,
|
||||
0x650,
|
||||
0x759,
|
||||
0x453,
|
||||
0x55A,
|
||||
0x256,
|
||||
0x35F,
|
||||
0x55,
|
||||
0x15C,
|
||||
0xE5C,
|
||||
0xF55,
|
||||
0xC5F,
|
||||
0xD56,
|
||||
0xA5A,
|
||||
0xB53,
|
||||
0x859,
|
||||
0x950,
|
||||
0x7C0,
|
||||
0x6C9,
|
||||
0x5C3,
|
||||
0x4CA,
|
||||
0x3C6,
|
||||
0x2CF,
|
||||
0x1C5,
|
||||
0xCC,
|
||||
0xFCC,
|
||||
0xEC5,
|
||||
0xDCF,
|
||||
0xCC6,
|
||||
0xBCA,
|
||||
0xAC3,
|
||||
0x9C9,
|
||||
0x8C0,
|
||||
0x8C0,
|
||||
0x9C9,
|
||||
0xAC3,
|
||||
0xBCA,
|
||||
0xCC6,
|
||||
0xDCF,
|
||||
0xEC5,
|
||||
0xFCC,
|
||||
0xCC,
|
||||
0x1C5,
|
||||
0x2CF,
|
||||
0x3C6,
|
||||
0x4CA,
|
||||
0x5C3,
|
||||
0x6C9,
|
||||
0x7C0,
|
||||
0x950,
|
||||
0x859,
|
||||
0xB53,
|
||||
0xA5A,
|
||||
0xD56,
|
||||
0xC5F,
|
||||
0xF55,
|
||||
0xE5C,
|
||||
0x15C,
|
||||
0x55,
|
||||
0x35F,
|
||||
0x256,
|
||||
0x55A,
|
||||
0x453,
|
||||
0x759,
|
||||
0x650,
|
||||
0xAF0,
|
||||
0xBF9,
|
||||
0x8F3,
|
||||
0x9FA,
|
||||
0xEF6,
|
||||
0xFFF,
|
||||
0xCF5,
|
||||
0xDFC,
|
||||
0x2FC,
|
||||
0x3F5,
|
||||
0xFF,
|
||||
0x1F6,
|
||||
0x6FA,
|
||||
0x7F3,
|
||||
0x4F9,
|
||||
0x5F0,
|
||||
0xB60,
|
||||
0xA69,
|
||||
0x963,
|
||||
0x86A,
|
||||
0xF66,
|
||||
0xE6F,
|
||||
0xD65,
|
||||
0xC6C,
|
||||
0x36C,
|
||||
0x265,
|
||||
0x16F,
|
||||
0x66,
|
||||
0x76A,
|
||||
0x663,
|
||||
0x569,
|
||||
0x460,
|
||||
0xCA0,
|
||||
0xDA9,
|
||||
0xEA3,
|
||||
0xFAA,
|
||||
0x8A6,
|
||||
0x9AF,
|
||||
0xAA5,
|
||||
0xBAC,
|
||||
0x4AC,
|
||||
0x5A5,
|
||||
0x6AF,
|
||||
0x7A6,
|
||||
0xAA,
|
||||
0x1A3,
|
||||
0x2A9,
|
||||
0x3A0,
|
||||
0xD30,
|
||||
0xC39,
|
||||
0xF33,
|
||||
0xE3A,
|
||||
0x936,
|
||||
0x83F,
|
||||
0xB35,
|
||||
0xA3C,
|
||||
0x53C,
|
||||
0x435,
|
||||
0x73F,
|
||||
0x636,
|
||||
0x13A,
|
||||
0x33,
|
||||
0x339,
|
||||
0x230,
|
||||
0xE90,
|
||||
0xF99,
|
||||
0xC93,
|
||||
0xD9A,
|
||||
0xA96,
|
||||
0xB9F,
|
||||
0x895,
|
||||
0x99C,
|
||||
0x69C,
|
||||
0x795,
|
||||
0x49F,
|
||||
0x596,
|
||||
0x29A,
|
||||
0x393,
|
||||
0x99,
|
||||
0x190,
|
||||
0xF00,
|
||||
0xE09,
|
||||
0xD03,
|
||||
0xC0A,
|
||||
0xB06,
|
||||
0xA0F,
|
||||
0x905,
|
||||
0x80C,
|
||||
0x70C,
|
||||
0x605,
|
||||
0x50F,
|
||||
0x406,
|
||||
0x30A,
|
||||
0x203,
|
||||
0x109,
|
||||
0x0,
|
||||
]
|
||||
|
||||
# Maps each edge (by index) to the corresponding cube vertices
|
||||
EDGE_TO_VERTICES = [
|
||||
[0, 1],
|
||||
[1, 2],
|
||||
[3, 2],
|
||||
[0, 3],
|
||||
[4, 5],
|
||||
[5, 6],
|
||||
[7, 6],
|
||||
[4, 7],
|
||||
[0, 4],
|
||||
[1, 5],
|
||||
[2, 6],
|
||||
[3, 7],
|
||||
]
|
||||
|
||||
# A list of lists mapping a cube_index (a given configuration)
|
||||
# to a list of faces corresponding to that configuration. Each face is represented
|
||||
# by 3 consecutive numbers. A configuration will at most have 5 faces.
|
||||
#
|
||||
# Table taken from http://paulbourke.net/geometry/polygonise/
|
||||
FACE_TABLE = [
|
||||
[],
|
||||
[0, 8, 3],
|
||||
[0, 1, 9],
|
||||
[1, 8, 3, 9, 8, 1],
|
||||
[1, 2, 10],
|
||||
[0, 8, 3, 1, 2, 10],
|
||||
[9, 2, 10, 0, 2, 9],
|
||||
[2, 8, 3, 2, 10, 8, 10, 9, 8],
|
||||
[3, 11, 2],
|
||||
[0, 11, 2, 8, 11, 0],
|
||||
[1, 9, 0, 2, 3, 11],
|
||||
[1, 11, 2, 1, 9, 11, 9, 8, 11],
|
||||
[3, 10, 1, 11, 10, 3],
|
||||
[0, 10, 1, 0, 8, 10, 8, 11, 10],
|
||||
[3, 9, 0, 3, 11, 9, 11, 10, 9],
|
||||
[9, 8, 10, 10, 8, 11],
|
||||
[4, 7, 8],
|
||||
[4, 3, 0, 7, 3, 4],
|
||||
[0, 1, 9, 8, 4, 7],
|
||||
[4, 1, 9, 4, 7, 1, 7, 3, 1],
|
||||
[1, 2, 10, 8, 4, 7],
|
||||
[3, 4, 7, 3, 0, 4, 1, 2, 10],
|
||||
[9, 2, 10, 9, 0, 2, 8, 4, 7],
|
||||
[2, 10, 9, 2, 9, 7, 2, 7, 3, 7, 9, 4],
|
||||
[8, 4, 7, 3, 11, 2],
|
||||
[11, 4, 7, 11, 2, 4, 2, 0, 4],
|
||||
[9, 0, 1, 8, 4, 7, 2, 3, 11],
|
||||
[4, 7, 11, 9, 4, 11, 9, 11, 2, 9, 2, 1],
|
||||
[3, 10, 1, 3, 11, 10, 7, 8, 4],
|
||||
[1, 11, 10, 1, 4, 11, 1, 0, 4, 7, 11, 4],
|
||||
[4, 7, 8, 9, 0, 11, 9, 11, 10, 11, 0, 3],
|
||||
[4, 7, 11, 4, 11, 9, 9, 11, 10],
|
||||
[9, 5, 4],
|
||||
[9, 5, 4, 0, 8, 3],
|
||||
[0, 5, 4, 1, 5, 0],
|
||||
[8, 5, 4, 8, 3, 5, 3, 1, 5],
|
||||
[1, 2, 10, 9, 5, 4],
|
||||
[3, 0, 8, 1, 2, 10, 4, 9, 5],
|
||||
[5, 2, 10, 5, 4, 2, 4, 0, 2],
|
||||
[2, 10, 5, 3, 2, 5, 3, 5, 4, 3, 4, 8],
|
||||
[9, 5, 4, 2, 3, 11],
|
||||
[0, 11, 2, 0, 8, 11, 4, 9, 5],
|
||||
[0, 5, 4, 0, 1, 5, 2, 3, 11],
|
||||
[2, 1, 5, 2, 5, 8, 2, 8, 11, 4, 8, 5],
|
||||
[10, 3, 11, 10, 1, 3, 9, 5, 4],
|
||||
[4, 9, 5, 0, 8, 1, 8, 10, 1, 8, 11, 10],
|
||||
[5, 4, 0, 5, 0, 11, 5, 11, 10, 11, 0, 3],
|
||||
[5, 4, 8, 5, 8, 10, 10, 8, 11],
|
||||
[9, 7, 8, 5, 7, 9],
|
||||
[9, 3, 0, 9, 5, 3, 5, 7, 3],
|
||||
[0, 7, 8, 0, 1, 7, 1, 5, 7],
|
||||
[1, 5, 3, 3, 5, 7],
|
||||
[9, 7, 8, 9, 5, 7, 10, 1, 2],
|
||||
[10, 1, 2, 9, 5, 0, 5, 3, 0, 5, 7, 3],
|
||||
[8, 0, 2, 8, 2, 5, 8, 5, 7, 10, 5, 2],
|
||||
[2, 10, 5, 2, 5, 3, 3, 5, 7],
|
||||
[7, 9, 5, 7, 8, 9, 3, 11, 2],
|
||||
[9, 5, 7, 9, 7, 2, 9, 2, 0, 2, 7, 11],
|
||||
[2, 3, 11, 0, 1, 8, 1, 7, 8, 1, 5, 7],
|
||||
[11, 2, 1, 11, 1, 7, 7, 1, 5],
|
||||
[9, 5, 8, 8, 5, 7, 10, 1, 3, 10, 3, 11],
|
||||
[5, 7, 0, 5, 0, 9, 7, 11, 0, 1, 0, 10, 11, 10, 0],
|
||||
[11, 10, 0, 11, 0, 3, 10, 5, 0, 8, 0, 7, 5, 7, 0],
|
||||
[11, 10, 5, 7, 11, 5],
|
||||
[10, 6, 5],
|
||||
[0, 8, 3, 5, 10, 6],
|
||||
[9, 0, 1, 5, 10, 6],
|
||||
[1, 8, 3, 1, 9, 8, 5, 10, 6],
|
||||
[1, 6, 5, 2, 6, 1],
|
||||
[1, 6, 5, 1, 2, 6, 3, 0, 8],
|
||||
[9, 6, 5, 9, 0, 6, 0, 2, 6],
|
||||
[5, 9, 8, 5, 8, 2, 5, 2, 6, 3, 2, 8],
|
||||
[2, 3, 11, 10, 6, 5],
|
||||
[11, 0, 8, 11, 2, 0, 10, 6, 5],
|
||||
[0, 1, 9, 2, 3, 11, 5, 10, 6],
|
||||
[5, 10, 6, 1, 9, 2, 9, 11, 2, 9, 8, 11],
|
||||
[6, 3, 11, 6, 5, 3, 5, 1, 3],
|
||||
[0, 8, 11, 0, 11, 5, 0, 5, 1, 5, 11, 6],
|
||||
[3, 11, 6, 0, 3, 6, 0, 6, 5, 0, 5, 9],
|
||||
[6, 5, 9, 6, 9, 11, 11, 9, 8],
|
||||
[5, 10, 6, 4, 7, 8],
|
||||
[4, 3, 0, 4, 7, 3, 6, 5, 10],
|
||||
[1, 9, 0, 5, 10, 6, 8, 4, 7],
|
||||
[10, 6, 5, 1, 9, 7, 1, 7, 3, 7, 9, 4],
|
||||
[6, 1, 2, 6, 5, 1, 4, 7, 8],
|
||||
[1, 2, 5, 5, 2, 6, 3, 0, 4, 3, 4, 7],
|
||||
[8, 4, 7, 9, 0, 5, 0, 6, 5, 0, 2, 6],
|
||||
[7, 3, 9, 7, 9, 4, 3, 2, 9, 5, 9, 6, 2, 6, 9],
|
||||
[3, 11, 2, 7, 8, 4, 10, 6, 5],
|
||||
[5, 10, 6, 4, 7, 2, 4, 2, 0, 2, 7, 11],
|
||||
[0, 1, 9, 4, 7, 8, 2, 3, 11, 5, 10, 6],
|
||||
[9, 2, 1, 9, 11, 2, 9, 4, 11, 7, 11, 4, 5, 10, 6],
|
||||
[8, 4, 7, 3, 11, 5, 3, 5, 1, 5, 11, 6],
|
||||
[5, 1, 11, 5, 11, 6, 1, 0, 11, 7, 11, 4, 0, 4, 11],
|
||||
[0, 5, 9, 0, 6, 5, 0, 3, 6, 11, 6, 3, 8, 4, 7],
|
||||
[6, 5, 9, 6, 9, 11, 4, 7, 9, 7, 11, 9],
|
||||
[10, 4, 9, 6, 4, 10],
|
||||
[4, 10, 6, 4, 9, 10, 0, 8, 3],
|
||||
[10, 0, 1, 10, 6, 0, 6, 4, 0],
|
||||
[8, 3, 1, 8, 1, 6, 8, 6, 4, 6, 1, 10],
|
||||
[1, 4, 9, 1, 2, 4, 2, 6, 4],
|
||||
[3, 0, 8, 1, 2, 9, 2, 4, 9, 2, 6, 4],
|
||||
[0, 2, 4, 4, 2, 6],
|
||||
[8, 3, 2, 8, 2, 4, 4, 2, 6],
|
||||
[10, 4, 9, 10, 6, 4, 11, 2, 3],
|
||||
[0, 8, 2, 2, 8, 11, 4, 9, 10, 4, 10, 6],
|
||||
[3, 11, 2, 0, 1, 6, 0, 6, 4, 6, 1, 10],
|
||||
[6, 4, 1, 6, 1, 10, 4, 8, 1, 2, 1, 11, 8, 11, 1],
|
||||
[9, 6, 4, 9, 3, 6, 9, 1, 3, 11, 6, 3],
|
||||
[8, 11, 1, 8, 1, 0, 11, 6, 1, 9, 1, 4, 6, 4, 1],
|
||||
[3, 11, 6, 3, 6, 0, 0, 6, 4],
|
||||
[6, 4, 8, 11, 6, 8],
|
||||
[7, 10, 6, 7, 8, 10, 8, 9, 10],
|
||||
[0, 7, 3, 0, 10, 7, 0, 9, 10, 6, 7, 10],
|
||||
[10, 6, 7, 1, 10, 7, 1, 7, 8, 1, 8, 0],
|
||||
[10, 6, 7, 10, 7, 1, 1, 7, 3],
|
||||
[1, 2, 6, 1, 6, 8, 1, 8, 9, 8, 6, 7],
|
||||
[2, 6, 9, 2, 9, 1, 6, 7, 9, 0, 9, 3, 7, 3, 9],
|
||||
[7, 8, 0, 7, 0, 6, 6, 0, 2],
|
||||
[7, 3, 2, 6, 7, 2],
|
||||
[2, 3, 11, 10, 6, 8, 10, 8, 9, 8, 6, 7],
|
||||
[2, 0, 7, 2, 7, 11, 0, 9, 7, 6, 7, 10, 9, 10, 7],
|
||||
[1, 8, 0, 1, 7, 8, 1, 10, 7, 6, 7, 10, 2, 3, 11],
|
||||
[11, 2, 1, 11, 1, 7, 10, 6, 1, 6, 7, 1],
|
||||
[8, 9, 6, 8, 6, 7, 9, 1, 6, 11, 6, 3, 1, 3, 6],
|
||||
[0, 9, 1, 11, 6, 7],
|
||||
[7, 8, 0, 7, 0, 6, 3, 11, 0, 11, 6, 0],
|
||||
[7, 11, 6],
|
||||
[7, 6, 11],
|
||||
[3, 0, 8, 11, 7, 6],
|
||||
[0, 1, 9, 11, 7, 6],
|
||||
[8, 1, 9, 8, 3, 1, 11, 7, 6],
|
||||
[10, 1, 2, 6, 11, 7],
|
||||
[1, 2, 10, 3, 0, 8, 6, 11, 7],
|
||||
[2, 9, 0, 2, 10, 9, 6, 11, 7],
|
||||
[6, 11, 7, 2, 10, 3, 10, 8, 3, 10, 9, 8],
|
||||
[7, 2, 3, 6, 2, 7],
|
||||
[7, 0, 8, 7, 6, 0, 6, 2, 0],
|
||||
[2, 7, 6, 2, 3, 7, 0, 1, 9],
|
||||
[1, 6, 2, 1, 8, 6, 1, 9, 8, 8, 7, 6],
|
||||
[10, 7, 6, 10, 1, 7, 1, 3, 7],
|
||||
[10, 7, 6, 1, 7, 10, 1, 8, 7, 1, 0, 8],
|
||||
[0, 3, 7, 0, 7, 10, 0, 10, 9, 6, 10, 7],
|
||||
[7, 6, 10, 7, 10, 8, 8, 10, 9],
|
||||
[6, 8, 4, 11, 8, 6],
|
||||
[3, 6, 11, 3, 0, 6, 0, 4, 6],
|
||||
[8, 6, 11, 8, 4, 6, 9, 0, 1],
|
||||
[9, 4, 6, 9, 6, 3, 9, 3, 1, 11, 3, 6],
|
||||
[6, 8, 4, 6, 11, 8, 2, 10, 1],
|
||||
[1, 2, 10, 3, 0, 11, 0, 6, 11, 0, 4, 6],
|
||||
[4, 11, 8, 4, 6, 11, 0, 2, 9, 2, 10, 9],
|
||||
[10, 9, 3, 10, 3, 2, 9, 4, 3, 11, 3, 6, 4, 6, 3],
|
||||
[8, 2, 3, 8, 4, 2, 4, 6, 2],
|
||||
[0, 4, 2, 4, 6, 2],
|
||||
[1, 9, 0, 2, 3, 4, 2, 4, 6, 4, 3, 8],
|
||||
[1, 9, 4, 1, 4, 2, 2, 4, 6],
|
||||
[8, 1, 3, 8, 6, 1, 8, 4, 6, 6, 10, 1],
|
||||
[10, 1, 0, 10, 0, 6, 6, 0, 4],
|
||||
[4, 6, 3, 4, 3, 8, 6, 10, 3, 0, 3, 9, 10, 9, 3],
|
||||
[10, 9, 4, 6, 10, 4],
|
||||
[4, 9, 5, 7, 6, 11],
|
||||
[0, 8, 3, 4, 9, 5, 11, 7, 6],
|
||||
[5, 0, 1, 5, 4, 0, 7, 6, 11],
|
||||
[11, 7, 6, 8, 3, 4, 3, 5, 4, 3, 1, 5],
|
||||
[9, 5, 4, 10, 1, 2, 7, 6, 11],
|
||||
[6, 11, 7, 1, 2, 10, 0, 8, 3, 4, 9, 5],
|
||||
[7, 6, 11, 5, 4, 10, 4, 2, 10, 4, 0, 2],
|
||||
[3, 4, 8, 3, 5, 4, 3, 2, 5, 10, 5, 2, 11, 7, 6],
|
||||
[7, 2, 3, 7, 6, 2, 5, 4, 9],
|
||||
[9, 5, 4, 0, 8, 6, 0, 6, 2, 6, 8, 7],
|
||||
[3, 6, 2, 3, 7, 6, 1, 5, 0, 5, 4, 0],
|
||||
[6, 2, 8, 6, 8, 7, 2, 1, 8, 4, 8, 5, 1, 5, 8],
|
||||
[9, 5, 4, 10, 1, 6, 1, 7, 6, 1, 3, 7],
|
||||
[1, 6, 10, 1, 7, 6, 1, 0, 7, 8, 7, 0, 9, 5, 4],
|
||||
[4, 0, 10, 4, 10, 5, 0, 3, 10, 6, 10, 7, 3, 7, 10],
|
||||
[7, 6, 10, 7, 10, 8, 5, 4, 10, 4, 8, 10],
|
||||
[6, 9, 5, 6, 11, 9, 11, 8, 9],
|
||||
[3, 6, 11, 0, 6, 3, 0, 5, 6, 0, 9, 5],
|
||||
[0, 11, 8, 0, 5, 11, 0, 1, 5, 5, 6, 11],
|
||||
[6, 11, 3, 6, 3, 5, 5, 3, 1],
|
||||
[1, 2, 10, 9, 5, 11, 9, 11, 8, 11, 5, 6],
|
||||
[0, 11, 3, 0, 6, 11, 0, 9, 6, 5, 6, 9, 1, 2, 10],
|
||||
[11, 8, 5, 11, 5, 6, 8, 0, 5, 10, 5, 2, 0, 2, 5],
|
||||
[6, 11, 3, 6, 3, 5, 2, 10, 3, 10, 5, 3],
|
||||
[5, 8, 9, 5, 2, 8, 5, 6, 2, 3, 8, 2],
|
||||
[9, 5, 6, 9, 6, 0, 0, 6, 2],
|
||||
[1, 5, 8, 1, 8, 0, 5, 6, 8, 3, 8, 2, 6, 2, 8],
|
||||
[1, 5, 6, 2, 1, 6],
|
||||
[1, 3, 6, 1, 6, 10, 3, 8, 6, 5, 6, 9, 8, 9, 6],
|
||||
[10, 1, 0, 10, 0, 6, 9, 5, 0, 5, 6, 0],
|
||||
[0, 3, 8, 5, 6, 10],
|
||||
[10, 5, 6],
|
||||
[11, 5, 10, 7, 5, 11],
|
||||
[11, 5, 10, 11, 7, 5, 8, 3, 0],
|
||||
[5, 11, 7, 5, 10, 11, 1, 9, 0],
|
||||
[10, 7, 5, 10, 11, 7, 9, 8, 1, 8, 3, 1],
|
||||
[11, 1, 2, 11, 7, 1, 7, 5, 1],
|
||||
[0, 8, 3, 1, 2, 7, 1, 7, 5, 7, 2, 11],
|
||||
[9, 7, 5, 9, 2, 7, 9, 0, 2, 2, 11, 7],
|
||||
[7, 5, 2, 7, 2, 11, 5, 9, 2, 3, 2, 8, 9, 8, 2],
|
||||
[2, 5, 10, 2, 3, 5, 3, 7, 5],
|
||||
[8, 2, 0, 8, 5, 2, 8, 7, 5, 10, 2, 5],
|
||||
[9, 0, 1, 5, 10, 3, 5, 3, 7, 3, 10, 2],
|
||||
[9, 8, 2, 9, 2, 1, 8, 7, 2, 10, 2, 5, 7, 5, 2],
|
||||
[1, 3, 5, 3, 7, 5],
|
||||
[0, 8, 7, 0, 7, 1, 1, 7, 5],
|
||||
[9, 0, 3, 9, 3, 5, 5, 3, 7],
|
||||
[9, 8, 7, 5, 9, 7],
|
||||
[5, 8, 4, 5, 10, 8, 10, 11, 8],
|
||||
[5, 0, 4, 5, 11, 0, 5, 10, 11, 11, 3, 0],
|
||||
[0, 1, 9, 8, 4, 10, 8, 10, 11, 10, 4, 5],
|
||||
[10, 11, 4, 10, 4, 5, 11, 3, 4, 9, 4, 1, 3, 1, 4],
|
||||
[2, 5, 1, 2, 8, 5, 2, 11, 8, 4, 5, 8],
|
||||
[0, 4, 11, 0, 11, 3, 4, 5, 11, 2, 11, 1, 5, 1, 11],
|
||||
[0, 2, 5, 0, 5, 9, 2, 11, 5, 4, 5, 8, 11, 8, 5],
|
||||
[9, 4, 5, 2, 11, 3],
|
||||
[2, 5, 10, 3, 5, 2, 3, 4, 5, 3, 8, 4],
|
||||
[5, 10, 2, 5, 2, 4, 4, 2, 0],
|
||||
[3, 10, 2, 3, 5, 10, 3, 8, 5, 4, 5, 8, 0, 1, 9],
|
||||
[5, 10, 2, 5, 2, 4, 1, 9, 2, 9, 4, 2],
|
||||
[8, 4, 5, 8, 5, 3, 3, 5, 1],
|
||||
[0, 4, 5, 1, 0, 5],
|
||||
[8, 4, 5, 8, 5, 3, 9, 0, 5, 0, 3, 5],
|
||||
[9, 4, 5],
|
||||
[4, 11, 7, 4, 9, 11, 9, 10, 11],
|
||||
[0, 8, 3, 4, 9, 7, 9, 11, 7, 9, 10, 11],
|
||||
[1, 10, 11, 1, 11, 4, 1, 4, 0, 7, 4, 11],
|
||||
[3, 1, 4, 3, 4, 8, 1, 10, 4, 7, 4, 11, 10, 11, 4],
|
||||
[4, 11, 7, 9, 11, 4, 9, 2, 11, 9, 1, 2],
|
||||
[9, 7, 4, 9, 11, 7, 9, 1, 11, 2, 11, 1, 0, 8, 3],
|
||||
[11, 7, 4, 11, 4, 2, 2, 4, 0],
|
||||
[11, 7, 4, 11, 4, 2, 8, 3, 4, 3, 2, 4],
|
||||
[2, 9, 10, 2, 7, 9, 2, 3, 7, 7, 4, 9],
|
||||
[9, 10, 7, 9, 7, 4, 10, 2, 7, 8, 7, 0, 2, 0, 7],
|
||||
[3, 7, 10, 3, 10, 2, 7, 4, 10, 1, 10, 0, 4, 0, 10],
|
||||
[1, 10, 2, 8, 7, 4],
|
||||
[4, 9, 1, 4, 1, 7, 7, 1, 3],
|
||||
[4, 9, 1, 4, 1, 7, 0, 8, 1, 8, 7, 1],
|
||||
[4, 0, 3, 7, 4, 3],
|
||||
[4, 8, 7],
|
||||
[9, 10, 8, 10, 11, 8],
|
||||
[3, 0, 9, 3, 9, 11, 11, 9, 10],
|
||||
[0, 1, 10, 0, 10, 8, 8, 10, 11],
|
||||
[3, 1, 10, 11, 3, 10],
|
||||
[1, 2, 11, 1, 11, 9, 9, 11, 8],
|
||||
[3, 0, 9, 3, 9, 11, 1, 2, 9, 2, 11, 9],
|
||||
[0, 2, 11, 8, 0, 11],
|
||||
[3, 2, 11],
|
||||
[2, 3, 8, 2, 8, 10, 10, 8, 9],
|
||||
[9, 10, 2, 0, 9, 2],
|
||||
[2, 3, 8, 2, 8, 10, 0, 1, 8, 1, 10, 8],
|
||||
[1, 10, 2],
|
||||
[1, 3, 8, 9, 1, 8],
|
||||
[0, 9, 1],
|
||||
[0, 3, 8],
|
||||
[],
|
||||
]
|
@ -414,7 +414,7 @@ class Transform3d:
|
||||
|
||||
|
||||
class Translate(Transform3d):
|
||||
def __init__(self, x, y=None, z=None, dtype=torch.float32, device: str = "cpu"):
|
||||
def __init__(self, x, y=None, z=None, dtype=torch.float32, device="cpu"):
|
||||
"""
|
||||
Create a new Transform3d representing 3D translations.
|
||||
|
||||
@ -448,7 +448,7 @@ class Translate(Transform3d):
|
||||
|
||||
|
||||
class Scale(Transform3d):
|
||||
def __init__(self, x, y=None, z=None, dtype=torch.float32, device: str = "cpu"):
|
||||
def __init__(self, x, y=None, z=None, dtype=torch.float32, device="cpu"):
|
||||
"""
|
||||
A Transform3d representing a scaling operation, with different scale
|
||||
factors along each coordinate axis.
|
||||
@ -489,7 +489,7 @@ class Scale(Transform3d):
|
||||
|
||||
class Rotate(Transform3d):
|
||||
def __init__(
|
||||
self, R, dtype=torch.float32, device: str = "cpu", orthogonal_tol: float = 1e-5
|
||||
self, R, dtype=torch.float32, device="cpu", orthogonal_tol: float = 1e-5
|
||||
):
|
||||
"""
|
||||
Create a new Transform3d representing 3D rotation using a rotation
|
||||
@ -528,7 +528,7 @@ class RotateAxisAngle(Rotate):
|
||||
axis: str = "X",
|
||||
degrees: bool = True,
|
||||
dtype=torch.float64,
|
||||
device: str = "cpu",
|
||||
device="cpu",
|
||||
):
|
||||
"""
|
||||
Create a new Transform3d representing 3D rotation about an axis
|
||||
@ -635,7 +635,7 @@ def _handle_input(x, y, z, dtype, device, name: str, allow_singleton: bool = Fal
|
||||
return xyz
|
||||
|
||||
|
||||
def _handle_angle_input(x, dtype, device: str, name: str):
|
||||
def _handle_angle_input(x, dtype, device, name: str):
|
||||
"""
|
||||
Helper function for building a rotation function using angles.
|
||||
The output is always of shape (N,).
|
||||
|
25
tests/bm_marching_cubes.py
Normal file
25
tests/bm_marching_cubes.py
Normal file
@ -0,0 +1,25 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
from fvcore.common.benchmark import benchmark
|
||||
from test_marching_cubes import TestMarchingCubes
|
||||
|
||||
|
||||
def bm_marching_cubes() -> None:
|
||||
kwargs_list = [
|
||||
{"batch_size": 1, "V": 5},
|
||||
{"batch_size": 1, "V": 10},
|
||||
{"batch_size": 1, "V": 20},
|
||||
{"batch_size": 1, "V": 40},
|
||||
{"batch_size": 5, "V": 5},
|
||||
{"batch_size": 20, "V": 20},
|
||||
]
|
||||
benchmark(
|
||||
TestMarchingCubes.marching_cubes_with_init,
|
||||
"MARCHING_CUBES",
|
||||
kwargs_list,
|
||||
warmup_iters=1,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
bm_marching_cubes()
|
BIN
tests/data/test_marching_cubes_data/double_ellipsoid.pickle
Normal file
BIN
tests/data/test_marching_cubes_data/double_ellipsoid.pickle
Normal file
Binary file not shown.
BIN
tests/data/test_marching_cubes_data/sphere_level64.pickle
Normal file
BIN
tests/data/test_marching_cubes_data/sphere_level64.pickle
Normal file
Binary file not shown.
771
tests/test_marching_cubes.py
Normal file
771
tests/test_marching_cubes.py
Normal file
@ -0,0 +1,771 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
import os
|
||||
import pickle
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from common_testing import TestCaseMixin
|
||||
from pytorch3d.ops.marching_cubes import marching_cubes_naive
|
||||
|
||||
|
||||
USE_SCIKIT = False
|
||||
|
||||
|
||||
def convert_to_local(verts, volume_dim):
|
||||
return (2 * verts) / (volume_dim - 1) - 1
|
||||
|
||||
|
||||
class TestCubeConfiguration(TestCaseMixin, unittest.TestCase):
|
||||
|
||||
# Test single cubes. Each case corresponds to the corresponding
|
||||
# cube vertex configuration in each case here (0-indexed):
|
||||
# https://en.wikipedia.org/wiki/Marching_cubes#/media/File:MarchingCubes.svg
|
||||
|
||||
def test_empty_volume(self): # case 0
|
||||
volume_data = torch.ones(1, 2, 2, 2) # (B, W, H, D)
|
||||
verts, faces = marching_cubes_naive(volume_data, return_local_coords=False)
|
||||
|
||||
expected_verts = torch.tensor([])
|
||||
expected_faces = torch.tensor([], dtype=torch.int64)
|
||||
self.assertClose(verts, expected_verts)
|
||||
self.assertClose(faces, expected_faces)
|
||||
|
||||
def test_case1(self): # case 1
|
||||
volume_data = torch.ones(1, 2, 2, 2) # (B, W, H, D)
|
||||
volume_data[0, 0, 0, 0] = 0
|
||||
volume_data = volume_data.permute(0, 3, 2, 1) # (B, D, H, W)
|
||||
verts, faces = marching_cubes_naive(volume_data, return_local_coords=False)
|
||||
|
||||
expected_verts = torch.tensor(
|
||||
[
|
||||
[0.5, 0, 0],
|
||||
[0, 0, 0.5],
|
||||
[0, 0.5, 0],
|
||||
]
|
||||
)
|
||||
|
||||
expected_faces = torch.tensor([[1, 2, 0]])
|
||||
self.assertClose(verts[0], expected_verts)
|
||||
self.assertClose(faces[0], expected_faces)
|
||||
|
||||
verts, faces = marching_cubes_naive(volume_data, return_local_coords=True)
|
||||
expected_verts = convert_to_local(expected_verts, 2)
|
||||
self.assertClose(verts[0], expected_verts)
|
||||
self.assertClose(faces[0], expected_faces)
|
||||
|
||||
def test_case2(self):
|
||||
volume_data = torch.ones(1, 2, 2, 2) # (B, W, H, D)
|
||||
volume_data[0, 0:2, 0, 0] = 0
|
||||
volume_data = volume_data.permute(0, 3, 2, 1) # (B, D, H, W)
|
||||
verts, faces = marching_cubes_naive(volume_data, return_local_coords=False)
|
||||
|
||||
expected_verts = torch.tensor(
|
||||
[
|
||||
[1.0000, 0.0000, 0.5000],
|
||||
[0.0000, 0.0000, 0.5000],
|
||||
[1.0000, 0.5000, 0.0000],
|
||||
[0.0000, 0.5000, 0.0000],
|
||||
]
|
||||
)
|
||||
expected_faces = torch.tensor([[1, 2, 0], [3, 2, 1]])
|
||||
self.assertClose(verts[0], expected_verts)
|
||||
self.assertClose(faces[0], expected_faces)
|
||||
|
||||
verts, faces = marching_cubes_naive(volume_data, return_local_coords=True)
|
||||
expected_verts = convert_to_local(expected_verts, 2)
|
||||
self.assertClose(verts[0], expected_verts)
|
||||
self.assertClose(faces[0], expected_faces)
|
||||
|
||||
def test_case3(self):
|
||||
volume_data = torch.ones(1, 2, 2, 2) # (B, W, H, D)
|
||||
volume_data[0, 0, 0, 0] = 0
|
||||
volume_data[0, 1, 1, 0] = 0
|
||||
volume_data = volume_data.permute(0, 3, 2, 1) # (B, D, H, W)
|
||||
verts, faces = marching_cubes_naive(volume_data, return_local_coords=False)
|
||||
|
||||
expected_verts = torch.tensor(
|
||||
[
|
||||
[0.5000, 0.0000, 0.0000],
|
||||
[0.0000, 0.0000, 0.5000],
|
||||
[1.0000, 1.0000, 0.5000],
|
||||
[0.5000, 1.0000, 0.0000],
|
||||
[1.0000, 0.5000, 0.0000],
|
||||
[0.0000, 0.5000, 0.0000],
|
||||
]
|
||||
)
|
||||
expected_faces = torch.tensor([[0, 1, 5], [4, 3, 2]])
|
||||
self.assertClose(verts[0], expected_verts)
|
||||
self.assertClose(faces[0], expected_faces)
|
||||
|
||||
verts, faces = marching_cubes_naive(volume_data, return_local_coords=True)
|
||||
expected_verts = convert_to_local(expected_verts, 2)
|
||||
self.assertClose(verts[0], expected_verts)
|
||||
self.assertClose(faces[0], expected_faces)
|
||||
|
||||
def test_case4(self):
|
||||
volume_data = torch.ones(1, 2, 2, 2) # (B, W, H, D)
|
||||
volume_data[0, 1, 0, 0] = 0
|
||||
volume_data[0, 1, 0, 1] = 0
|
||||
volume_data[0, 0, 0, 1] = 0
|
||||
volume_data = volume_data.permute(0, 3, 2, 1) # (B, D, H, W)
|
||||
verts, faces = marching_cubes_naive(volume_data, return_local_coords=False)
|
||||
|
||||
expected_verts = torch.tensor(
|
||||
[
|
||||
[0.5000, 0.0000, 0.0000],
|
||||
[0.0000, 0.0000, 0.5000],
|
||||
[0.0000, 0.5000, 1.0000],
|
||||
[1.0000, 0.5000, 1.0000],
|
||||
[1.0000, 0.5000, 0.0000],
|
||||
]
|
||||
)
|
||||
expected_faces = torch.tensor([[0, 2, 1], [0, 4, 2], [4, 3, 2]])
|
||||
self.assertClose(verts[0], expected_verts)
|
||||
self.assertClose(faces[0], expected_faces)
|
||||
|
||||
verts, faces = marching_cubes_naive(volume_data, return_local_coords=True)
|
||||
expected_verts = convert_to_local(expected_verts, 2)
|
||||
self.assertClose(verts[0], expected_verts)
|
||||
self.assertClose(faces[0], expected_faces)
|
||||
|
||||
def test_case5(self):
|
||||
volume_data = torch.ones(1, 2, 2, 2) # (B, W, H, D)
|
||||
volume_data[0, 0:2, 0, 0:2] = 0
|
||||
volume_data = volume_data.permute(0, 3, 2, 1) # (B, D, H, W)
|
||||
verts, faces = marching_cubes_naive(volume_data, return_local_coords=False)
|
||||
|
||||
expected_verts = torch.tensor(
|
||||
[
|
||||
[0.0000, 0.5000, 1.0000],
|
||||
[1.0000, 0.5000, 1.0000],
|
||||
[1.0000, 0.5000, 0.0000],
|
||||
[0.0000, 0.5000, 0.0000],
|
||||
]
|
||||
)
|
||||
|
||||
expected_faces = torch.tensor([[1, 0, 2], [2, 0, 3]])
|
||||
|
||||
self.assertClose(verts[0], expected_verts)
|
||||
self.assertClose(faces[0], expected_faces)
|
||||
|
||||
verts, faces = marching_cubes_naive(volume_data, return_local_coords=True)
|
||||
expected_verts = convert_to_local(expected_verts, 2)
|
||||
self.assertClose(verts[0], expected_verts)
|
||||
self.assertClose(faces[0], expected_faces)
|
||||
|
||||
def test_case6(self):
|
||||
volume_data = torch.ones(1, 2, 2, 2) # (B, W, H, D)
|
||||
volume_data[0, 1, 0, 0] = 0
|
||||
volume_data[0, 1, 0, 1] = 0
|
||||
volume_data[0, 0, 0, 1] = 0
|
||||
volume_data[0, 0, 1, 0] = 0
|
||||
volume_data = volume_data.permute(0, 3, 2, 1) # (B, D, H, W)
|
||||
verts, faces = marching_cubes_naive(volume_data, return_local_coords=False)
|
||||
|
||||
expected_verts = torch.tensor(
|
||||
[
|
||||
[0.5000, 0.0000, 0.0000],
|
||||
[0.0000, 0.0000, 0.5000],
|
||||
[0.5000, 1.0000, 0.0000],
|
||||
[0.0000, 1.0000, 0.5000],
|
||||
[0.0000, 0.5000, 1.0000],
|
||||
[1.0000, 0.5000, 1.0000],
|
||||
[1.0000, 0.5000, 0.0000],
|
||||
[0.0000, 0.5000, 0.0000],
|
||||
]
|
||||
)
|
||||
expected_faces = torch.tensor([[2, 7, 3], [0, 6, 1], [6, 4, 1], [6, 5, 4]])
|
||||
|
||||
self.assertClose(verts[0], expected_verts)
|
||||
self.assertClose(faces[0], expected_faces)
|
||||
|
||||
verts, faces = marching_cubes_naive(volume_data, return_local_coords=True)
|
||||
expected_verts = convert_to_local(expected_verts, 2)
|
||||
self.assertClose(verts[0], expected_verts)
|
||||
self.assertClose(faces[0], expected_faces)
|
||||
|
||||
def test_case7(self):
|
||||
volume_data = torch.ones(1, 2, 2, 2) # (B, W, H, D)
|
||||
volume_data[0, 0, 0, 0] = 0
|
||||
volume_data[0, 1, 0, 1] = 0
|
||||
volume_data[0, 1, 1, 0] = 0
|
||||
volume_data[0, 0, 1, 1] = 0
|
||||
volume_data = volume_data.permute(0, 3, 2, 1) # (B, D, H, W)
|
||||
verts, faces = marching_cubes_naive(volume_data, return_local_coords=False)
|
||||
|
||||
expected_verts = torch.tensor(
|
||||
[
|
||||
[0.5000, 0.0000, 1.0000],
|
||||
[1.0000, 0.0000, 0.5000],
|
||||
[0.5000, 0.0000, 0.0000],
|
||||
[0.0000, 0.0000, 0.5000],
|
||||
[0.5000, 1.0000, 1.0000],
|
||||
[1.0000, 1.0000, 0.5000],
|
||||
[0.5000, 1.0000, 0.0000],
|
||||
[0.0000, 1.0000, 0.5000],
|
||||
[0.0000, 0.5000, 1.0000],
|
||||
[1.0000, 0.5000, 1.0000],
|
||||
[1.0000, 0.5000, 0.0000],
|
||||
[0.0000, 0.5000, 0.0000],
|
||||
]
|
||||
)
|
||||
|
||||
expected_faces = torch.tensor([[0, 1, 9], [4, 7, 8], [2, 3, 11], [5, 10, 6]])
|
||||
|
||||
self.assertClose(verts[0], expected_verts)
|
||||
self.assertClose(faces[0], expected_faces)
|
||||
|
||||
verts, faces = marching_cubes_naive(volume_data, return_local_coords=True)
|
||||
expected_verts = convert_to_local(expected_verts, 2)
|
||||
self.assertClose(verts[0], expected_verts)
|
||||
self.assertClose(faces[0], expected_faces)
|
||||
|
||||
def test_case8(self):
|
||||
volume_data = torch.ones(1, 2, 2, 2) # (B, W, H, D)
|
||||
volume_data[0, 0, 0, 0] = 0
|
||||
volume_data[0, 0, 0, 1] = 0
|
||||
volume_data[0, 1, 0, 1] = 0
|
||||
volume_data[0, 0, 1, 1] = 0
|
||||
volume_data = volume_data.permute(0, 3, 2, 1) # (B, D, H, W)
|
||||
verts, faces = marching_cubes_naive(volume_data, return_local_coords=False)
|
||||
|
||||
expected_verts = torch.tensor(
|
||||
[
|
||||
[1.0000, 0.0000, 0.5000],
|
||||
[0.5000, 0.0000, 0.0000],
|
||||
[0.5000, 1.0000, 1.0000],
|
||||
[0.0000, 1.0000, 0.5000],
|
||||
[1.0000, 0.5000, 1.0000],
|
||||
[0.0000, 0.5000, 0.0000],
|
||||
]
|
||||
)
|
||||
expected_faces = torch.tensor([[2, 3, 5], [4, 2, 5], [4, 5, 1], [4, 1, 0]])
|
||||
|
||||
self.assertClose(verts[0], expected_verts)
|
||||
self.assertClose(faces[0], expected_faces)
|
||||
|
||||
verts, faces = marching_cubes_naive(volume_data, return_local_coords=True)
|
||||
expected_verts = convert_to_local(expected_verts, 2)
|
||||
self.assertClose(verts[0], expected_verts)
|
||||
self.assertClose(faces[0], expected_faces)
|
||||
|
||||
def test_case9(self):
|
||||
volume_data = torch.ones(1, 2, 2, 2) # (B, W, H, D)
|
||||
volume_data[0, 1, 0, 0] = 0
|
||||
volume_data[0, 0, 0, 1] = 0
|
||||
volume_data[0, 1, 0, 1] = 0
|
||||
volume_data[0, 0, 1, 1] = 0
|
||||
volume_data = volume_data.permute(0, 3, 2, 1) # (B, D, H, W)
|
||||
verts, faces = marching_cubes_naive(volume_data, return_local_coords=False)
|
||||
|
||||
expected_verts = torch.tensor(
|
||||
[
|
||||
[0.5000, 0.0000, 0.0000],
|
||||
[0.0000, 0.0000, 0.5000],
|
||||
[0.5000, 1.0000, 1.0000],
|
||||
[0.0000, 1.0000, 0.5000],
|
||||
[1.0000, 0.5000, 1.0000],
|
||||
[1.0000, 0.5000, 0.0000],
|
||||
]
|
||||
)
|
||||
expected_faces = torch.tensor([[0, 5, 4], [0, 4, 3], [0, 3, 1], [3, 4, 2]])
|
||||
|
||||
self.assertClose(verts[0], expected_verts)
|
||||
self.assertClose(faces[0], expected_faces)
|
||||
|
||||
verts, faces = marching_cubes_naive(volume_data, return_local_coords=True)
|
||||
expected_verts = convert_to_local(expected_verts, 2)
|
||||
self.assertClose(verts[0], expected_verts)
|
||||
self.assertClose(faces[0], expected_faces)
|
||||
|
||||
def test_case10(self):
|
||||
volume_data = torch.ones(1, 2, 2, 2) # (B, W, H, D)
|
||||
volume_data[0, 0, 0, 0] = 0
|
||||
volume_data[0, 1, 1, 1] = 0
|
||||
volume_data = volume_data.permute(0, 3, 2, 1) # (B, D, H, W)
|
||||
verts, faces = marching_cubes_naive(volume_data, return_local_coords=False)
|
||||
|
||||
expected_verts = torch.tensor(
|
||||
[
|
||||
[0.5000, 0.0000, 0.0000],
|
||||
[0.0000, 0.0000, 0.5000],
|
||||
[0.5000, 1.0000, 1.0000],
|
||||
[1.0000, 1.0000, 0.5000],
|
||||
[1.0000, 0.5000, 1.0000],
|
||||
[0.0000, 0.5000, 0.0000],
|
||||
]
|
||||
)
|
||||
|
||||
expected_faces = torch.tensor([[4, 3, 2], [0, 1, 5]])
|
||||
|
||||
self.assertClose(verts[0], expected_verts)
|
||||
self.assertClose(faces[0], expected_faces)
|
||||
|
||||
verts, faces = marching_cubes_naive(volume_data, return_local_coords=True)
|
||||
expected_verts = convert_to_local(expected_verts, 2)
|
||||
self.assertClose(verts[0], expected_verts)
|
||||
self.assertClose(faces[0], expected_faces)
|
||||
|
||||
def test_case11(self):
|
||||
volume_data = torch.ones(1, 2, 2, 2) # (B, W, H, D)
|
||||
volume_data[0, 0, 0, 0] = 0
|
||||
volume_data[0, 1, 0, 0] = 0
|
||||
volume_data[0, 1, 1, 1] = 0
|
||||
volume_data = volume_data.permute(0, 3, 2, 1) # (B, D, H, W)
|
||||
verts, faces = marching_cubes_naive(volume_data, return_local_coords=False)
|
||||
|
||||
expected_verts = torch.tensor(
|
||||
[
|
||||
[1.0000, 0.0000, 0.5000],
|
||||
[0.0000, 0.0000, 0.5000],
|
||||
[0.5000, 1.0000, 1.0000],
|
||||
[1.0000, 1.0000, 0.5000],
|
||||
[1.0000, 0.5000, 1.0000],
|
||||
[1.0000, 0.5000, 0.0000],
|
||||
[0.0000, 0.5000, 0.0000],
|
||||
]
|
||||
)
|
||||
|
||||
expected_faces = torch.tensor([[5, 1, 6], [5, 0, 1], [4, 3, 2]])
|
||||
|
||||
self.assertClose(verts[0], expected_verts)
|
||||
self.assertClose(faces[0], expected_faces)
|
||||
|
||||
verts, faces = marching_cubes_naive(volume_data, return_local_coords=True)
|
||||
expected_verts = convert_to_local(expected_verts, 2)
|
||||
self.assertClose(verts[0], expected_verts)
|
||||
self.assertClose(faces[0], expected_faces)
|
||||
|
||||
def test_case12(self):
|
||||
volume_data = torch.ones(1, 2, 2, 2) # (B, W, H, D)
|
||||
volume_data[0, 1, 0, 0] = 0
|
||||
volume_data[0, 0, 1, 0] = 0
|
||||
volume_data[0, 1, 1, 1] = 0
|
||||
volume_data = volume_data.permute(0, 3, 2, 1) # (B, D, H, W)
|
||||
verts, faces = marching_cubes_naive(volume_data, return_local_coords=False)
|
||||
|
||||
expected_verts = torch.tensor(
|
||||
[
|
||||
[1.0000, 0.0000, 0.5000],
|
||||
[0.5000, 0.0000, 0.0000],
|
||||
[0.5000, 1.0000, 1.0000],
|
||||
[1.0000, 1.0000, 0.5000],
|
||||
[0.5000, 1.0000, 0.0000],
|
||||
[0.0000, 1.0000, 0.5000],
|
||||
[1.0000, 0.5000, 1.0000],
|
||||
[1.0000, 0.5000, 0.0000],
|
||||
[0.0000, 0.5000, 0.0000],
|
||||
]
|
||||
)
|
||||
|
||||
expected_faces = torch.tensor([[6, 3, 2], [7, 0, 1], [5, 4, 8]])
|
||||
|
||||
self.assertClose(verts[0], expected_verts)
|
||||
self.assertClose(faces[0], expected_faces)
|
||||
|
||||
verts, faces = marching_cubes_naive(volume_data, return_local_coords=True)
|
||||
expected_verts = convert_to_local(expected_verts, 2)
|
||||
self.assertClose(verts[0], expected_verts)
|
||||
self.assertClose(faces[0], expected_faces)
|
||||
|
||||
def test_case13(self):
|
||||
volume_data = torch.ones(1, 2, 2, 2) # (B, W, H, D)
|
||||
volume_data[0, 0, 0, 0] = 0
|
||||
volume_data[0, 0, 1, 0] = 0
|
||||
volume_data[0, 1, 0, 1] = 0
|
||||
volume_data[0, 1, 1, 1] = 0
|
||||
volume_data = volume_data.permute(0, 3, 2, 1) # (B, D, H, W)
|
||||
verts, faces = marching_cubes_naive(volume_data, return_local_coords=False)
|
||||
|
||||
expected_verts = torch.tensor(
|
||||
[
|
||||
[0.5000, 0.0000, 1.0000],
|
||||
[1.0000, 0.0000, 0.5000],
|
||||
[0.5000, 0.0000, 0.0000],
|
||||
[0.0000, 0.0000, 0.5000],
|
||||
[0.5000, 1.0000, 1.0000],
|
||||
[1.0000, 1.0000, 0.5000],
|
||||
[0.5000, 1.0000, 0.0000],
|
||||
[0.0000, 1.0000, 0.5000],
|
||||
]
|
||||
)
|
||||
|
||||
expected_faces = torch.tensor([[3, 6, 2], [3, 7, 6], [1, 5, 0], [5, 4, 0]])
|
||||
|
||||
self.assertClose(verts[0], expected_verts)
|
||||
self.assertClose(faces[0], expected_faces)
|
||||
|
||||
verts, faces = marching_cubes_naive(volume_data, return_local_coords=True)
|
||||
expected_verts = convert_to_local(expected_verts, 2)
|
||||
self.assertClose(verts[0], expected_verts)
|
||||
self.assertClose(faces[0], expected_faces)
|
||||
|
||||
def test_case14(self):
|
||||
volume_data = torch.ones(1, 2, 2, 2) # (B, W, H, D)
|
||||
volume_data[0, 0, 0, 0] = 0
|
||||
volume_data[0, 0, 0, 1] = 0
|
||||
volume_data[0, 1, 0, 1] = 0
|
||||
volume_data[0, 1, 1, 1] = 0
|
||||
volume_data = volume_data.permute(0, 3, 2, 1) # (B, D, H, W)
|
||||
verts, faces = marching_cubes_naive(volume_data, return_local_coords=False)
|
||||
|
||||
expected_verts = torch.tensor(
|
||||
[
|
||||
[1.0000, 0.0000, 0.5000],
|
||||
[0.5000, 0.0000, 0.0000],
|
||||
[0.5000, 1.0000, 1.0000],
|
||||
[1.0000, 1.0000, 0.5000],
|
||||
[0.0000, 0.5000, 1.0000],
|
||||
[0.0000, 0.5000, 0.0000],
|
||||
]
|
||||
)
|
||||
|
||||
expected_faces = torch.tensor([[1, 0, 3], [1, 3, 4], [1, 4, 5], [2, 4, 3]])
|
||||
|
||||
self.assertClose(verts[0], expected_verts)
|
||||
self.assertClose(faces[0], expected_faces)
|
||||
|
||||
verts, faces = marching_cubes_naive(volume_data, return_local_coords=True)
|
||||
expected_verts = convert_to_local(expected_verts, 2)
|
||||
self.assertClose(verts[0], expected_verts)
|
||||
self.assertClose(faces[0], expected_faces)
|
||||
|
||||
|
||||
class TestMarchingCubes(TestCaseMixin, unittest.TestCase):
|
||||
def test_single_point(self):
|
||||
volume_data = torch.zeros(1, 3, 3, 3) # (B, W, H, D)
|
||||
volume_data[0, 1, 1, 1] = 1
|
||||
volume_data = volume_data.permute(0, 3, 2, 1) # (B, D, H, W)
|
||||
verts, faces = marching_cubes_naive(volume_data, return_local_coords=False)
|
||||
|
||||
expected_verts = torch.tensor(
|
||||
[
|
||||
[0.5, 1, 1],
|
||||
[1, 1, 0.5],
|
||||
[1, 0.5, 1],
|
||||
[1, 1, 1.5],
|
||||
[1, 1.5, 1],
|
||||
[1.5, 1, 1],
|
||||
]
|
||||
)
|
||||
expected_faces = torch.tensor(
|
||||
[
|
||||
[2, 0, 1],
|
||||
[2, 3, 0],
|
||||
[0, 4, 1],
|
||||
[3, 4, 0],
|
||||
[5, 2, 1],
|
||||
[3, 2, 5],
|
||||
[5, 1, 4],
|
||||
[3, 5, 4],
|
||||
]
|
||||
)
|
||||
|
||||
self.assertClose(verts[0], expected_verts)
|
||||
self.assertClose(faces[0], expected_faces)
|
||||
|
||||
verts, faces = marching_cubes_naive(volume_data, return_local_coords=True)
|
||||
expected_verts = convert_to_local(expected_verts, 3)
|
||||
self.assertClose(verts[0], expected_verts)
|
||||
self.assertClose(faces[0], expected_faces)
|
||||
self.assertTrue(verts[0].ge(-1).all() and verts[0].le(1).all())
|
||||
|
||||
def test_cube(self):
|
||||
volume_data = torch.zeros(1, 5, 5, 5) # (B, W, H, D)
|
||||
volume_data[0, 1, 1, 1] = 1
|
||||
volume_data[0, 1, 1, 2] = 1
|
||||
volume_data[0, 2, 1, 1] = 1
|
||||
volume_data[0, 2, 1, 2] = 1
|
||||
volume_data[0, 1, 2, 1] = 1
|
||||
volume_data[0, 1, 2, 2] = 1
|
||||
volume_data[0, 2, 2, 1] = 1
|
||||
volume_data[0, 2, 2, 2] = 1
|
||||
volume_data = volume_data.permute(0, 3, 2, 1) # (B, D, H, W)
|
||||
verts, faces = marching_cubes_naive(volume_data, 0.9, return_local_coords=False)
|
||||
|
||||
expected_verts = torch.tensor(
|
||||
[
|
||||
[0.9000, 1.0000, 1.0000],
|
||||
[1.0000, 1.0000, 0.9000],
|
||||
[1.0000, 0.9000, 1.0000],
|
||||
[0.9000, 1.0000, 2.0000],
|
||||
[1.0000, 0.9000, 2.0000],
|
||||
[1.0000, 1.0000, 2.1000],
|
||||
[0.9000, 2.0000, 1.0000],
|
||||
[1.0000, 2.0000, 0.9000],
|
||||
[0.9000, 2.0000, 2.0000],
|
||||
[1.0000, 2.0000, 2.1000],
|
||||
[1.0000, 2.1000, 1.0000],
|
||||
[1.0000, 2.1000, 2.0000],
|
||||
[2.0000, 1.0000, 0.9000],
|
||||
[2.0000, 0.9000, 1.0000],
|
||||
[2.0000, 0.9000, 2.0000],
|
||||
[2.0000, 1.0000, 2.1000],
|
||||
[2.0000, 2.0000, 0.9000],
|
||||
[2.0000, 2.0000, 2.1000],
|
||||
[2.0000, 2.1000, 1.0000],
|
||||
[2.0000, 2.1000, 2.0000],
|
||||
[2.1000, 1.0000, 1.0000],
|
||||
[2.1000, 1.0000, 2.0000],
|
||||
[2.1000, 2.0000, 1.0000],
|
||||
[2.1000, 2.0000, 2.0000],
|
||||
]
|
||||
)
|
||||
|
||||
expected_faces = torch.tensor(
|
||||
[
|
||||
[2, 0, 1],
|
||||
[2, 4, 3],
|
||||
[0, 2, 3],
|
||||
[4, 5, 3],
|
||||
[0, 6, 7],
|
||||
[1, 0, 7],
|
||||
[3, 8, 0],
|
||||
[8, 6, 0],
|
||||
[5, 9, 8],
|
||||
[3, 5, 8],
|
||||
[6, 10, 7],
|
||||
[11, 10, 6],
|
||||
[8, 11, 6],
|
||||
[9, 11, 8],
|
||||
[13, 2, 1],
|
||||
[12, 13, 1],
|
||||
[14, 4, 13],
|
||||
[13, 4, 2],
|
||||
[4, 14, 15],
|
||||
[5, 4, 15],
|
||||
[12, 1, 16],
|
||||
[1, 7, 16],
|
||||
[15, 17, 5],
|
||||
[5, 17, 9],
|
||||
[16, 7, 10],
|
||||
[18, 16, 10],
|
||||
[19, 18, 11],
|
||||
[18, 10, 11],
|
||||
[9, 17, 19],
|
||||
[11, 9, 19],
|
||||
[20, 13, 12],
|
||||
[20, 21, 14],
|
||||
[13, 20, 14],
|
||||
[15, 14, 21],
|
||||
[22, 20, 12],
|
||||
[16, 22, 12],
|
||||
[21, 20, 23],
|
||||
[23, 20, 22],
|
||||
[17, 15, 21],
|
||||
[23, 17, 21],
|
||||
[22, 16, 18],
|
||||
[23, 22, 18],
|
||||
[19, 23, 18],
|
||||
[17, 23, 19],
|
||||
]
|
||||
)
|
||||
self.assertClose(verts[0], expected_verts)
|
||||
self.assertClose(faces[0], expected_faces)
|
||||
verts, faces = marching_cubes_naive(volume_data, 0.9, return_local_coords=True)
|
||||
expected_verts = convert_to_local(expected_verts, 5)
|
||||
self.assertClose(verts[0], expected_verts)
|
||||
self.assertClose(faces[0], expected_faces)
|
||||
|
||||
# Check all values are in the range [-1, 1]
|
||||
self.assertTrue(verts[0].ge(-1).all() and verts[0].le(1).all())
|
||||
|
||||
def test_cube_no_duplicate_verts(self):
|
||||
volume_data = torch.zeros(1, 5, 5, 5) # (B, W, H, D)
|
||||
volume_data[0, 1, 1, 1] = 1
|
||||
volume_data[0, 1, 1, 2] = 1
|
||||
volume_data[0, 2, 1, 1] = 1
|
||||
volume_data[0, 2, 1, 2] = 1
|
||||
volume_data[0, 1, 2, 1] = 1
|
||||
volume_data[0, 1, 2, 2] = 1
|
||||
volume_data[0, 2, 2, 1] = 1
|
||||
volume_data[0, 2, 2, 2] = 1
|
||||
volume_data = volume_data.permute(0, 3, 2, 1) # (B, D, H, W)
|
||||
verts, faces = marching_cubes_naive(volume_data, 1, return_local_coords=False)
|
||||
|
||||
expected_verts = torch.tensor(
|
||||
[
|
||||
[1.0, 1.0, 1.0],
|
||||
[1.0, 1.0, 2.0],
|
||||
[1.0, 2.0, 1.0],
|
||||
[1.0, 2.0, 2.0],
|
||||
[2.0, 1.0, 1.0],
|
||||
[2.0, 1.0, 2.0],
|
||||
[2.0, 2.0, 1.0],
|
||||
[2.0, 2.0, 2.0],
|
||||
]
|
||||
)
|
||||
|
||||
expected_faces = torch.tensor(
|
||||
[
|
||||
[1, 3, 0],
|
||||
[3, 2, 0],
|
||||
[5, 1, 4],
|
||||
[4, 1, 0],
|
||||
[4, 0, 6],
|
||||
[0, 2, 6],
|
||||
[5, 7, 1],
|
||||
[1, 7, 3],
|
||||
[7, 6, 3],
|
||||
[6, 2, 3],
|
||||
[5, 4, 7],
|
||||
[7, 4, 6],
|
||||
]
|
||||
)
|
||||
|
||||
self.assertClose(verts[0], expected_verts)
|
||||
self.assertClose(faces[0], expected_faces)
|
||||
|
||||
verts, faces = marching_cubes_naive(volume_data, 1, return_local_coords=True)
|
||||
expected_verts = convert_to_local(expected_verts, 5)
|
||||
self.assertClose(verts[0], expected_verts)
|
||||
self.assertClose(faces[0], expected_faces)
|
||||
|
||||
# Check all values are in the range [-1, 1]
|
||||
self.assertTrue(verts[0].ge(-1).all() and verts[0].le(1).all())
|
||||
|
||||
def test_sphere(self):
|
||||
# (B, W, H, D)
|
||||
volume = torch.Tensor(
|
||||
[
|
||||
[
|
||||
[(x - 10) ** 2 + (y - 10) ** 2 + (z - 10) ** 2 for z in range(20)]
|
||||
for y in range(20)
|
||||
]
|
||||
for x in range(20)
|
||||
]
|
||||
).unsqueeze(0)
|
||||
volume = volume.permute(0, 3, 2, 1) # (B, D, H, W)
|
||||
verts, faces = marching_cubes_naive(
|
||||
volume, isolevel=64, return_local_coords=False
|
||||
)
|
||||
|
||||
DATA_DIR = Path(__file__).resolve().parent / "data"
|
||||
data_filename = "test_marching_cubes_data/sphere_level64.pickle"
|
||||
filename = os.path.join(DATA_DIR, data_filename)
|
||||
with open(filename, "rb") as file:
|
||||
verts_and_faces = pickle.load(file)
|
||||
expected_verts = verts_and_faces["verts"].squeeze()
|
||||
expected_faces = verts_and_faces["faces"].squeeze()
|
||||
|
||||
self.assertClose(verts[0], expected_verts)
|
||||
self.assertClose(faces[0], expected_faces)
|
||||
|
||||
verts, faces = marching_cubes_naive(
|
||||
volume, isolevel=64, return_local_coords=True
|
||||
)
|
||||
|
||||
expected_verts = convert_to_local(expected_verts, 20)
|
||||
self.assertClose(verts[0], expected_verts)
|
||||
self.assertClose(faces[0], expected_faces)
|
||||
|
||||
# Check all values are in the range [-1, 1]
|
||||
self.assertTrue(verts[0].ge(-1).all() and verts[0].le(1).all())
|
||||
|
||||
# Uses skimage.draw.ellipsoid
|
||||
def test_double_ellipsoid(self):
|
||||
if USE_SCIKIT:
|
||||
import numpy as np
|
||||
from skimage.draw import ellipsoid
|
||||
|
||||
ellip_base = ellipsoid(6, 10, 16, levelset=True)
|
||||
ellip_double = np.concatenate(
|
||||
(ellip_base[:-1, ...], ellip_base[2:, ...]), axis=0
|
||||
)
|
||||
volume = torch.Tensor(ellip_double).unsqueeze(0)
|
||||
volume = volume.permute(0, 3, 2, 1) # (B, D, H, W)
|
||||
verts, faces = marching_cubes_naive(volume, isolevel=0.001)
|
||||
|
||||
DATA_DIR = Path(__file__).resolve().parent / "data"
|
||||
data_filename = "test_marching_cubes_data/double_ellipsoid.pickle"
|
||||
filename = os.path.join(DATA_DIR, data_filename)
|
||||
with open(filename, "rb") as file:
|
||||
verts_and_faces = pickle.load(file)
|
||||
expected_verts = verts_and_faces["verts"]
|
||||
expected_faces = verts_and_faces["faces"]
|
||||
|
||||
self.assertClose(verts[0], expected_verts[0])
|
||||
self.assertClose(faces[0], expected_faces[0])
|
||||
|
||||
def test_cube_surface_area(self):
|
||||
if USE_SCIKIT:
|
||||
from skimage.measure import marching_cubes_classic, mesh_surface_area
|
||||
|
||||
volume_data = torch.zeros(1, 5, 5, 5)
|
||||
volume_data[0, 1, 1, 1] = 1
|
||||
volume_data[0, 1, 1, 2] = 1
|
||||
volume_data[0, 2, 1, 1] = 1
|
||||
volume_data[0, 2, 1, 2] = 1
|
||||
volume_data[0, 1, 2, 1] = 1
|
||||
volume_data[0, 1, 2, 2] = 1
|
||||
volume_data[0, 2, 2, 1] = 1
|
||||
volume_data[0, 2, 2, 2] = 1
|
||||
volume_data = volume_data.permute(0, 3, 2, 1) # (B, D, H, W)
|
||||
verts, faces = marching_cubes_naive(volume_data, return_local_coords=False)
|
||||
verts_sci, faces_sci = marching_cubes_classic(volume_data[0])
|
||||
|
||||
surf = mesh_surface_area(verts[0], faces[0])
|
||||
surf_sci = mesh_surface_area(verts_sci, faces_sci)
|
||||
|
||||
self.assertClose(surf, surf_sci)
|
||||
|
||||
def test_sphere_surface_area(self):
|
||||
if USE_SCIKIT:
|
||||
from skimage.measure import marching_cubes_classic, mesh_surface_area
|
||||
|
||||
# (B, W, H, D)
|
||||
volume = torch.Tensor(
|
||||
[
|
||||
[
|
||||
[
|
||||
(x - 10) ** 2 + (y - 10) ** 2 + (z - 10) ** 2
|
||||
for z in range(20)
|
||||
]
|
||||
for y in range(20)
|
||||
]
|
||||
for x in range(20)
|
||||
]
|
||||
).unsqueeze(0)
|
||||
volume = volume.permute(0, 3, 2, 1) # (B, D, H, W)
|
||||
verts, faces = marching_cubes_naive(volume, isolevel=64)
|
||||
verts_sci, faces_sci = marching_cubes_classic(volume[0], level=64)
|
||||
|
||||
surf = mesh_surface_area(verts[0], faces[0])
|
||||
surf_sci = mesh_surface_area(verts_sci, faces_sci)
|
||||
|
||||
self.assertClose(surf, surf_sci)
|
||||
|
||||
def test_double_ellipsoid_surface_area(self):
|
||||
if USE_SCIKIT:
|
||||
import numpy as np
|
||||
from skimage.draw import ellipsoid
|
||||
from skimage.measure import marching_cubes_classic, mesh_surface_area
|
||||
|
||||
ellip_base = ellipsoid(6, 10, 16, levelset=True)
|
||||
ellip_double = np.concatenate(
|
||||
(ellip_base[:-1, ...], ellip_base[2:, ...]), axis=0
|
||||
)
|
||||
volume = torch.Tensor(ellip_double).unsqueeze(0)
|
||||
volume = volume.permute(0, 3, 2, 1) # (B, D, H, W)
|
||||
verts, faces = marching_cubes_naive(volume, isolevel=0)
|
||||
verts_sci, faces_sci = marching_cubes_classic(volume[0], level=0)
|
||||
|
||||
surf = mesh_surface_area(verts[0], faces[0])
|
||||
surf_sci = mesh_surface_area(verts_sci, faces_sci)
|
||||
|
||||
self.assertClose(surf, surf_sci)
|
||||
|
||||
@staticmethod
|
||||
def marching_cubes_with_init(batch_size: int, V: int):
|
||||
device = torch.device("cuda:0")
|
||||
volume_data = torch.rand(
|
||||
(batch_size, V, V, V), dtype=torch.float32, device=device
|
||||
)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
def convert():
|
||||
marching_cubes_naive(volume_data, return_local_coords=False)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
return convert
|
Loading…
x
Reference in New Issue
Block a user