diff --git a/pytorch3d/ops/vert_align.py b/pytorch3d/ops/vert_align.py index 43e50eed..1ec42053 100644 --- a/pytorch3d/ops/vert_align.py +++ b/pytorch3d/ops/vert_align.py @@ -25,10 +25,10 @@ def vert_align( feats: FloatTensor of shape (N, C, H, W) representing image features from which to sample or a list of features each with potentially different C, H or W dimensions. - verts: FloatTensor of shape (N, V, 3) or an object (e.g. Meshes) with - 'verts_padded' as an attribute giving the (x, y, z) vertex positions - for which to sample. (x, y) verts should be normalized such that - (-1, -1) corresponds to top-left and (+1, +1) to bottom-right + verts: FloatTensor of shape (N, V, 3) or an object (e.g. Meshes or Pointclouds) + with `verts_padded' or `points_padded' as an attribute giving the (x, y, z) + vertex positions for which to sample. (x, y) verts should be normalized such + that (-1, -1) corresponds to top-left and (+1, +1) to bottom-right location in the input feature map. return_packed: (bool) Indicates whether to return packed features interp_mode: (str) Specifies how to interpolate features. @@ -44,13 +44,11 @@ def vert_align( resolution agnostic. Default: ``True`` Returns: - feats_sampled: FloatTensor of shape (N, V, C) giving sampled features for - each vertex. If feats is a list, we return concatentated - features in axis=2 of shape (N, V, sum(C_n)) where - C_n = feats[n].shape[1]. If return_packed = True, the - features are transformed to a packed representation - of shape (sum(V), C) - + feats_sampled: FloatTensor of shape (N, V, C) giving sampled features for each + vertex. If feats is a list, we return concatentated features in axis=2 of + shape (N, V, sum(C_n)) where C_n = feats[n].shape[1]. + If return_packed = True, the features are transformed to a packed + representation of shape (sum(V), C) """ if torch.is_tensor(verts): if verts.dim() != 3: @@ -58,8 +56,13 @@ def vert_align( grid = verts elif hasattr(verts, "verts_padded"): grid = verts.verts_padded() + elif hasattr(verts, "points_padded"): + grid = verts.points_padded() else: - raise ValueError("verts must be a tensor or have a `verts_padded` attribute") + raise ValueError( + "verts must be a tensor or have a " + + "`points_padded' or`verts_padded` attribute." + ) grid = grid[:, None, :, :2] # (N, 1, V, 2) diff --git a/tests/test_vert_align.py b/tests/test_vert_align.py index 13935590..810e0e5a 100644 --- a/tests/test_vert_align.py +++ b/tests/test_vert_align.py @@ -8,12 +8,13 @@ import torch.nn.functional as F from common_testing import TestCaseMixin from pytorch3d.ops.vert_align import vert_align from pytorch3d.structures.meshes import Meshes +from pytorch3d.structures.pointclouds import Pointclouds class TestVertAlign(TestCaseMixin, unittest.TestCase): @staticmethod def vert_align_naive( - feats, verts_or_meshes, return_packed: bool = False, align_corners: bool = True + feats, verts, return_packed: bool = False, align_corners: bool = True ): """ Naive implementation of vert_align. @@ -28,12 +29,12 @@ class TestVertAlign(TestCaseMixin, unittest.TestCase): out_i_feats = [] for feat in feats: feats_i = feat[i][None, :, :, :] # (1, C, H, W) - if torch.is_tensor(verts_or_meshes): - grid = verts_or_meshes[i][None, None, :, :2] # (1, 1, V, 2) - elif hasattr(verts_or_meshes, "verts_list"): - grid = verts_or_meshes.verts_list()[i][ - None, None, :, :2 - ] # (1, 1, V, 2) + if torch.is_tensor(verts): + grid = verts[i][None, None, :, :2] # (1, 1, V, 2) + elif hasattr(verts, "verts_list"): + grid = verts.verts_list()[i][None, None, :, :2] # (1, 1, V, 2) + elif hasattr(verts, "points_list"): + grid = verts.points_list()[i][None, None, :, :2] # (1, 1, V, 2) else: raise ValueError("verts_or_meshes is invalid") feat_sampled_i = F.grid_sample( @@ -56,7 +57,9 @@ class TestVertAlign(TestCaseMixin, unittest.TestCase): return out_feats @staticmethod - def init_meshes(num_meshes: int = 10, num_verts: int = 1000, num_faces: int = 3000): + def init_meshes( + num_meshes: int = 10, num_verts: int = 1000, num_faces: int = 3000 + ) -> Meshes: device = torch.device("cuda:0") verts_list = [] faces_list = [] @@ -74,6 +77,20 @@ class TestVertAlign(TestCaseMixin, unittest.TestCase): return meshes + @staticmethod + def init_pointclouds(num_clouds: int = 10, num_points: int = 1000) -> Pointclouds: + device = torch.device("cuda:0") + points_list = [] + for _ in range(num_clouds): + points = ( + torch.rand((num_points, 3), dtype=torch.float32, device=device) * 2.0 + - 1.0 + ) # points in the space of [-1, 1] + points_list.append(points) + pointclouds = Pointclouds(points=points_list) + + return pointclouds + @staticmethod def init_feats(batch_size: int = 10, num_channels: int = 256, device: str = "cuda"): H, W = [14, 28], [14, 28] @@ -99,6 +116,27 @@ class TestVertAlign(TestCaseMixin, unittest.TestCase): naive_out = TestVertAlign.vert_align_naive(feats[0], meshes, return_packed=True) self.assertClose(out, naive_out) + def test_vert_align_with_pointclouds(self): + """ + Test vert align vs naive implementation with meshes. + """ + pointclouds = TestVertAlign.init_pointclouds(10, 1000) + feats = TestVertAlign.init_feats(10, 256) + + # feats in list + out = vert_align(feats, pointclouds, return_packed=True) + naive_out = TestVertAlign.vert_align_naive( + feats, pointclouds, return_packed=True + ) + self.assertClose(out, naive_out) + + # feats as tensor + out = vert_align(feats[0], pointclouds, return_packed=True) + naive_out = TestVertAlign.vert_align_naive( + feats[0], pointclouds, return_packed=True + ) + self.assertClose(out, naive_out) + def test_vert_align_with_verts(self): """ Test vert align vs naive implementation with verts as tensor.