mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-12-19 05:40:34 +08:00
CPU function for points2vols
Summary: Single C++ function for the core of points2vols, not used anywhere yet. Added ability to control align_corners and the weight of each point, which may be useful later. Reviewed By: nikhilaravi Differential Revision: D29548607 fbshipit-source-id: a5cda7ec2c14836624e7dfe744c4bbb3f3d3dfe2
This commit is contained in:
committed by
Facebook GitHub Bot
parent
c7c6deab86
commit
0dfc6e0eb8
@@ -5,12 +5,14 @@
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import unittest
|
||||
from functools import partial
|
||||
from typing import Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from common_testing import TestCaseMixin
|
||||
from pytorch3d.ops import add_pointclouds_to_volumes
|
||||
from pytorch3d.ops.points_to_volumes import _points_to_volumes
|
||||
from pytorch3d.ops.sample_points_from_meshes import sample_points_from_meshes
|
||||
from pytorch3d.structures.meshes import Meshes
|
||||
from pytorch3d.structures.pointclouds import Pointclouds
|
||||
@@ -395,3 +397,140 @@ class TestPointsToVolumes(TestCaseMixin, unittest.TestCase):
|
||||
|
||||
# check that all per-slice avg errors vanish
|
||||
self.assertClose(clr_diff, torch.zeros_like(clr_diff), atol=1e-2)
|
||||
|
||||
|
||||
class TestRawFunction(TestCaseMixin, unittest.TestCase):
|
||||
"""
|
||||
Testing the _C.points_to_volumes function through its wrapper
|
||||
_points_to_volumes.
|
||||
"""
|
||||
|
||||
def setUp(self) -> None:
|
||||
torch.manual_seed(42)
|
||||
|
||||
def test_grad_corners_splat_cpu(self):
|
||||
self.do_gradcheck(torch.device("cpu"), True, True)
|
||||
|
||||
def test_grad_corners_round_cpu(self):
|
||||
self.do_gradcheck(torch.device("cpu"), False, True)
|
||||
|
||||
def test_grad_splat_cpu(self):
|
||||
self.do_gradcheck(torch.device("cpu"), True, False)
|
||||
|
||||
def test_grad_round_cpu(self):
|
||||
self.do_gradcheck(torch.device("cpu"), False, False)
|
||||
|
||||
def do_gradcheck(self, device, splat: bool, align_corners: bool):
|
||||
"""
|
||||
Use gradcheck to verify the gradient of _points_to_volumes
|
||||
with random input.
|
||||
"""
|
||||
N, C, D, H, W, P = 2, 4, 5, 6, 7, 5
|
||||
points_3d = (
|
||||
torch.rand((N, P, 3), device=device, dtype=torch.float64) * 0.8 + 0.1
|
||||
)
|
||||
points_features = torch.rand((N, P, C), device=device, dtype=torch.float64)
|
||||
volume_densities = torch.zeros((N, 1, D, H, W), device=device)
|
||||
volume_features = torch.zeros((N, C, D, H, W), device=device)
|
||||
volume_densities_scale = torch.rand_like(volume_densities)
|
||||
volume_features_scale = torch.rand_like(volume_features)
|
||||
grid_sizes = torch.tensor([D, H, W], dtype=torch.int64, device=device).expand(
|
||||
N, 3
|
||||
)
|
||||
mask = torch.ones((N, P), device=device)
|
||||
mask[:, 0] = 0
|
||||
align_corners = False
|
||||
|
||||
def f(points_3d_, points_features_):
|
||||
(volume_densities_, volume_features_) = _points_to_volumes(
|
||||
points_3d_.to(torch.float32),
|
||||
points_features_.to(torch.float32),
|
||||
volume_densities.clone(),
|
||||
volume_features.clone(),
|
||||
grid_sizes,
|
||||
2.0,
|
||||
mask,
|
||||
align_corners,
|
||||
splat,
|
||||
)
|
||||
density = (volume_densities_ * volume_densities_scale).sum()
|
||||
features = (volume_features_ * volume_features_scale).sum()
|
||||
return density, features
|
||||
|
||||
base = f(points_3d.clone(), points_features.clone())
|
||||
self.assertGreater(base[0], 0)
|
||||
self.assertGreater(base[1], 0)
|
||||
|
||||
points_features.requires_grad = True
|
||||
if splat:
|
||||
points_3d.requires_grad = True
|
||||
torch.autograd.gradcheck(
|
||||
f,
|
||||
(points_3d, points_features),
|
||||
check_undefined_grad=False,
|
||||
eps=2e-4,
|
||||
atol=0.01,
|
||||
)
|
||||
else:
|
||||
torch.autograd.gradcheck(
|
||||
partial(f, points_3d),
|
||||
points_features,
|
||||
check_undefined_grad=False,
|
||||
eps=2e-3,
|
||||
atol=0.001,
|
||||
)
|
||||
|
||||
def test_single_corners_round_cpu(self):
|
||||
self.single_point(torch.device("cpu"), False, True)
|
||||
|
||||
def test_single_corners_splat_cpu(self):
|
||||
self.single_point(torch.device("cpu"), True, True)
|
||||
|
||||
def test_single_round_cpu(self):
|
||||
self.single_point(torch.device("cpu"), False, False)
|
||||
|
||||
def test_single_splat_cpu(self):
|
||||
self.single_point(torch.device("cpu"), True, False)
|
||||
|
||||
def single_point(self, device, splat: bool, align_corners: bool):
|
||||
"""
|
||||
Check the outcome of _points_to_volumes where a single point
|
||||
exists which lines up with a single voxel.
|
||||
"""
|
||||
D, H, W = (6, 6, 11) if align_corners else (5, 5, 10)
|
||||
N, C, P = 1, 1, 1
|
||||
if align_corners:
|
||||
points_3d = torch.tensor([[[-0.2, 0.2, -0.2]]], device=device)
|
||||
else:
|
||||
points_3d = torch.tensor([[[-0.3, 0.4, -0.4]]], device=device)
|
||||
points_features = torch.zeros((N, P, C), device=device)
|
||||
volume_densities = torch.zeros((N, 1, D, H, W), device=device)
|
||||
volume_densities_expected = torch.zeros((N, 1, D, H, W), device=device)
|
||||
volume_features = torch.zeros((N, C, D, H, W), device=device)
|
||||
grid_sizes = torch.tensor([D, H, W], dtype=torch.int64, device=device).expand(
|
||||
N, 3
|
||||
)
|
||||
mask = torch.ones((N, P), device=device)
|
||||
point_weight = 19.0
|
||||
|
||||
volume_densities_, volume_features_ = _points_to_volumes(
|
||||
points_3d,
|
||||
points_features,
|
||||
volume_densities,
|
||||
volume_features,
|
||||
grid_sizes,
|
||||
point_weight,
|
||||
mask,
|
||||
align_corners,
|
||||
splat,
|
||||
)
|
||||
|
||||
self.assertIs(volume_densities, volume_densities_)
|
||||
self.assertIs(volume_features, volume_features_)
|
||||
|
||||
if align_corners:
|
||||
volume_densities_expected[0, 0, 2, 3, 4] = point_weight
|
||||
else:
|
||||
volume_densities_expected[0, 0, 1, 3, 3] = point_weight
|
||||
|
||||
self.assertClose(volume_densities, volume_densities_expected)
|
||||
|
||||
Reference in New Issue
Block a user