make points2volumes feature rescaling optional

Summary: Add option to not rescale the features, giving more control. https://github.com/facebookresearch/pytorch3d/issues/1137

Reviewed By: nikhilaravi

Differential Revision: D35219577

fbshipit-source-id: cbbb643b91b71bc908cedc6dac0f63f6d1355c85
This commit is contained in:
Jeremy Reizenstein
2022-04-13 04:39:47 -07:00
committed by Facebook GitHub Bot
parent 0a7c354dc1
commit 78fd5af1a6
3 changed files with 47 additions and 8 deletions

View File

@@ -192,12 +192,22 @@ class TestCaseMixin(unittest.TestCase):
self.fail(f"{msg} {err}")
self.fail(err)
def assertConstant(self, input: TensorOrArray, value: Real) -> None:
def assertConstant(
self, input: TensorOrArray, value: Real, *, atol: float = 0
) -> None:
"""
Asserts input is entirely filled with value.
Args:
input: tensor or array
value: expected value
atol: tolerance
"""
self.assertEqual(input.min(), value)
self.assertEqual(input.max(), value)
mn, mx = input.min(), input.max()
msg = f"values in range [{mn}, {mx}], not {value}, shape {input.shape}"
if atol == 0:
self.assertEqual(input.min(), value, msg=msg)
self.assertEqual(input.max(), value, msg=msg)
else:
self.assertGreater(input.min(), value - atol, msg=msg)
self.assertLess(input.max(), value + atol, msg=msg)

View File

@@ -387,6 +387,23 @@ class TestPointsToVolumes(TestCaseMixin, unittest.TestCase):
)
self.assertClose(torch.sum(densities), torch.tensor(30 * 1000.0), atol=0.1)
def test_unscaled(self):
D = 5
P = 1000
B, C, H, W = 2, 3, D, D
densities = torch.zeros(B, 1, D, H, W)
features = torch.zeros(B, C, D, H, W)
volumes = Volumes(densities=densities, features=features)
points = torch.rand(B, 1000, 3) * (D - 1) - ((D - 1) * 0.5)
point_features = torch.rand(B, 1000, C)
pointclouds = Pointclouds(points=points, features=point_features)
volumes2 = add_pointclouds_to_volumes(
pointclouds, volumes, rescale_features=False
)
self.assertConstant(volumes2.densities().sum([2, 3, 4]) / P, 1, atol=1e-5)
self.assertConstant(volumes2.features().sum([2, 3, 4]) / P, 0.5, atol=0.03)
def _check_volume_slice_color_density(
self, V, split_dim, interp_mode, clr_gt, slice_type, border=3
):