From ee2b2feb9891a26939a688fd3c57d03462d7f773 Mon Sep 17 00:00:00 2001 From: Jeremy Reizenstein Date: Fri, 1 Oct 2021 11:57:07 -0700 Subject: [PATCH] Use C++/CUDA in points2vols MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: Move the core of add_points_to_volumes to the new C++/CUDA implementation. Add new flag to let the user stop this happening. Avoids copies. About a 30% speedup on the larger cases, up to 50% on the smaller cases. New timings ``` Benchmark Avg Time(μs) Peak Time(μs) Iterations -------------------------------------------------------------------------------- ADD_POINTS_TO_VOLUMES_cpu_10_trilinear_[25, 25, 25]_1000 4575 12591 110 ADD_POINTS_TO_VOLUMES_cpu_10_trilinear_[25, 25, 25]_10000 25468 29186 20 ADD_POINTS_TO_VOLUMES_cpu_10_trilinear_[25, 25, 25]_100000 202085 209897 3 ADD_POINTS_TO_VOLUMES_cpu_10_trilinear_[101, 111, 121]_1000 46059 48188 11 ADD_POINTS_TO_VOLUMES_cpu_10_trilinear_[101, 111, 121]_10000 83759 95669 7 ADD_POINTS_TO_VOLUMES_cpu_10_trilinear_[101, 111, 121]_100000 326056 339393 2 ADD_POINTS_TO_VOLUMES_cpu_10_nearest_[25, 25, 25]_1000 2379 4738 211 ADD_POINTS_TO_VOLUMES_cpu_10_nearest_[25, 25, 25]_10000 12100 63099 42 ADD_POINTS_TO_VOLUMES_cpu_10_nearest_[25, 25, 25]_100000 63323 63737 8 ADD_POINTS_TO_VOLUMES_cpu_10_nearest_[101, 111, 121]_1000 45216 45479 12 ADD_POINTS_TO_VOLUMES_cpu_10_nearest_[101, 111, 121]_10000 57205 58524 9 ADD_POINTS_TO_VOLUMES_cpu_10_nearest_[101, 111, 121]_100000 139499 139926 4 ADD_POINTS_TO_VOLUMES_cpu_100_trilinear_[25, 25, 25]_1000 40129 40431 13 ADD_POINTS_TO_VOLUMES_cpu_100_trilinear_[25, 25, 25]_10000 204949 239293 3 ADD_POINTS_TO_VOLUMES_cpu_100_trilinear_[25, 25, 25]_100000 1664541 1664541 1 ADD_POINTS_TO_VOLUMES_cpu_100_trilinear_[101, 111, 121]_1000 391573 395108 2 ADD_POINTS_TO_VOLUMES_cpu_100_trilinear_[101, 111, 121]_10000 674869 674869 1 ADD_POINTS_TO_VOLUMES_cpu_100_trilinear_[101, 111, 121]_100000 2713632 2713632 1 ADD_POINTS_TO_VOLUMES_cpu_100_nearest_[25, 25, 25]_1000 12726 13506 40 ADD_POINTS_TO_VOLUMES_cpu_100_nearest_[25, 25, 25]_10000 73103 73299 7 ADD_POINTS_TO_VOLUMES_cpu_100_nearest_[25, 25, 25]_100000 598634 598634 1 ADD_POINTS_TO_VOLUMES_cpu_100_nearest_[101, 111, 121]_1000 398742 399256 2 ADD_POINTS_TO_VOLUMES_cpu_100_nearest_[101, 111, 121]_10000 543129 543129 1 ADD_POINTS_TO_VOLUMES_cpu_100_nearest_[101, 111, 121]_100000 1242956 1242956 1 ADD_POINTS_TO_VOLUMES_cuda:0_10_trilinear_[25, 25, 25]_1000 1814 8884 276 ADD_POINTS_TO_VOLUMES_cuda:0_10_trilinear_[25, 25, 25]_10000 1996 8851 251 ADD_POINTS_TO_VOLUMES_cuda:0_10_trilinear_[25, 25, 25]_100000 4608 11529 109 ADD_POINTS_TO_VOLUMES_cuda:0_10_trilinear_[101, 111, 121]_1000 5183 12508 97 ADD_POINTS_TO_VOLUMES_cuda:0_10_trilinear_[101, 111, 121]_10000 7106 14077 71 ADD_POINTS_TO_VOLUMES_cuda:0_10_trilinear_[101, 111, 121]_100000 25914 31818 20 ADD_POINTS_TO_VOLUMES_cuda:0_10_nearest_[25, 25, 25]_1000 1778 8823 282 ADD_POINTS_TO_VOLUMES_cuda:0_10_nearest_[25, 25, 25]_10000 1825 8613 274 ADD_POINTS_TO_VOLUMES_cuda:0_10_nearest_[25, 25, 25]_100000 3154 10161 159 ADD_POINTS_TO_VOLUMES_cuda:0_10_nearest_[101, 111, 121]_1000 4888 9404 103 ADD_POINTS_TO_VOLUMES_cuda:0_10_nearest_[101, 111, 121]_10000 5194 9963 97 ADD_POINTS_TO_VOLUMES_cuda:0_10_nearest_[101, 111, 121]_100000 8109 14933 62 ADD_POINTS_TO_VOLUMES_cuda:0_100_trilinear_[25, 25, 25]_1000 3320 10306 151 ADD_POINTS_TO_VOLUMES_cuda:0_100_trilinear_[25, 25, 25]_10000 7003 8595 72 ADD_POINTS_TO_VOLUMES_cuda:0_100_trilinear_[25, 25, 25]_100000 49140 52957 11 ADD_POINTS_TO_VOLUMES_cuda:0_100_trilinear_[101, 111, 121]_1000 35890 36918 14 ADD_POINTS_TO_VOLUMES_cuda:0_100_trilinear_[101, 111, 121]_10000 58890 59337 9 ADD_POINTS_TO_VOLUMES_cuda:0_100_trilinear_[101, 111, 121]_100000 286878 287600 2 ADD_POINTS_TO_VOLUMES_cuda:0_100_nearest_[25, 25, 25]_1000 2484 8805 202 ADD_POINTS_TO_VOLUMES_cuda:0_100_nearest_[25, 25, 25]_10000 3967 9090 127 ADD_POINTS_TO_VOLUMES_cuda:0_100_nearest_[25, 25, 25]_100000 19423 19799 26 ADD_POINTS_TO_VOLUMES_cuda:0_100_nearest_[101, 111, 121]_1000 33228 33329 16 ADD_POINTS_TO_VOLUMES_cuda:0_100_nearest_[101, 111, 121]_10000 37292 37370 14 ADD_POINTS_TO_VOLUMES_cuda:0_100_nearest_[101, 111, 121]_100000 73550 74017 7 -------------------------------------------------------------------------------- ``` Previous timings ``` Benchmark Avg Time(μs) Peak Time(μs) Iterations -------------------------------------------------------------------------------- ADD_POINTS_TO_VOLUMES_cpu_10_trilinear_[25, 25, 25]_1000 10100 46422 50 ADD_POINTS_TO_VOLUMES_cpu_10_trilinear_[25, 25, 25]_10000 28442 32100 18 ADD_POINTS_TO_VOLUMES_cpu_10_trilinear_[25, 25, 25]_100000 241127 254269 3 ADD_POINTS_TO_VOLUMES_cpu_10_trilinear_[101, 111, 121]_1000 54149 79480 10 ADD_POINTS_TO_VOLUMES_cpu_10_trilinear_[101, 111, 121]_10000 125459 212734 4 ADD_POINTS_TO_VOLUMES_cpu_10_trilinear_[101, 111, 121]_100000 512739 512739 1 ADD_POINTS_TO_VOLUMES_cpu_10_nearest_[25, 25, 25]_1000 2866 13365 175 ADD_POINTS_TO_VOLUMES_cpu_10_nearest_[25, 25, 25]_10000 7026 12604 72 ADD_POINTS_TO_VOLUMES_cpu_10_nearest_[25, 25, 25]_100000 48822 55607 11 ADD_POINTS_TO_VOLUMES_cpu_10_nearest_[101, 111, 121]_1000 38098 38576 14 ADD_POINTS_TO_VOLUMES_cpu_10_nearest_[101, 111, 121]_10000 48006 54120 11 ADD_POINTS_TO_VOLUMES_cpu_10_nearest_[101, 111, 121]_100000 131563 138536 4 ADD_POINTS_TO_VOLUMES_cpu_100_trilinear_[25, 25, 25]_1000 64615 91735 8 ADD_POINTS_TO_VOLUMES_cpu_100_trilinear_[25, 25, 25]_10000 228815 246095 3 ADD_POINTS_TO_VOLUMES_cpu_100_trilinear_[25, 25, 25]_100000 3086615 3086615 1 ADD_POINTS_TO_VOLUMES_cpu_100_trilinear_[101, 111, 121]_1000 464298 465292 2 ADD_POINTS_TO_VOLUMES_cpu_100_trilinear_[101, 111, 121]_10000 1053440 1053440 1 ADD_POINTS_TO_VOLUMES_cpu_100_trilinear_[101, 111, 121]_100000 6736236 6736236 1 ADD_POINTS_TO_VOLUMES_cpu_100_nearest_[25, 25, 25]_1000 11940 12440 42 ADD_POINTS_TO_VOLUMES_cpu_100_nearest_[25, 25, 25]_10000 56641 58051 9 ADD_POINTS_TO_VOLUMES_cpu_100_nearest_[25, 25, 25]_100000 711492 711492 1 ADD_POINTS_TO_VOLUMES_cpu_100_nearest_[101, 111, 121]_1000 326437 329846 2 ADD_POINTS_TO_VOLUMES_cpu_100_nearest_[101, 111, 121]_10000 418514 427911 2 ADD_POINTS_TO_VOLUMES_cpu_100_nearest_[101, 111, 121]_100000 1524285 1524285 1 ADD_POINTS_TO_VOLUMES_cuda:0_10_trilinear_[25, 25, 25]_1000 5949 13602 85 ADD_POINTS_TO_VOLUMES_cuda:0_10_trilinear_[25, 25, 25]_10000 5817 13001 86 ADD_POINTS_TO_VOLUMES_cuda:0_10_trilinear_[25, 25, 25]_100000 23833 25971 21 ADD_POINTS_TO_VOLUMES_cuda:0_10_trilinear_[101, 111, 121]_1000 9029 16178 56 ADD_POINTS_TO_VOLUMES_cuda:0_10_trilinear_[101, 111, 121]_10000 11595 18601 44 ADD_POINTS_TO_VOLUMES_cuda:0_10_trilinear_[101, 111, 121]_100000 46986 47344 11 ADD_POINTS_TO_VOLUMES_cuda:0_10_nearest_[25, 25, 25]_1000 2554 9747 196 ADD_POINTS_TO_VOLUMES_cuda:0_10_nearest_[25, 25, 25]_10000 2676 9537 187 ADD_POINTS_TO_VOLUMES_cuda:0_10_nearest_[25, 25, 25]_100000 6567 14179 77 ADD_POINTS_TO_VOLUMES_cuda:0_10_nearest_[101, 111, 121]_1000 5840 12811 86 ADD_POINTS_TO_VOLUMES_cuda:0_10_nearest_[101, 111, 121]_10000 6102 13128 82 ADD_POINTS_TO_VOLUMES_cuda:0_10_nearest_[101, 111, 121]_100000 11945 11995 42 ADD_POINTS_TO_VOLUMES_cuda:0_100_trilinear_[25, 25, 25]_1000 7642 13671 66 ADD_POINTS_TO_VOLUMES_cuda:0_100_trilinear_[25, 25, 25]_10000 25190 25260 20 ADD_POINTS_TO_VOLUMES_cuda:0_100_trilinear_[25, 25, 25]_100000 212018 212134 3 ADD_POINTS_TO_VOLUMES_cuda:0_100_trilinear_[101, 111, 121]_1000 40421 45692 13 ADD_POINTS_TO_VOLUMES_cuda:0_100_trilinear_[101, 111, 121]_10000 92078 92132 6 ADD_POINTS_TO_VOLUMES_cuda:0_100_trilinear_[101, 111, 121]_100000 457211 457229 2 ADD_POINTS_TO_VOLUMES_cuda:0_100_nearest_[25, 25, 25]_1000 3574 10377 140 ADD_POINTS_TO_VOLUMES_cuda:0_100_nearest_[25, 25, 25]_10000 7222 13023 70 ADD_POINTS_TO_VOLUMES_cuda:0_100_nearest_[25, 25, 25]_100000 48127 48165 11 ADD_POINTS_TO_VOLUMES_cuda:0_100_nearest_[101, 111, 121]_1000 34732 35295 15 ADD_POINTS_TO_VOLUMES_cuda:0_100_nearest_[101, 111, 121]_10000 43050 51064 12 ADD_POINTS_TO_VOLUMES_cuda:0_100_nearest_[101, 111, 121]_100000 106028 106058 5 -------------------------------------------------------------------------------- ``` Reviewed By: nikhilaravi Differential Revision: D29548609 fbshipit-source-id: 7026e832ea299145c3f6b55687f3c1601294f5c0 --- pytorch3d/ops/points_to_volumes.py | 75 ++++++++++++++++++++++++++++-- tests/test_points_to_volumes.py | 8 +++- 2 files changed, 76 insertions(+), 7 deletions(-) diff --git a/pytorch3d/ops/points_to_volumes.py b/pytorch3d/ops/points_to_volumes.py index 2b80a0fb..4e853fab 100644 --- a/pytorch3d/ops/points_to_volumes.py +++ b/pytorch3d/ops/points_to_volumes.py @@ -192,6 +192,7 @@ def add_pointclouds_to_volumes( initial_volumes: "Volumes", mode: str = "trilinear", min_weight: float = 1e-4, + _python: bool = False, ) -> "Volumes": """ Add a batch of point clouds represented with a `Pointclouds` structure @@ -249,6 +250,8 @@ def add_pointclouds_to_volumes( min_weight: A scalar controlling the lowest possible total per-voxel weight used to normalize the features accumulated in a voxel. Only active for `mode==trilinear`. + _python: Set to True to use a pure Python implementation, e.g. for test + purposes, which requires more memory and may be slower. Returns: updated_volumes: Output `Volumes` structure containing the conversion result. @@ -283,6 +286,7 @@ def add_pointclouds_to_volumes( grid_sizes=initial_volumes.get_grid_sizes(), mask=mask, mode=mode, + _python=_python, ) return initial_volumes.update_padded( @@ -299,6 +303,7 @@ def add_points_features_to_volume_densities_features( min_weight: float = 1e-4, mask: Optional[torch.Tensor] = None, grid_sizes: Optional[torch.LongTensor] = None, + _python: bool = False, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Convert a batch of point clouds represented with tensors of per-point @@ -340,6 +345,7 @@ def add_points_features_to_volume_densities_features( grid_sizes: `LongTensor` of shape (minibatch, 3) representing the spatial resolutions of each of the the non-flattened `volumes` tensors, or None to indicate the whole volume is used for every batch element. + _python: Set to True to use a pure Python implementation. Returns: volume_features: Output volume of shape `(minibatch, feature_dim, D, H, W)` volume_densities: Occupancy volume of shape `(minibatch, 1, D, H, W)` @@ -362,6 +368,66 @@ def add_points_features_to_volume_densities_features( .expand(volume_densities.shape[0], 3) ) + if _python: + return _add_points_features_to_volume_densities_features_python( + points_3d=points_3d, + points_features=points_features, + volume_densities=volume_densities, + volume_features=volume_features, + mode=mode, + min_weight=min_weight, + mask=mask, + grid_sizes=grid_sizes, + ) + + if mode == "trilinear": + splat = True + elif mode == "nearest": + splat = False + else: + raise ValueError('No such interpolation mode "%s"' % mode) + volume_densities, volume_features = _points_to_volumes( + points_3d, + points_features, + volume_densities, + volume_features, + grid_sizes, + 1.0, # point_weight + mask, + True, # align_corners + splat, + ) + if splat: + # divide each feature by the total weight of the votes + volume_features = volume_features / volume_densities.clamp(min_weight) + else: + # divide each feature by the total weight of the votes + volume_features = volume_features / volume_densities.clamp(1.0) + + return volume_features, volume_densities + + +def _add_points_features_to_volume_densities_features_python( + *, + points_3d: torch.Tensor, + points_features: torch.Tensor, + volume_densities: torch.Tensor, + volume_features: Optional[torch.Tensor], + mode: str, + min_weight: float, + mask: Optional[torch.Tensor], + grid_sizes: torch.LongTensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Python implementation for add_points_features_to_volume_densities_features. + + Returns: + volume_features: Output volume of shape `(minibatch, feature_dim, D, H, W)` + volume_densities: Occupancy volume of shape `(minibatch, 1, D, H, W)` + containing the total amount of votes cast to each of the voxels. + """ + ba, n_points, feature_dim = points_features.shape + # flatten densities and features v_shape = volume_densities.shape[2:] volume_densities_flatten = volume_densities.view(ba, -1, 1) @@ -376,7 +442,7 @@ def add_points_features_to_volume_densities_features( volume_features_flatten = volume_features.view(ba, feature_dim, n_voxels) if mode == "trilinear": # do the splatting (trilinear interp) - volume_features, volume_densities = splat_points_to_volumes( + volume_features, volume_densities = _splat_points_to_volumes( points_3d, points_features, volume_densities_flatten, @@ -386,7 +452,7 @@ def add_points_features_to_volume_densities_features( min_weight=min_weight, ) elif mode == "nearest": # nearest neighbor interp - volume_features, volume_densities = round_points_to_volumes( + volume_features, volume_densities = _round_points_to_volumes( points_3d, points_features, volume_densities_flatten, @@ -400,7 +466,6 @@ def add_points_features_to_volume_densities_features( # reshape into the volume shape volume_features = volume_features.view(ba, feature_dim, *v_shape) volume_densities = volume_densities.view(ba, 1, *v_shape) - return volume_features, volume_densities @@ -441,7 +506,7 @@ def _check_points_to_volumes_inputs( ) -def splat_points_to_volumes( +def _splat_points_to_volumes( points_3d: torch.Tensor, points_features: torch.Tensor, volume_densities: torch.Tensor, @@ -574,7 +639,7 @@ def splat_points_to_volumes( return volume_features, volume_densities -def round_points_to_volumes( +def _round_points_to_volumes( points_3d: torch.Tensor, points_features: torch.Tensor, volume_densities: torch.Tensor, diff --git a/tests/test_points_to_volumes.py b/tests/test_points_to_volumes.py index e417781e..0e190ae6 100644 --- a/tests/test_points_to_volumes.py +++ b/tests/test_points_to_volumes.py @@ -6,6 +6,7 @@ import unittest from functools import partial +from itertools import product from typing import Tuple import numpy as np @@ -254,7 +255,7 @@ class TestPointsToVolumes(TestCaseMixin, unittest.TestCase): for volume_size in ([25, 25, 25], [30, 25, 15]): - for interp_mode in ("trilinear", "nearest"): + for python, interp_mode in product([True, False], ["trilinear", "nearest"]): (pointclouds, initial_volumes) = init_volume_boundary_pointcloud( volume_size=volume_size, @@ -266,7 +267,10 @@ class TestPointsToVolumes(TestCaseMixin, unittest.TestCase): ) volumes = add_pointclouds_to_volumes( - pointclouds, initial_volumes, mode=interp_mode + pointclouds, + initial_volumes, + mode=interp_mode, + _python=python, ) V_color, V_density = volumes.features(), volumes.densities()