Marching Cubes C++ torch extension

Summary:
Torch C++ extension for Marching Cubes

- Add torch C++ extension for marching cubes. Observe a speed up of ~255x-324x speed up (over varying batch sizes and spatial resolutions)

- Add C++ impl in existing unit-tests.

(Note: this ignores all push blocking failures!)

Reviewed By: kjchalup

Differential Revision: D39590638

fbshipit-source-id: e44d2852a24c2c398e5ea9db20f0dfaa1817e457
This commit is contained in:
Jiali Duan
2022-10-06 11:13:53 -07:00
committed by Facebook GitHub Bot
parent 850efdf706
commit 0d8608b9f9
7 changed files with 879 additions and 9 deletions

View File

@@ -7,8 +7,10 @@
from typing import List, Optional, Tuple
import torch
from pytorch3d import _C
from pytorch3d.ops.marching_cubes_data import EDGE_TO_VERTICES, FACE_TABLE, INDEX
from pytorch3d.transforms import Translate
from torch.autograd import Function
EPS = 0.00001
@@ -225,3 +227,71 @@ def marching_cubes_naive(
batched_verts.append([])
batched_faces.append([])
return batched_verts, batched_faces
########################################
# Marching Cubes Implementation in C++
########################################
class _marching_cubes(Function):
"""
Torch Function wrapper for marching_cubes C++ implementation
Backward is not supported.
"""
@staticmethod
def forward(ctx, vol, isolevel):
verts, faces = _C.marching_cubes(vol, isolevel)
return verts, faces
@staticmethod
def backward(ctx, grad_verts, grad_faces):
raise ValueError("marching_cubes backward is not supported")
def marching_cubes(
vol_batch: torch.Tensor,
isolevel: Optional[float] = None,
return_local_coords: bool = True,
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
"""
Run marching cubes over a volume scalar field with a designated isolevel.
Returns vertices and faces of the obtained mesh.
This operation is non-differentiable.
Args:
vol_batch: a Tensor of size (N, D, H, W) corresponding to
a batch of 3D scalar fields
isolevel: float used as threshold to determine if a point is inside/outside
the volume. If None, then the average of the maximum and minimum value
of the scalar field is used.
return_local_coords: bool. If True the output vertices will be in local coordinates in
the range [-1, 1] x [-1, 1] x [-1, 1]. If False they will be in the range
[0, W-1] x [0, H-1] x [0, D-1]
Returns:
verts: [{V_0}, {V_1}, ...] List of N sets of vertices of shape (|V_i|, 3) in FloatTensor
faces: [{F_0}, {F_1}, ...] List of N sets of faces of shape (|F_i|, 3) in LongTensors
"""
batched_verts, batched_faces = [], []
D, H, W = vol_batch.shape[1:]
for i in range(len(vol_batch)):
vol = vol_batch[i]
thresh = ((vol.max() + vol.min()) / 2).item() if isolevel is None else isolevel
# pyre-fixme[16]: `_marching_cubes` has no attribute `apply`.
verts, faces = _marching_cubes.apply(vol, thresh)
if len(faces) > 0 and len(verts) > 0:
# Convert from world coordinates ([0, D-1], [0, H-1], [0, W-1]) to
# local coordinates in the range [-1, 1]
if return_local_coords:
verts = (
Translate(x=+1.0, y=+1.0, z=+1.0, device=vol.device)
.scale((vol.new_tensor([W, H, D])[None] - 1) * 0.5)
.inverse()
).transform_points(verts[None])[0]
batched_verts.append(verts)
batched_faces.append(faces)
else:
batched_verts.append([])
batched_faces.append([])
return batched_verts, batched_faces