vert_align for Pointclouds object

Reviewed By: gkioxari

Differential Revision: D21088730

fbshipit-source-id: f8c125ac8c8009d45712ae63237ca64acf1faf45
This commit is contained in:
Jeremy Reizenstein 2020-04-17 10:35:45 -07:00 committed by Facebook GitHub Bot
parent e19df58766
commit f25af96959
2 changed files with 61 additions and 20 deletions

View File

@ -25,10 +25,10 @@ def vert_align(
feats: FloatTensor of shape (N, C, H, W) representing image features feats: FloatTensor of shape (N, C, H, W) representing image features
from which to sample or a list of features each with potentially from which to sample or a list of features each with potentially
different C, H or W dimensions. different C, H or W dimensions.
verts: FloatTensor of shape (N, V, 3) or an object (e.g. Meshes) with verts: FloatTensor of shape (N, V, 3) or an object (e.g. Meshes or Pointclouds)
'verts_padded' as an attribute giving the (x, y, z) vertex positions with `verts_padded' or `points_padded' as an attribute giving the (x, y, z)
for which to sample. (x, y) verts should be normalized such that vertex positions for which to sample. (x, y) verts should be normalized such
(-1, -1) corresponds to top-left and (+1, +1) to bottom-right that (-1, -1) corresponds to top-left and (+1, +1) to bottom-right
location in the input feature map. location in the input feature map.
return_packed: (bool) Indicates whether to return packed features return_packed: (bool) Indicates whether to return packed features
interp_mode: (str) Specifies how to interpolate features. interp_mode: (str) Specifies how to interpolate features.
@ -44,13 +44,11 @@ def vert_align(
resolution agnostic. Default: ``True`` resolution agnostic. Default: ``True``
Returns: Returns:
feats_sampled: FloatTensor of shape (N, V, C) giving sampled features for feats_sampled: FloatTensor of shape (N, V, C) giving sampled features for each
each vertex. If feats is a list, we return concatentated vertex. If feats is a list, we return concatentated features in axis=2 of
features in axis=2 of shape (N, V, sum(C_n)) where shape (N, V, sum(C_n)) where C_n = feats[n].shape[1].
C_n = feats[n].shape[1]. If return_packed = True, the If return_packed = True, the features are transformed to a packed
features are transformed to a packed representation representation of shape (sum(V), C)
of shape (sum(V), C)
""" """
if torch.is_tensor(verts): if torch.is_tensor(verts):
if verts.dim() != 3: if verts.dim() != 3:
@ -58,8 +56,13 @@ def vert_align(
grid = verts grid = verts
elif hasattr(verts, "verts_padded"): elif hasattr(verts, "verts_padded"):
grid = verts.verts_padded() grid = verts.verts_padded()
elif hasattr(verts, "points_padded"):
grid = verts.points_padded()
else: else:
raise ValueError("verts must be a tensor or have a `verts_padded` attribute") raise ValueError(
"verts must be a tensor or have a "
+ "`points_padded' or`verts_padded` attribute."
)
grid = grid[:, None, :, :2] # (N, 1, V, 2) grid = grid[:, None, :, :2] # (N, 1, V, 2)

View File

@ -8,12 +8,13 @@ import torch.nn.functional as F
from common_testing import TestCaseMixin from common_testing import TestCaseMixin
from pytorch3d.ops.vert_align import vert_align from pytorch3d.ops.vert_align import vert_align
from pytorch3d.structures.meshes import Meshes from pytorch3d.structures.meshes import Meshes
from pytorch3d.structures.pointclouds import Pointclouds
class TestVertAlign(TestCaseMixin, unittest.TestCase): class TestVertAlign(TestCaseMixin, unittest.TestCase):
@staticmethod @staticmethod
def vert_align_naive( def vert_align_naive(
feats, verts_or_meshes, return_packed: bool = False, align_corners: bool = True feats, verts, return_packed: bool = False, align_corners: bool = True
): ):
""" """
Naive implementation of vert_align. Naive implementation of vert_align.
@ -28,12 +29,12 @@ class TestVertAlign(TestCaseMixin, unittest.TestCase):
out_i_feats = [] out_i_feats = []
for feat in feats: for feat in feats:
feats_i = feat[i][None, :, :, :] # (1, C, H, W) feats_i = feat[i][None, :, :, :] # (1, C, H, W)
if torch.is_tensor(verts_or_meshes): if torch.is_tensor(verts):
grid = verts_or_meshes[i][None, None, :, :2] # (1, 1, V, 2) grid = verts[i][None, None, :, :2] # (1, 1, V, 2)
elif hasattr(verts_or_meshes, "verts_list"): elif hasattr(verts, "verts_list"):
grid = verts_or_meshes.verts_list()[i][ grid = verts.verts_list()[i][None, None, :, :2] # (1, 1, V, 2)
None, None, :, :2 elif hasattr(verts, "points_list"):
] # (1, 1, V, 2) grid = verts.points_list()[i][None, None, :, :2] # (1, 1, V, 2)
else: else:
raise ValueError("verts_or_meshes is invalid") raise ValueError("verts_or_meshes is invalid")
feat_sampled_i = F.grid_sample( feat_sampled_i = F.grid_sample(
@ -56,7 +57,9 @@ class TestVertAlign(TestCaseMixin, unittest.TestCase):
return out_feats return out_feats
@staticmethod @staticmethod
def init_meshes(num_meshes: int = 10, num_verts: int = 1000, num_faces: int = 3000): def init_meshes(
num_meshes: int = 10, num_verts: int = 1000, num_faces: int = 3000
) -> Meshes:
device = torch.device("cuda:0") device = torch.device("cuda:0")
verts_list = [] verts_list = []
faces_list = [] faces_list = []
@ -74,6 +77,20 @@ class TestVertAlign(TestCaseMixin, unittest.TestCase):
return meshes return meshes
@staticmethod
def init_pointclouds(num_clouds: int = 10, num_points: int = 1000) -> Pointclouds:
device = torch.device("cuda:0")
points_list = []
for _ in range(num_clouds):
points = (
torch.rand((num_points, 3), dtype=torch.float32, device=device) * 2.0
- 1.0
) # points in the space of [-1, 1]
points_list.append(points)
pointclouds = Pointclouds(points=points_list)
return pointclouds
@staticmethod @staticmethod
def init_feats(batch_size: int = 10, num_channels: int = 256, device: str = "cuda"): def init_feats(batch_size: int = 10, num_channels: int = 256, device: str = "cuda"):
H, W = [14, 28], [14, 28] H, W = [14, 28], [14, 28]
@ -99,6 +116,27 @@ class TestVertAlign(TestCaseMixin, unittest.TestCase):
naive_out = TestVertAlign.vert_align_naive(feats[0], meshes, return_packed=True) naive_out = TestVertAlign.vert_align_naive(feats[0], meshes, return_packed=True)
self.assertClose(out, naive_out) self.assertClose(out, naive_out)
def test_vert_align_with_pointclouds(self):
"""
Test vert align vs naive implementation with meshes.
"""
pointclouds = TestVertAlign.init_pointclouds(10, 1000)
feats = TestVertAlign.init_feats(10, 256)
# feats in list
out = vert_align(feats, pointclouds, return_packed=True)
naive_out = TestVertAlign.vert_align_naive(
feats, pointclouds, return_packed=True
)
self.assertClose(out, naive_out)
# feats as tensor
out = vert_align(feats[0], pointclouds, return_packed=True)
naive_out = TestVertAlign.vert_align_naive(
feats[0], pointclouds, return_packed=True
)
self.assertClose(out, naive_out)
def test_vert_align_with_verts(self): def test_vert_align_with_verts(self):
""" """
Test vert align vs naive implementation with verts as tensor. Test vert align vs naive implementation with verts as tensor.