diff --git a/pytorch3d/structures/pointclouds.py b/pytorch3d/structures/pointclouds.py index 6d17fb3e..bc014af5 100644 --- a/pytorch3d/structures/pointclouds.py +++ b/pytorch3d/structures/pointclouds.py @@ -1178,3 +1178,40 @@ class Pointclouds: coord_inside = (points_packed >= box[:, 0]) * (points_packed <= box[:, 1]) return coord_inside.all(dim=-1) + + +def join_pointclouds_as_batch(pointclouds: Sequence[Pointclouds]): + """ + Merge a list of Pointclouds objects into a single batched Pointclouds + object. All pointclouds must be on the same device. + + Args: + batch: List of Pointclouds objects each with batch dim [b1, b2, ..., bN] + Returns: + pointcloud: Poinclouds object with all input pointclouds collated into + a single object with batch dim = sum(b1, b2, ..., bN) + """ + if isinstance(pointclouds, Pointclouds) or not isinstance(pointclouds, Sequence): + raise ValueError("Wrong first argument to join_points_as_batch.") + + device = pointclouds[0].device + if not all(p.device == device for p in pointclouds): + raise ValueError("Pointclouds must all be on the same device") + + kwargs = {} + for field in ("points", "normals", "features"): + field_list = [getattr(p, field + "_list")() for p in pointclouds] + if None in field_list: + if field == "points": + raise ValueError("Pointclouds cannot have their points set to None!") + if not all(f is None for f in field_list): + raise ValueError( + f"Pointclouds in the batch have some fields '{field}'" + + " defined and some set to None." + ) + field_list = None + else: + field_list = [p for points in field_list for p in points] + kwargs[field] = field_list + + return Pointclouds(**kwargs) diff --git a/tests/test_pointclouds.py b/tests/test_pointclouds.py index d92cf895..9e3863b5 100644 --- a/tests/test_pointclouds.py +++ b/tests/test_pointclouds.py @@ -12,7 +12,7 @@ import numpy as np import torch from common_testing import TestCaseMixin from pytorch3d.structures import utils as struct_utils -from pytorch3d.structures.pointclouds import Pointclouds +from pytorch3d.structures.pointclouds import Pointclouds, join_pointclouds_as_batch class TestPointclouds(TestCaseMixin, unittest.TestCase): @@ -1098,6 +1098,70 @@ class TestPointclouds(TestCaseMixin, unittest.TestCase): for length, points_ in zip(lengths_max_4, pcl_copy2.points_list()): self.assertEqual(points_.shape, (length, 3)) + def test_join_pointclouds_as_batch(self): + """ + Test join_pointclouds_as_batch + """ + + def check_item(x, y): + self.assertEqual(x is None, y is None) + if x is not None: + self.assertClose(torch.cat([x, x, x]), y) + + def check_triple(points, points3): + """ + Verify that points3 is three copies of points. + """ + check_item(points.points_padded(), points3.points_padded()) + check_item(points.normals_padded(), points3.normals_padded()) + check_item(points.features_padded(), points3.features_padded()) + + lengths = [4, 5, 13, 3] + points = [torch.rand(length, 3) for length in lengths] + features = [torch.rand(length, 5) for length in lengths] + normals = [torch.rand(length, 3) for length in lengths] + + # Test with normals and features present + pcl = Pointclouds(points=points, features=features, normals=normals) + pcl3 = join_pointclouds_as_batch([pcl] * 3) + check_triple(pcl, pcl3) + + # Test with normals and features present for tensor backed pointclouds + N, P, D = 5, 30, 4 + pcl = Pointclouds( + points=torch.rand(N, P, 3), + features=torch.rand(N, P, D), + normals=torch.rand(N, P, 3), + ) + pcl3 = join_pointclouds_as_batch([pcl] * 3) + check_triple(pcl, pcl3) + + # Test without normals + pcl_nonormals = Pointclouds(points=points, features=features) + pcl3 = join_pointclouds_as_batch([pcl_nonormals] * 3) + check_triple(pcl_nonormals, pcl3) + + # Test without features + pcl_nofeats = Pointclouds(points=points, normals=normals) + pcl3 = join_pointclouds_as_batch([pcl_nofeats] * 3) + check_triple(pcl_nofeats, pcl3) + + # Check error raised if all pointclouds in the batch + # are not consistent in including normals/features + with self.assertRaisesRegex(ValueError, "some set to None"): + join_pointclouds_as_batch([pcl, pcl_nonormals, pcl_nonormals]) + with self.assertRaisesRegex(ValueError, "some set to None"): + join_pointclouds_as_batch([pcl, pcl_nofeats, pcl_nofeats]) + + # Check error if first input is a single pointclouds object + # instead of a list + with self.assertRaisesRegex(ValueError, "Wrong first argument"): + join_pointclouds_as_batch(pcl) + + # Check error if all pointclouds are not on the same device + with self.assertRaisesRegex(ValueError, "same device"): + join_pointclouds_as_batch([pcl, pcl.to("cuda:0")]) + @staticmethod def compute_packed_with_init( num_clouds: int = 10, max_p: int = 100, features: int = 300