mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-12-22 07:10:34 +08:00
Marching Cubes C++ torch extension
Summary: Torch C++ extension for Marching Cubes - Add torch C++ extension for marching cubes. Observe a speed up of ~255x-324x speed up (over varying batch sizes and spatial resolutions) - Add C++ impl in existing unit-tests. (Note: this ignores all push blocking failures!) Reviewed By: kjchalup Differential Revision: D39590638 fbshipit-source-id: e44d2852a24c2c398e5ea9db20f0dfaa1817e457
This commit is contained in:
committed by
Facebook GitHub Bot
parent
850efdf706
commit
0d8608b9f9
@@ -7,8 +7,10 @@
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from pytorch3d import _C
|
||||
from pytorch3d.ops.marching_cubes_data import EDGE_TO_VERTICES, FACE_TABLE, INDEX
|
||||
from pytorch3d.transforms import Translate
|
||||
from torch.autograd import Function
|
||||
|
||||
|
||||
EPS = 0.00001
|
||||
@@ -225,3 +227,71 @@ def marching_cubes_naive(
|
||||
batched_verts.append([])
|
||||
batched_faces.append([])
|
||||
return batched_verts, batched_faces
|
||||
|
||||
|
||||
########################################
|
||||
# Marching Cubes Implementation in C++
|
||||
########################################
|
||||
class _marching_cubes(Function):
|
||||
"""
|
||||
Torch Function wrapper for marching_cubes C++ implementation
|
||||
Backward is not supported.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, vol, isolevel):
|
||||
verts, faces = _C.marching_cubes(vol, isolevel)
|
||||
return verts, faces
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_verts, grad_faces):
|
||||
raise ValueError("marching_cubes backward is not supported")
|
||||
|
||||
|
||||
def marching_cubes(
|
||||
vol_batch: torch.Tensor,
|
||||
isolevel: Optional[float] = None,
|
||||
return_local_coords: bool = True,
|
||||
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
|
||||
"""
|
||||
Run marching cubes over a volume scalar field with a designated isolevel.
|
||||
Returns vertices and faces of the obtained mesh.
|
||||
This operation is non-differentiable.
|
||||
|
||||
Args:
|
||||
vol_batch: a Tensor of size (N, D, H, W) corresponding to
|
||||
a batch of 3D scalar fields
|
||||
isolevel: float used as threshold to determine if a point is inside/outside
|
||||
the volume. If None, then the average of the maximum and minimum value
|
||||
of the scalar field is used.
|
||||
return_local_coords: bool. If True the output vertices will be in local coordinates in
|
||||
the range [-1, 1] x [-1, 1] x [-1, 1]. If False they will be in the range
|
||||
[0, W-1] x [0, H-1] x [0, D-1]
|
||||
|
||||
|
||||
Returns:
|
||||
verts: [{V_0}, {V_1}, ...] List of N sets of vertices of shape (|V_i|, 3) in FloatTensor
|
||||
faces: [{F_0}, {F_1}, ...] List of N sets of faces of shape (|F_i|, 3) in LongTensors
|
||||
"""
|
||||
batched_verts, batched_faces = [], []
|
||||
D, H, W = vol_batch.shape[1:]
|
||||
for i in range(len(vol_batch)):
|
||||
vol = vol_batch[i]
|
||||
thresh = ((vol.max() + vol.min()) / 2).item() if isolevel is None else isolevel
|
||||
# pyre-fixme[16]: `_marching_cubes` has no attribute `apply`.
|
||||
verts, faces = _marching_cubes.apply(vol, thresh)
|
||||
if len(faces) > 0 and len(verts) > 0:
|
||||
# Convert from world coordinates ([0, D-1], [0, H-1], [0, W-1]) to
|
||||
# local coordinates in the range [-1, 1]
|
||||
if return_local_coords:
|
||||
verts = (
|
||||
Translate(x=+1.0, y=+1.0, z=+1.0, device=vol.device)
|
||||
.scale((vol.new_tensor([W, H, D])[None] - 1) * 0.5)
|
||||
.inverse()
|
||||
).transform_points(verts[None])[0]
|
||||
batched_verts.append(verts)
|
||||
batched_faces.append(faces)
|
||||
else:
|
||||
batched_verts.append([])
|
||||
batched_faces.append([])
|
||||
return batched_verts, batched_faces
|
||||
|
||||
Reference in New Issue
Block a user