mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
add align modes for cubify
Summary: Add alignment modes for cubify operation. Reviewed By: nikhilaravi Differential Revision: D21393199 fbshipit-source-id: 7022044e591229a6ed5efc361fd3215e65f43f86
This commit is contained in:
parent
8fc28baa27
commit
a61c9376d5
@ -45,7 +45,7 @@ def ravel_index(idx, dims) -> torch.Tensor:
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def cubify(voxels, thresh, device=None) -> Meshes:
|
||||
def cubify(voxels, thresh, device=None, align: str = "topleft") -> Meshes:
|
||||
r"""
|
||||
Converts a voxel to a mesh by replacing each occupied voxel with a cube
|
||||
consisting of 12 faces and 8 vertices. Shared vertices are merged, and
|
||||
@ -54,13 +54,38 @@ def cubify(voxels, thresh, device=None) -> Meshes:
|
||||
voxels: A FloatTensor of shape (N, D, H, W) containing occupancy probabilities.
|
||||
thresh: A scalar threshold. If a voxel occupancy is larger than
|
||||
thresh, the voxel is considered occupied.
|
||||
device: The device of the output meshes
|
||||
align: Defines the alignment of the mesh vertices and the grid locations.
|
||||
Has to be one of {"topleft", "corner", "center"}. See below for explanation.
|
||||
Default is "topleft".
|
||||
Returns:
|
||||
meshes: A Meshes object of the corresponding meshes.
|
||||
|
||||
|
||||
The alignment between the vertices of the cubified mesh and the voxel locations (or pixels)
|
||||
is defined by the choice of `align`. We support three modes, as shown below for a 2x2 grid:
|
||||
|
||||
X---X---- X-------X ---------
|
||||
| | | | | | | X | X |
|
||||
X---X---- --------- ---------
|
||||
| | | | | | | X | X |
|
||||
--------- X-------X ---------
|
||||
|
||||
topleft corner center
|
||||
|
||||
In the figure, X denote the grid locations and the squares represent the added cuboids.
|
||||
When `align="topleft"`, then the top left corner of each cuboid corresponds to the
|
||||
pixel coordinate of the input grid.
|
||||
When `align="corner"`, then the corners of the output mesh span the whole grid.
|
||||
When `align="center"`, then the grid locations form the center of the cuboids.
|
||||
"""
|
||||
|
||||
if device is None:
|
||||
device = voxels.device
|
||||
|
||||
if align not in ["topleft", "corner", "center"]:
|
||||
raise ValueError("Align mode must be one of (topleft, corner, center).")
|
||||
|
||||
if len(voxels) == 0:
|
||||
return Meshes(verts=[], faces=[])
|
||||
|
||||
@ -146,7 +171,7 @@ def cubify(voxels, thresh, device=None) -> Meshes:
|
||||
|
||||
# boolean to linear index
|
||||
# NF x 2
|
||||
linind = torch.nonzero(faces_idx)
|
||||
linind = torch.nonzero(faces_idx, as_tuple=False)
|
||||
# NF x 4
|
||||
nyxz = unravel_index(linind[:, 0], (N, H, W, D))
|
||||
|
||||
@ -170,11 +195,19 @@ def cubify(voxels, thresh, device=None) -> Meshes:
|
||||
torch.arange(H + 1), torch.arange(W + 1), torch.arange(D + 1)
|
||||
)
|
||||
y = y.to(device=device, dtype=torch.float32)
|
||||
y = y * 2.0 / (H - 1.0) - 1.0
|
||||
x = x.to(device=device, dtype=torch.float32)
|
||||
x = x * 2.0 / (W - 1.0) - 1.0
|
||||
z = z.to(device=device, dtype=torch.float32)
|
||||
z = z * 2.0 / (D - 1.0) - 1.0
|
||||
|
||||
if align == "center":
|
||||
x = x - 0.5
|
||||
y = y - 0.5
|
||||
z = z - 0.5
|
||||
|
||||
margin = 0.0 if align == "corner" else 1.0
|
||||
y = y * 2.0 / (H - margin) - 1.0
|
||||
x = x * 2.0 / (W - margin) - 1.0
|
||||
z = z * 2.0 / (D - margin) - 1.0
|
||||
|
||||
# ((H+1)(W+1)(D+1)) x 3
|
||||
grid_verts = torch.stack((x, y, z), dim=3).view(-1, 3)
|
||||
|
||||
@ -196,7 +229,7 @@ def cubify(voxels, thresh, device=None) -> Meshes:
|
||||
idlenum = idleverts.cumsum(1)
|
||||
|
||||
verts_list = [
|
||||
grid_verts.index_select(0, (idleverts[n] == 0).nonzero()[:, 0])
|
||||
grid_verts.index_select(0, (idleverts[n] == 0).nonzero(as_tuple=False)[:, 0])
|
||||
for n in range(N)
|
||||
]
|
||||
faces_list = [nface - idlenum[n][nface] for n, nface in enumerate(faces_list)]
|
||||
|
@ -3,16 +3,17 @@
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
from common_testing import TestCaseMixin
|
||||
from pytorch3d.ops import cubify
|
||||
|
||||
|
||||
class TestCubify(unittest.TestCase):
|
||||
class TestCubify(TestCaseMixin, unittest.TestCase):
|
||||
def test_allempty(self):
|
||||
N, V = 32, 14
|
||||
device = torch.device("cuda:0")
|
||||
voxels = torch.zeros((N, V, V, V), dtype=torch.float32, device=device)
|
||||
meshes = cubify(voxels, 0.5, 0)
|
||||
self.assertTrue(meshes.isempty)
|
||||
meshes = cubify(voxels, 0.5)
|
||||
self.assertTrue(meshes.isempty())
|
||||
|
||||
def test_cubify(self):
|
||||
N, V = 4, 2
|
||||
@ -29,155 +30,143 @@ class TestCubify(unittest.TestCase):
|
||||
voxels[3, 1, 1, 0] = 1.0
|
||||
|
||||
# compute cubify
|
||||
meshes = cubify(voxels, 0.5, 0)
|
||||
meshes = cubify(voxels, 0.5)
|
||||
|
||||
# 1st-check
|
||||
verts, faces = meshes.get_mesh_verts_faces(0)
|
||||
self.assertTrue(
|
||||
torch.allclose(faces.max().cpu(), torch.tensor([verts.size(0) - 1]))
|
||||
self.assertClose(faces.max().cpu(), torch.tensor(verts.size(0) - 1))
|
||||
self.assertClose(
|
||||
verts,
|
||||
torch.tensor(
|
||||
[
|
||||
[-1.0, -1.0, -1.0],
|
||||
[-1.0, -1.0, 1.0],
|
||||
[1.0, -1.0, -1.0],
|
||||
[1.0, -1.0, 1.0],
|
||||
[-1.0, 1.0, -1.0],
|
||||
[-1.0, 1.0, 1.0],
|
||||
[1.0, 1.0, -1.0],
|
||||
[1.0, 1.0, 1.0],
|
||||
],
|
||||
dtype=torch.float32,
|
||||
device=device,
|
||||
),
|
||||
)
|
||||
self.assertTrue(
|
||||
torch.allclose(
|
||||
verts,
|
||||
torch.tensor(
|
||||
[
|
||||
[-1.0, -1.0, -1.0],
|
||||
[-1.0, -1.0, 1.0],
|
||||
[1.0, -1.0, -1.0],
|
||||
[1.0, -1.0, 1.0],
|
||||
[-1.0, 1.0, -1.0],
|
||||
[-1.0, 1.0, 1.0],
|
||||
[1.0, 1.0, -1.0],
|
||||
[1.0, 1.0, 1.0],
|
||||
],
|
||||
dtype=torch.float32,
|
||||
device=device,
|
||||
),
|
||||
)
|
||||
)
|
||||
self.assertTrue(
|
||||
torch.allclose(
|
||||
faces,
|
||||
torch.tensor(
|
||||
[
|
||||
[0, 1, 4],
|
||||
[1, 5, 4],
|
||||
[4, 5, 6],
|
||||
[5, 7, 6],
|
||||
[0, 4, 6],
|
||||
[0, 6, 2],
|
||||
[0, 3, 1],
|
||||
[0, 2, 3],
|
||||
[6, 7, 3],
|
||||
[6, 3, 2],
|
||||
[1, 7, 5],
|
||||
[1, 3, 7],
|
||||
],
|
||||
dtype=torch.int64,
|
||||
device=device,
|
||||
),
|
||||
)
|
||||
self.assertClose(
|
||||
faces,
|
||||
torch.tensor(
|
||||
[
|
||||
[0, 1, 4],
|
||||
[1, 5, 4],
|
||||
[4, 5, 6],
|
||||
[5, 7, 6],
|
||||
[0, 4, 6],
|
||||
[0, 6, 2],
|
||||
[0, 3, 1],
|
||||
[0, 2, 3],
|
||||
[6, 7, 3],
|
||||
[6, 3, 2],
|
||||
[1, 7, 5],
|
||||
[1, 3, 7],
|
||||
],
|
||||
dtype=torch.int64,
|
||||
device=device,
|
||||
),
|
||||
)
|
||||
# 2nd-check
|
||||
verts, faces = meshes.get_mesh_verts_faces(1)
|
||||
self.assertTrue(
|
||||
torch.allclose(faces.max().cpu(), torch.tensor([verts.size(0) - 1]))
|
||||
self.assertClose(faces.max().cpu(), torch.tensor(verts.size(0) - 1))
|
||||
self.assertClose(
|
||||
verts,
|
||||
torch.tensor(
|
||||
[
|
||||
[-1.0, -1.0, -1.0],
|
||||
[-1.0, -1.0, 1.0],
|
||||
[-1.0, -1.0, 3.0],
|
||||
[1.0, -1.0, -1.0],
|
||||
[1.0, -1.0, 1.0],
|
||||
[1.0, -1.0, 3.0],
|
||||
[3.0, -1.0, -1.0],
|
||||
[3.0, -1.0, 1.0],
|
||||
[3.0, -1.0, 3.0],
|
||||
[-1.0, 1.0, -1.0],
|
||||
[-1.0, 1.0, 1.0],
|
||||
[-1.0, 1.0, 3.0],
|
||||
[1.0, 1.0, -1.0],
|
||||
[1.0, 1.0, 3.0],
|
||||
[3.0, 1.0, -1.0],
|
||||
[3.0, 1.0, 1.0],
|
||||
[3.0, 1.0, 3.0],
|
||||
[-1.0, 3.0, -1.0],
|
||||
[-1.0, 3.0, 1.0],
|
||||
[-1.0, 3.0, 3.0],
|
||||
[1.0, 3.0, -1.0],
|
||||
[1.0, 3.0, 1.0],
|
||||
[1.0, 3.0, 3.0],
|
||||
[3.0, 3.0, -1.0],
|
||||
[3.0, 3.0, 1.0],
|
||||
[3.0, 3.0, 3.0],
|
||||
],
|
||||
dtype=torch.float32,
|
||||
device=device,
|
||||
),
|
||||
)
|
||||
self.assertTrue(
|
||||
torch.allclose(
|
||||
verts,
|
||||
torch.tensor(
|
||||
[
|
||||
[-1.0, -1.0, -1.0],
|
||||
[-1.0, -1.0, 1.0],
|
||||
[-1.0, -1.0, 3.0],
|
||||
[1.0, -1.0, -1.0],
|
||||
[1.0, -1.0, 1.0],
|
||||
[1.0, -1.0, 3.0],
|
||||
[3.0, -1.0, -1.0],
|
||||
[3.0, -1.0, 1.0],
|
||||
[3.0, -1.0, 3.0],
|
||||
[-1.0, 1.0, -1.0],
|
||||
[-1.0, 1.0, 1.0],
|
||||
[-1.0, 1.0, 3.0],
|
||||
[1.0, 1.0, -1.0],
|
||||
[1.0, 1.0, 3.0],
|
||||
[3.0, 1.0, -1.0],
|
||||
[3.0, 1.0, 1.0],
|
||||
[3.0, 1.0, 3.0],
|
||||
[-1.0, 3.0, -1.0],
|
||||
[-1.0, 3.0, 1.0],
|
||||
[-1.0, 3.0, 3.0],
|
||||
[1.0, 3.0, -1.0],
|
||||
[1.0, 3.0, 1.0],
|
||||
[1.0, 3.0, 3.0],
|
||||
[3.0, 3.0, -1.0],
|
||||
[3.0, 3.0, 1.0],
|
||||
[3.0, 3.0, 3.0],
|
||||
],
|
||||
dtype=torch.float32,
|
||||
device=device,
|
||||
),
|
||||
)
|
||||
)
|
||||
self.assertTrue(
|
||||
torch.allclose(
|
||||
faces,
|
||||
torch.tensor(
|
||||
[
|
||||
[0, 1, 9],
|
||||
[1, 10, 9],
|
||||
[0, 9, 12],
|
||||
[0, 12, 3],
|
||||
[0, 4, 1],
|
||||
[0, 3, 4],
|
||||
[1, 2, 10],
|
||||
[2, 11, 10],
|
||||
[1, 5, 2],
|
||||
[1, 4, 5],
|
||||
[2, 13, 11],
|
||||
[2, 5, 13],
|
||||
[3, 12, 14],
|
||||
[3, 14, 6],
|
||||
[3, 7, 4],
|
||||
[3, 6, 7],
|
||||
[14, 15, 7],
|
||||
[14, 7, 6],
|
||||
[4, 8, 5],
|
||||
[4, 7, 8],
|
||||
[15, 16, 8],
|
||||
[15, 8, 7],
|
||||
[5, 16, 13],
|
||||
[5, 8, 16],
|
||||
[9, 10, 17],
|
||||
[10, 18, 17],
|
||||
[17, 18, 20],
|
||||
[18, 21, 20],
|
||||
[9, 17, 20],
|
||||
[9, 20, 12],
|
||||
[10, 11, 18],
|
||||
[11, 19, 18],
|
||||
[18, 19, 21],
|
||||
[19, 22, 21],
|
||||
[11, 22, 19],
|
||||
[11, 13, 22],
|
||||
[20, 21, 23],
|
||||
[21, 24, 23],
|
||||
[12, 20, 23],
|
||||
[12, 23, 14],
|
||||
[23, 24, 15],
|
||||
[23, 15, 14],
|
||||
[21, 22, 24],
|
||||
[22, 25, 24],
|
||||
[24, 25, 16],
|
||||
[24, 16, 15],
|
||||
[13, 25, 22],
|
||||
[13, 16, 25],
|
||||
],
|
||||
dtype=torch.int64,
|
||||
device=device,
|
||||
),
|
||||
)
|
||||
self.assertClose(
|
||||
faces,
|
||||
torch.tensor(
|
||||
[
|
||||
[0, 1, 9],
|
||||
[1, 10, 9],
|
||||
[0, 9, 12],
|
||||
[0, 12, 3],
|
||||
[0, 4, 1],
|
||||
[0, 3, 4],
|
||||
[1, 2, 10],
|
||||
[2, 11, 10],
|
||||
[1, 5, 2],
|
||||
[1, 4, 5],
|
||||
[2, 13, 11],
|
||||
[2, 5, 13],
|
||||
[3, 12, 14],
|
||||
[3, 14, 6],
|
||||
[3, 7, 4],
|
||||
[3, 6, 7],
|
||||
[14, 15, 7],
|
||||
[14, 7, 6],
|
||||
[4, 8, 5],
|
||||
[4, 7, 8],
|
||||
[15, 16, 8],
|
||||
[15, 8, 7],
|
||||
[5, 16, 13],
|
||||
[5, 8, 16],
|
||||
[9, 10, 17],
|
||||
[10, 18, 17],
|
||||
[17, 18, 20],
|
||||
[18, 21, 20],
|
||||
[9, 17, 20],
|
||||
[9, 20, 12],
|
||||
[10, 11, 18],
|
||||
[11, 19, 18],
|
||||
[18, 19, 21],
|
||||
[19, 22, 21],
|
||||
[11, 22, 19],
|
||||
[11, 13, 22],
|
||||
[20, 21, 23],
|
||||
[21, 24, 23],
|
||||
[12, 20, 23],
|
||||
[12, 23, 14],
|
||||
[23, 24, 15],
|
||||
[23, 15, 14],
|
||||
[21, 22, 24],
|
||||
[22, 25, 24],
|
||||
[24, 25, 16],
|
||||
[24, 16, 15],
|
||||
[13, 25, 22],
|
||||
[13, 16, 25],
|
||||
],
|
||||
dtype=torch.int64,
|
||||
device=device,
|
||||
),
|
||||
)
|
||||
|
||||
# 3rd-check
|
||||
@ -187,91 +176,127 @@ class TestCubify(unittest.TestCase):
|
||||
|
||||
# 4th-check
|
||||
verts, faces = meshes.get_mesh_verts_faces(3)
|
||||
self.assertTrue(
|
||||
torch.allclose(
|
||||
verts,
|
||||
torch.tensor(
|
||||
[
|
||||
[1.0, -1.0, -1.0],
|
||||
[1.0, -1.0, 1.0],
|
||||
[1.0, -1.0, 3.0],
|
||||
[3.0, -1.0, -1.0],
|
||||
[3.0, -1.0, 1.0],
|
||||
[3.0, -1.0, 3.0],
|
||||
[-1.0, 1.0, 1.0],
|
||||
[-1.0, 1.0, 3.0],
|
||||
[1.0, 1.0, -1.0],
|
||||
[1.0, 1.0, 1.0],
|
||||
[1.0, 1.0, 3.0],
|
||||
[3.0, 1.0, -1.0],
|
||||
[3.0, 1.0, 1.0],
|
||||
[3.0, 1.0, 3.0],
|
||||
[-1.0, 3.0, 1.0],
|
||||
[-1.0, 3.0, 3.0],
|
||||
[1.0, 3.0, -1.0],
|
||||
[1.0, 3.0, 1.0],
|
||||
[1.0, 3.0, 3.0],
|
||||
[3.0, 3.0, -1.0],
|
||||
[3.0, 3.0, 1.0],
|
||||
[3.0, 3.0, 3.0],
|
||||
],
|
||||
dtype=torch.float32,
|
||||
device=device,
|
||||
),
|
||||
)
|
||||
self.assertClose(
|
||||
verts,
|
||||
torch.tensor(
|
||||
[
|
||||
[1.0, -1.0, -1.0],
|
||||
[1.0, -1.0, 1.0],
|
||||
[1.0, -1.0, 3.0],
|
||||
[3.0, -1.0, -1.0],
|
||||
[3.0, -1.0, 1.0],
|
||||
[3.0, -1.0, 3.0],
|
||||
[-1.0, 1.0, 1.0],
|
||||
[-1.0, 1.0, 3.0],
|
||||
[1.0, 1.0, -1.0],
|
||||
[1.0, 1.0, 1.0],
|
||||
[1.0, 1.0, 3.0],
|
||||
[3.0, 1.0, -1.0],
|
||||
[3.0, 1.0, 1.0],
|
||||
[3.0, 1.0, 3.0],
|
||||
[-1.0, 3.0, 1.0],
|
||||
[-1.0, 3.0, 3.0],
|
||||
[1.0, 3.0, -1.0],
|
||||
[1.0, 3.0, 1.0],
|
||||
[1.0, 3.0, 3.0],
|
||||
[3.0, 3.0, -1.0],
|
||||
[3.0, 3.0, 1.0],
|
||||
[3.0, 3.0, 3.0],
|
||||
],
|
||||
dtype=torch.float32,
|
||||
device=device,
|
||||
),
|
||||
)
|
||||
self.assertTrue(
|
||||
torch.allclose(
|
||||
faces,
|
||||
torch.tensor(
|
||||
[
|
||||
[0, 1, 8],
|
||||
[1, 9, 8],
|
||||
[0, 8, 11],
|
||||
[0, 11, 3],
|
||||
[0, 4, 1],
|
||||
[0, 3, 4],
|
||||
[11, 12, 4],
|
||||
[11, 4, 3],
|
||||
[1, 2, 9],
|
||||
[2, 10, 9],
|
||||
[1, 5, 2],
|
||||
[1, 4, 5],
|
||||
[12, 13, 5],
|
||||
[12, 5, 4],
|
||||
[2, 13, 10],
|
||||
[2, 5, 13],
|
||||
[6, 7, 14],
|
||||
[7, 15, 14],
|
||||
[14, 15, 17],
|
||||
[15, 18, 17],
|
||||
[6, 14, 17],
|
||||
[6, 17, 9],
|
||||
[6, 10, 7],
|
||||
[6, 9, 10],
|
||||
[7, 18, 15],
|
||||
[7, 10, 18],
|
||||
[8, 9, 16],
|
||||
[9, 17, 16],
|
||||
[16, 17, 19],
|
||||
[17, 20, 19],
|
||||
[8, 16, 19],
|
||||
[8, 19, 11],
|
||||
[19, 20, 12],
|
||||
[19, 12, 11],
|
||||
[17, 18, 20],
|
||||
[18, 21, 20],
|
||||
[20, 21, 13],
|
||||
[20, 13, 12],
|
||||
[10, 21, 18],
|
||||
[10, 13, 21],
|
||||
],
|
||||
dtype=torch.int64,
|
||||
device=device,
|
||||
),
|
||||
)
|
||||
self.assertClose(
|
||||
faces,
|
||||
torch.tensor(
|
||||
[
|
||||
[0, 1, 8],
|
||||
[1, 9, 8],
|
||||
[0, 8, 11],
|
||||
[0, 11, 3],
|
||||
[0, 4, 1],
|
||||
[0, 3, 4],
|
||||
[11, 12, 4],
|
||||
[11, 4, 3],
|
||||
[1, 2, 9],
|
||||
[2, 10, 9],
|
||||
[1, 5, 2],
|
||||
[1, 4, 5],
|
||||
[12, 13, 5],
|
||||
[12, 5, 4],
|
||||
[2, 13, 10],
|
||||
[2, 5, 13],
|
||||
[6, 7, 14],
|
||||
[7, 15, 14],
|
||||
[14, 15, 17],
|
||||
[15, 18, 17],
|
||||
[6, 14, 17],
|
||||
[6, 17, 9],
|
||||
[6, 10, 7],
|
||||
[6, 9, 10],
|
||||
[7, 18, 15],
|
||||
[7, 10, 18],
|
||||
[8, 9, 16],
|
||||
[9, 17, 16],
|
||||
[16, 17, 19],
|
||||
[17, 20, 19],
|
||||
[8, 16, 19],
|
||||
[8, 19, 11],
|
||||
[19, 20, 12],
|
||||
[19, 12, 11],
|
||||
[17, 18, 20],
|
||||
[18, 21, 20],
|
||||
[20, 21, 13],
|
||||
[20, 13, 12],
|
||||
[10, 21, 18],
|
||||
[10, 13, 21],
|
||||
],
|
||||
dtype=torch.int64,
|
||||
device=device,
|
||||
),
|
||||
)
|
||||
|
||||
def test_align(self):
|
||||
N, V = 1, 2
|
||||
device = torch.device("cuda:0")
|
||||
voxels = torch.ones((N, V, V, V), dtype=torch.float32, device=device)
|
||||
|
||||
# topleft align
|
||||
mesh = cubify(voxels, 0.5)
|
||||
verts, faces = mesh.get_mesh_verts_faces(0)
|
||||
self.assertClose(verts.min(), torch.tensor(-1.0, device=device))
|
||||
self.assertClose(verts.max(), torch.tensor(3.0, device=device))
|
||||
|
||||
# corner align
|
||||
mesh = cubify(voxels, 0.5, align="corner")
|
||||
verts, faces = mesh.get_mesh_verts_faces(0)
|
||||
self.assertClose(verts.min(), torch.tensor(-1.0, device=device))
|
||||
self.assertClose(verts.max(), torch.tensor(1.0, device=device))
|
||||
|
||||
# center align
|
||||
mesh = cubify(voxels, 0.5, align="center")
|
||||
verts, faces = mesh.get_mesh_verts_faces(0)
|
||||
self.assertClose(verts.min(), torch.tensor(-2.0, device=device))
|
||||
self.assertClose(verts.max(), torch.tensor(2.0, device=device))
|
||||
|
||||
# invalid align
|
||||
with self.assertRaisesRegex(ValueError, "Align mode must be one of"):
|
||||
cubify(voxels, 0.5, align="")
|
||||
|
||||
# invalid align
|
||||
with self.assertRaisesRegex(ValueError, "Align mode must be one of"):
|
||||
cubify(voxels, 0.5, align="topright")
|
||||
|
||||
# inside occupancy, similar to GH#185 use case
|
||||
N, V = 1, 4
|
||||
voxels = torch.zeros((N, V, V, V), dtype=torch.float32, device=device)
|
||||
voxels[0, : V // 2, : V // 2, : V // 2] = 1.0
|
||||
mesh = cubify(voxels, 0.5, align="corner")
|
||||
verts, faces = mesh.get_mesh_verts_faces(0)
|
||||
self.assertClose(verts.min(), torch.tensor(-1.0, device=device))
|
||||
self.assertClose(verts.max(), torch.tensor(0.0, device=device))
|
||||
|
||||
@staticmethod
|
||||
def cubify_with_init(batch_size: int, V: int):
|
||||
device = torch.device("cuda:0")
|
||||
|
Loading…
x
Reference in New Issue
Block a user