mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
Python marching cubes improvements
Summary: Overhaul of marching_cubes_naive for better performance and to avoid relying on unstable hashing. In particular, instead of hashing vertex positions, we index each interpolated vertex with its corresponding edge in the 3d grid. Reviewed By: kjchalup Differential Revision: D39419642 fbshipit-source-id: b5fede3525c545d1d374198928dfb216262f0ec0
This commit is contained in:
parent
6471893f59
commit
850efdf706
@ -4,10 +4,10 @@
|
|||||||
# This source code is licensed under the BSD-style license found in the
|
# This source code is licensed under the BSD-style license found in the
|
||||||
# LICENSE file in the root directory of this source tree.
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
from typing import Dict, List, Optional, Tuple
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from pytorch3d.ops.marching_cubes_data import EDGE_TABLE, EDGE_TO_VERTICES, FACE_TABLE
|
from pytorch3d.ops.marching_cubes_data import EDGE_TO_VERTICES, FACE_TABLE, INDEX
|
||||||
from pytorch3d.transforms import Translate
|
from pytorch3d.transforms import Translate
|
||||||
|
|
||||||
|
|
||||||
@ -15,10 +15,15 @@ EPS = 0.00001
|
|||||||
|
|
||||||
|
|
||||||
class Cube:
|
class Cube:
|
||||||
def __init__(self, bfl_vertex: Tuple[int, int, int], spacing: int = 1) -> None:
|
def __init__(
|
||||||
|
self,
|
||||||
|
bfl_v: Tuple[int, int, int],
|
||||||
|
volume: torch.Tensor,
|
||||||
|
isolevel: float,
|
||||||
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Initializes a cube given the bottom front left vertex coordinate
|
Initializes a cube given the bottom front left vertex coordinate
|
||||||
and the cube spacing
|
and computes the cube configuration given vertex values and isolevel.
|
||||||
|
|
||||||
Edge and vertex convention:
|
Edge and vertex convention:
|
||||||
|
|
||||||
@ -31,8 +36,8 @@ class Cube:
|
|||||||
| | | |
|
| | | |
|
||||||
| |e8 |e10|
|
| |e8 |e10|
|
||||||
e11| | | |
|
e11| | | |
|
||||||
| |_________________|___|
|
| |______e0_________|___|
|
||||||
| / v0 e0 | /v1
|
| / v0(bfl_v) | |v1
|
||||||
| / | /
|
| / | /
|
||||||
| /e3 | /e1
|
| /e3 | /e1
|
||||||
|/_____________________|/
|
|/_____________________|/
|
||||||
@ -41,311 +46,182 @@ class Cube:
|
|||||||
Args:
|
Args:
|
||||||
bfl_vertex: a tuple of size 3 corresponding to the bottom front left vertex
|
bfl_vertex: a tuple of size 3 corresponding to the bottom front left vertex
|
||||||
of the cube in (x, y, z) format
|
of the cube in (x, y, z) format
|
||||||
spacing: the length of each edge of the cube
|
volume: the 3D scalar data
|
||||||
|
isolevel: the isosurface value used as a threshold for determining whether a point
|
||||||
|
is inside/outside the volume
|
||||||
"""
|
"""
|
||||||
# match corner orders to algorithm convention
|
x, y, z = bfl_v
|
||||||
if len(bfl_vertex) != 3:
|
self.x, self.y, self.z = x, y, z
|
||||||
msg = "The vertex {} is size {} instead of size 3".format(
|
self.bfl_v = bfl_v
|
||||||
bfl_vertex, len(bfl_vertex)
|
self.verts = [
|
||||||
)
|
[x + (v & 1), y + (v >> 1 & 1), z + (v >> 2 & 1)] for v in range(8)
|
||||||
raise ValueError(msg)
|
] # vertex position (x, y, z) for v0-v1-v4-v5-v3-v2-v7-v6
|
||||||
|
|
||||||
x, y, z = bfl_vertex
|
# Calculates cube configuration index given values of the cube vertices
|
||||||
self.vertices = torch.tensor(
|
self.cube_index = 0
|
||||||
[
|
for i in range(8):
|
||||||
[x, y, z + spacing],
|
v = self.verts[INDEX[i]]
|
||||||
[x + spacing, y, z + spacing],
|
value = volume[v[2]][v[1]][v[0]]
|
||||||
[x + spacing, y, z],
|
|
||||||
[x, y, z],
|
|
||||||
[x, y + spacing, z + spacing],
|
|
||||||
[x + spacing, y + spacing, z + spacing],
|
|
||||||
[x + spacing, y + spacing, z],
|
|
||||||
[x, y + spacing, z],
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_index(self, volume_data: torch.Tensor, isolevel: float) -> int:
|
|
||||||
"""
|
|
||||||
Calculates the cube_index in the range 0-255 to index
|
|
||||||
into EDGE_TABLE and FACE_TABLE
|
|
||||||
Args:
|
|
||||||
volume_data: the 3D scalar data
|
|
||||||
isolevel: the isosurface value used as a threshold
|
|
||||||
for determining whether a point is inside/outside
|
|
||||||
the volume
|
|
||||||
"""
|
|
||||||
cube_index = 0
|
|
||||||
bit = 1
|
|
||||||
for index in range(len(self.vertices)):
|
|
||||||
vertex = self.vertices[index]
|
|
||||||
value = _get_value(vertex, volume_data)
|
|
||||||
if value < isolevel:
|
if value < isolevel:
|
||||||
cube_index |= bit
|
self.cube_index |= 1 << i
|
||||||
bit *= 2
|
|
||||||
return cube_index
|
def get_vpair_from_edge(self, edge: int, W: int, H: int) -> Tuple[int, int]:
|
||||||
|
"""
|
||||||
|
Get a tuple of global vertex ID from a local edge ID
|
||||||
|
Global vertex ID is calculated as (x + dx) + (y + dy) * W + (z + dz) * W * H
|
||||||
|
|
||||||
|
Args:
|
||||||
|
edge: local edge ID in the cube
|
||||||
|
bfl_vertex: bottom-front-left coordinate of the cube
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
a pair of global vertex ID
|
||||||
|
"""
|
||||||
|
v1, v2 = EDGE_TO_VERTICES[edge] # two end-points on the edge
|
||||||
|
v1_id = self.verts[v1][0] + self.verts[v1][1] * W + self.verts[v1][2] * W * H
|
||||||
|
v2_id = self.verts[v2][0] + self.verts[v2][1] * W + self.verts[v2][2] * W * H
|
||||||
|
return (v1_id, v2_id)
|
||||||
|
|
||||||
|
def vert_interp(
|
||||||
|
self,
|
||||||
|
isolevel: float,
|
||||||
|
edge: int,
|
||||||
|
vol: torch.Tensor,
|
||||||
|
) -> List:
|
||||||
|
"""
|
||||||
|
Linearly interpolate a vertex where an isosurface cuts an edge
|
||||||
|
between the two endpoint vertices, based on their values
|
||||||
|
|
||||||
|
Args:
|
||||||
|
isolevel: the isosurface value to use as the threshold to determine
|
||||||
|
whether points are within a volume.
|
||||||
|
edge: edge (ID) to interpolate
|
||||||
|
cube: current cube vertices
|
||||||
|
vol: 3D scalar field
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
interpolated vertex: position of the interpolated vertex on the edge
|
||||||
|
"""
|
||||||
|
v1, v2 = EDGE_TO_VERTICES[edge]
|
||||||
|
p1, p2 = self.verts[v1], self.verts[v2]
|
||||||
|
val1, val2 = (
|
||||||
|
vol[p1[2]][p1[1]][p1[0]],
|
||||||
|
vol[p2[2]][p2[1]][p2[0]],
|
||||||
|
)
|
||||||
|
point = None
|
||||||
|
if abs(isolevel - val1) < EPS:
|
||||||
|
point = p1
|
||||||
|
elif abs(isolevel - val2) < EPS:
|
||||||
|
point = p2
|
||||||
|
elif abs(val1 - val2) < EPS:
|
||||||
|
point = p1
|
||||||
|
|
||||||
|
if point is None:
|
||||||
|
mu = (isolevel - val1) / (val2 - val1)
|
||||||
|
x1, y1, z1 = p1
|
||||||
|
x2, y2, z2 = p2
|
||||||
|
x = x1 + mu * (x2 - x1)
|
||||||
|
y = y1 + mu * (y2 - y1)
|
||||||
|
z = z1 + mu * (z2 - z1)
|
||||||
|
else:
|
||||||
|
x, y, z = point
|
||||||
|
return [x, y, z]
|
||||||
|
|
||||||
|
|
||||||
def marching_cubes_naive(
|
def marching_cubes_naive(
|
||||||
volume_data_batch: torch.Tensor,
|
vol_batch: torch.Tensor,
|
||||||
isolevel: Optional[float] = None,
|
isolevel: Optional[float] = None,
|
||||||
spacing: int = 1,
|
|
||||||
return_local_coords: bool = True,
|
return_local_coords: bool = True,
|
||||||
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
|
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
|
||||||
"""
|
"""
|
||||||
Runs the classic marching cubes algorithm, iterating over
|
Runs the classic marching cubes algorithm, iterating over
|
||||||
the coordinates of the volume_data and using a given isolevel
|
the coordinates of the volume and using a given isolevel
|
||||||
for determining intersected edges of cubes of size `spacing`.
|
for determining intersected edges of cubes.
|
||||||
Returns vertices and faces of the obtained mesh.
|
Returns vertices and faces of the obtained mesh.
|
||||||
This operation is non-differentiable.
|
This operation is non-differentiable.
|
||||||
|
|
||||||
This is a naive implementation, and is not optimized for efficiency.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
volume_data_batch: a Tensor of size (N, D, H, W) corresponding to
|
vol_batch: a Tensor of size (N, D, H, W) corresponding to
|
||||||
a batch of 3D scalar fields
|
a batch of 3D scalar fields
|
||||||
isolevel: the isosurface value to use as the threshold to determine
|
isolevel: the isosurface value to use as the threshold to determine
|
||||||
whether points are within a volume. If None, then the average of the
|
whether points are within a volume. If None, then the average of the
|
||||||
maximum and minimum value of the scalar field will be used.
|
maximum and minimum value of the scalar field will be used.
|
||||||
spacing: an integer specifying the cube size to use
|
|
||||||
return_local_coords: bool. If True the output vertices will be in local coordinates in
|
return_local_coords: bool. If True the output vertices will be in local coordinates in
|
||||||
the range [-1, 1] x [-1, 1] x [-1, 1]. If False they will be in the range
|
the range [-1, 1] x [-1, 1] x [-1, 1]. If False they will be in the range
|
||||||
[0, W-1] x [0, H-1] x [0, D-1]
|
[0, W-1] x [0, H-1] x [0, D-1]
|
||||||
Returns:
|
Returns:
|
||||||
verts: [(V_0, 3), (V_1, 3), ...] List of N FloatTensors of vertices.
|
verts: [{V_0}, {V_1}, ...] List of N sets of vertices of shape (|V_i|, 3) in FloatTensor
|
||||||
faces: [(F_0, 3), (F_1, 3), ...] List of N LongTensors of faces.
|
faces: [{F_0}, {F_1}, ...] List of N sets of faces of shape (|F_i|, 3) in LongTensors
|
||||||
"""
|
"""
|
||||||
volume_data_batch = volume_data_batch.detach().cpu()
|
|
||||||
batched_verts, batched_faces = [], []
|
batched_verts, batched_faces = [], []
|
||||||
D, H, W = volume_data_batch.shape[1:]
|
D, H, W = vol_batch.shape[1:]
|
||||||
volume_size_xyz = volume_data_batch.new_tensor([W, H, D])[None]
|
|
||||||
|
|
||||||
if return_local_coords:
|
# each edge is represented with its two endpoints (represented with global id)
|
||||||
# Convert from local coordinates in the range [-1, 1] range to
|
for i in range(len(vol_batch)):
|
||||||
# world coordinates in the range [0, D-1], [0, H-1], [0, W-1]
|
vol = vol_batch[i]
|
||||||
local_to_world_transform = Translate(
|
thresh = ((vol.max() + vol.min()) / 2).item() if isolevel is None else isolevel
|
||||||
x=+1.0, y=+1.0, z=+1.0, device=volume_data_batch.device
|
vpair_to_edge = {} # maps from tuple of edge endpoints to edge_id
|
||||||
).scale((volume_size_xyz - 1) * spacing * 0.5)
|
edge_id_to_v = {} # maps from edge ID to vertex position
|
||||||
# Perform the inverse to go from world to local
|
uniq_edge_id = {} # unique edge IDs
|
||||||
world_to_local_transform = local_to_world_transform.inverse()
|
verts = [] # store vertex positions
|
||||||
|
faces = [] # store face indices
|
||||||
|
# enumerate each cell in the 3d grid
|
||||||
|
for z in range(0, D - 1):
|
||||||
|
for y in range(0, H - 1):
|
||||||
|
for x in range(0, W - 1):
|
||||||
|
cube = Cube((x, y, z), vol, thresh)
|
||||||
|
edge_indices = FACE_TABLE[cube.cube_index]
|
||||||
|
# cube is entirely in/out of the surface
|
||||||
|
if len(edge_indices) == 0:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# gather mesh vertices/faces by processing each cube
|
||||||
|
interp_points = [[0.0, 0.0, 0.0]] * 12
|
||||||
|
# triangle vertex IDs and positions
|
||||||
|
tri = []
|
||||||
|
ps = []
|
||||||
|
for i, edge in enumerate(edge_indices):
|
||||||
|
interp_points[edge] = cube.vert_interp(thresh, edge, vol)
|
||||||
|
|
||||||
|
# Bind interpolated vertex with a global edge_id, which
|
||||||
|
# is represented by a pair of vertex ids (v1_id, v2_id)
|
||||||
|
# corresponding to a local edge.
|
||||||
|
(v1_id, v2_id) = cube.get_vpair_from_edge(edge, W, H)
|
||||||
|
edge_id = vpair_to_edge.setdefault(
|
||||||
|
(v1_id, v2_id), len(vpair_to_edge)
|
||||||
|
)
|
||||||
|
tri.append(edge_id)
|
||||||
|
ps.append(interp_points[edge])
|
||||||
|
# when the isolevel are the same as the edge endpoints, the interploated
|
||||||
|
# vertices can share the same values, and lead to degenerate triangles.
|
||||||
|
if (
|
||||||
|
(i + 1) % 3 == 0
|
||||||
|
and ps[0] != ps[1]
|
||||||
|
and ps[1] != ps[2]
|
||||||
|
and ps[2] != ps[0]
|
||||||
|
):
|
||||||
|
for j, edge_id in enumerate(tri):
|
||||||
|
edge_id_to_v[edge_id] = ps[j]
|
||||||
|
if edge_id not in uniq_edge_id:
|
||||||
|
uniq_edge_id[edge_id] = len(verts)
|
||||||
|
verts.append(edge_id_to_v[edge_id])
|
||||||
|
faces.append([uniq_edge_id[tri[j]] for j in range(3)])
|
||||||
|
tri = []
|
||||||
|
ps = []
|
||||||
|
|
||||||
for i in range(len(volume_data_batch)):
|
|
||||||
volume_data = volume_data_batch[i]
|
|
||||||
curr_isolevel = (
|
|
||||||
((volume_data.max() + volume_data.min()) / 2).item()
|
|
||||||
if isolevel is None
|
|
||||||
else isolevel
|
|
||||||
)
|
|
||||||
edge_vertices_to_index = {}
|
|
||||||
vertex_coords_to_index = {}
|
|
||||||
verts, faces = [], []
|
|
||||||
# Use length - spacing for the bounds since we are using
|
|
||||||
# cubes of size spacing, with the lowest x,y,z values
|
|
||||||
# (bottom front left)
|
|
||||||
for x in range(0, W - spacing, spacing):
|
|
||||||
for y in range(0, H - spacing, spacing):
|
|
||||||
for z in range(0, D - spacing, spacing):
|
|
||||||
cube = Cube((x, y, z), spacing)
|
|
||||||
new_verts, new_faces = polygonise(
|
|
||||||
cube,
|
|
||||||
curr_isolevel,
|
|
||||||
volume_data,
|
|
||||||
edge_vertices_to_index,
|
|
||||||
vertex_coords_to_index,
|
|
||||||
)
|
|
||||||
verts.extend(new_verts)
|
|
||||||
faces.extend(new_faces)
|
|
||||||
if len(faces) > 0 and len(verts) > 0:
|
if len(faces) > 0 and len(verts) > 0:
|
||||||
verts = torch.tensor(verts, dtype=torch.float32)
|
verts = torch.tensor(verts, dtype=vol.dtype)
|
||||||
# Convert vertices from world to local coords
|
# Convert from world coordinates ([0, D-1], [0, H-1], [0, W-1]) to
|
||||||
|
# local coordinates in the range [-1, 1]
|
||||||
if return_local_coords:
|
if return_local_coords:
|
||||||
verts = world_to_local_transform.transform_points(verts[None, ...])
|
verts = (
|
||||||
verts = verts.squeeze()
|
Translate(x=+1.0, y=+1.0, z=+1.0, device=vol_batch.device)
|
||||||
|
.scale((vol_batch.new_tensor([W, H, D])[None] - 1) * 0.5)
|
||||||
|
.inverse()
|
||||||
|
).transform_points(verts[None])[0]
|
||||||
batched_verts.append(verts)
|
batched_verts.append(verts)
|
||||||
batched_faces.append(torch.tensor(faces, dtype=torch.int64))
|
batched_faces.append(torch.tensor(faces, dtype=torch.int64))
|
||||||
return batched_verts, batched_faces
|
|
||||||
|
|
||||||
|
|
||||||
def polygonise(
|
|
||||||
cube: Cube,
|
|
||||||
isolevel: float,
|
|
||||||
volume_data: torch.Tensor,
|
|
||||||
edge_vertices_to_index: Dict[Tuple[Tuple, Tuple], int],
|
|
||||||
vertex_coords_to_index: Dict[Tuple[float, float, float], int],
|
|
||||||
) -> Tuple[list, list]:
|
|
||||||
"""
|
|
||||||
Runs the classic marching cubes algorithm for one Cube in the volume.
|
|
||||||
Returns the vertices and faces for the given cube.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
cube: a Cube indicating the cube being examined for edges that intersect
|
|
||||||
the volume data.
|
|
||||||
isolevel: the isosurface value to use as the threshold to determine
|
|
||||||
whether points are within a volume.
|
|
||||||
volume_data: a Tensor of shape (D, H, W) corresponding to
|
|
||||||
a 3D scalar field
|
|
||||||
edge_vertices_to_index: A dictionary which maps an edge's two coordinates
|
|
||||||
to the index of its interpolated point, if that interpolated point
|
|
||||||
has already been used by a previous point
|
|
||||||
vertex_coords_to_index: A dictionary mapping a point (x, y, z) to the corresponding
|
|
||||||
index of that vertex, if that point has already been marked as a vertex.
|
|
||||||
Returns:
|
|
||||||
verts: List of triangle vertices for the given cube in the volume
|
|
||||||
faces: List of triangle faces for the given cube in the volume
|
|
||||||
"""
|
|
||||||
num_existing_verts = max(edge_vertices_to_index.values(), default=-1) + 1
|
|
||||||
verts, faces = [], []
|
|
||||||
cube_index = cube.get_index(volume_data, isolevel)
|
|
||||||
edges = EDGE_TABLE[cube_index]
|
|
||||||
edge_indices = _get_edge_indices(edges)
|
|
||||||
if len(edge_indices) == 0:
|
|
||||||
return [], []
|
|
||||||
|
|
||||||
new_verts, edge_index_to_point_index = _calculate_interp_vertices(
|
|
||||||
edge_indices,
|
|
||||||
volume_data,
|
|
||||||
cube,
|
|
||||||
isolevel,
|
|
||||||
edge_vertices_to_index,
|
|
||||||
vertex_coords_to_index,
|
|
||||||
num_existing_verts,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create faces
|
|
||||||
face_triangles = FACE_TABLE[cube_index]
|
|
||||||
for i in range(0, len(face_triangles), 3):
|
|
||||||
tri1 = edge_index_to_point_index[face_triangles[i]]
|
|
||||||
tri2 = edge_index_to_point_index[face_triangles[i + 1]]
|
|
||||||
tri3 = edge_index_to_point_index[face_triangles[i + 2]]
|
|
||||||
if tri1 != tri2 and tri2 != tri3 and tri1 != tri3:
|
|
||||||
faces.append([tri1, tri2, tri3])
|
|
||||||
|
|
||||||
verts += new_verts
|
|
||||||
return verts, faces
|
|
||||||
|
|
||||||
|
|
||||||
def _get_edge_indices(edges: int) -> List[int]:
|
|
||||||
"""
|
|
||||||
Finds which edge numbers are intersected given the bit representation
|
|
||||||
detailed in marching_cubes_data.EDGE_TABLE.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
edges: an integer corresponding to the value at cube_index
|
|
||||||
from the EDGE_TABLE in marching_cubes_data.py
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
edge_indices: A list of edge indices
|
|
||||||
"""
|
|
||||||
if edges == 0:
|
|
||||||
return []
|
|
||||||
|
|
||||||
edge_indices = []
|
|
||||||
for i in range(12):
|
|
||||||
if edges & (2**i):
|
|
||||||
edge_indices.append(i)
|
|
||||||
return edge_indices
|
|
||||||
|
|
||||||
|
|
||||||
def _calculate_interp_vertices(
|
|
||||||
edge_indices: List[int],
|
|
||||||
volume_data: torch.Tensor,
|
|
||||||
cube: Cube,
|
|
||||||
isolevel: float,
|
|
||||||
edge_vertices_to_index: Dict[Tuple[Tuple, Tuple], int],
|
|
||||||
vertex_coords_to_index: Dict[Tuple[float, float, float], int],
|
|
||||||
num_existing_verts: int,
|
|
||||||
) -> Tuple[List, Dict[int, int]]:
|
|
||||||
"""
|
|
||||||
Finds the interpolated vertices for the intersected edges, either referencing
|
|
||||||
previous calculations or newly calculating and storing the new interpolated
|
|
||||||
points.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
edge_indices: the numbers of the edges which are intersected. See the
|
|
||||||
Cube class for more detail on the edge numbering convention.
|
|
||||||
volume_data: a Tensor of size (D, H, W) corresponding to
|
|
||||||
a 3D scalar field
|
|
||||||
cube: a Cube indicating the cube being examined for edges that intersect
|
|
||||||
the volume
|
|
||||||
isolevel: the isosurface value to use as the threshold to determine
|
|
||||||
whether points are within a volume.
|
|
||||||
edge_vertices_to_index: A dictionary which maps an edge's two coordinates
|
|
||||||
to the index of its interpolated point, if that interpolated point
|
|
||||||
has already been used by a previous point
|
|
||||||
vertex_coords_to_index: A dictionary mapping a point (x, y, z) to the corresponding
|
|
||||||
index of that vertex, if that point has already been marked as a vertex.
|
|
||||||
num_existing_verts: the number of vertices that have been found in previous
|
|
||||||
calls to polygonise for the given volume_data in the above function, marching_cubes.
|
|
||||||
This is equal to the 1 + the maximum value in edge_vertices_to_index.
|
|
||||||
Returns:
|
|
||||||
interp_points: a list of new interpolated points
|
|
||||||
edge_index_to_point_index: a dictionary mapping an edge number to the index in the
|
|
||||||
marching cubes' vertices list of the interpolated point on that edge. To be precise,
|
|
||||||
it refers to the index within the vertices list after interp_points
|
|
||||||
has been appended to the verts list constructed in the marching_cubes_naive
|
|
||||||
function.
|
|
||||||
"""
|
|
||||||
interp_points = []
|
|
||||||
edge_index_to_point_index = {}
|
|
||||||
for edge_index in edge_indices:
|
|
||||||
v1, v2 = EDGE_TO_VERTICES[edge_index]
|
|
||||||
point1, point2 = cube.vertices[v1], cube.vertices[v2]
|
|
||||||
p_tuple1, p_tuple2 = tuple(point1.tolist()), tuple(point2.tolist())
|
|
||||||
if (p_tuple1, p_tuple2) in edge_vertices_to_index:
|
|
||||||
edge_index_to_point_index[edge_index] = edge_vertices_to_index[
|
|
||||||
(p_tuple1, p_tuple2)
|
|
||||||
]
|
|
||||||
else:
|
else:
|
||||||
val1, val2 = _get_value(point1, volume_data), _get_value(
|
batched_verts.append([])
|
||||||
point2, volume_data
|
batched_faces.append([])
|
||||||
)
|
return batched_verts, batched_faces
|
||||||
|
|
||||||
point = None
|
|
||||||
if abs(isolevel - val1) < EPS:
|
|
||||||
point = point1
|
|
||||||
|
|
||||||
if abs(isolevel - val2) < EPS:
|
|
||||||
point = point2
|
|
||||||
|
|
||||||
if abs(val1 - val2) < EPS:
|
|
||||||
point = point1
|
|
||||||
|
|
||||||
if point is None:
|
|
||||||
mu = (isolevel - val1) / (val2 - val1)
|
|
||||||
x1, y1, z1 = point1
|
|
||||||
x2, y2, z2 = point2
|
|
||||||
x = x1 + mu * (x2 - x1)
|
|
||||||
y = y1 + mu * (y2 - y1)
|
|
||||||
z = z1 + mu * (z2 - z1)
|
|
||||||
else:
|
|
||||||
x, y, z = point
|
|
||||||
|
|
||||||
x, y, z = x.item(), y.item(), z.item() # for dictionary keys
|
|
||||||
|
|
||||||
vert_index = None
|
|
||||||
if (x, y, z) in vertex_coords_to_index:
|
|
||||||
vert_index = vertex_coords_to_index[(x, y, z)]
|
|
||||||
else:
|
|
||||||
vert_index = num_existing_verts + len(interp_points)
|
|
||||||
interp_points.append([x, y, z])
|
|
||||||
vertex_coords_to_index[(x, y, z)] = vert_index
|
|
||||||
|
|
||||||
edge_vertices_to_index[(p_tuple1, p_tuple2)] = vert_index
|
|
||||||
edge_index_to_point_index[edge_index] = vert_index
|
|
||||||
|
|
||||||
return interp_points, edge_index_to_point_index
|
|
||||||
|
|
||||||
|
|
||||||
def _get_value(point: Tuple[int, int, int], volume_data: torch.Tensor) -> float:
|
|
||||||
"""
|
|
||||||
Gets the value at a given coordinate point in the scalar field.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
point: data of shape (3) corresponding to an xyz coordinate.
|
|
||||||
volume_data: a Tensor of size (D, H, W) corresponding to
|
|
||||||
a 3D scalar field
|
|
||||||
Returns:
|
|
||||||
data: scalar value in the volume at the given point
|
|
||||||
"""
|
|
||||||
x, y, z = point
|
|
||||||
# pyre-fixme[7]: Expected `float` but got `Tensor`.
|
|
||||||
return volume_data[z][y][x]
|
|
||||||
|
@ -4,284 +4,21 @@
|
|||||||
# This source code is licensed under the BSD-style license found in the
|
# This source code is licensed under the BSD-style license found in the
|
||||||
# LICENSE file in the root directory of this source tree.
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
# A length 256 list which maps a cubeindex to a number
|
|
||||||
# with the intersected edges' bits set to 1.
|
|
||||||
# Each cubeindex corresponds to a given cube configuration, where
|
|
||||||
# it is composed of a bitstring where the 0th bit is flipped if vertex 0
|
|
||||||
# is below the isosurface (i.e. 0x01), for each of the 8 vertices.
|
|
||||||
EDGE_TABLE = [
|
|
||||||
0x0,
|
|
||||||
0x109,
|
|
||||||
0x203,
|
|
||||||
0x30A,
|
|
||||||
0x406,
|
|
||||||
0x50F,
|
|
||||||
0x605,
|
|
||||||
0x70C,
|
|
||||||
0x80C,
|
|
||||||
0x905,
|
|
||||||
0xA0F,
|
|
||||||
0xB06,
|
|
||||||
0xC0A,
|
|
||||||
0xD03,
|
|
||||||
0xE09,
|
|
||||||
0xF00,
|
|
||||||
0x190,
|
|
||||||
0x99,
|
|
||||||
0x393,
|
|
||||||
0x29A,
|
|
||||||
0x596,
|
|
||||||
0x49F,
|
|
||||||
0x795,
|
|
||||||
0x69C,
|
|
||||||
0x99C,
|
|
||||||
0x895,
|
|
||||||
0xB9F,
|
|
||||||
0xA96,
|
|
||||||
0xD9A,
|
|
||||||
0xC93,
|
|
||||||
0xF99,
|
|
||||||
0xE90,
|
|
||||||
0x230,
|
|
||||||
0x339,
|
|
||||||
0x33,
|
|
||||||
0x13A,
|
|
||||||
0x636,
|
|
||||||
0x73F,
|
|
||||||
0x435,
|
|
||||||
0x53C,
|
|
||||||
0xA3C,
|
|
||||||
0xB35,
|
|
||||||
0x83F,
|
|
||||||
0x936,
|
|
||||||
0xE3A,
|
|
||||||
0xF33,
|
|
||||||
0xC39,
|
|
||||||
0xD30,
|
|
||||||
0x3A0,
|
|
||||||
0x2A9,
|
|
||||||
0x1A3,
|
|
||||||
0xAA,
|
|
||||||
0x7A6,
|
|
||||||
0x6AF,
|
|
||||||
0x5A5,
|
|
||||||
0x4AC,
|
|
||||||
0xBAC,
|
|
||||||
0xAA5,
|
|
||||||
0x9AF,
|
|
||||||
0x8A6,
|
|
||||||
0xFAA,
|
|
||||||
0xEA3,
|
|
||||||
0xDA9,
|
|
||||||
0xCA0,
|
|
||||||
0x460,
|
|
||||||
0x569,
|
|
||||||
0x663,
|
|
||||||
0x76A,
|
|
||||||
0x66,
|
|
||||||
0x16F,
|
|
||||||
0x265,
|
|
||||||
0x36C,
|
|
||||||
0xC6C,
|
|
||||||
0xD65,
|
|
||||||
0xE6F,
|
|
||||||
0xF66,
|
|
||||||
0x86A,
|
|
||||||
0x963,
|
|
||||||
0xA69,
|
|
||||||
0xB60,
|
|
||||||
0x5F0,
|
|
||||||
0x4F9,
|
|
||||||
0x7F3,
|
|
||||||
0x6FA,
|
|
||||||
0x1F6,
|
|
||||||
0xFF,
|
|
||||||
0x3F5,
|
|
||||||
0x2FC,
|
|
||||||
0xDFC,
|
|
||||||
0xCF5,
|
|
||||||
0xFFF,
|
|
||||||
0xEF6,
|
|
||||||
0x9FA,
|
|
||||||
0x8F3,
|
|
||||||
0xBF9,
|
|
||||||
0xAF0,
|
|
||||||
0x650,
|
|
||||||
0x759,
|
|
||||||
0x453,
|
|
||||||
0x55A,
|
|
||||||
0x256,
|
|
||||||
0x35F,
|
|
||||||
0x55,
|
|
||||||
0x15C,
|
|
||||||
0xE5C,
|
|
||||||
0xF55,
|
|
||||||
0xC5F,
|
|
||||||
0xD56,
|
|
||||||
0xA5A,
|
|
||||||
0xB53,
|
|
||||||
0x859,
|
|
||||||
0x950,
|
|
||||||
0x7C0,
|
|
||||||
0x6C9,
|
|
||||||
0x5C3,
|
|
||||||
0x4CA,
|
|
||||||
0x3C6,
|
|
||||||
0x2CF,
|
|
||||||
0x1C5,
|
|
||||||
0xCC,
|
|
||||||
0xFCC,
|
|
||||||
0xEC5,
|
|
||||||
0xDCF,
|
|
||||||
0xCC6,
|
|
||||||
0xBCA,
|
|
||||||
0xAC3,
|
|
||||||
0x9C9,
|
|
||||||
0x8C0,
|
|
||||||
0x8C0,
|
|
||||||
0x9C9,
|
|
||||||
0xAC3,
|
|
||||||
0xBCA,
|
|
||||||
0xCC6,
|
|
||||||
0xDCF,
|
|
||||||
0xEC5,
|
|
||||||
0xFCC,
|
|
||||||
0xCC,
|
|
||||||
0x1C5,
|
|
||||||
0x2CF,
|
|
||||||
0x3C6,
|
|
||||||
0x4CA,
|
|
||||||
0x5C3,
|
|
||||||
0x6C9,
|
|
||||||
0x7C0,
|
|
||||||
0x950,
|
|
||||||
0x859,
|
|
||||||
0xB53,
|
|
||||||
0xA5A,
|
|
||||||
0xD56,
|
|
||||||
0xC5F,
|
|
||||||
0xF55,
|
|
||||||
0xE5C,
|
|
||||||
0x15C,
|
|
||||||
0x55,
|
|
||||||
0x35F,
|
|
||||||
0x256,
|
|
||||||
0x55A,
|
|
||||||
0x453,
|
|
||||||
0x759,
|
|
||||||
0x650,
|
|
||||||
0xAF0,
|
|
||||||
0xBF9,
|
|
||||||
0x8F3,
|
|
||||||
0x9FA,
|
|
||||||
0xEF6,
|
|
||||||
0xFFF,
|
|
||||||
0xCF5,
|
|
||||||
0xDFC,
|
|
||||||
0x2FC,
|
|
||||||
0x3F5,
|
|
||||||
0xFF,
|
|
||||||
0x1F6,
|
|
||||||
0x6FA,
|
|
||||||
0x7F3,
|
|
||||||
0x4F9,
|
|
||||||
0x5F0,
|
|
||||||
0xB60,
|
|
||||||
0xA69,
|
|
||||||
0x963,
|
|
||||||
0x86A,
|
|
||||||
0xF66,
|
|
||||||
0xE6F,
|
|
||||||
0xD65,
|
|
||||||
0xC6C,
|
|
||||||
0x36C,
|
|
||||||
0x265,
|
|
||||||
0x16F,
|
|
||||||
0x66,
|
|
||||||
0x76A,
|
|
||||||
0x663,
|
|
||||||
0x569,
|
|
||||||
0x460,
|
|
||||||
0xCA0,
|
|
||||||
0xDA9,
|
|
||||||
0xEA3,
|
|
||||||
0xFAA,
|
|
||||||
0x8A6,
|
|
||||||
0x9AF,
|
|
||||||
0xAA5,
|
|
||||||
0xBAC,
|
|
||||||
0x4AC,
|
|
||||||
0x5A5,
|
|
||||||
0x6AF,
|
|
||||||
0x7A6,
|
|
||||||
0xAA,
|
|
||||||
0x1A3,
|
|
||||||
0x2A9,
|
|
||||||
0x3A0,
|
|
||||||
0xD30,
|
|
||||||
0xC39,
|
|
||||||
0xF33,
|
|
||||||
0xE3A,
|
|
||||||
0x936,
|
|
||||||
0x83F,
|
|
||||||
0xB35,
|
|
||||||
0xA3C,
|
|
||||||
0x53C,
|
|
||||||
0x435,
|
|
||||||
0x73F,
|
|
||||||
0x636,
|
|
||||||
0x13A,
|
|
||||||
0x33,
|
|
||||||
0x339,
|
|
||||||
0x230,
|
|
||||||
0xE90,
|
|
||||||
0xF99,
|
|
||||||
0xC93,
|
|
||||||
0xD9A,
|
|
||||||
0xA96,
|
|
||||||
0xB9F,
|
|
||||||
0x895,
|
|
||||||
0x99C,
|
|
||||||
0x69C,
|
|
||||||
0x795,
|
|
||||||
0x49F,
|
|
||||||
0x596,
|
|
||||||
0x29A,
|
|
||||||
0x393,
|
|
||||||
0x99,
|
|
||||||
0x190,
|
|
||||||
0xF00,
|
|
||||||
0xE09,
|
|
||||||
0xD03,
|
|
||||||
0xC0A,
|
|
||||||
0xB06,
|
|
||||||
0xA0F,
|
|
||||||
0x905,
|
|
||||||
0x80C,
|
|
||||||
0x70C,
|
|
||||||
0x605,
|
|
||||||
0x50F,
|
|
||||||
0x406,
|
|
||||||
0x30A,
|
|
||||||
0x203,
|
|
||||||
0x109,
|
|
||||||
0x0,
|
|
||||||
]
|
|
||||||
|
|
||||||
# Maps each edge (by index) to the corresponding cube vertices
|
# Maps each edge (by index) to the corresponding cube vertices
|
||||||
EDGE_TO_VERTICES = [
|
EDGE_TO_VERTICES = [
|
||||||
[0, 1],
|
[0, 1],
|
||||||
[1, 2],
|
|
||||||
[3, 2],
|
|
||||||
[0, 3],
|
|
||||||
[4, 5],
|
|
||||||
[5, 6],
|
|
||||||
[7, 6],
|
|
||||||
[4, 7],
|
|
||||||
[0, 4],
|
|
||||||
[1, 5],
|
[1, 5],
|
||||||
[2, 6],
|
[4, 5],
|
||||||
|
[0, 4],
|
||||||
|
[2, 3],
|
||||||
[3, 7],
|
[3, 7],
|
||||||
|
[6, 7],
|
||||||
|
[2, 6],
|
||||||
|
[0, 2],
|
||||||
|
[1, 3],
|
||||||
|
[5, 7],
|
||||||
|
[4, 6],
|
||||||
]
|
]
|
||||||
|
|
||||||
# A list of lists mapping a cube_index (a given configuration)
|
# A list of lists mapping a cube_index (a given configuration)
|
||||||
@ -547,3 +284,6 @@ FACE_TABLE = [
|
|||||||
[0, 3, 8],
|
[0, 3, 8],
|
||||||
[],
|
[],
|
||||||
]
|
]
|
||||||
|
|
||||||
|
# mapping from 0-7 to v0-v7 in cube.vertices
|
||||||
|
INDEX = [0, 1, 5, 4, 2, 3, 7, 6]
|
||||||
|
Binary file not shown.
Binary file not shown.
@ -32,8 +32,8 @@ class TestCubeConfiguration(TestCaseMixin, unittest.TestCase):
|
|||||||
volume_data = torch.ones(1, 2, 2, 2) # (B, W, H, D)
|
volume_data = torch.ones(1, 2, 2, 2) # (B, W, H, D)
|
||||||
verts, faces = marching_cubes_naive(volume_data, return_local_coords=False)
|
verts, faces = marching_cubes_naive(volume_data, return_local_coords=False)
|
||||||
|
|
||||||
expected_verts = torch.tensor([])
|
expected_verts = torch.tensor([[]])
|
||||||
expected_faces = torch.tensor([], dtype=torch.int64)
|
expected_faces = torch.tensor([[]], dtype=torch.int64)
|
||||||
self.assertClose(verts, expected_verts)
|
self.assertClose(verts, expected_verts)
|
||||||
self.assertClose(faces, expected_faces)
|
self.assertClose(faces, expected_faces)
|
||||||
|
|
||||||
@ -42,16 +42,15 @@ class TestCubeConfiguration(TestCaseMixin, unittest.TestCase):
|
|||||||
volume_data[0, 0, 0, 0] = 0
|
volume_data[0, 0, 0, 0] = 0
|
||||||
volume_data = volume_data.permute(0, 3, 2, 1) # (B, D, H, W)
|
volume_data = volume_data.permute(0, 3, 2, 1) # (B, D, H, W)
|
||||||
verts, faces = marching_cubes_naive(volume_data, return_local_coords=False)
|
verts, faces = marching_cubes_naive(volume_data, return_local_coords=False)
|
||||||
|
|
||||||
expected_verts = torch.tensor(
|
expected_verts = torch.tensor(
|
||||||
[
|
[
|
||||||
[0.5, 0, 0],
|
[0.5, 0, 0],
|
||||||
[0, 0, 0.5],
|
|
||||||
[0, 0.5, 0],
|
[0, 0.5, 0],
|
||||||
|
[0, 0, 0.5],
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
expected_faces = torch.tensor([[1, 2, 0]])
|
expected_faces = torch.tensor([[0, 1, 2]])
|
||||||
self.assertClose(verts[0], expected_verts)
|
self.assertClose(verts[0], expected_verts)
|
||||||
self.assertClose(faces[0], expected_faces)
|
self.assertClose(faces[0], expected_faces)
|
||||||
|
|
||||||
@ -69,12 +68,12 @@ class TestCubeConfiguration(TestCaseMixin, unittest.TestCase):
|
|||||||
expected_verts = torch.tensor(
|
expected_verts = torch.tensor(
|
||||||
[
|
[
|
||||||
[1.0000, 0.0000, 0.5000],
|
[1.0000, 0.0000, 0.5000],
|
||||||
|
[0.0000, 0.5000, 0.0000],
|
||||||
[0.0000, 0.0000, 0.5000],
|
[0.0000, 0.0000, 0.5000],
|
||||||
[1.0000, 0.5000, 0.0000],
|
[1.0000, 0.5000, 0.0000],
|
||||||
[0.0000, 0.5000, 0.0000],
|
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
expected_faces = torch.tensor([[1, 2, 0], [3, 2, 1]])
|
expected_faces = torch.tensor([[0, 1, 2], [3, 1, 0]])
|
||||||
self.assertClose(verts[0], expected_verts)
|
self.assertClose(verts[0], expected_verts)
|
||||||
self.assertClose(faces[0], expected_faces)
|
self.assertClose(faces[0], expected_faces)
|
||||||
|
|
||||||
@ -92,15 +91,15 @@ class TestCubeConfiguration(TestCaseMixin, unittest.TestCase):
|
|||||||
|
|
||||||
expected_verts = torch.tensor(
|
expected_verts = torch.tensor(
|
||||||
[
|
[
|
||||||
[0.5000, 0.0000, 0.0000],
|
[1.0000, 0.5000, 0.0000],
|
||||||
[0.0000, 0.0000, 0.5000],
|
|
||||||
[1.0000, 1.0000, 0.5000],
|
[1.0000, 1.0000, 0.5000],
|
||||||
[0.5000, 1.0000, 0.0000],
|
[0.5000, 1.0000, 0.0000],
|
||||||
[1.0000, 0.5000, 0.0000],
|
[0.5000, 0.0000, 0.0000],
|
||||||
[0.0000, 0.5000, 0.0000],
|
[0.0000, 0.5000, 0.0000],
|
||||||
|
[0.0000, 0.0000, 0.5000],
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
expected_faces = torch.tensor([[0, 1, 5], [4, 3, 2]])
|
expected_faces = torch.tensor([[0, 1, 2], [3, 4, 5]])
|
||||||
self.assertClose(verts[0], expected_verts)
|
self.assertClose(verts[0], expected_verts)
|
||||||
self.assertClose(faces[0], expected_faces)
|
self.assertClose(faces[0], expected_faces)
|
||||||
|
|
||||||
@ -119,14 +118,14 @@ class TestCubeConfiguration(TestCaseMixin, unittest.TestCase):
|
|||||||
|
|
||||||
expected_verts = torch.tensor(
|
expected_verts = torch.tensor(
|
||||||
[
|
[
|
||||||
[0.5000, 0.0000, 0.0000],
|
|
||||||
[0.0000, 0.0000, 0.5000],
|
[0.0000, 0.0000, 0.5000],
|
||||||
|
[1.0000, 0.5000, 0.0000],
|
||||||
|
[0.5000, 0.0000, 0.0000],
|
||||||
[0.0000, 0.5000, 1.0000],
|
[0.0000, 0.5000, 1.0000],
|
||||||
[1.0000, 0.5000, 1.0000],
|
[1.0000, 0.5000, 1.0000],
|
||||||
[1.0000, 0.5000, 0.0000],
|
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
expected_faces = torch.tensor([[0, 2, 1], [0, 4, 2], [4, 3, 2]])
|
expected_faces = torch.tensor([[0, 1, 2], [0, 3, 1], [3, 4, 1]])
|
||||||
self.assertClose(verts[0], expected_verts)
|
self.assertClose(verts[0], expected_verts)
|
||||||
self.assertClose(faces[0], expected_faces)
|
self.assertClose(faces[0], expected_faces)
|
||||||
|
|
||||||
@ -143,15 +142,14 @@ class TestCubeConfiguration(TestCaseMixin, unittest.TestCase):
|
|||||||
|
|
||||||
expected_verts = torch.tensor(
|
expected_verts = torch.tensor(
|
||||||
[
|
[
|
||||||
[0.0000, 0.5000, 1.0000],
|
|
||||||
[1.0000, 0.5000, 1.0000],
|
|
||||||
[1.0000, 0.5000, 0.0000],
|
[1.0000, 0.5000, 0.0000],
|
||||||
[0.0000, 0.5000, 0.0000],
|
[0.0000, 0.5000, 0.0000],
|
||||||
|
[1.0000, 0.5000, 1.0000],
|
||||||
|
[0.0000, 0.5000, 1.0000],
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
expected_faces = torch.tensor([[1, 0, 2], [2, 0, 3]])
|
expected_faces = torch.tensor([[0, 1, 2], [2, 1, 3]])
|
||||||
|
|
||||||
self.assertClose(verts[0], expected_verts)
|
self.assertClose(verts[0], expected_verts)
|
||||||
self.assertClose(faces[0], expected_faces)
|
self.assertClose(faces[0], expected_faces)
|
||||||
|
|
||||||
@ -171,17 +169,17 @@ class TestCubeConfiguration(TestCaseMixin, unittest.TestCase):
|
|||||||
|
|
||||||
expected_verts = torch.tensor(
|
expected_verts = torch.tensor(
|
||||||
[
|
[
|
||||||
[0.5000, 0.0000, 0.0000],
|
|
||||||
[0.0000, 0.0000, 0.5000],
|
|
||||||
[0.5000, 1.0000, 0.0000],
|
[0.5000, 1.0000, 0.0000],
|
||||||
[0.0000, 1.0000, 0.5000],
|
[0.0000, 1.0000, 0.5000],
|
||||||
|
[0.0000, 0.5000, 0.0000],
|
||||||
|
[1.0000, 0.5000, 0.0000],
|
||||||
|
[0.5000, 0.0000, 0.0000],
|
||||||
[0.0000, 0.5000, 1.0000],
|
[0.0000, 0.5000, 1.0000],
|
||||||
[1.0000, 0.5000, 1.0000],
|
[1.0000, 0.5000, 1.0000],
|
||||||
[1.0000, 0.5000, 0.0000],
|
[0.0000, 0.0000, 0.5000],
|
||||||
[0.0000, 0.5000, 0.0000],
|
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
expected_faces = torch.tensor([[2, 7, 3], [0, 6, 1], [6, 4, 1], [6, 5, 4]])
|
expected_faces = torch.tensor([[0, 1, 2], [3, 4, 5], [3, 5, 6], [5, 4, 7]])
|
||||||
|
|
||||||
self.assertClose(verts[0], expected_verts)
|
self.assertClose(verts[0], expected_verts)
|
||||||
self.assertClose(faces[0], expected_faces)
|
self.assertClose(faces[0], expected_faces)
|
||||||
@ -202,22 +200,22 @@ class TestCubeConfiguration(TestCaseMixin, unittest.TestCase):
|
|||||||
|
|
||||||
expected_verts = torch.tensor(
|
expected_verts = torch.tensor(
|
||||||
[
|
[
|
||||||
[0.5000, 0.0000, 1.0000],
|
|
||||||
[1.0000, 0.0000, 0.5000],
|
|
||||||
[0.5000, 0.0000, 0.0000],
|
|
||||||
[0.0000, 0.0000, 0.5000],
|
|
||||||
[0.5000, 1.0000, 1.0000],
|
[0.5000, 1.0000, 1.0000],
|
||||||
[1.0000, 1.0000, 0.5000],
|
|
||||||
[0.5000, 1.0000, 0.0000],
|
|
||||||
[0.0000, 1.0000, 0.5000],
|
|
||||||
[0.0000, 0.5000, 1.0000],
|
[0.0000, 0.5000, 1.0000],
|
||||||
|
[0.0000, 1.0000, 0.5000],
|
||||||
|
[1.0000, 0.0000, 0.5000],
|
||||||
|
[0.5000, 0.0000, 1.0000],
|
||||||
[1.0000, 0.5000, 1.0000],
|
[1.0000, 0.5000, 1.0000],
|
||||||
[1.0000, 0.5000, 0.0000],
|
[0.5000, 0.0000, 0.0000],
|
||||||
[0.0000, 0.5000, 0.0000],
|
[0.0000, 0.5000, 0.0000],
|
||||||
|
[0.0000, 0.0000, 0.5000],
|
||||||
|
[0.5000, 1.0000, 0.0000],
|
||||||
|
[1.0000, 0.5000, 0.0000],
|
||||||
|
[1.0000, 1.0000, 0.5000],
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
expected_faces = torch.tensor([[0, 1, 9], [4, 7, 8], [2, 3, 11], [5, 10, 6]])
|
expected_faces = torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]])
|
||||||
|
|
||||||
self.assertClose(verts[0], expected_verts)
|
self.assertClose(verts[0], expected_verts)
|
||||||
self.assertClose(faces[0], expected_faces)
|
self.assertClose(faces[0], expected_faces)
|
||||||
@ -238,15 +236,15 @@ class TestCubeConfiguration(TestCaseMixin, unittest.TestCase):
|
|||||||
|
|
||||||
expected_verts = torch.tensor(
|
expected_verts = torch.tensor(
|
||||||
[
|
[
|
||||||
[1.0000, 0.0000, 0.5000],
|
|
||||||
[0.5000, 0.0000, 0.0000],
|
|
||||||
[0.5000, 1.0000, 1.0000],
|
|
||||||
[0.0000, 1.0000, 0.5000],
|
|
||||||
[1.0000, 0.5000, 1.0000],
|
[1.0000, 0.5000, 1.0000],
|
||||||
|
[0.0000, 1.0000, 0.5000],
|
||||||
|
[0.5000, 1.0000, 1.0000],
|
||||||
|
[1.0000, 0.0000, 0.5000],
|
||||||
[0.0000, 0.5000, 0.0000],
|
[0.0000, 0.5000, 0.0000],
|
||||||
|
[0.5000, 0.0000, 0.0000],
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
expected_faces = torch.tensor([[2, 3, 5], [4, 2, 5], [4, 5, 1], [4, 1, 0]])
|
expected_faces = torch.tensor([[0, 1, 2], [3, 1, 0], [3, 4, 1], [3, 5, 4]])
|
||||||
|
|
||||||
self.assertClose(verts[0], expected_verts)
|
self.assertClose(verts[0], expected_verts)
|
||||||
self.assertClose(faces[0], expected_faces)
|
self.assertClose(faces[0], expected_faces)
|
||||||
@ -269,13 +267,13 @@ class TestCubeConfiguration(TestCaseMixin, unittest.TestCase):
|
|||||||
[
|
[
|
||||||
[0.5000, 0.0000, 0.0000],
|
[0.5000, 0.0000, 0.0000],
|
||||||
[0.0000, 0.0000, 0.5000],
|
[0.0000, 0.0000, 0.5000],
|
||||||
[0.5000, 1.0000, 1.0000],
|
|
||||||
[0.0000, 1.0000, 0.5000],
|
[0.0000, 1.0000, 0.5000],
|
||||||
[1.0000, 0.5000, 1.0000],
|
[1.0000, 0.5000, 1.0000],
|
||||||
[1.0000, 0.5000, 0.0000],
|
[1.0000, 0.5000, 0.0000],
|
||||||
|
[0.5000, 1.0000, 1.0000],
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
expected_faces = torch.tensor([[0, 5, 4], [0, 4, 3], [0, 3, 1], [3, 4, 2]])
|
expected_faces = torch.tensor([[0, 1, 2], [0, 2, 3], [0, 3, 4], [5, 3, 2]])
|
||||||
|
|
||||||
self.assertClose(verts[0], expected_verts)
|
self.assertClose(verts[0], expected_verts)
|
||||||
self.assertClose(faces[0], expected_faces)
|
self.assertClose(faces[0], expected_faces)
|
||||||
@ -295,15 +293,15 @@ class TestCubeConfiguration(TestCaseMixin, unittest.TestCase):
|
|||||||
expected_verts = torch.tensor(
|
expected_verts = torch.tensor(
|
||||||
[
|
[
|
||||||
[0.5000, 0.0000, 0.0000],
|
[0.5000, 0.0000, 0.0000],
|
||||||
|
[0.0000, 0.5000, 0.0000],
|
||||||
[0.0000, 0.0000, 0.5000],
|
[0.0000, 0.0000, 0.5000],
|
||||||
[0.5000, 1.0000, 1.0000],
|
|
||||||
[1.0000, 1.0000, 0.5000],
|
[1.0000, 1.0000, 0.5000],
|
||||||
[1.0000, 0.5000, 1.0000],
|
[1.0000, 0.5000, 1.0000],
|
||||||
[0.0000, 0.5000, 0.0000],
|
[0.5000, 1.0000, 1.0000],
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
expected_faces = torch.tensor([[4, 3, 2], [0, 1, 5]])
|
expected_faces = torch.tensor([[0, 1, 2], [3, 4, 5]])
|
||||||
|
|
||||||
self.assertClose(verts[0], expected_verts)
|
self.assertClose(verts[0], expected_verts)
|
||||||
self.assertClose(faces[0], expected_faces)
|
self.assertClose(faces[0], expected_faces)
|
||||||
@ -324,16 +322,16 @@ class TestCubeConfiguration(TestCaseMixin, unittest.TestCase):
|
|||||||
expected_verts = torch.tensor(
|
expected_verts = torch.tensor(
|
||||||
[
|
[
|
||||||
[1.0000, 0.0000, 0.5000],
|
[1.0000, 0.0000, 0.5000],
|
||||||
|
[0.0000, 0.5000, 0.0000],
|
||||||
[0.0000, 0.0000, 0.5000],
|
[0.0000, 0.0000, 0.5000],
|
||||||
[0.5000, 1.0000, 1.0000],
|
[1.0000, 0.5000, 0.0000],
|
||||||
[1.0000, 1.0000, 0.5000],
|
[1.0000, 1.0000, 0.5000],
|
||||||
[1.0000, 0.5000, 1.0000],
|
[1.0000, 0.5000, 1.0000],
|
||||||
[1.0000, 0.5000, 0.0000],
|
[0.5000, 1.0000, 1.0000],
|
||||||
[0.0000, 0.5000, 0.0000],
|
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
expected_faces = torch.tensor([[5, 1, 6], [5, 0, 1], [4, 3, 2]])
|
expected_faces = torch.tensor([[0, 1, 2], [0, 3, 1], [4, 5, 6]])
|
||||||
|
|
||||||
self.assertClose(verts[0], expected_verts)
|
self.assertClose(verts[0], expected_verts)
|
||||||
self.assertClose(faces[0], expected_faces)
|
self.assertClose(faces[0], expected_faces)
|
||||||
@ -354,18 +352,18 @@ class TestCubeConfiguration(TestCaseMixin, unittest.TestCase):
|
|||||||
expected_verts = torch.tensor(
|
expected_verts = torch.tensor(
|
||||||
[
|
[
|
||||||
[1.0000, 0.0000, 0.5000],
|
[1.0000, 0.0000, 0.5000],
|
||||||
|
[1.0000, 0.5000, 0.0000],
|
||||||
[0.5000, 0.0000, 0.0000],
|
[0.5000, 0.0000, 0.0000],
|
||||||
[0.5000, 1.0000, 1.0000],
|
|
||||||
[1.0000, 1.0000, 0.5000],
|
[1.0000, 1.0000, 0.5000],
|
||||||
|
[1.0000, 0.5000, 1.0000],
|
||||||
|
[0.5000, 1.0000, 1.0000],
|
||||||
|
[0.0000, 0.5000, 0.0000],
|
||||||
[0.5000, 1.0000, 0.0000],
|
[0.5000, 1.0000, 0.0000],
|
||||||
[0.0000, 1.0000, 0.5000],
|
[0.0000, 1.0000, 0.5000],
|
||||||
[1.0000, 0.5000, 1.0000],
|
|
||||||
[1.0000, 0.5000, 0.0000],
|
|
||||||
[0.0000, 0.5000, 0.0000],
|
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
expected_faces = torch.tensor([[6, 3, 2], [7, 0, 1], [5, 4, 8]])
|
expected_faces = torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8]])
|
||||||
|
|
||||||
self.assertClose(verts[0], expected_verts)
|
self.assertClose(verts[0], expected_verts)
|
||||||
self.assertClose(faces[0], expected_faces)
|
self.assertClose(faces[0], expected_faces)
|
||||||
@ -386,18 +384,18 @@ class TestCubeConfiguration(TestCaseMixin, unittest.TestCase):
|
|||||||
|
|
||||||
expected_verts = torch.tensor(
|
expected_verts = torch.tensor(
|
||||||
[
|
[
|
||||||
[0.5000, 0.0000, 1.0000],
|
|
||||||
[1.0000, 0.0000, 0.5000],
|
[1.0000, 0.0000, 0.5000],
|
||||||
[0.5000, 0.0000, 0.0000],
|
[0.5000, 0.0000, 1.0000],
|
||||||
[0.0000, 0.0000, 0.5000],
|
|
||||||
[0.5000, 1.0000, 1.0000],
|
|
||||||
[1.0000, 1.0000, 0.5000],
|
[1.0000, 1.0000, 0.5000],
|
||||||
|
[0.5000, 1.0000, 1.0000],
|
||||||
|
[0.0000, 0.0000, 0.5000],
|
||||||
|
[0.5000, 0.0000, 0.0000],
|
||||||
[0.5000, 1.0000, 0.0000],
|
[0.5000, 1.0000, 0.0000],
|
||||||
[0.0000, 1.0000, 0.5000],
|
[0.0000, 1.0000, 0.5000],
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
expected_faces = torch.tensor([[3, 6, 2], [3, 7, 6], [1, 5, 0], [5, 4, 0]])
|
expected_faces = torch.tensor([[0, 1, 2], [2, 1, 3], [4, 5, 6], [4, 6, 7]])
|
||||||
|
|
||||||
self.assertClose(verts[0], expected_verts)
|
self.assertClose(verts[0], expected_verts)
|
||||||
self.assertClose(faces[0], expected_faces)
|
self.assertClose(faces[0], expected_faces)
|
||||||
@ -418,16 +416,16 @@ class TestCubeConfiguration(TestCaseMixin, unittest.TestCase):
|
|||||||
|
|
||||||
expected_verts = torch.tensor(
|
expected_verts = torch.tensor(
|
||||||
[
|
[
|
||||||
[1.0000, 0.0000, 0.5000],
|
|
||||||
[0.5000, 0.0000, 0.0000],
|
[0.5000, 0.0000, 0.0000],
|
||||||
[0.5000, 1.0000, 1.0000],
|
|
||||||
[1.0000, 1.0000, 0.5000],
|
|
||||||
[0.0000, 0.5000, 1.0000],
|
|
||||||
[0.0000, 0.5000, 0.0000],
|
[0.0000, 0.5000, 0.0000],
|
||||||
|
[0.0000, 0.5000, 1.0000],
|
||||||
|
[1.0000, 1.0000, 0.5000],
|
||||||
|
[1.0000, 0.0000, 0.5000],
|
||||||
|
[0.5000, 1.0000, 1.0000],
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
expected_faces = torch.tensor([[1, 0, 3], [1, 3, 4], [1, 4, 5], [2, 4, 3]])
|
expected_faces = torch.tensor([[0, 1, 2], [0, 2, 3], [0, 3, 4], [3, 2, 5]])
|
||||||
|
|
||||||
self.assertClose(verts[0], expected_verts)
|
self.assertClose(verts[0], expected_verts)
|
||||||
self.assertClose(faces[0], expected_faces)
|
self.assertClose(faces[0], expected_faces)
|
||||||
@ -447,27 +445,26 @@ class TestMarchingCubes(TestCaseMixin, unittest.TestCase):
|
|||||||
|
|
||||||
expected_verts = torch.tensor(
|
expected_verts = torch.tensor(
|
||||||
[
|
[
|
||||||
[0.5, 1, 1],
|
[1.0000, 0.5000, 1.0000],
|
||||||
[1, 1, 0.5],
|
[1.0000, 1.0000, 0.5000],
|
||||||
[1, 0.5, 1],
|
[0.5000, 1.0000, 1.0000],
|
||||||
[1, 1, 1.5],
|
[1.5000, 1.0000, 1.0000],
|
||||||
[1, 1.5, 1],
|
[1.0000, 1.5000, 1.0000],
|
||||||
[1.5, 1, 1],
|
[1.0000, 1.0000, 1.5000],
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
expected_faces = torch.tensor(
|
expected_faces = torch.tensor(
|
||||||
[
|
[
|
||||||
[2, 0, 1],
|
[0, 1, 2],
|
||||||
[2, 3, 0],
|
[1, 0, 3],
|
||||||
[0, 4, 1],
|
[1, 4, 2],
|
||||||
[3, 4, 0],
|
[1, 3, 4],
|
||||||
[5, 2, 1],
|
[0, 2, 5],
|
||||||
[3, 2, 5],
|
[3, 0, 5],
|
||||||
[5, 1, 4],
|
[2, 4, 5],
|
||||||
[3, 5, 4],
|
[3, 5, 4],
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertClose(verts[0], expected_verts)
|
self.assertClose(verts[0], expected_verts)
|
||||||
self.assertClose(faces[0], expected_faces)
|
self.assertClose(faces[0], expected_faces)
|
||||||
|
|
||||||
@ -492,76 +489,76 @@ class TestMarchingCubes(TestCaseMixin, unittest.TestCase):
|
|||||||
|
|
||||||
expected_verts = torch.tensor(
|
expected_verts = torch.tensor(
|
||||||
[
|
[
|
||||||
[0.9000, 1.0000, 1.0000],
|
|
||||||
[1.0000, 1.0000, 0.9000],
|
|
||||||
[1.0000, 0.9000, 1.0000],
|
[1.0000, 0.9000, 1.0000],
|
||||||
[0.9000, 1.0000, 2.0000],
|
[1.0000, 1.0000, 0.9000],
|
||||||
[1.0000, 0.9000, 2.0000],
|
[0.9000, 1.0000, 1.0000],
|
||||||
[1.0000, 1.0000, 2.1000],
|
|
||||||
[0.9000, 2.0000, 1.0000],
|
|
||||||
[1.0000, 2.0000, 0.9000],
|
|
||||||
[0.9000, 2.0000, 2.0000],
|
|
||||||
[1.0000, 2.0000, 2.1000],
|
|
||||||
[1.0000, 2.1000, 1.0000],
|
|
||||||
[1.0000, 2.1000, 2.0000],
|
|
||||||
[2.0000, 1.0000, 0.9000],
|
|
||||||
[2.0000, 0.9000, 1.0000],
|
[2.0000, 0.9000, 1.0000],
|
||||||
[2.0000, 0.9000, 2.0000],
|
[2.0000, 1.0000, 0.9000],
|
||||||
[2.0000, 1.0000, 2.1000],
|
|
||||||
[2.0000, 2.0000, 0.9000],
|
|
||||||
[2.0000, 2.0000, 2.1000],
|
|
||||||
[2.0000, 2.1000, 1.0000],
|
|
||||||
[2.0000, 2.1000, 2.0000],
|
|
||||||
[2.1000, 1.0000, 1.0000],
|
[2.1000, 1.0000, 1.0000],
|
||||||
[2.1000, 1.0000, 2.0000],
|
[1.0000, 2.0000, 0.9000],
|
||||||
|
[0.9000, 2.0000, 1.0000],
|
||||||
|
[2.0000, 2.0000, 0.9000],
|
||||||
[2.1000, 2.0000, 1.0000],
|
[2.1000, 2.0000, 1.0000],
|
||||||
|
[1.0000, 2.1000, 1.0000],
|
||||||
|
[2.0000, 2.1000, 1.0000],
|
||||||
|
[1.0000, 0.9000, 2.0000],
|
||||||
|
[0.9000, 1.0000, 2.0000],
|
||||||
|
[2.0000, 0.9000, 2.0000],
|
||||||
|
[2.1000, 1.0000, 2.0000],
|
||||||
|
[0.9000, 2.0000, 2.0000],
|
||||||
[2.1000, 2.0000, 2.0000],
|
[2.1000, 2.0000, 2.0000],
|
||||||
|
[1.0000, 2.1000, 2.0000],
|
||||||
|
[2.0000, 2.1000, 2.0000],
|
||||||
|
[1.0000, 1.0000, 2.1000],
|
||||||
|
[2.0000, 1.0000, 2.1000],
|
||||||
|
[1.0000, 2.0000, 2.1000],
|
||||||
|
[2.0000, 2.0000, 2.1000],
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
expected_faces = torch.tensor(
|
expected_faces = torch.tensor(
|
||||||
[
|
[
|
||||||
[2, 0, 1],
|
[0, 1, 2],
|
||||||
[2, 4, 3],
|
[0, 3, 4],
|
||||||
[0, 2, 3],
|
[1, 0, 4],
|
||||||
[4, 5, 3],
|
[4, 3, 5],
|
||||||
[0, 6, 7],
|
[1, 6, 7],
|
||||||
[1, 0, 7],
|
[2, 1, 7],
|
||||||
[3, 8, 0],
|
[4, 8, 1],
|
||||||
[8, 6, 0],
|
[1, 8, 6],
|
||||||
[5, 9, 8],
|
[8, 4, 5],
|
||||||
[3, 5, 8],
|
[9, 8, 5],
|
||||||
[6, 10, 7],
|
[6, 10, 7],
|
||||||
[11, 10, 6],
|
[6, 8, 11],
|
||||||
[8, 11, 6],
|
[10, 6, 11],
|
||||||
[9, 11, 8],
|
[8, 9, 11],
|
||||||
[13, 2, 1],
|
[12, 0, 2],
|
||||||
[12, 13, 1],
|
[13, 12, 2],
|
||||||
[14, 4, 13],
|
[3, 0, 14],
|
||||||
[13, 4, 2],
|
[14, 0, 12],
|
||||||
[4, 14, 15],
|
[15, 5, 3],
|
||||||
[5, 4, 15],
|
[14, 15, 3],
|
||||||
[12, 1, 16],
|
[2, 7, 13],
|
||||||
[1, 7, 16],
|
[7, 16, 13],
|
||||||
[15, 17, 5],
|
[5, 15, 9],
|
||||||
[5, 17, 9],
|
[9, 15, 17],
|
||||||
[16, 7, 10],
|
[10, 18, 16],
|
||||||
[18, 16, 10],
|
[7, 10, 16],
|
||||||
[19, 18, 11],
|
[11, 19, 10],
|
||||||
[18, 10, 11],
|
[19, 18, 10],
|
||||||
[9, 17, 19],
|
[9, 17, 19],
|
||||||
[11, 9, 19],
|
[11, 9, 19],
|
||||||
[20, 13, 12],
|
[12, 13, 20],
|
||||||
[20, 21, 14],
|
[14, 12, 20],
|
||||||
[13, 20, 14],
|
[21, 14, 20],
|
||||||
[15, 14, 21],
|
[15, 14, 21],
|
||||||
[22, 20, 12],
|
[13, 16, 22],
|
||||||
[16, 22, 12],
|
[20, 13, 22],
|
||||||
[21, 20, 23],
|
[21, 20, 23],
|
||||||
[23, 20, 22],
|
[20, 22, 23],
|
||||||
[17, 15, 21],
|
[17, 15, 21],
|
||||||
[23, 17, 21],
|
[23, 17, 21],
|
||||||
[22, 16, 18],
|
[16, 18, 22],
|
||||||
[23, 22, 18],
|
[23, 22, 18],
|
||||||
[19, 23, 18],
|
[19, 23, 18],
|
||||||
[17, 23, 19],
|
[17, 23, 19],
|
||||||
@ -569,6 +566,7 @@ class TestMarchingCubes(TestCaseMixin, unittest.TestCase):
|
|||||||
)
|
)
|
||||||
self.assertClose(verts[0], expected_verts)
|
self.assertClose(verts[0], expected_verts)
|
||||||
self.assertClose(faces[0], expected_faces)
|
self.assertClose(faces[0], expected_faces)
|
||||||
|
|
||||||
verts, faces = marching_cubes_naive(volume_data, 0.9, return_local_coords=True)
|
verts, faces = marching_cubes_naive(volume_data, 0.9, return_local_coords=True)
|
||||||
expected_verts = convert_to_local(expected_verts, 5)
|
expected_verts = convert_to_local(expected_verts, 5)
|
||||||
self.assertClose(verts[0], expected_verts)
|
self.assertClose(verts[0], expected_verts)
|
||||||
@ -592,34 +590,49 @@ class TestMarchingCubes(TestCaseMixin, unittest.TestCase):
|
|||||||
|
|
||||||
expected_verts = torch.tensor(
|
expected_verts = torch.tensor(
|
||||||
[
|
[
|
||||||
|
[2.0, 1.0, 1.0],
|
||||||
|
[2.0, 2.0, 1.0],
|
||||||
[1.0, 1.0, 1.0],
|
[1.0, 1.0, 1.0],
|
||||||
[1.0, 1.0, 2.0],
|
|
||||||
[1.0, 2.0, 1.0],
|
[1.0, 2.0, 1.0],
|
||||||
|
[2.0, 1.0, 1.0],
|
||||||
|
[1.0, 1.0, 1.0],
|
||||||
|
[2.0, 1.0, 2.0],
|
||||||
|
[1.0, 1.0, 2.0],
|
||||||
|
[1.0, 1.0, 1.0],
|
||||||
|
[1.0, 2.0, 1.0],
|
||||||
|
[1.0, 1.0, 2.0],
|
||||||
[1.0, 2.0, 2.0],
|
[1.0, 2.0, 2.0],
|
||||||
[2.0, 1.0, 1.0],
|
[2.0, 1.0, 1.0],
|
||||||
[2.0, 1.0, 2.0],
|
[2.0, 1.0, 2.0],
|
||||||
[2.0, 2.0, 1.0],
|
[2.0, 2.0, 1.0],
|
||||||
[2.0, 2.0, 2.0],
|
[2.0, 2.0, 2.0],
|
||||||
|
[2.0, 2.0, 1.0],
|
||||||
|
[2.0, 2.0, 2.0],
|
||||||
|
[1.0, 2.0, 1.0],
|
||||||
|
[1.0, 2.0, 2.0],
|
||||||
|
[2.0, 1.0, 2.0],
|
||||||
|
[1.0, 1.0, 2.0],
|
||||||
|
[2.0, 2.0, 2.0],
|
||||||
|
[1.0, 2.0, 2.0],
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
expected_faces = torch.tensor(
|
expected_faces = torch.tensor(
|
||||||
[
|
[
|
||||||
[1, 3, 0],
|
[0, 1, 2],
|
||||||
[3, 2, 0],
|
[2, 1, 3],
|
||||||
[5, 1, 4],
|
[4, 5, 6],
|
||||||
[4, 1, 0],
|
[6, 5, 7],
|
||||||
[4, 0, 6],
|
[8, 9, 10],
|
||||||
[0, 2, 6],
|
[9, 11, 10],
|
||||||
[5, 7, 1],
|
[12, 13, 14],
|
||||||
[1, 7, 3],
|
[14, 13, 15],
|
||||||
[7, 6, 3],
|
[16, 17, 18],
|
||||||
[6, 2, 3],
|
[17, 19, 18],
|
||||||
[5, 4, 7],
|
[20, 21, 22],
|
||||||
[7, 4, 6],
|
[21, 23, 22],
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertClose(verts[0], expected_verts)
|
self.assertClose(verts[0], expected_verts)
|
||||||
self.assertClose(faces[0], expected_faces)
|
self.assertClose(faces[0], expected_faces)
|
||||||
|
|
||||||
@ -651,8 +664,8 @@ class TestMarchingCubes(TestCaseMixin, unittest.TestCase):
|
|||||||
filename = os.path.join(DATA_DIR, data_filename)
|
filename = os.path.join(DATA_DIR, data_filename)
|
||||||
with open(filename, "rb") as file:
|
with open(filename, "rb") as file:
|
||||||
verts_and_faces = pickle.load(file)
|
verts_and_faces = pickle.load(file)
|
||||||
expected_verts = verts_and_faces["verts"].squeeze()
|
expected_verts = verts_and_faces["verts"]
|
||||||
expected_faces = verts_and_faces["faces"].squeeze()
|
expected_faces = verts_and_faces["faces"]
|
||||||
|
|
||||||
self.assertClose(verts[0], expected_verts)
|
self.assertClose(verts[0], expected_verts)
|
||||||
self.assertClose(faces[0], expected_faces)
|
self.assertClose(faces[0], expected_faces)
|
||||||
@ -689,8 +702,8 @@ class TestMarchingCubes(TestCaseMixin, unittest.TestCase):
|
|||||||
expected_verts = verts_and_faces["verts"]
|
expected_verts = verts_and_faces["verts"]
|
||||||
expected_faces = verts_and_faces["faces"]
|
expected_faces = verts_and_faces["faces"]
|
||||||
|
|
||||||
self.assertClose(verts[0], expected_verts[0])
|
self.assertClose(verts[0], expected_verts)
|
||||||
self.assertClose(faces[0], expected_faces[0])
|
self.assertClose(faces[0], expected_faces)
|
||||||
|
|
||||||
def test_cube_surface_area(self):
|
def test_cube_surface_area(self):
|
||||||
if USE_SCIKIT:
|
if USE_SCIKIT:
|
||||||
@ -760,16 +773,26 @@ class TestMarchingCubes(TestCaseMixin, unittest.TestCase):
|
|||||||
|
|
||||||
self.assertClose(surf, surf_sci)
|
self.assertClose(surf, surf_sci)
|
||||||
|
|
||||||
|
def test_ball_example(self):
|
||||||
|
N = 15
|
||||||
|
axis_tensor = torch.arange(0, N)
|
||||||
|
X, Y, Z = torch.meshgrid(axis_tensor, axis_tensor, axis_tensor, indexing="ij")
|
||||||
|
u = (X - 15) ** 2 + (Y - 15) ** 2 + (Z - 15) ** 2 - 8**2
|
||||||
|
u = u[None].float()
|
||||||
|
verts, faces = marching_cubes_naive(u, 0, return_local_coords=False)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def marching_cubes_with_init(batch_size: int, V: int):
|
def marching_cubes_with_init(algo_type: str, batch_size: int, V: int):
|
||||||
device = torch.device("cuda:0")
|
device = torch.device("cuda:0")
|
||||||
volume_data = torch.rand(
|
volume_data = torch.rand(
|
||||||
(batch_size, V, V, V), dtype=torch.float32, device=device
|
(batch_size, V, V, V), dtype=torch.float32, device=device
|
||||||
)
|
)
|
||||||
torch.cuda.synchronize()
|
algo_table = {
|
||||||
|
"naive": marching_cubes_naive,
|
||||||
|
}
|
||||||
|
|
||||||
def convert():
|
def convert():
|
||||||
marching_cubes_naive(volume_data, return_local_coords=False)
|
algo_table[algo_type](volume_data, return_local_coords=False)
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
return convert
|
return convert
|
||||||
|
Loading…
x
Reference in New Issue
Block a user