make points2volumes feature rescaling optional

Summary: Add option to not rescale the features, giving more control. https://github.com/facebookresearch/pytorch3d/issues/1137

Reviewed By: nikhilaravi

Differential Revision: D35219577

fbshipit-source-id: cbbb643b91b71bc908cedc6dac0f63f6d1355c85
This commit is contained in:
Jeremy Reizenstein 2022-04-13 04:39:47 -07:00 committed by Facebook GitHub Bot
parent 0a7c354dc1
commit 78fd5af1a6
3 changed files with 47 additions and 8 deletions

View File

@ -192,6 +192,7 @@ def add_pointclouds_to_volumes(
initial_volumes: "Volumes",
mode: str = "trilinear",
min_weight: float = 1e-4,
rescale_features: bool = True,
_python: bool = False,
) -> "Volumes":
"""
@ -250,6 +251,10 @@ def add_pointclouds_to_volumes(
min_weight: A scalar controlling the lowest possible total per-voxel
weight used to normalize the features accumulated in a voxel.
Only active for `mode==trilinear`.
rescale_features: If False, output features are just the sum of input and
added points. If True, they are averaged. In both cases,
output densities are just summed without rescaling, so
you may need to rescale them afterwards.
_python: Set to True to use a pure Python implementation, e.g. for test
purposes, which requires more memory and may be slower.
@ -286,6 +291,7 @@ def add_pointclouds_to_volumes(
grid_sizes=initial_volumes.get_grid_sizes(),
mask=mask,
mode=mode,
rescale_features=rescale_features,
_python=_python,
)
@ -303,6 +309,7 @@ def add_points_features_to_volume_densities_features(
min_weight: float = 1e-4,
mask: Optional[torch.Tensor] = None,
grid_sizes: Optional[torch.LongTensor] = None,
rescale_features: bool = True,
_python: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
@ -345,6 +352,10 @@ def add_points_features_to_volume_densities_features(
grid_sizes: `LongTensor` of shape (minibatch, 3) representing the
spatial resolutions of each of the the non-flattened `volumes` tensors,
or None to indicate the whole volume is used for every batch element.
rescale_features: If False, output features are just the sum of input and
added points. If True, they are averaged. In both cases,
output densities are just summed without rescaling, so
you may need to rescale them afterwards.
_python: Set to True to use a pure Python implementation.
Returns:
volume_features: Output volume of shape `(minibatch, feature_dim, D, H, W)`
@ -401,12 +412,13 @@ def add_points_features_to_volume_densities_features(
True, # align_corners
splat,
)
if splat:
if rescale_features:
# divide each feature by the total weight of the votes
volume_features = volume_features / volume_densities.clamp(min_weight)
else:
# divide each feature by the total weight of the votes
volume_features = volume_features / volume_densities.clamp(1.0)
if splat:
volume_features = volume_features / volume_densities.clamp(min_weight)
else:
volume_features = volume_features / volume_densities.clamp(1.0)
return volume_features, volume_densities

View File

@ -192,12 +192,22 @@ class TestCaseMixin(unittest.TestCase):
self.fail(f"{msg} {err}")
self.fail(err)
def assertConstant(self, input: TensorOrArray, value: Real) -> None:
def assertConstant(
self, input: TensorOrArray, value: Real, *, atol: float = 0
) -> None:
"""
Asserts input is entirely filled with value.
Args:
input: tensor or array
value: expected value
atol: tolerance
"""
self.assertEqual(input.min(), value)
self.assertEqual(input.max(), value)
mn, mx = input.min(), input.max()
msg = f"values in range [{mn}, {mx}], not {value}, shape {input.shape}"
if atol == 0:
self.assertEqual(input.min(), value, msg=msg)
self.assertEqual(input.max(), value, msg=msg)
else:
self.assertGreater(input.min(), value - atol, msg=msg)
self.assertLess(input.max(), value + atol, msg=msg)

View File

@ -387,6 +387,23 @@ class TestPointsToVolumes(TestCaseMixin, unittest.TestCase):
)
self.assertClose(torch.sum(densities), torch.tensor(30 * 1000.0), atol=0.1)
def test_unscaled(self):
D = 5
P = 1000
B, C, H, W = 2, 3, D, D
densities = torch.zeros(B, 1, D, H, W)
features = torch.zeros(B, C, D, H, W)
volumes = Volumes(densities=densities, features=features)
points = torch.rand(B, 1000, 3) * (D - 1) - ((D - 1) * 0.5)
point_features = torch.rand(B, 1000, C)
pointclouds = Pointclouds(points=points, features=point_features)
volumes2 = add_pointclouds_to_volumes(
pointclouds, volumes, rescale_features=False
)
self.assertConstant(volumes2.densities().sum([2, 3, 4]) / P, 1, atol=1e-5)
self.assertConstant(volumes2.features().sum([2, 3, 4]) / P, 0.5, atol=0.03)
def _check_volume_slice_color_density(
self, V, split_dim, interp_mode, clr_gt, slice_type, border=3
):