Join points as batch

Summary: Function to join a list of pointclouds as a batch similar to the corresponding function for Meshes.

Reviewed By: bottler

Differential Revision: D33145906

fbshipit-source-id: 160639ebb5065e4fae1a1aa43117172719f3871b
This commit is contained in:
Nikhila Ravi
2021-12-21 04:43:09 -08:00
committed by Facebook GitHub Bot
parent eb2bbf8433
commit 262c1bfcd4
2 changed files with 102 additions and 1 deletions

View File

@@ -12,7 +12,7 @@ import numpy as np
import torch
from common_testing import TestCaseMixin
from pytorch3d.structures import utils as struct_utils
from pytorch3d.structures.pointclouds import Pointclouds
from pytorch3d.structures.pointclouds import Pointclouds, join_pointclouds_as_batch
class TestPointclouds(TestCaseMixin, unittest.TestCase):
@@ -1098,6 +1098,70 @@ class TestPointclouds(TestCaseMixin, unittest.TestCase):
for length, points_ in zip(lengths_max_4, pcl_copy2.points_list()):
self.assertEqual(points_.shape, (length, 3))
def test_join_pointclouds_as_batch(self):
"""
Test join_pointclouds_as_batch
"""
def check_item(x, y):
self.assertEqual(x is None, y is None)
if x is not None:
self.assertClose(torch.cat([x, x, x]), y)
def check_triple(points, points3):
"""
Verify that points3 is three copies of points.
"""
check_item(points.points_padded(), points3.points_padded())
check_item(points.normals_padded(), points3.normals_padded())
check_item(points.features_padded(), points3.features_padded())
lengths = [4, 5, 13, 3]
points = [torch.rand(length, 3) for length in lengths]
features = [torch.rand(length, 5) for length in lengths]
normals = [torch.rand(length, 3) for length in lengths]
# Test with normals and features present
pcl = Pointclouds(points=points, features=features, normals=normals)
pcl3 = join_pointclouds_as_batch([pcl] * 3)
check_triple(pcl, pcl3)
# Test with normals and features present for tensor backed pointclouds
N, P, D = 5, 30, 4
pcl = Pointclouds(
points=torch.rand(N, P, 3),
features=torch.rand(N, P, D),
normals=torch.rand(N, P, 3),
)
pcl3 = join_pointclouds_as_batch([pcl] * 3)
check_triple(pcl, pcl3)
# Test without normals
pcl_nonormals = Pointclouds(points=points, features=features)
pcl3 = join_pointclouds_as_batch([pcl_nonormals] * 3)
check_triple(pcl_nonormals, pcl3)
# Test without features
pcl_nofeats = Pointclouds(points=points, normals=normals)
pcl3 = join_pointclouds_as_batch([pcl_nofeats] * 3)
check_triple(pcl_nofeats, pcl3)
# Check error raised if all pointclouds in the batch
# are not consistent in including normals/features
with self.assertRaisesRegex(ValueError, "some set to None"):
join_pointclouds_as_batch([pcl, pcl_nonormals, pcl_nonormals])
with self.assertRaisesRegex(ValueError, "some set to None"):
join_pointclouds_as_batch([pcl, pcl_nofeats, pcl_nofeats])
# Check error if first input is a single pointclouds object
# instead of a list
with self.assertRaisesRegex(ValueError, "Wrong first argument"):
join_pointclouds_as_batch(pcl)
# Check error if all pointclouds are not on the same device
with self.assertRaisesRegex(ValueError, "same device"):
join_pointclouds_as_batch([pcl, pcl.to("cuda:0")])
@staticmethod
def compute_packed_with_init(
num_clouds: int = 10, max_p: int = 100, features: int = 300