diff --git a/pytorch3d/ops/points_to_volumes.py b/pytorch3d/ops/points_to_volumes.py index 2b80a0fb..4e853fab 100644 --- a/pytorch3d/ops/points_to_volumes.py +++ b/pytorch3d/ops/points_to_volumes.py @@ -192,6 +192,7 @@ def add_pointclouds_to_volumes( initial_volumes: "Volumes", mode: str = "trilinear", min_weight: float = 1e-4, + _python: bool = False, ) -> "Volumes": """ Add a batch of point clouds represented with a `Pointclouds` structure @@ -249,6 +250,8 @@ def add_pointclouds_to_volumes( min_weight: A scalar controlling the lowest possible total per-voxel weight used to normalize the features accumulated in a voxel. Only active for `mode==trilinear`. + _python: Set to True to use a pure Python implementation, e.g. for test + purposes, which requires more memory and may be slower. Returns: updated_volumes: Output `Volumes` structure containing the conversion result. @@ -283,6 +286,7 @@ def add_pointclouds_to_volumes( grid_sizes=initial_volumes.get_grid_sizes(), mask=mask, mode=mode, + _python=_python, ) return initial_volumes.update_padded( @@ -299,6 +303,7 @@ def add_points_features_to_volume_densities_features( min_weight: float = 1e-4, mask: Optional[torch.Tensor] = None, grid_sizes: Optional[torch.LongTensor] = None, + _python: bool = False, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Convert a batch of point clouds represented with tensors of per-point @@ -340,6 +345,7 @@ def add_points_features_to_volume_densities_features( grid_sizes: `LongTensor` of shape (minibatch, 3) representing the spatial resolutions of each of the the non-flattened `volumes` tensors, or None to indicate the whole volume is used for every batch element. + _python: Set to True to use a pure Python implementation. Returns: volume_features: Output volume of shape `(minibatch, feature_dim, D, H, W)` volume_densities: Occupancy volume of shape `(minibatch, 1, D, H, W)` @@ -362,6 +368,66 @@ def add_points_features_to_volume_densities_features( .expand(volume_densities.shape[0], 3) ) + if _python: + return _add_points_features_to_volume_densities_features_python( + points_3d=points_3d, + points_features=points_features, + volume_densities=volume_densities, + volume_features=volume_features, + mode=mode, + min_weight=min_weight, + mask=mask, + grid_sizes=grid_sizes, + ) + + if mode == "trilinear": + splat = True + elif mode == "nearest": + splat = False + else: + raise ValueError('No such interpolation mode "%s"' % mode) + volume_densities, volume_features = _points_to_volumes( + points_3d, + points_features, + volume_densities, + volume_features, + grid_sizes, + 1.0, # point_weight + mask, + True, # align_corners + splat, + ) + if splat: + # divide each feature by the total weight of the votes + volume_features = volume_features / volume_densities.clamp(min_weight) + else: + # divide each feature by the total weight of the votes + volume_features = volume_features / volume_densities.clamp(1.0) + + return volume_features, volume_densities + + +def _add_points_features_to_volume_densities_features_python( + *, + points_3d: torch.Tensor, + points_features: torch.Tensor, + volume_densities: torch.Tensor, + volume_features: Optional[torch.Tensor], + mode: str, + min_weight: float, + mask: Optional[torch.Tensor], + grid_sizes: torch.LongTensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Python implementation for add_points_features_to_volume_densities_features. + + Returns: + volume_features: Output volume of shape `(minibatch, feature_dim, D, H, W)` + volume_densities: Occupancy volume of shape `(minibatch, 1, D, H, W)` + containing the total amount of votes cast to each of the voxels. + """ + ba, n_points, feature_dim = points_features.shape + # flatten densities and features v_shape = volume_densities.shape[2:] volume_densities_flatten = volume_densities.view(ba, -1, 1) @@ -376,7 +442,7 @@ def add_points_features_to_volume_densities_features( volume_features_flatten = volume_features.view(ba, feature_dim, n_voxels) if mode == "trilinear": # do the splatting (trilinear interp) - volume_features, volume_densities = splat_points_to_volumes( + volume_features, volume_densities = _splat_points_to_volumes( points_3d, points_features, volume_densities_flatten, @@ -386,7 +452,7 @@ def add_points_features_to_volume_densities_features( min_weight=min_weight, ) elif mode == "nearest": # nearest neighbor interp - volume_features, volume_densities = round_points_to_volumes( + volume_features, volume_densities = _round_points_to_volumes( points_3d, points_features, volume_densities_flatten, @@ -400,7 +466,6 @@ def add_points_features_to_volume_densities_features( # reshape into the volume shape volume_features = volume_features.view(ba, feature_dim, *v_shape) volume_densities = volume_densities.view(ba, 1, *v_shape) - return volume_features, volume_densities @@ -441,7 +506,7 @@ def _check_points_to_volumes_inputs( ) -def splat_points_to_volumes( +def _splat_points_to_volumes( points_3d: torch.Tensor, points_features: torch.Tensor, volume_densities: torch.Tensor, @@ -574,7 +639,7 @@ def splat_points_to_volumes( return volume_features, volume_densities -def round_points_to_volumes( +def _round_points_to_volumes( points_3d: torch.Tensor, points_features: torch.Tensor, volume_densities: torch.Tensor, diff --git a/tests/test_points_to_volumes.py b/tests/test_points_to_volumes.py index e417781e..0e190ae6 100644 --- a/tests/test_points_to_volumes.py +++ b/tests/test_points_to_volumes.py @@ -6,6 +6,7 @@ import unittest from functools import partial +from itertools import product from typing import Tuple import numpy as np @@ -254,7 +255,7 @@ class TestPointsToVolumes(TestCaseMixin, unittest.TestCase): for volume_size in ([25, 25, 25], [30, 25, 15]): - for interp_mode in ("trilinear", "nearest"): + for python, interp_mode in product([True, False], ["trilinear", "nearest"]): (pointclouds, initial_volumes) = init_volume_boundary_pointcloud( volume_size=volume_size, @@ -266,7 +267,10 @@ class TestPointsToVolumes(TestCaseMixin, unittest.TestCase): ) volumes = add_pointclouds_to_volumes( - pointclouds, initial_volumes, mode=interp_mode + pointclouds, + initial_volumes, + mode=interp_mode, + _python=python, ) V_color, V_density = volumes.features(), volumes.densities()