mirror of
				https://github.com/facebookresearch/pytorch3d.git
				synced 2025-11-04 18:02:14 +08:00 
			
		
		
		
	Join points as batch
Summary: Function to join a list of pointclouds as a batch similar to the corresponding function for Meshes. Reviewed By: bottler Differential Revision: D33145906 fbshipit-source-id: 160639ebb5065e4fae1a1aa43117172719f3871b
This commit is contained in:
		
							parent
							
								
									eb2bbf8433
								
							
						
					
					
						commit
						262c1bfcd4
					
				@ -1178,3 +1178,40 @@ class Pointclouds:
 | 
			
		||||
 | 
			
		||||
        coord_inside = (points_packed >= box[:, 0]) * (points_packed <= box[:, 1])
 | 
			
		||||
        return coord_inside.all(dim=-1)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def join_pointclouds_as_batch(pointclouds: Sequence[Pointclouds]):
 | 
			
		||||
    """
 | 
			
		||||
    Merge a list of Pointclouds objects into a single batched Pointclouds
 | 
			
		||||
    object. All pointclouds must be on the same device.
 | 
			
		||||
 | 
			
		||||
    Args:
 | 
			
		||||
        batch: List of Pointclouds objects each with batch dim [b1, b2, ..., bN]
 | 
			
		||||
    Returns:
 | 
			
		||||
        pointcloud: Poinclouds object with all input pointclouds collated into
 | 
			
		||||
            a single object with batch dim = sum(b1, b2, ..., bN)
 | 
			
		||||
    """
 | 
			
		||||
    if isinstance(pointclouds, Pointclouds) or not isinstance(pointclouds, Sequence):
 | 
			
		||||
        raise ValueError("Wrong first argument to join_points_as_batch.")
 | 
			
		||||
 | 
			
		||||
    device = pointclouds[0].device
 | 
			
		||||
    if not all(p.device == device for p in pointclouds):
 | 
			
		||||
        raise ValueError("Pointclouds must all be on the same device")
 | 
			
		||||
 | 
			
		||||
    kwargs = {}
 | 
			
		||||
    for field in ("points", "normals", "features"):
 | 
			
		||||
        field_list = [getattr(p, field + "_list")() for p in pointclouds]
 | 
			
		||||
        if None in field_list:
 | 
			
		||||
            if field == "points":
 | 
			
		||||
                raise ValueError("Pointclouds cannot have their points set to None!")
 | 
			
		||||
            if not all(f is None for f in field_list):
 | 
			
		||||
                raise ValueError(
 | 
			
		||||
                    f"Pointclouds in the batch have some fields '{field}'"
 | 
			
		||||
                    + " defined and some set to None."
 | 
			
		||||
                )
 | 
			
		||||
            field_list = None
 | 
			
		||||
        else:
 | 
			
		||||
            field_list = [p for points in field_list for p in points]
 | 
			
		||||
        kwargs[field] = field_list
 | 
			
		||||
 | 
			
		||||
    return Pointclouds(**kwargs)
 | 
			
		||||
 | 
			
		||||
@ -12,7 +12,7 @@ import numpy as np
 | 
			
		||||
import torch
 | 
			
		||||
from common_testing import TestCaseMixin
 | 
			
		||||
from pytorch3d.structures import utils as struct_utils
 | 
			
		||||
from pytorch3d.structures.pointclouds import Pointclouds
 | 
			
		||||
from pytorch3d.structures.pointclouds import Pointclouds, join_pointclouds_as_batch
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TestPointclouds(TestCaseMixin, unittest.TestCase):
 | 
			
		||||
@ -1098,6 +1098,70 @@ class TestPointclouds(TestCaseMixin, unittest.TestCase):
 | 
			
		||||
        for length, points_ in zip(lengths_max_4, pcl_copy2.points_list()):
 | 
			
		||||
            self.assertEqual(points_.shape, (length, 3))
 | 
			
		||||
 | 
			
		||||
    def test_join_pointclouds_as_batch(self):
 | 
			
		||||
        """
 | 
			
		||||
        Test join_pointclouds_as_batch
 | 
			
		||||
        """
 | 
			
		||||
 | 
			
		||||
        def check_item(x, y):
 | 
			
		||||
            self.assertEqual(x is None, y is None)
 | 
			
		||||
            if x is not None:
 | 
			
		||||
                self.assertClose(torch.cat([x, x, x]), y)
 | 
			
		||||
 | 
			
		||||
        def check_triple(points, points3):
 | 
			
		||||
            """
 | 
			
		||||
            Verify that points3 is three copies of points.
 | 
			
		||||
            """
 | 
			
		||||
            check_item(points.points_padded(), points3.points_padded())
 | 
			
		||||
            check_item(points.normals_padded(), points3.normals_padded())
 | 
			
		||||
            check_item(points.features_padded(), points3.features_padded())
 | 
			
		||||
 | 
			
		||||
        lengths = [4, 5, 13, 3]
 | 
			
		||||
        points = [torch.rand(length, 3) for length in lengths]
 | 
			
		||||
        features = [torch.rand(length, 5) for length in lengths]
 | 
			
		||||
        normals = [torch.rand(length, 3) for length in lengths]
 | 
			
		||||
 | 
			
		||||
        # Test with normals and features present
 | 
			
		||||
        pcl = Pointclouds(points=points, features=features, normals=normals)
 | 
			
		||||
        pcl3 = join_pointclouds_as_batch([pcl] * 3)
 | 
			
		||||
        check_triple(pcl, pcl3)
 | 
			
		||||
 | 
			
		||||
        # Test with normals and features present for tensor backed pointclouds
 | 
			
		||||
        N, P, D = 5, 30, 4
 | 
			
		||||
        pcl = Pointclouds(
 | 
			
		||||
            points=torch.rand(N, P, 3),
 | 
			
		||||
            features=torch.rand(N, P, D),
 | 
			
		||||
            normals=torch.rand(N, P, 3),
 | 
			
		||||
        )
 | 
			
		||||
        pcl3 = join_pointclouds_as_batch([pcl] * 3)
 | 
			
		||||
        check_triple(pcl, pcl3)
 | 
			
		||||
 | 
			
		||||
        # Test without normals
 | 
			
		||||
        pcl_nonormals = Pointclouds(points=points, features=features)
 | 
			
		||||
        pcl3 = join_pointclouds_as_batch([pcl_nonormals] * 3)
 | 
			
		||||
        check_triple(pcl_nonormals, pcl3)
 | 
			
		||||
 | 
			
		||||
        # Test without features
 | 
			
		||||
        pcl_nofeats = Pointclouds(points=points, normals=normals)
 | 
			
		||||
        pcl3 = join_pointclouds_as_batch([pcl_nofeats] * 3)
 | 
			
		||||
        check_triple(pcl_nofeats, pcl3)
 | 
			
		||||
 | 
			
		||||
        # Check error raised if all pointclouds in the batch
 | 
			
		||||
        # are not consistent in including normals/features
 | 
			
		||||
        with self.assertRaisesRegex(ValueError, "some set to None"):
 | 
			
		||||
            join_pointclouds_as_batch([pcl, pcl_nonormals, pcl_nonormals])
 | 
			
		||||
        with self.assertRaisesRegex(ValueError, "some set to None"):
 | 
			
		||||
            join_pointclouds_as_batch([pcl, pcl_nofeats, pcl_nofeats])
 | 
			
		||||
 | 
			
		||||
        # Check error if first input is a single pointclouds object
 | 
			
		||||
        # instead of a list
 | 
			
		||||
        with self.assertRaisesRegex(ValueError, "Wrong first argument"):
 | 
			
		||||
            join_pointclouds_as_batch(pcl)
 | 
			
		||||
 | 
			
		||||
        # Check error if all pointclouds are not on the same device
 | 
			
		||||
        with self.assertRaisesRegex(ValueError, "same device"):
 | 
			
		||||
            join_pointclouds_as_batch([pcl, pcl.to("cuda:0")])
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def compute_packed_with_init(
 | 
			
		||||
        num_clouds: int = 10, max_p: int = 100, features: int = 300
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user