vert_align for Pointclouds object

Reviewed By: gkioxari

Differential Revision: D21088730

fbshipit-source-id: f8c125ac8c8009d45712ae63237ca64acf1faf45
This commit is contained in:
Jeremy Reizenstein
2020-04-17 10:35:45 -07:00
committed by Facebook GitHub Bot
parent e19df58766
commit f25af96959
2 changed files with 61 additions and 20 deletions

View File

@@ -8,12 +8,13 @@ import torch.nn.functional as F
from common_testing import TestCaseMixin
from pytorch3d.ops.vert_align import vert_align
from pytorch3d.structures.meshes import Meshes
from pytorch3d.structures.pointclouds import Pointclouds
class TestVertAlign(TestCaseMixin, unittest.TestCase):
@staticmethod
def vert_align_naive(
feats, verts_or_meshes, return_packed: bool = False, align_corners: bool = True
feats, verts, return_packed: bool = False, align_corners: bool = True
):
"""
Naive implementation of vert_align.
@@ -28,12 +29,12 @@ class TestVertAlign(TestCaseMixin, unittest.TestCase):
out_i_feats = []
for feat in feats:
feats_i = feat[i][None, :, :, :] # (1, C, H, W)
if torch.is_tensor(verts_or_meshes):
grid = verts_or_meshes[i][None, None, :, :2] # (1, 1, V, 2)
elif hasattr(verts_or_meshes, "verts_list"):
grid = verts_or_meshes.verts_list()[i][
None, None, :, :2
] # (1, 1, V, 2)
if torch.is_tensor(verts):
grid = verts[i][None, None, :, :2] # (1, 1, V, 2)
elif hasattr(verts, "verts_list"):
grid = verts.verts_list()[i][None, None, :, :2] # (1, 1, V, 2)
elif hasattr(verts, "points_list"):
grid = verts.points_list()[i][None, None, :, :2] # (1, 1, V, 2)
else:
raise ValueError("verts_or_meshes is invalid")
feat_sampled_i = F.grid_sample(
@@ -56,7 +57,9 @@ class TestVertAlign(TestCaseMixin, unittest.TestCase):
return out_feats
@staticmethod
def init_meshes(num_meshes: int = 10, num_verts: int = 1000, num_faces: int = 3000):
def init_meshes(
num_meshes: int = 10, num_verts: int = 1000, num_faces: int = 3000
) -> Meshes:
device = torch.device("cuda:0")
verts_list = []
faces_list = []
@@ -74,6 +77,20 @@ class TestVertAlign(TestCaseMixin, unittest.TestCase):
return meshes
@staticmethod
def init_pointclouds(num_clouds: int = 10, num_points: int = 1000) -> Pointclouds:
device = torch.device("cuda:0")
points_list = []
for _ in range(num_clouds):
points = (
torch.rand((num_points, 3), dtype=torch.float32, device=device) * 2.0
- 1.0
) # points in the space of [-1, 1]
points_list.append(points)
pointclouds = Pointclouds(points=points_list)
return pointclouds
@staticmethod
def init_feats(batch_size: int = 10, num_channels: int = 256, device: str = "cuda"):
H, W = [14, 28], [14, 28]
@@ -99,6 +116,27 @@ class TestVertAlign(TestCaseMixin, unittest.TestCase):
naive_out = TestVertAlign.vert_align_naive(feats[0], meshes, return_packed=True)
self.assertClose(out, naive_out)
def test_vert_align_with_pointclouds(self):
"""
Test vert align vs naive implementation with meshes.
"""
pointclouds = TestVertAlign.init_pointclouds(10, 1000)
feats = TestVertAlign.init_feats(10, 256)
# feats in list
out = vert_align(feats, pointclouds, return_packed=True)
naive_out = TestVertAlign.vert_align_naive(
feats, pointclouds, return_packed=True
)
self.assertClose(out, naive_out)
# feats as tensor
out = vert_align(feats[0], pointclouds, return_packed=True)
naive_out = TestVertAlign.vert_align_naive(
feats[0], pointclouds, return_packed=True
)
self.assertClose(out, naive_out)
def test_vert_align_with_verts(self):
"""
Test vert align vs naive implementation with verts as tensor.