Python marching cubes improvements

Summary: Overhaul of marching_cubes_naive for better performance and to avoid relying on unstable hashing. In particular, instead of hashing vertex positions, we index each interpolated vertex with its corresponding edge in the 3d grid.

Reviewed By: kjchalup

Differential Revision: D39419642

fbshipit-source-id: b5fede3525c545d1d374198928dfb216262f0ec0
This commit is contained in:
Jiali Duan 2022-10-06 11:08:49 -07:00 committed by Facebook GitHub Bot
parent 6471893f59
commit 850efdf706
5 changed files with 342 additions and 703 deletions

View File

@ -4,10 +4,10 @@
# This source code is licensed under the BSD-style license found in the # This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
from typing import Dict, List, Optional, Tuple from typing import List, Optional, Tuple
import torch import torch
from pytorch3d.ops.marching_cubes_data import EDGE_TABLE, EDGE_TO_VERTICES, FACE_TABLE from pytorch3d.ops.marching_cubes_data import EDGE_TO_VERTICES, FACE_TABLE, INDEX
from pytorch3d.transforms import Translate from pytorch3d.transforms import Translate
@ -15,10 +15,15 @@ EPS = 0.00001
class Cube: class Cube:
def __init__(self, bfl_vertex: Tuple[int, int, int], spacing: int = 1) -> None: def __init__(
self,
bfl_v: Tuple[int, int, int],
volume: torch.Tensor,
isolevel: float,
) -> None:
""" """
Initializes a cube given the bottom front left vertex coordinate Initializes a cube given the bottom front left vertex coordinate
and the cube spacing and computes the cube configuration given vertex values and isolevel.
Edge and vertex convention: Edge and vertex convention:
@ -31,8 +36,8 @@ class Cube:
| | | | | | | |
| |e8 |e10| | |e8 |e10|
e11| | | | e11| | | |
| |_________________|___| | |______e0_________|___|
| / v0 e0 | /v1 | / v0(bfl_v) | |v1
| / | / | / | /
| /e3 | /e1 | /e3 | /e1
|/_____________________|/ |/_____________________|/
@ -41,311 +46,182 @@ class Cube:
Args: Args:
bfl_vertex: a tuple of size 3 corresponding to the bottom front left vertex bfl_vertex: a tuple of size 3 corresponding to the bottom front left vertex
of the cube in (x, y, z) format of the cube in (x, y, z) format
spacing: the length of each edge of the cube volume: the 3D scalar data
isolevel: the isosurface value used as a threshold for determining whether a point
is inside/outside the volume
""" """
# match corner orders to algorithm convention x, y, z = bfl_v
if len(bfl_vertex) != 3: self.x, self.y, self.z = x, y, z
msg = "The vertex {} is size {} instead of size 3".format( self.bfl_v = bfl_v
bfl_vertex, len(bfl_vertex) self.verts = [
) [x + (v & 1), y + (v >> 1 & 1), z + (v >> 2 & 1)] for v in range(8)
raise ValueError(msg) ] # vertex position (x, y, z) for v0-v1-v4-v5-v3-v2-v7-v6
x, y, z = bfl_vertex # Calculates cube configuration index given values of the cube vertices
self.vertices = torch.tensor( self.cube_index = 0
[ for i in range(8):
[x, y, z + spacing], v = self.verts[INDEX[i]]
[x + spacing, y, z + spacing], value = volume[v[2]][v[1]][v[0]]
[x + spacing, y, z],
[x, y, z],
[x, y + spacing, z + spacing],
[x + spacing, y + spacing, z + spacing],
[x + spacing, y + spacing, z],
[x, y + spacing, z],
]
)
def get_index(self, volume_data: torch.Tensor, isolevel: float) -> int:
"""
Calculates the cube_index in the range 0-255 to index
into EDGE_TABLE and FACE_TABLE
Args:
volume_data: the 3D scalar data
isolevel: the isosurface value used as a threshold
for determining whether a point is inside/outside
the volume
"""
cube_index = 0
bit = 1
for index in range(len(self.vertices)):
vertex = self.vertices[index]
value = _get_value(vertex, volume_data)
if value < isolevel: if value < isolevel:
cube_index |= bit self.cube_index |= 1 << i
bit *= 2
return cube_index def get_vpair_from_edge(self, edge: int, W: int, H: int) -> Tuple[int, int]:
"""
Get a tuple of global vertex ID from a local edge ID
Global vertex ID is calculated as (x + dx) + (y + dy) * W + (z + dz) * W * H
Args:
edge: local edge ID in the cube
bfl_vertex: bottom-front-left coordinate of the cube
Returns:
a pair of global vertex ID
"""
v1, v2 = EDGE_TO_VERTICES[edge] # two end-points on the edge
v1_id = self.verts[v1][0] + self.verts[v1][1] * W + self.verts[v1][2] * W * H
v2_id = self.verts[v2][0] + self.verts[v2][1] * W + self.verts[v2][2] * W * H
return (v1_id, v2_id)
def vert_interp(
self,
isolevel: float,
edge: int,
vol: torch.Tensor,
) -> List:
"""
Linearly interpolate a vertex where an isosurface cuts an edge
between the two endpoint vertices, based on their values
Args:
isolevel: the isosurface value to use as the threshold to determine
whether points are within a volume.
edge: edge (ID) to interpolate
cube: current cube vertices
vol: 3D scalar field
Returns:
interpolated vertex: position of the interpolated vertex on the edge
"""
v1, v2 = EDGE_TO_VERTICES[edge]
p1, p2 = self.verts[v1], self.verts[v2]
val1, val2 = (
vol[p1[2]][p1[1]][p1[0]],
vol[p2[2]][p2[1]][p2[0]],
)
point = None
if abs(isolevel - val1) < EPS:
point = p1
elif abs(isolevel - val2) < EPS:
point = p2
elif abs(val1 - val2) < EPS:
point = p1
if point is None:
mu = (isolevel - val1) / (val2 - val1)
x1, y1, z1 = p1
x2, y2, z2 = p2
x = x1 + mu * (x2 - x1)
y = y1 + mu * (y2 - y1)
z = z1 + mu * (z2 - z1)
else:
x, y, z = point
return [x, y, z]
def marching_cubes_naive( def marching_cubes_naive(
volume_data_batch: torch.Tensor, vol_batch: torch.Tensor,
isolevel: Optional[float] = None, isolevel: Optional[float] = None,
spacing: int = 1,
return_local_coords: bool = True, return_local_coords: bool = True,
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
""" """
Runs the classic marching cubes algorithm, iterating over Runs the classic marching cubes algorithm, iterating over
the coordinates of the volume_data and using a given isolevel the coordinates of the volume and using a given isolevel
for determining intersected edges of cubes of size `spacing`. for determining intersected edges of cubes.
Returns vertices and faces of the obtained mesh. Returns vertices and faces of the obtained mesh.
This operation is non-differentiable. This operation is non-differentiable.
This is a naive implementation, and is not optimized for efficiency.
Args: Args:
volume_data_batch: a Tensor of size (N, D, H, W) corresponding to vol_batch: a Tensor of size (N, D, H, W) corresponding to
a batch of 3D scalar fields a batch of 3D scalar fields
isolevel: the isosurface value to use as the threshold to determine isolevel: the isosurface value to use as the threshold to determine
whether points are within a volume. If None, then the average of the whether points are within a volume. If None, then the average of the
maximum and minimum value of the scalar field will be used. maximum and minimum value of the scalar field will be used.
spacing: an integer specifying the cube size to use
return_local_coords: bool. If True the output vertices will be in local coordinates in return_local_coords: bool. If True the output vertices will be in local coordinates in
the range [-1, 1] x [-1, 1] x [-1, 1]. If False they will be in the range the range [-1, 1] x [-1, 1] x [-1, 1]. If False they will be in the range
[0, W-1] x [0, H-1] x [0, D-1] [0, W-1] x [0, H-1] x [0, D-1]
Returns: Returns:
verts: [(V_0, 3), (V_1, 3), ...] List of N FloatTensors of vertices. verts: [{V_0}, {V_1}, ...] List of N sets of vertices of shape (|V_i|, 3) in FloatTensor
faces: [(F_0, 3), (F_1, 3), ...] List of N LongTensors of faces. faces: [{F_0}, {F_1}, ...] List of N sets of faces of shape (|F_i|, 3) in LongTensors
""" """
volume_data_batch = volume_data_batch.detach().cpu()
batched_verts, batched_faces = [], [] batched_verts, batched_faces = [], []
D, H, W = volume_data_batch.shape[1:] D, H, W = vol_batch.shape[1:]
volume_size_xyz = volume_data_batch.new_tensor([W, H, D])[None]
if return_local_coords: # each edge is represented with its two endpoints (represented with global id)
# Convert from local coordinates in the range [-1, 1] range to for i in range(len(vol_batch)):
# world coordinates in the range [0, D-1], [0, H-1], [0, W-1] vol = vol_batch[i]
local_to_world_transform = Translate( thresh = ((vol.max() + vol.min()) / 2).item() if isolevel is None else isolevel
x=+1.0, y=+1.0, z=+1.0, device=volume_data_batch.device vpair_to_edge = {} # maps from tuple of edge endpoints to edge_id
).scale((volume_size_xyz - 1) * spacing * 0.5) edge_id_to_v = {} # maps from edge ID to vertex position
# Perform the inverse to go from world to local uniq_edge_id = {} # unique edge IDs
world_to_local_transform = local_to_world_transform.inverse() verts = [] # store vertex positions
faces = [] # store face indices
# enumerate each cell in the 3d grid
for z in range(0, D - 1):
for y in range(0, H - 1):
for x in range(0, W - 1):
cube = Cube((x, y, z), vol, thresh)
edge_indices = FACE_TABLE[cube.cube_index]
# cube is entirely in/out of the surface
if len(edge_indices) == 0:
continue
# gather mesh vertices/faces by processing each cube
interp_points = [[0.0, 0.0, 0.0]] * 12
# triangle vertex IDs and positions
tri = []
ps = []
for i, edge in enumerate(edge_indices):
interp_points[edge] = cube.vert_interp(thresh, edge, vol)
# Bind interpolated vertex with a global edge_id, which
# is represented by a pair of vertex ids (v1_id, v2_id)
# corresponding to a local edge.
(v1_id, v2_id) = cube.get_vpair_from_edge(edge, W, H)
edge_id = vpair_to_edge.setdefault(
(v1_id, v2_id), len(vpair_to_edge)
)
tri.append(edge_id)
ps.append(interp_points[edge])
# when the isolevel are the same as the edge endpoints, the interploated
# vertices can share the same values, and lead to degenerate triangles.
if (
(i + 1) % 3 == 0
and ps[0] != ps[1]
and ps[1] != ps[2]
and ps[2] != ps[0]
):
for j, edge_id in enumerate(tri):
edge_id_to_v[edge_id] = ps[j]
if edge_id not in uniq_edge_id:
uniq_edge_id[edge_id] = len(verts)
verts.append(edge_id_to_v[edge_id])
faces.append([uniq_edge_id[tri[j]] for j in range(3)])
tri = []
ps = []
for i in range(len(volume_data_batch)):
volume_data = volume_data_batch[i]
curr_isolevel = (
((volume_data.max() + volume_data.min()) / 2).item()
if isolevel is None
else isolevel
)
edge_vertices_to_index = {}
vertex_coords_to_index = {}
verts, faces = [], []
# Use length - spacing for the bounds since we are using
# cubes of size spacing, with the lowest x,y,z values
# (bottom front left)
for x in range(0, W - spacing, spacing):
for y in range(0, H - spacing, spacing):
for z in range(0, D - spacing, spacing):
cube = Cube((x, y, z), spacing)
new_verts, new_faces = polygonise(
cube,
curr_isolevel,
volume_data,
edge_vertices_to_index,
vertex_coords_to_index,
)
verts.extend(new_verts)
faces.extend(new_faces)
if len(faces) > 0 and len(verts) > 0: if len(faces) > 0 and len(verts) > 0:
verts = torch.tensor(verts, dtype=torch.float32) verts = torch.tensor(verts, dtype=vol.dtype)
# Convert vertices from world to local coords # Convert from world coordinates ([0, D-1], [0, H-1], [0, W-1]) to
# local coordinates in the range [-1, 1]
if return_local_coords: if return_local_coords:
verts = world_to_local_transform.transform_points(verts[None, ...]) verts = (
verts = verts.squeeze() Translate(x=+1.0, y=+1.0, z=+1.0, device=vol_batch.device)
.scale((vol_batch.new_tensor([W, H, D])[None] - 1) * 0.5)
.inverse()
).transform_points(verts[None])[0]
batched_verts.append(verts) batched_verts.append(verts)
batched_faces.append(torch.tensor(faces, dtype=torch.int64)) batched_faces.append(torch.tensor(faces, dtype=torch.int64))
return batched_verts, batched_faces
def polygonise(
cube: Cube,
isolevel: float,
volume_data: torch.Tensor,
edge_vertices_to_index: Dict[Tuple[Tuple, Tuple], int],
vertex_coords_to_index: Dict[Tuple[float, float, float], int],
) -> Tuple[list, list]:
"""
Runs the classic marching cubes algorithm for one Cube in the volume.
Returns the vertices and faces for the given cube.
Args:
cube: a Cube indicating the cube being examined for edges that intersect
the volume data.
isolevel: the isosurface value to use as the threshold to determine
whether points are within a volume.
volume_data: a Tensor of shape (D, H, W) corresponding to
a 3D scalar field
edge_vertices_to_index: A dictionary which maps an edge's two coordinates
to the index of its interpolated point, if that interpolated point
has already been used by a previous point
vertex_coords_to_index: A dictionary mapping a point (x, y, z) to the corresponding
index of that vertex, if that point has already been marked as a vertex.
Returns:
verts: List of triangle vertices for the given cube in the volume
faces: List of triangle faces for the given cube in the volume
"""
num_existing_verts = max(edge_vertices_to_index.values(), default=-1) + 1
verts, faces = [], []
cube_index = cube.get_index(volume_data, isolevel)
edges = EDGE_TABLE[cube_index]
edge_indices = _get_edge_indices(edges)
if len(edge_indices) == 0:
return [], []
new_verts, edge_index_to_point_index = _calculate_interp_vertices(
edge_indices,
volume_data,
cube,
isolevel,
edge_vertices_to_index,
vertex_coords_to_index,
num_existing_verts,
)
# Create faces
face_triangles = FACE_TABLE[cube_index]
for i in range(0, len(face_triangles), 3):
tri1 = edge_index_to_point_index[face_triangles[i]]
tri2 = edge_index_to_point_index[face_triangles[i + 1]]
tri3 = edge_index_to_point_index[face_triangles[i + 2]]
if tri1 != tri2 and tri2 != tri3 and tri1 != tri3:
faces.append([tri1, tri2, tri3])
verts += new_verts
return verts, faces
def _get_edge_indices(edges: int) -> List[int]:
"""
Finds which edge numbers are intersected given the bit representation
detailed in marching_cubes_data.EDGE_TABLE.
Args:
edges: an integer corresponding to the value at cube_index
from the EDGE_TABLE in marching_cubes_data.py
Returns:
edge_indices: A list of edge indices
"""
if edges == 0:
return []
edge_indices = []
for i in range(12):
if edges & (2**i):
edge_indices.append(i)
return edge_indices
def _calculate_interp_vertices(
edge_indices: List[int],
volume_data: torch.Tensor,
cube: Cube,
isolevel: float,
edge_vertices_to_index: Dict[Tuple[Tuple, Tuple], int],
vertex_coords_to_index: Dict[Tuple[float, float, float], int],
num_existing_verts: int,
) -> Tuple[List, Dict[int, int]]:
"""
Finds the interpolated vertices for the intersected edges, either referencing
previous calculations or newly calculating and storing the new interpolated
points.
Args:
edge_indices: the numbers of the edges which are intersected. See the
Cube class for more detail on the edge numbering convention.
volume_data: a Tensor of size (D, H, W) corresponding to
a 3D scalar field
cube: a Cube indicating the cube being examined for edges that intersect
the volume
isolevel: the isosurface value to use as the threshold to determine
whether points are within a volume.
edge_vertices_to_index: A dictionary which maps an edge's two coordinates
to the index of its interpolated point, if that interpolated point
has already been used by a previous point
vertex_coords_to_index: A dictionary mapping a point (x, y, z) to the corresponding
index of that vertex, if that point has already been marked as a vertex.
num_existing_verts: the number of vertices that have been found in previous
calls to polygonise for the given volume_data in the above function, marching_cubes.
This is equal to the 1 + the maximum value in edge_vertices_to_index.
Returns:
interp_points: a list of new interpolated points
edge_index_to_point_index: a dictionary mapping an edge number to the index in the
marching cubes' vertices list of the interpolated point on that edge. To be precise,
it refers to the index within the vertices list after interp_points
has been appended to the verts list constructed in the marching_cubes_naive
function.
"""
interp_points = []
edge_index_to_point_index = {}
for edge_index in edge_indices:
v1, v2 = EDGE_TO_VERTICES[edge_index]
point1, point2 = cube.vertices[v1], cube.vertices[v2]
p_tuple1, p_tuple2 = tuple(point1.tolist()), tuple(point2.tolist())
if (p_tuple1, p_tuple2) in edge_vertices_to_index:
edge_index_to_point_index[edge_index] = edge_vertices_to_index[
(p_tuple1, p_tuple2)
]
else: else:
val1, val2 = _get_value(point1, volume_data), _get_value( batched_verts.append([])
point2, volume_data batched_faces.append([])
) return batched_verts, batched_faces
point = None
if abs(isolevel - val1) < EPS:
point = point1
if abs(isolevel - val2) < EPS:
point = point2
if abs(val1 - val2) < EPS:
point = point1
if point is None:
mu = (isolevel - val1) / (val2 - val1)
x1, y1, z1 = point1
x2, y2, z2 = point2
x = x1 + mu * (x2 - x1)
y = y1 + mu * (y2 - y1)
z = z1 + mu * (z2 - z1)
else:
x, y, z = point
x, y, z = x.item(), y.item(), z.item() # for dictionary keys
vert_index = None
if (x, y, z) in vertex_coords_to_index:
vert_index = vertex_coords_to_index[(x, y, z)]
else:
vert_index = num_existing_verts + len(interp_points)
interp_points.append([x, y, z])
vertex_coords_to_index[(x, y, z)] = vert_index
edge_vertices_to_index[(p_tuple1, p_tuple2)] = vert_index
edge_index_to_point_index[edge_index] = vert_index
return interp_points, edge_index_to_point_index
def _get_value(point: Tuple[int, int, int], volume_data: torch.Tensor) -> float:
"""
Gets the value at a given coordinate point in the scalar field.
Args:
point: data of shape (3) corresponding to an xyz coordinate.
volume_data: a Tensor of size (D, H, W) corresponding to
a 3D scalar field
Returns:
data: scalar value in the volume at the given point
"""
x, y, z = point
# pyre-fixme[7]: Expected `float` but got `Tensor`.
return volume_data[z][y][x]

View File

@ -4,284 +4,21 @@
# This source code is licensed under the BSD-style license found in the # This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
# A length 256 list which maps a cubeindex to a number
# with the intersected edges' bits set to 1.
# Each cubeindex corresponds to a given cube configuration, where
# it is composed of a bitstring where the 0th bit is flipped if vertex 0
# is below the isosurface (i.e. 0x01), for each of the 8 vertices.
EDGE_TABLE = [
0x0,
0x109,
0x203,
0x30A,
0x406,
0x50F,
0x605,
0x70C,
0x80C,
0x905,
0xA0F,
0xB06,
0xC0A,
0xD03,
0xE09,
0xF00,
0x190,
0x99,
0x393,
0x29A,
0x596,
0x49F,
0x795,
0x69C,
0x99C,
0x895,
0xB9F,
0xA96,
0xD9A,
0xC93,
0xF99,
0xE90,
0x230,
0x339,
0x33,
0x13A,
0x636,
0x73F,
0x435,
0x53C,
0xA3C,
0xB35,
0x83F,
0x936,
0xE3A,
0xF33,
0xC39,
0xD30,
0x3A0,
0x2A9,
0x1A3,
0xAA,
0x7A6,
0x6AF,
0x5A5,
0x4AC,
0xBAC,
0xAA5,
0x9AF,
0x8A6,
0xFAA,
0xEA3,
0xDA9,
0xCA0,
0x460,
0x569,
0x663,
0x76A,
0x66,
0x16F,
0x265,
0x36C,
0xC6C,
0xD65,
0xE6F,
0xF66,
0x86A,
0x963,
0xA69,
0xB60,
0x5F0,
0x4F9,
0x7F3,
0x6FA,
0x1F6,
0xFF,
0x3F5,
0x2FC,
0xDFC,
0xCF5,
0xFFF,
0xEF6,
0x9FA,
0x8F3,
0xBF9,
0xAF0,
0x650,
0x759,
0x453,
0x55A,
0x256,
0x35F,
0x55,
0x15C,
0xE5C,
0xF55,
0xC5F,
0xD56,
0xA5A,
0xB53,
0x859,
0x950,
0x7C0,
0x6C9,
0x5C3,
0x4CA,
0x3C6,
0x2CF,
0x1C5,
0xCC,
0xFCC,
0xEC5,
0xDCF,
0xCC6,
0xBCA,
0xAC3,
0x9C9,
0x8C0,
0x8C0,
0x9C9,
0xAC3,
0xBCA,
0xCC6,
0xDCF,
0xEC5,
0xFCC,
0xCC,
0x1C5,
0x2CF,
0x3C6,
0x4CA,
0x5C3,
0x6C9,
0x7C0,
0x950,
0x859,
0xB53,
0xA5A,
0xD56,
0xC5F,
0xF55,
0xE5C,
0x15C,
0x55,
0x35F,
0x256,
0x55A,
0x453,
0x759,
0x650,
0xAF0,
0xBF9,
0x8F3,
0x9FA,
0xEF6,
0xFFF,
0xCF5,
0xDFC,
0x2FC,
0x3F5,
0xFF,
0x1F6,
0x6FA,
0x7F3,
0x4F9,
0x5F0,
0xB60,
0xA69,
0x963,
0x86A,
0xF66,
0xE6F,
0xD65,
0xC6C,
0x36C,
0x265,
0x16F,
0x66,
0x76A,
0x663,
0x569,
0x460,
0xCA0,
0xDA9,
0xEA3,
0xFAA,
0x8A6,
0x9AF,
0xAA5,
0xBAC,
0x4AC,
0x5A5,
0x6AF,
0x7A6,
0xAA,
0x1A3,
0x2A9,
0x3A0,
0xD30,
0xC39,
0xF33,
0xE3A,
0x936,
0x83F,
0xB35,
0xA3C,
0x53C,
0x435,
0x73F,
0x636,
0x13A,
0x33,
0x339,
0x230,
0xE90,
0xF99,
0xC93,
0xD9A,
0xA96,
0xB9F,
0x895,
0x99C,
0x69C,
0x795,
0x49F,
0x596,
0x29A,
0x393,
0x99,
0x190,
0xF00,
0xE09,
0xD03,
0xC0A,
0xB06,
0xA0F,
0x905,
0x80C,
0x70C,
0x605,
0x50F,
0x406,
0x30A,
0x203,
0x109,
0x0,
]
# Maps each edge (by index) to the corresponding cube vertices # Maps each edge (by index) to the corresponding cube vertices
EDGE_TO_VERTICES = [ EDGE_TO_VERTICES = [
[0, 1], [0, 1],
[1, 2],
[3, 2],
[0, 3],
[4, 5],
[5, 6],
[7, 6],
[4, 7],
[0, 4],
[1, 5], [1, 5],
[2, 6], [4, 5],
[0, 4],
[2, 3],
[3, 7], [3, 7],
[6, 7],
[2, 6],
[0, 2],
[1, 3],
[5, 7],
[4, 6],
] ]
# A list of lists mapping a cube_index (a given configuration) # A list of lists mapping a cube_index (a given configuration)
@ -547,3 +284,6 @@ FACE_TABLE = [
[0, 3, 8], [0, 3, 8],
[], [],
] ]
# mapping from 0-7 to v0-v7 in cube.vertices
INDEX = [0, 1, 5, 4, 2, 3, 7, 6]

View File

@ -32,8 +32,8 @@ class TestCubeConfiguration(TestCaseMixin, unittest.TestCase):
volume_data = torch.ones(1, 2, 2, 2) # (B, W, H, D) volume_data = torch.ones(1, 2, 2, 2) # (B, W, H, D)
verts, faces = marching_cubes_naive(volume_data, return_local_coords=False) verts, faces = marching_cubes_naive(volume_data, return_local_coords=False)
expected_verts = torch.tensor([]) expected_verts = torch.tensor([[]])
expected_faces = torch.tensor([], dtype=torch.int64) expected_faces = torch.tensor([[]], dtype=torch.int64)
self.assertClose(verts, expected_verts) self.assertClose(verts, expected_verts)
self.assertClose(faces, expected_faces) self.assertClose(faces, expected_faces)
@ -42,16 +42,15 @@ class TestCubeConfiguration(TestCaseMixin, unittest.TestCase):
volume_data[0, 0, 0, 0] = 0 volume_data[0, 0, 0, 0] = 0
volume_data = volume_data.permute(0, 3, 2, 1) # (B, D, H, W) volume_data = volume_data.permute(0, 3, 2, 1) # (B, D, H, W)
verts, faces = marching_cubes_naive(volume_data, return_local_coords=False) verts, faces = marching_cubes_naive(volume_data, return_local_coords=False)
expected_verts = torch.tensor( expected_verts = torch.tensor(
[ [
[0.5, 0, 0], [0.5, 0, 0],
[0, 0, 0.5],
[0, 0.5, 0], [0, 0.5, 0],
[0, 0, 0.5],
] ]
) )
expected_faces = torch.tensor([[1, 2, 0]]) expected_faces = torch.tensor([[0, 1, 2]])
self.assertClose(verts[0], expected_verts) self.assertClose(verts[0], expected_verts)
self.assertClose(faces[0], expected_faces) self.assertClose(faces[0], expected_faces)
@ -69,12 +68,12 @@ class TestCubeConfiguration(TestCaseMixin, unittest.TestCase):
expected_verts = torch.tensor( expected_verts = torch.tensor(
[ [
[1.0000, 0.0000, 0.5000], [1.0000, 0.0000, 0.5000],
[0.0000, 0.5000, 0.0000],
[0.0000, 0.0000, 0.5000], [0.0000, 0.0000, 0.5000],
[1.0000, 0.5000, 0.0000], [1.0000, 0.5000, 0.0000],
[0.0000, 0.5000, 0.0000],
] ]
) )
expected_faces = torch.tensor([[1, 2, 0], [3, 2, 1]]) expected_faces = torch.tensor([[0, 1, 2], [3, 1, 0]])
self.assertClose(verts[0], expected_verts) self.assertClose(verts[0], expected_verts)
self.assertClose(faces[0], expected_faces) self.assertClose(faces[0], expected_faces)
@ -92,15 +91,15 @@ class TestCubeConfiguration(TestCaseMixin, unittest.TestCase):
expected_verts = torch.tensor( expected_verts = torch.tensor(
[ [
[0.5000, 0.0000, 0.0000], [1.0000, 0.5000, 0.0000],
[0.0000, 0.0000, 0.5000],
[1.0000, 1.0000, 0.5000], [1.0000, 1.0000, 0.5000],
[0.5000, 1.0000, 0.0000], [0.5000, 1.0000, 0.0000],
[1.0000, 0.5000, 0.0000], [0.5000, 0.0000, 0.0000],
[0.0000, 0.5000, 0.0000], [0.0000, 0.5000, 0.0000],
[0.0000, 0.0000, 0.5000],
] ]
) )
expected_faces = torch.tensor([[0, 1, 5], [4, 3, 2]]) expected_faces = torch.tensor([[0, 1, 2], [3, 4, 5]])
self.assertClose(verts[0], expected_verts) self.assertClose(verts[0], expected_verts)
self.assertClose(faces[0], expected_faces) self.assertClose(faces[0], expected_faces)
@ -119,14 +118,14 @@ class TestCubeConfiguration(TestCaseMixin, unittest.TestCase):
expected_verts = torch.tensor( expected_verts = torch.tensor(
[ [
[0.5000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.5000], [0.0000, 0.0000, 0.5000],
[1.0000, 0.5000, 0.0000],
[0.5000, 0.0000, 0.0000],
[0.0000, 0.5000, 1.0000], [0.0000, 0.5000, 1.0000],
[1.0000, 0.5000, 1.0000], [1.0000, 0.5000, 1.0000],
[1.0000, 0.5000, 0.0000],
] ]
) )
expected_faces = torch.tensor([[0, 2, 1], [0, 4, 2], [4, 3, 2]]) expected_faces = torch.tensor([[0, 1, 2], [0, 3, 1], [3, 4, 1]])
self.assertClose(verts[0], expected_verts) self.assertClose(verts[0], expected_verts)
self.assertClose(faces[0], expected_faces) self.assertClose(faces[0], expected_faces)
@ -143,15 +142,14 @@ class TestCubeConfiguration(TestCaseMixin, unittest.TestCase):
expected_verts = torch.tensor( expected_verts = torch.tensor(
[ [
[0.0000, 0.5000, 1.0000],
[1.0000, 0.5000, 1.0000],
[1.0000, 0.5000, 0.0000], [1.0000, 0.5000, 0.0000],
[0.0000, 0.5000, 0.0000], [0.0000, 0.5000, 0.0000],
[1.0000, 0.5000, 1.0000],
[0.0000, 0.5000, 1.0000],
] ]
) )
expected_faces = torch.tensor([[1, 0, 2], [2, 0, 3]]) expected_faces = torch.tensor([[0, 1, 2], [2, 1, 3]])
self.assertClose(verts[0], expected_verts) self.assertClose(verts[0], expected_verts)
self.assertClose(faces[0], expected_faces) self.assertClose(faces[0], expected_faces)
@ -171,17 +169,17 @@ class TestCubeConfiguration(TestCaseMixin, unittest.TestCase):
expected_verts = torch.tensor( expected_verts = torch.tensor(
[ [
[0.5000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.5000],
[0.5000, 1.0000, 0.0000], [0.5000, 1.0000, 0.0000],
[0.0000, 1.0000, 0.5000], [0.0000, 1.0000, 0.5000],
[0.0000, 0.5000, 0.0000],
[1.0000, 0.5000, 0.0000],
[0.5000, 0.0000, 0.0000],
[0.0000, 0.5000, 1.0000], [0.0000, 0.5000, 1.0000],
[1.0000, 0.5000, 1.0000], [1.0000, 0.5000, 1.0000],
[1.0000, 0.5000, 0.0000], [0.0000, 0.0000, 0.5000],
[0.0000, 0.5000, 0.0000],
] ]
) )
expected_faces = torch.tensor([[2, 7, 3], [0, 6, 1], [6, 4, 1], [6, 5, 4]]) expected_faces = torch.tensor([[0, 1, 2], [3, 4, 5], [3, 5, 6], [5, 4, 7]])
self.assertClose(verts[0], expected_verts) self.assertClose(verts[0], expected_verts)
self.assertClose(faces[0], expected_faces) self.assertClose(faces[0], expected_faces)
@ -202,22 +200,22 @@ class TestCubeConfiguration(TestCaseMixin, unittest.TestCase):
expected_verts = torch.tensor( expected_verts = torch.tensor(
[ [
[0.5000, 0.0000, 1.0000],
[1.0000, 0.0000, 0.5000],
[0.5000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.5000],
[0.5000, 1.0000, 1.0000], [0.5000, 1.0000, 1.0000],
[1.0000, 1.0000, 0.5000],
[0.5000, 1.0000, 0.0000],
[0.0000, 1.0000, 0.5000],
[0.0000, 0.5000, 1.0000], [0.0000, 0.5000, 1.0000],
[0.0000, 1.0000, 0.5000],
[1.0000, 0.0000, 0.5000],
[0.5000, 0.0000, 1.0000],
[1.0000, 0.5000, 1.0000], [1.0000, 0.5000, 1.0000],
[1.0000, 0.5000, 0.0000], [0.5000, 0.0000, 0.0000],
[0.0000, 0.5000, 0.0000], [0.0000, 0.5000, 0.0000],
[0.0000, 0.0000, 0.5000],
[0.5000, 1.0000, 0.0000],
[1.0000, 0.5000, 0.0000],
[1.0000, 1.0000, 0.5000],
] ]
) )
expected_faces = torch.tensor([[0, 1, 9], [4, 7, 8], [2, 3, 11], [5, 10, 6]]) expected_faces = torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]])
self.assertClose(verts[0], expected_verts) self.assertClose(verts[0], expected_verts)
self.assertClose(faces[0], expected_faces) self.assertClose(faces[0], expected_faces)
@ -238,15 +236,15 @@ class TestCubeConfiguration(TestCaseMixin, unittest.TestCase):
expected_verts = torch.tensor( expected_verts = torch.tensor(
[ [
[1.0000, 0.0000, 0.5000],
[0.5000, 0.0000, 0.0000],
[0.5000, 1.0000, 1.0000],
[0.0000, 1.0000, 0.5000],
[1.0000, 0.5000, 1.0000], [1.0000, 0.5000, 1.0000],
[0.0000, 1.0000, 0.5000],
[0.5000, 1.0000, 1.0000],
[1.0000, 0.0000, 0.5000],
[0.0000, 0.5000, 0.0000], [0.0000, 0.5000, 0.0000],
[0.5000, 0.0000, 0.0000],
] ]
) )
expected_faces = torch.tensor([[2, 3, 5], [4, 2, 5], [4, 5, 1], [4, 1, 0]]) expected_faces = torch.tensor([[0, 1, 2], [3, 1, 0], [3, 4, 1], [3, 5, 4]])
self.assertClose(verts[0], expected_verts) self.assertClose(verts[0], expected_verts)
self.assertClose(faces[0], expected_faces) self.assertClose(faces[0], expected_faces)
@ -269,13 +267,13 @@ class TestCubeConfiguration(TestCaseMixin, unittest.TestCase):
[ [
[0.5000, 0.0000, 0.0000], [0.5000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.5000], [0.0000, 0.0000, 0.5000],
[0.5000, 1.0000, 1.0000],
[0.0000, 1.0000, 0.5000], [0.0000, 1.0000, 0.5000],
[1.0000, 0.5000, 1.0000], [1.0000, 0.5000, 1.0000],
[1.0000, 0.5000, 0.0000], [1.0000, 0.5000, 0.0000],
[0.5000, 1.0000, 1.0000],
] ]
) )
expected_faces = torch.tensor([[0, 5, 4], [0, 4, 3], [0, 3, 1], [3, 4, 2]]) expected_faces = torch.tensor([[0, 1, 2], [0, 2, 3], [0, 3, 4], [5, 3, 2]])
self.assertClose(verts[0], expected_verts) self.assertClose(verts[0], expected_verts)
self.assertClose(faces[0], expected_faces) self.assertClose(faces[0], expected_faces)
@ -295,15 +293,15 @@ class TestCubeConfiguration(TestCaseMixin, unittest.TestCase):
expected_verts = torch.tensor( expected_verts = torch.tensor(
[ [
[0.5000, 0.0000, 0.0000], [0.5000, 0.0000, 0.0000],
[0.0000, 0.5000, 0.0000],
[0.0000, 0.0000, 0.5000], [0.0000, 0.0000, 0.5000],
[0.5000, 1.0000, 1.0000],
[1.0000, 1.0000, 0.5000], [1.0000, 1.0000, 0.5000],
[1.0000, 0.5000, 1.0000], [1.0000, 0.5000, 1.0000],
[0.0000, 0.5000, 0.0000], [0.5000, 1.0000, 1.0000],
] ]
) )
expected_faces = torch.tensor([[4, 3, 2], [0, 1, 5]]) expected_faces = torch.tensor([[0, 1, 2], [3, 4, 5]])
self.assertClose(verts[0], expected_verts) self.assertClose(verts[0], expected_verts)
self.assertClose(faces[0], expected_faces) self.assertClose(faces[0], expected_faces)
@ -324,16 +322,16 @@ class TestCubeConfiguration(TestCaseMixin, unittest.TestCase):
expected_verts = torch.tensor( expected_verts = torch.tensor(
[ [
[1.0000, 0.0000, 0.5000], [1.0000, 0.0000, 0.5000],
[0.0000, 0.5000, 0.0000],
[0.0000, 0.0000, 0.5000], [0.0000, 0.0000, 0.5000],
[0.5000, 1.0000, 1.0000], [1.0000, 0.5000, 0.0000],
[1.0000, 1.0000, 0.5000], [1.0000, 1.0000, 0.5000],
[1.0000, 0.5000, 1.0000], [1.0000, 0.5000, 1.0000],
[1.0000, 0.5000, 0.0000], [0.5000, 1.0000, 1.0000],
[0.0000, 0.5000, 0.0000],
] ]
) )
expected_faces = torch.tensor([[5, 1, 6], [5, 0, 1], [4, 3, 2]]) expected_faces = torch.tensor([[0, 1, 2], [0, 3, 1], [4, 5, 6]])
self.assertClose(verts[0], expected_verts) self.assertClose(verts[0], expected_verts)
self.assertClose(faces[0], expected_faces) self.assertClose(faces[0], expected_faces)
@ -354,18 +352,18 @@ class TestCubeConfiguration(TestCaseMixin, unittest.TestCase):
expected_verts = torch.tensor( expected_verts = torch.tensor(
[ [
[1.0000, 0.0000, 0.5000], [1.0000, 0.0000, 0.5000],
[1.0000, 0.5000, 0.0000],
[0.5000, 0.0000, 0.0000], [0.5000, 0.0000, 0.0000],
[0.5000, 1.0000, 1.0000],
[1.0000, 1.0000, 0.5000], [1.0000, 1.0000, 0.5000],
[1.0000, 0.5000, 1.0000],
[0.5000, 1.0000, 1.0000],
[0.0000, 0.5000, 0.0000],
[0.5000, 1.0000, 0.0000], [0.5000, 1.0000, 0.0000],
[0.0000, 1.0000, 0.5000], [0.0000, 1.0000, 0.5000],
[1.0000, 0.5000, 1.0000],
[1.0000, 0.5000, 0.0000],
[0.0000, 0.5000, 0.0000],
] ]
) )
expected_faces = torch.tensor([[6, 3, 2], [7, 0, 1], [5, 4, 8]]) expected_faces = torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8]])
self.assertClose(verts[0], expected_verts) self.assertClose(verts[0], expected_verts)
self.assertClose(faces[0], expected_faces) self.assertClose(faces[0], expected_faces)
@ -386,18 +384,18 @@ class TestCubeConfiguration(TestCaseMixin, unittest.TestCase):
expected_verts = torch.tensor( expected_verts = torch.tensor(
[ [
[0.5000, 0.0000, 1.0000],
[1.0000, 0.0000, 0.5000], [1.0000, 0.0000, 0.5000],
[0.5000, 0.0000, 0.0000], [0.5000, 0.0000, 1.0000],
[0.0000, 0.0000, 0.5000],
[0.5000, 1.0000, 1.0000],
[1.0000, 1.0000, 0.5000], [1.0000, 1.0000, 0.5000],
[0.5000, 1.0000, 1.0000],
[0.0000, 0.0000, 0.5000],
[0.5000, 0.0000, 0.0000],
[0.5000, 1.0000, 0.0000], [0.5000, 1.0000, 0.0000],
[0.0000, 1.0000, 0.5000], [0.0000, 1.0000, 0.5000],
] ]
) )
expected_faces = torch.tensor([[3, 6, 2], [3, 7, 6], [1, 5, 0], [5, 4, 0]]) expected_faces = torch.tensor([[0, 1, 2], [2, 1, 3], [4, 5, 6], [4, 6, 7]])
self.assertClose(verts[0], expected_verts) self.assertClose(verts[0], expected_verts)
self.assertClose(faces[0], expected_faces) self.assertClose(faces[0], expected_faces)
@ -418,16 +416,16 @@ class TestCubeConfiguration(TestCaseMixin, unittest.TestCase):
expected_verts = torch.tensor( expected_verts = torch.tensor(
[ [
[1.0000, 0.0000, 0.5000],
[0.5000, 0.0000, 0.0000], [0.5000, 0.0000, 0.0000],
[0.5000, 1.0000, 1.0000],
[1.0000, 1.0000, 0.5000],
[0.0000, 0.5000, 1.0000],
[0.0000, 0.5000, 0.0000], [0.0000, 0.5000, 0.0000],
[0.0000, 0.5000, 1.0000],
[1.0000, 1.0000, 0.5000],
[1.0000, 0.0000, 0.5000],
[0.5000, 1.0000, 1.0000],
] ]
) )
expected_faces = torch.tensor([[1, 0, 3], [1, 3, 4], [1, 4, 5], [2, 4, 3]]) expected_faces = torch.tensor([[0, 1, 2], [0, 2, 3], [0, 3, 4], [3, 2, 5]])
self.assertClose(verts[0], expected_verts) self.assertClose(verts[0], expected_verts)
self.assertClose(faces[0], expected_faces) self.assertClose(faces[0], expected_faces)
@ -447,27 +445,26 @@ class TestMarchingCubes(TestCaseMixin, unittest.TestCase):
expected_verts = torch.tensor( expected_verts = torch.tensor(
[ [
[0.5, 1, 1], [1.0000, 0.5000, 1.0000],
[1, 1, 0.5], [1.0000, 1.0000, 0.5000],
[1, 0.5, 1], [0.5000, 1.0000, 1.0000],
[1, 1, 1.5], [1.5000, 1.0000, 1.0000],
[1, 1.5, 1], [1.0000, 1.5000, 1.0000],
[1.5, 1, 1], [1.0000, 1.0000, 1.5000],
] ]
) )
expected_faces = torch.tensor( expected_faces = torch.tensor(
[ [
[2, 0, 1], [0, 1, 2],
[2, 3, 0], [1, 0, 3],
[0, 4, 1], [1, 4, 2],
[3, 4, 0], [1, 3, 4],
[5, 2, 1], [0, 2, 5],
[3, 2, 5], [3, 0, 5],
[5, 1, 4], [2, 4, 5],
[3, 5, 4], [3, 5, 4],
] ]
) )
self.assertClose(verts[0], expected_verts) self.assertClose(verts[0], expected_verts)
self.assertClose(faces[0], expected_faces) self.assertClose(faces[0], expected_faces)
@ -492,76 +489,76 @@ class TestMarchingCubes(TestCaseMixin, unittest.TestCase):
expected_verts = torch.tensor( expected_verts = torch.tensor(
[ [
[0.9000, 1.0000, 1.0000],
[1.0000, 1.0000, 0.9000],
[1.0000, 0.9000, 1.0000], [1.0000, 0.9000, 1.0000],
[0.9000, 1.0000, 2.0000], [1.0000, 1.0000, 0.9000],
[1.0000, 0.9000, 2.0000], [0.9000, 1.0000, 1.0000],
[1.0000, 1.0000, 2.1000],
[0.9000, 2.0000, 1.0000],
[1.0000, 2.0000, 0.9000],
[0.9000, 2.0000, 2.0000],
[1.0000, 2.0000, 2.1000],
[1.0000, 2.1000, 1.0000],
[1.0000, 2.1000, 2.0000],
[2.0000, 1.0000, 0.9000],
[2.0000, 0.9000, 1.0000], [2.0000, 0.9000, 1.0000],
[2.0000, 0.9000, 2.0000], [2.0000, 1.0000, 0.9000],
[2.0000, 1.0000, 2.1000],
[2.0000, 2.0000, 0.9000],
[2.0000, 2.0000, 2.1000],
[2.0000, 2.1000, 1.0000],
[2.0000, 2.1000, 2.0000],
[2.1000, 1.0000, 1.0000], [2.1000, 1.0000, 1.0000],
[2.1000, 1.0000, 2.0000], [1.0000, 2.0000, 0.9000],
[0.9000, 2.0000, 1.0000],
[2.0000, 2.0000, 0.9000],
[2.1000, 2.0000, 1.0000], [2.1000, 2.0000, 1.0000],
[1.0000, 2.1000, 1.0000],
[2.0000, 2.1000, 1.0000],
[1.0000, 0.9000, 2.0000],
[0.9000, 1.0000, 2.0000],
[2.0000, 0.9000, 2.0000],
[2.1000, 1.0000, 2.0000],
[0.9000, 2.0000, 2.0000],
[2.1000, 2.0000, 2.0000], [2.1000, 2.0000, 2.0000],
[1.0000, 2.1000, 2.0000],
[2.0000, 2.1000, 2.0000],
[1.0000, 1.0000, 2.1000],
[2.0000, 1.0000, 2.1000],
[1.0000, 2.0000, 2.1000],
[2.0000, 2.0000, 2.1000],
] ]
) )
expected_faces = torch.tensor( expected_faces = torch.tensor(
[ [
[2, 0, 1], [0, 1, 2],
[2, 4, 3], [0, 3, 4],
[0, 2, 3], [1, 0, 4],
[4, 5, 3], [4, 3, 5],
[0, 6, 7], [1, 6, 7],
[1, 0, 7], [2, 1, 7],
[3, 8, 0], [4, 8, 1],
[8, 6, 0], [1, 8, 6],
[5, 9, 8], [8, 4, 5],
[3, 5, 8], [9, 8, 5],
[6, 10, 7], [6, 10, 7],
[11, 10, 6], [6, 8, 11],
[8, 11, 6], [10, 6, 11],
[9, 11, 8], [8, 9, 11],
[13, 2, 1], [12, 0, 2],
[12, 13, 1], [13, 12, 2],
[14, 4, 13], [3, 0, 14],
[13, 4, 2], [14, 0, 12],
[4, 14, 15], [15, 5, 3],
[5, 4, 15], [14, 15, 3],
[12, 1, 16], [2, 7, 13],
[1, 7, 16], [7, 16, 13],
[15, 17, 5], [5, 15, 9],
[5, 17, 9], [9, 15, 17],
[16, 7, 10], [10, 18, 16],
[18, 16, 10], [7, 10, 16],
[19, 18, 11], [11, 19, 10],
[18, 10, 11], [19, 18, 10],
[9, 17, 19], [9, 17, 19],
[11, 9, 19], [11, 9, 19],
[20, 13, 12], [12, 13, 20],
[20, 21, 14], [14, 12, 20],
[13, 20, 14], [21, 14, 20],
[15, 14, 21], [15, 14, 21],
[22, 20, 12], [13, 16, 22],
[16, 22, 12], [20, 13, 22],
[21, 20, 23], [21, 20, 23],
[23, 20, 22], [20, 22, 23],
[17, 15, 21], [17, 15, 21],
[23, 17, 21], [23, 17, 21],
[22, 16, 18], [16, 18, 22],
[23, 22, 18], [23, 22, 18],
[19, 23, 18], [19, 23, 18],
[17, 23, 19], [17, 23, 19],
@ -569,6 +566,7 @@ class TestMarchingCubes(TestCaseMixin, unittest.TestCase):
) )
self.assertClose(verts[0], expected_verts) self.assertClose(verts[0], expected_verts)
self.assertClose(faces[0], expected_faces) self.assertClose(faces[0], expected_faces)
verts, faces = marching_cubes_naive(volume_data, 0.9, return_local_coords=True) verts, faces = marching_cubes_naive(volume_data, 0.9, return_local_coords=True)
expected_verts = convert_to_local(expected_verts, 5) expected_verts = convert_to_local(expected_verts, 5)
self.assertClose(verts[0], expected_verts) self.assertClose(verts[0], expected_verts)
@ -592,34 +590,49 @@ class TestMarchingCubes(TestCaseMixin, unittest.TestCase):
expected_verts = torch.tensor( expected_verts = torch.tensor(
[ [
[2.0, 1.0, 1.0],
[2.0, 2.0, 1.0],
[1.0, 1.0, 1.0], [1.0, 1.0, 1.0],
[1.0, 1.0, 2.0],
[1.0, 2.0, 1.0], [1.0, 2.0, 1.0],
[2.0, 1.0, 1.0],
[1.0, 1.0, 1.0],
[2.0, 1.0, 2.0],
[1.0, 1.0, 2.0],
[1.0, 1.0, 1.0],
[1.0, 2.0, 1.0],
[1.0, 1.0, 2.0],
[1.0, 2.0, 2.0], [1.0, 2.0, 2.0],
[2.0, 1.0, 1.0], [2.0, 1.0, 1.0],
[2.0, 1.0, 2.0], [2.0, 1.0, 2.0],
[2.0, 2.0, 1.0], [2.0, 2.0, 1.0],
[2.0, 2.0, 2.0], [2.0, 2.0, 2.0],
[2.0, 2.0, 1.0],
[2.0, 2.0, 2.0],
[1.0, 2.0, 1.0],
[1.0, 2.0, 2.0],
[2.0, 1.0, 2.0],
[1.0, 1.0, 2.0],
[2.0, 2.0, 2.0],
[1.0, 2.0, 2.0],
] ]
) )
expected_faces = torch.tensor( expected_faces = torch.tensor(
[ [
[1, 3, 0], [0, 1, 2],
[3, 2, 0], [2, 1, 3],
[5, 1, 4], [4, 5, 6],
[4, 1, 0], [6, 5, 7],
[4, 0, 6], [8, 9, 10],
[0, 2, 6], [9, 11, 10],
[5, 7, 1], [12, 13, 14],
[1, 7, 3], [14, 13, 15],
[7, 6, 3], [16, 17, 18],
[6, 2, 3], [17, 19, 18],
[5, 4, 7], [20, 21, 22],
[7, 4, 6], [21, 23, 22],
] ]
) )
self.assertClose(verts[0], expected_verts) self.assertClose(verts[0], expected_verts)
self.assertClose(faces[0], expected_faces) self.assertClose(faces[0], expected_faces)
@ -651,8 +664,8 @@ class TestMarchingCubes(TestCaseMixin, unittest.TestCase):
filename = os.path.join(DATA_DIR, data_filename) filename = os.path.join(DATA_DIR, data_filename)
with open(filename, "rb") as file: with open(filename, "rb") as file:
verts_and_faces = pickle.load(file) verts_and_faces = pickle.load(file)
expected_verts = verts_and_faces["verts"].squeeze() expected_verts = verts_and_faces["verts"]
expected_faces = verts_and_faces["faces"].squeeze() expected_faces = verts_and_faces["faces"]
self.assertClose(verts[0], expected_verts) self.assertClose(verts[0], expected_verts)
self.assertClose(faces[0], expected_faces) self.assertClose(faces[0], expected_faces)
@ -689,8 +702,8 @@ class TestMarchingCubes(TestCaseMixin, unittest.TestCase):
expected_verts = verts_and_faces["verts"] expected_verts = verts_and_faces["verts"]
expected_faces = verts_and_faces["faces"] expected_faces = verts_and_faces["faces"]
self.assertClose(verts[0], expected_verts[0]) self.assertClose(verts[0], expected_verts)
self.assertClose(faces[0], expected_faces[0]) self.assertClose(faces[0], expected_faces)
def test_cube_surface_area(self): def test_cube_surface_area(self):
if USE_SCIKIT: if USE_SCIKIT:
@ -760,16 +773,26 @@ class TestMarchingCubes(TestCaseMixin, unittest.TestCase):
self.assertClose(surf, surf_sci) self.assertClose(surf, surf_sci)
def test_ball_example(self):
N = 15
axis_tensor = torch.arange(0, N)
X, Y, Z = torch.meshgrid(axis_tensor, axis_tensor, axis_tensor, indexing="ij")
u = (X - 15) ** 2 + (Y - 15) ** 2 + (Z - 15) ** 2 - 8**2
u = u[None].float()
verts, faces = marching_cubes_naive(u, 0, return_local_coords=False)
@staticmethod @staticmethod
def marching_cubes_with_init(batch_size: int, V: int): def marching_cubes_with_init(algo_type: str, batch_size: int, V: int):
device = torch.device("cuda:0") device = torch.device("cuda:0")
volume_data = torch.rand( volume_data = torch.rand(
(batch_size, V, V, V), dtype=torch.float32, device=device (batch_size, V, V, V), dtype=torch.float32, device=device
) )
torch.cuda.synchronize() algo_table = {
"naive": marching_cubes_naive,
}
def convert(): def convert():
marching_cubes_naive(volume_data, return_local_coords=False) algo_table[algo_type](volume_data, return_local_coords=False)
torch.cuda.synchronize() torch.cuda.synchronize()
return convert return convert