mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 20:02:49 +08:00
Join points as batch
Summary: Function to join a list of pointclouds as a batch similar to the corresponding function for Meshes. Reviewed By: bottler Differential Revision: D33145906 fbshipit-source-id: 160639ebb5065e4fae1a1aa43117172719f3871b
This commit is contained in:
parent
eb2bbf8433
commit
262c1bfcd4
@ -1178,3 +1178,40 @@ class Pointclouds:
|
|||||||
|
|
||||||
coord_inside = (points_packed >= box[:, 0]) * (points_packed <= box[:, 1])
|
coord_inside = (points_packed >= box[:, 0]) * (points_packed <= box[:, 1])
|
||||||
return coord_inside.all(dim=-1)
|
return coord_inside.all(dim=-1)
|
||||||
|
|
||||||
|
|
||||||
|
def join_pointclouds_as_batch(pointclouds: Sequence[Pointclouds]):
|
||||||
|
"""
|
||||||
|
Merge a list of Pointclouds objects into a single batched Pointclouds
|
||||||
|
object. All pointclouds must be on the same device.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
batch: List of Pointclouds objects each with batch dim [b1, b2, ..., bN]
|
||||||
|
Returns:
|
||||||
|
pointcloud: Poinclouds object with all input pointclouds collated into
|
||||||
|
a single object with batch dim = sum(b1, b2, ..., bN)
|
||||||
|
"""
|
||||||
|
if isinstance(pointclouds, Pointclouds) or not isinstance(pointclouds, Sequence):
|
||||||
|
raise ValueError("Wrong first argument to join_points_as_batch.")
|
||||||
|
|
||||||
|
device = pointclouds[0].device
|
||||||
|
if not all(p.device == device for p in pointclouds):
|
||||||
|
raise ValueError("Pointclouds must all be on the same device")
|
||||||
|
|
||||||
|
kwargs = {}
|
||||||
|
for field in ("points", "normals", "features"):
|
||||||
|
field_list = [getattr(p, field + "_list")() for p in pointclouds]
|
||||||
|
if None in field_list:
|
||||||
|
if field == "points":
|
||||||
|
raise ValueError("Pointclouds cannot have their points set to None!")
|
||||||
|
if not all(f is None for f in field_list):
|
||||||
|
raise ValueError(
|
||||||
|
f"Pointclouds in the batch have some fields '{field}'"
|
||||||
|
+ " defined and some set to None."
|
||||||
|
)
|
||||||
|
field_list = None
|
||||||
|
else:
|
||||||
|
field_list = [p for points in field_list for p in points]
|
||||||
|
kwargs[field] = field_list
|
||||||
|
|
||||||
|
return Pointclouds(**kwargs)
|
||||||
|
@ -12,7 +12,7 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
from common_testing import TestCaseMixin
|
from common_testing import TestCaseMixin
|
||||||
from pytorch3d.structures import utils as struct_utils
|
from pytorch3d.structures import utils as struct_utils
|
||||||
from pytorch3d.structures.pointclouds import Pointclouds
|
from pytorch3d.structures.pointclouds import Pointclouds, join_pointclouds_as_batch
|
||||||
|
|
||||||
|
|
||||||
class TestPointclouds(TestCaseMixin, unittest.TestCase):
|
class TestPointclouds(TestCaseMixin, unittest.TestCase):
|
||||||
@ -1098,6 +1098,70 @@ class TestPointclouds(TestCaseMixin, unittest.TestCase):
|
|||||||
for length, points_ in zip(lengths_max_4, pcl_copy2.points_list()):
|
for length, points_ in zip(lengths_max_4, pcl_copy2.points_list()):
|
||||||
self.assertEqual(points_.shape, (length, 3))
|
self.assertEqual(points_.shape, (length, 3))
|
||||||
|
|
||||||
|
def test_join_pointclouds_as_batch(self):
|
||||||
|
"""
|
||||||
|
Test join_pointclouds_as_batch
|
||||||
|
"""
|
||||||
|
|
||||||
|
def check_item(x, y):
|
||||||
|
self.assertEqual(x is None, y is None)
|
||||||
|
if x is not None:
|
||||||
|
self.assertClose(torch.cat([x, x, x]), y)
|
||||||
|
|
||||||
|
def check_triple(points, points3):
|
||||||
|
"""
|
||||||
|
Verify that points3 is three copies of points.
|
||||||
|
"""
|
||||||
|
check_item(points.points_padded(), points3.points_padded())
|
||||||
|
check_item(points.normals_padded(), points3.normals_padded())
|
||||||
|
check_item(points.features_padded(), points3.features_padded())
|
||||||
|
|
||||||
|
lengths = [4, 5, 13, 3]
|
||||||
|
points = [torch.rand(length, 3) for length in lengths]
|
||||||
|
features = [torch.rand(length, 5) for length in lengths]
|
||||||
|
normals = [torch.rand(length, 3) for length in lengths]
|
||||||
|
|
||||||
|
# Test with normals and features present
|
||||||
|
pcl = Pointclouds(points=points, features=features, normals=normals)
|
||||||
|
pcl3 = join_pointclouds_as_batch([pcl] * 3)
|
||||||
|
check_triple(pcl, pcl3)
|
||||||
|
|
||||||
|
# Test with normals and features present for tensor backed pointclouds
|
||||||
|
N, P, D = 5, 30, 4
|
||||||
|
pcl = Pointclouds(
|
||||||
|
points=torch.rand(N, P, 3),
|
||||||
|
features=torch.rand(N, P, D),
|
||||||
|
normals=torch.rand(N, P, 3),
|
||||||
|
)
|
||||||
|
pcl3 = join_pointclouds_as_batch([pcl] * 3)
|
||||||
|
check_triple(pcl, pcl3)
|
||||||
|
|
||||||
|
# Test without normals
|
||||||
|
pcl_nonormals = Pointclouds(points=points, features=features)
|
||||||
|
pcl3 = join_pointclouds_as_batch([pcl_nonormals] * 3)
|
||||||
|
check_triple(pcl_nonormals, pcl3)
|
||||||
|
|
||||||
|
# Test without features
|
||||||
|
pcl_nofeats = Pointclouds(points=points, normals=normals)
|
||||||
|
pcl3 = join_pointclouds_as_batch([pcl_nofeats] * 3)
|
||||||
|
check_triple(pcl_nofeats, pcl3)
|
||||||
|
|
||||||
|
# Check error raised if all pointclouds in the batch
|
||||||
|
# are not consistent in including normals/features
|
||||||
|
with self.assertRaisesRegex(ValueError, "some set to None"):
|
||||||
|
join_pointclouds_as_batch([pcl, pcl_nonormals, pcl_nonormals])
|
||||||
|
with self.assertRaisesRegex(ValueError, "some set to None"):
|
||||||
|
join_pointclouds_as_batch([pcl, pcl_nofeats, pcl_nofeats])
|
||||||
|
|
||||||
|
# Check error if first input is a single pointclouds object
|
||||||
|
# instead of a list
|
||||||
|
with self.assertRaisesRegex(ValueError, "Wrong first argument"):
|
||||||
|
join_pointclouds_as_batch(pcl)
|
||||||
|
|
||||||
|
# Check error if all pointclouds are not on the same device
|
||||||
|
with self.assertRaisesRegex(ValueError, "same device"):
|
||||||
|
join_pointclouds_as_batch([pcl, pcl.to("cuda:0")])
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def compute_packed_with_init(
|
def compute_packed_with_init(
|
||||||
num_clouds: int = 10, max_p: int = 100, features: int = 300
|
num_clouds: int = 10, max_p: int = 100, features: int = 300
|
||||||
|
Loading…
x
Reference in New Issue
Block a user