mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 20:02:49 +08:00
vert_align for Pointclouds object
Reviewed By: gkioxari Differential Revision: D21088730 fbshipit-source-id: f8c125ac8c8009d45712ae63237ca64acf1faf45
This commit is contained in:
parent
e19df58766
commit
f25af96959
@ -25,10 +25,10 @@ def vert_align(
|
|||||||
feats: FloatTensor of shape (N, C, H, W) representing image features
|
feats: FloatTensor of shape (N, C, H, W) representing image features
|
||||||
from which to sample or a list of features each with potentially
|
from which to sample or a list of features each with potentially
|
||||||
different C, H or W dimensions.
|
different C, H or W dimensions.
|
||||||
verts: FloatTensor of shape (N, V, 3) or an object (e.g. Meshes) with
|
verts: FloatTensor of shape (N, V, 3) or an object (e.g. Meshes or Pointclouds)
|
||||||
'verts_padded' as an attribute giving the (x, y, z) vertex positions
|
with `verts_padded' or `points_padded' as an attribute giving the (x, y, z)
|
||||||
for which to sample. (x, y) verts should be normalized such that
|
vertex positions for which to sample. (x, y) verts should be normalized such
|
||||||
(-1, -1) corresponds to top-left and (+1, +1) to bottom-right
|
that (-1, -1) corresponds to top-left and (+1, +1) to bottom-right
|
||||||
location in the input feature map.
|
location in the input feature map.
|
||||||
return_packed: (bool) Indicates whether to return packed features
|
return_packed: (bool) Indicates whether to return packed features
|
||||||
interp_mode: (str) Specifies how to interpolate features.
|
interp_mode: (str) Specifies how to interpolate features.
|
||||||
@ -44,13 +44,11 @@ def vert_align(
|
|||||||
resolution agnostic. Default: ``True``
|
resolution agnostic. Default: ``True``
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
feats_sampled: FloatTensor of shape (N, V, C) giving sampled features for
|
feats_sampled: FloatTensor of shape (N, V, C) giving sampled features for each
|
||||||
each vertex. If feats is a list, we return concatentated
|
vertex. If feats is a list, we return concatentated features in axis=2 of
|
||||||
features in axis=2 of shape (N, V, sum(C_n)) where
|
shape (N, V, sum(C_n)) where C_n = feats[n].shape[1].
|
||||||
C_n = feats[n].shape[1]. If return_packed = True, the
|
If return_packed = True, the features are transformed to a packed
|
||||||
features are transformed to a packed representation
|
representation of shape (sum(V), C)
|
||||||
of shape (sum(V), C)
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
if torch.is_tensor(verts):
|
if torch.is_tensor(verts):
|
||||||
if verts.dim() != 3:
|
if verts.dim() != 3:
|
||||||
@ -58,8 +56,13 @@ def vert_align(
|
|||||||
grid = verts
|
grid = verts
|
||||||
elif hasattr(verts, "verts_padded"):
|
elif hasattr(verts, "verts_padded"):
|
||||||
grid = verts.verts_padded()
|
grid = verts.verts_padded()
|
||||||
|
elif hasattr(verts, "points_padded"):
|
||||||
|
grid = verts.points_padded()
|
||||||
else:
|
else:
|
||||||
raise ValueError("verts must be a tensor or have a `verts_padded` attribute")
|
raise ValueError(
|
||||||
|
"verts must be a tensor or have a "
|
||||||
|
+ "`points_padded' or`verts_padded` attribute."
|
||||||
|
)
|
||||||
|
|
||||||
grid = grid[:, None, :, :2] # (N, 1, V, 2)
|
grid = grid[:, None, :, :2] # (N, 1, V, 2)
|
||||||
|
|
||||||
|
@ -8,12 +8,13 @@ import torch.nn.functional as F
|
|||||||
from common_testing import TestCaseMixin
|
from common_testing import TestCaseMixin
|
||||||
from pytorch3d.ops.vert_align import vert_align
|
from pytorch3d.ops.vert_align import vert_align
|
||||||
from pytorch3d.structures.meshes import Meshes
|
from pytorch3d.structures.meshes import Meshes
|
||||||
|
from pytorch3d.structures.pointclouds import Pointclouds
|
||||||
|
|
||||||
|
|
||||||
class TestVertAlign(TestCaseMixin, unittest.TestCase):
|
class TestVertAlign(TestCaseMixin, unittest.TestCase):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def vert_align_naive(
|
def vert_align_naive(
|
||||||
feats, verts_or_meshes, return_packed: bool = False, align_corners: bool = True
|
feats, verts, return_packed: bool = False, align_corners: bool = True
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Naive implementation of vert_align.
|
Naive implementation of vert_align.
|
||||||
@ -28,12 +29,12 @@ class TestVertAlign(TestCaseMixin, unittest.TestCase):
|
|||||||
out_i_feats = []
|
out_i_feats = []
|
||||||
for feat in feats:
|
for feat in feats:
|
||||||
feats_i = feat[i][None, :, :, :] # (1, C, H, W)
|
feats_i = feat[i][None, :, :, :] # (1, C, H, W)
|
||||||
if torch.is_tensor(verts_or_meshes):
|
if torch.is_tensor(verts):
|
||||||
grid = verts_or_meshes[i][None, None, :, :2] # (1, 1, V, 2)
|
grid = verts[i][None, None, :, :2] # (1, 1, V, 2)
|
||||||
elif hasattr(verts_or_meshes, "verts_list"):
|
elif hasattr(verts, "verts_list"):
|
||||||
grid = verts_or_meshes.verts_list()[i][
|
grid = verts.verts_list()[i][None, None, :, :2] # (1, 1, V, 2)
|
||||||
None, None, :, :2
|
elif hasattr(verts, "points_list"):
|
||||||
] # (1, 1, V, 2)
|
grid = verts.points_list()[i][None, None, :, :2] # (1, 1, V, 2)
|
||||||
else:
|
else:
|
||||||
raise ValueError("verts_or_meshes is invalid")
|
raise ValueError("verts_or_meshes is invalid")
|
||||||
feat_sampled_i = F.grid_sample(
|
feat_sampled_i = F.grid_sample(
|
||||||
@ -56,7 +57,9 @@ class TestVertAlign(TestCaseMixin, unittest.TestCase):
|
|||||||
return out_feats
|
return out_feats
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def init_meshes(num_meshes: int = 10, num_verts: int = 1000, num_faces: int = 3000):
|
def init_meshes(
|
||||||
|
num_meshes: int = 10, num_verts: int = 1000, num_faces: int = 3000
|
||||||
|
) -> Meshes:
|
||||||
device = torch.device("cuda:0")
|
device = torch.device("cuda:0")
|
||||||
verts_list = []
|
verts_list = []
|
||||||
faces_list = []
|
faces_list = []
|
||||||
@ -74,6 +77,20 @@ class TestVertAlign(TestCaseMixin, unittest.TestCase):
|
|||||||
|
|
||||||
return meshes
|
return meshes
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def init_pointclouds(num_clouds: int = 10, num_points: int = 1000) -> Pointclouds:
|
||||||
|
device = torch.device("cuda:0")
|
||||||
|
points_list = []
|
||||||
|
for _ in range(num_clouds):
|
||||||
|
points = (
|
||||||
|
torch.rand((num_points, 3), dtype=torch.float32, device=device) * 2.0
|
||||||
|
- 1.0
|
||||||
|
) # points in the space of [-1, 1]
|
||||||
|
points_list.append(points)
|
||||||
|
pointclouds = Pointclouds(points=points_list)
|
||||||
|
|
||||||
|
return pointclouds
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def init_feats(batch_size: int = 10, num_channels: int = 256, device: str = "cuda"):
|
def init_feats(batch_size: int = 10, num_channels: int = 256, device: str = "cuda"):
|
||||||
H, W = [14, 28], [14, 28]
|
H, W = [14, 28], [14, 28]
|
||||||
@ -99,6 +116,27 @@ class TestVertAlign(TestCaseMixin, unittest.TestCase):
|
|||||||
naive_out = TestVertAlign.vert_align_naive(feats[0], meshes, return_packed=True)
|
naive_out = TestVertAlign.vert_align_naive(feats[0], meshes, return_packed=True)
|
||||||
self.assertClose(out, naive_out)
|
self.assertClose(out, naive_out)
|
||||||
|
|
||||||
|
def test_vert_align_with_pointclouds(self):
|
||||||
|
"""
|
||||||
|
Test vert align vs naive implementation with meshes.
|
||||||
|
"""
|
||||||
|
pointclouds = TestVertAlign.init_pointclouds(10, 1000)
|
||||||
|
feats = TestVertAlign.init_feats(10, 256)
|
||||||
|
|
||||||
|
# feats in list
|
||||||
|
out = vert_align(feats, pointclouds, return_packed=True)
|
||||||
|
naive_out = TestVertAlign.vert_align_naive(
|
||||||
|
feats, pointclouds, return_packed=True
|
||||||
|
)
|
||||||
|
self.assertClose(out, naive_out)
|
||||||
|
|
||||||
|
# feats as tensor
|
||||||
|
out = vert_align(feats[0], pointclouds, return_packed=True)
|
||||||
|
naive_out = TestVertAlign.vert_align_naive(
|
||||||
|
feats[0], pointclouds, return_packed=True
|
||||||
|
)
|
||||||
|
self.assertClose(out, naive_out)
|
||||||
|
|
||||||
def test_vert_align_with_verts(self):
|
def test_vert_align_with_verts(self):
|
||||||
"""
|
"""
|
||||||
Test vert align vs naive implementation with verts as tensor.
|
Test vert align vs naive implementation with verts as tensor.
|
||||||
|
Loading…
x
Reference in New Issue
Block a user