CPU function for points2vols

Summary: Single C++ function for the core of points2vols, not used anywhere yet. Added ability to control align_corners and the weight of each point, which may be useful later.

Reviewed By: nikhilaravi

Differential Revision: D29548607

fbshipit-source-id: a5cda7ec2c14836624e7dfe744c4bbb3f3d3dfe2
This commit is contained in:
Jeremy Reizenstein
2021-10-01 11:57:07 -07:00
committed by Facebook GitHub Bot
parent c7c6deab86
commit 0dfc6e0eb8
5 changed files with 767 additions and 0 deletions

View File

@@ -5,12 +5,14 @@
# LICENSE file in the root directory of this source tree.
import unittest
from functools import partial
from typing import Tuple
import numpy as np
import torch
from common_testing import TestCaseMixin
from pytorch3d.ops import add_pointclouds_to_volumes
from pytorch3d.ops.points_to_volumes import _points_to_volumes
from pytorch3d.ops.sample_points_from_meshes import sample_points_from_meshes
from pytorch3d.structures.meshes import Meshes
from pytorch3d.structures.pointclouds import Pointclouds
@@ -395,3 +397,140 @@ class TestPointsToVolumes(TestCaseMixin, unittest.TestCase):
# check that all per-slice avg errors vanish
self.assertClose(clr_diff, torch.zeros_like(clr_diff), atol=1e-2)
class TestRawFunction(TestCaseMixin, unittest.TestCase):
"""
Testing the _C.points_to_volumes function through its wrapper
_points_to_volumes.
"""
def setUp(self) -> None:
torch.manual_seed(42)
def test_grad_corners_splat_cpu(self):
self.do_gradcheck(torch.device("cpu"), True, True)
def test_grad_corners_round_cpu(self):
self.do_gradcheck(torch.device("cpu"), False, True)
def test_grad_splat_cpu(self):
self.do_gradcheck(torch.device("cpu"), True, False)
def test_grad_round_cpu(self):
self.do_gradcheck(torch.device("cpu"), False, False)
def do_gradcheck(self, device, splat: bool, align_corners: bool):
"""
Use gradcheck to verify the gradient of _points_to_volumes
with random input.
"""
N, C, D, H, W, P = 2, 4, 5, 6, 7, 5
points_3d = (
torch.rand((N, P, 3), device=device, dtype=torch.float64) * 0.8 + 0.1
)
points_features = torch.rand((N, P, C), device=device, dtype=torch.float64)
volume_densities = torch.zeros((N, 1, D, H, W), device=device)
volume_features = torch.zeros((N, C, D, H, W), device=device)
volume_densities_scale = torch.rand_like(volume_densities)
volume_features_scale = torch.rand_like(volume_features)
grid_sizes = torch.tensor([D, H, W], dtype=torch.int64, device=device).expand(
N, 3
)
mask = torch.ones((N, P), device=device)
mask[:, 0] = 0
align_corners = False
def f(points_3d_, points_features_):
(volume_densities_, volume_features_) = _points_to_volumes(
points_3d_.to(torch.float32),
points_features_.to(torch.float32),
volume_densities.clone(),
volume_features.clone(),
grid_sizes,
2.0,
mask,
align_corners,
splat,
)
density = (volume_densities_ * volume_densities_scale).sum()
features = (volume_features_ * volume_features_scale).sum()
return density, features
base = f(points_3d.clone(), points_features.clone())
self.assertGreater(base[0], 0)
self.assertGreater(base[1], 0)
points_features.requires_grad = True
if splat:
points_3d.requires_grad = True
torch.autograd.gradcheck(
f,
(points_3d, points_features),
check_undefined_grad=False,
eps=2e-4,
atol=0.01,
)
else:
torch.autograd.gradcheck(
partial(f, points_3d),
points_features,
check_undefined_grad=False,
eps=2e-3,
atol=0.001,
)
def test_single_corners_round_cpu(self):
self.single_point(torch.device("cpu"), False, True)
def test_single_corners_splat_cpu(self):
self.single_point(torch.device("cpu"), True, True)
def test_single_round_cpu(self):
self.single_point(torch.device("cpu"), False, False)
def test_single_splat_cpu(self):
self.single_point(torch.device("cpu"), True, False)
def single_point(self, device, splat: bool, align_corners: bool):
"""
Check the outcome of _points_to_volumes where a single point
exists which lines up with a single voxel.
"""
D, H, W = (6, 6, 11) if align_corners else (5, 5, 10)
N, C, P = 1, 1, 1
if align_corners:
points_3d = torch.tensor([[[-0.2, 0.2, -0.2]]], device=device)
else:
points_3d = torch.tensor([[[-0.3, 0.4, -0.4]]], device=device)
points_features = torch.zeros((N, P, C), device=device)
volume_densities = torch.zeros((N, 1, D, H, W), device=device)
volume_densities_expected = torch.zeros((N, 1, D, H, W), device=device)
volume_features = torch.zeros((N, C, D, H, W), device=device)
grid_sizes = torch.tensor([D, H, W], dtype=torch.int64, device=device).expand(
N, 3
)
mask = torch.ones((N, P), device=device)
point_weight = 19.0
volume_densities_, volume_features_ = _points_to_volumes(
points_3d,
points_features,
volume_densities,
volume_features,
grid_sizes,
point_weight,
mask,
align_corners,
splat,
)
self.assertIs(volume_densities, volume_densities_)
self.assertIs(volume_features, volume_features_)
if align_corners:
volume_densities_expected[0, 0, 2, 3, 4] = point_weight
else:
volume_densities_expected[0, 0, 1, 3, 3] = point_weight
self.assertClose(volume_densities, volume_densities_expected)