CPU function for points2vols

Summary: Single C++ function for the core of points2vols, not used anywhere yet. Added ability to control align_corners and the weight of each point, which may be useful later.

Reviewed By: nikhilaravi

Differential Revision: D29548607

fbshipit-source-id: a5cda7ec2c14836624e7dfe744c4bbb3f3d3dfe2
Authored by Jeremy Reizenstein on 2021-10-01 11:57:07 -07:00; committed by Facebook GitHub Bot
parent c7c6deab86
commit 0dfc6e0eb8
5 changed files with 767 additions and 0 deletions

View File

@@ -25,6 +25,7 @@
#include "mesh_normal_consistency/mesh_normal_consistency.h"
#include "packed_to_padded_tensor/packed_to_padded_tensor.h"
#include "point_mesh/point_mesh_cuda.h"
#include "points_to_volumes/points_to_volumes.h"
#include "rasterize_meshes/rasterize_meshes.h"
#include "rasterize_points/rasterize_points.h"
#include "sample_farthest_points/sample_farthest_points.h"
@@ -47,6 +48,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def(
"mesh_normal_consistency_find_verts", &MeshNormalConsistencyFindVertices);
m.def("gather_scatter", &GatherScatter);
m.def("points_to_volumes_forward", PointsToVolumesForward);
m.def("points_to_volumes_backward", PointsToVolumesBackward);
m.def("rasterize_points", &RasterizePoints);
m.def("rasterize_points_backward", &RasterizePointsBackward);
m.def("rasterize_meshes_backward", &RasterizeMeshesBackward);

View File

@@ -0,0 +1,135 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#pragma once
#include <torch/extension.h>
#include <cstdio>
#include <tuple>
#include "utils/pytorch3d_cutils.h"
/*
volume_features and volume_densities are modified in place.
Args:
points_3d: Batch of 3D point cloud coordinates of shape
`(minibatch, N, 3)` where N is the number of points
in each point cloud. Coordinates have to be specified in the
local volume coordinates (ranging in [-1, 1]).
points_features: Features of shape `(minibatch, N, feature_dim)`
corresponding to the points of the input point cloud `points_3d`.
volume_features: Batch of input feature volumes
of shape `(minibatch, feature_dim, D, H, W)`
volume_densities: Batch of input feature volume densities
of shape `(minibatch, 1, D, H, W)`. Each voxel should
contain a non-negative number corresponding to its
opaqueness (the higher, the less transparent).
grid_sizes: `LongTensor` of shape (minibatch, 3) representing the
spatial resolutions of each of the non-flattened `volumes`
tensors. Note that the following has to hold:
`torch.prod(grid_sizes, dim=1)==N_voxels`.
point_weight: A scalar controlling how much weight a single point has.
mask: A binary mask of shape `(minibatch, N)` determining
which 3D points are going to be converted to the resulting
volume. Set to `None` if all points are valid.
align_corners: as for torch.nn.functional.grid_sample.
splat: if true, each point's weight is spread over its eight nearest voxels
using trilinear interpolation. If false, all the weight goes to the
nearest voxel.
*/
void PointsToVolumesForwardCpu(
const torch::Tensor& points_3d,
const torch::Tensor& points_features,
const torch::Tensor& volume_densities,
const torch::Tensor& volume_features,
const torch::Tensor& grid_sizes,
const torch::Tensor& mask,
float point_weight,
bool align_corners,
bool splat);
inline void PointsToVolumesForward(
const torch::Tensor& points_3d,
const torch::Tensor& points_features,
const torch::Tensor& volume_densities,
const torch::Tensor& volume_features,
const torch::Tensor& grid_sizes,
const torch::Tensor& mask,
float point_weight,
bool align_corners,
bool splat) {
if (points_3d.is_cuda()) {
#ifdef WITH_CUDA
AT_ERROR("CUDA not implemented yet");
#else
AT_ERROR("Not compiled with GPU support.");
#endif
}
PointsToVolumesForwardCpu(
points_3d,
points_features,
volume_densities,
volume_features,
grid_sizes,
mask,
point_weight,
align_corners,
splat);
}
// grad_points_3d and grad_points_features are modified in place.
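// The other arguments mirror PointsToVolumesForwardCpu. grad_volume_densities
// and grad_volume_features are the upstream gradients with respect to the
// outputs of the forward pass; gradients for points_3d are only meaningful
// when splat is true.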
void PointsToVolumesBackwardCpu(
const torch::Tensor& points_3d,
const torch::Tensor& points_features,
const torch::Tensor& grid_sizes,
const torch::Tensor& mask,
float point_weight,
bool align_corners,
bool splat,
const torch::Tensor& grad_volume_densities,
const torch::Tensor& grad_volume_features,
const torch::Tensor& grad_points_3d,
const torch::Tensor& grad_points_features);
inline void PointsToVolumesBackward(
const torch::Tensor& points_3d,
const torch::Tensor& points_features,
const torch::Tensor& grid_sizes,
const torch::Tensor& mask,
float point_weight,
bool align_corners,
bool splat,
const torch::Tensor& grad_volume_densities,
const torch::Tensor& grad_volume_features,
const torch::Tensor& grad_points_3d,
const torch::Tensor& grad_points_features) {
if (points_3d.is_cuda()) {
#ifdef WITH_CUDA
AT_ERROR("CUDA not implemented yet");
#else
AT_ERROR("Not compiled with GPU support.");
#endif
}
PointsToVolumesBackwardCpu(
points_3d,
points_features,
grid_sizes,
mask,
point_weight,
align_corners,
splat,
grad_volume_densities,
grad_volume_features,
grad_points_3d,
grad_points_features);
}

View File

@@ -0,0 +1,316 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#include <torch/extension.h>
#include <algorithm>
#include <cmath>
#include <thread>
#include <vector>
// In the x direction, the locations {0, ..., grid_size_x - 1} correspond to
// points px in [-1, 1]. There are two ways to do this.
// If align_corners=True, px=-1 is the exact location 0 and px=1 is the exact
// location grid_size_x - 1.
// So the location of px is {(px + 1) * 0.5} * (grid_size_x - 1).
// Note that if you generate random points within the bounds you are less likely
// to hit the edge locations than other locations.
// This can be thought of as saying "location i" means a specific point.
// If align_corners=False, px=-1 is halfway between the exact location 0 and
// the non-existent location -1, i.e. location -0.5.
// Similarly px=1 is halfway between the exact location grid_size_x - 1 and
// the non-existent location grid_size_x, i.e. the location grid_size_x - 0.5.
// So the location of px is ({(px + 1) * 0.5} * grid_size_x) - 0.5.
// Note that if you generate random points within the bounds you are equally
// likely to hit any location.
// This can be thought of as saying "location i" means the whole box from
// (i-0.5) to (i+0.5)
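// As a worked example (added for reference), take grid_size_x = 5:
//   align_corners=True : px=-1 -> 0.0, px=0 -> 2.0, px=1 -> 4.0
//   align_corners=False: px=-1 -> -0.5, px=0 -> 2.0, px=1 -> 4.5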
// EightDirections(t) runs t(a,b,c) for every combination of boolean a, b, c.
template <class T>
static void EightDirections(T&& t) {
t(false, false, false);
t(false, false, true);
t(false, true, false);
t(false, true, true);
t(true, false, false);
t(true, false, true);
t(true, true, false);
t(true, true, true);
}
void PointsToVolumesForwardCpu(
const torch::Tensor& points_3d,
const torch::Tensor& points_features,
const torch::Tensor& volume_densities,
const torch::Tensor& volume_features,
const torch::Tensor& grid_sizes,
const torch::Tensor& mask,
const float point_weight,
const bool align_corners,
const bool splat) {
const int64_t batch_size = points_3d.size(0);
const int64_t P = points_3d.size(1);
const int64_t n_features = points_features.size(2);
// We unify the formula for the location of px in the comment above as
// ({(px + 1) * 0.5} * (grid_size_x-scale_offset)) - offset.
const int scale_offset = align_corners ? 1 : 0;
const float offset = align_corners ? 0 : 0.5;
auto points_3d_a = points_3d.accessor<float, 3>();
auto points_features_a = points_features.accessor<float, 3>();
auto volume_densities_a = volume_densities.accessor<float, 5>();
auto volume_features_a = volume_features.accessor<float, 5>();
auto grid_sizes_a = grid_sizes.accessor<int64_t, 2>();
auto mask_a = mask.accessor<float, 2>();
// For each batch element
for (int64_t batch_idx = 0; batch_idx < batch_size; ++batch_idx) {
auto points_3d_aa = points_3d_a[batch_idx];
auto points_features_aa = points_features_a[batch_idx];
auto volume_densities_aa = volume_densities_a[batch_idx][0];
auto volume_features_aa = volume_features_a[batch_idx];
auto grid_sizes_aa = grid_sizes_a[batch_idx];
auto mask_aa = mask_a[batch_idx];
const int64_t grid_size_x = grid_sizes_aa[2];
const int64_t grid_size_y = grid_sizes_aa[1];
const int64_t grid_size_z = grid_sizes_aa[0];
// For each point
for (int64_t point_idx = 0; point_idx < P; ++point_idx) {
// Ignore point if mask is 0
if (mask_aa[point_idx] == 0) {
continue;
}
auto point = points_3d_aa[point_idx];
auto point_features = points_features_aa[point_idx];
// Define how to increment a location in the volume by a weighted amount.
// How many times this is called per point depends on the interpolation
// method: once for nearest, eight times for splat.
auto increment_location =
[&](int64_t x, int64_t y, int64_t z, float weight) {
if (x >= grid_size_x || y >= grid_size_y || z >= grid_size_z) {
return;
}
if (x < 0 || y < 0 || z < 0) {
return;
}
volume_densities_aa[z][y][x] += weight * point_weight;
for (int64_t feature_idx = 0; feature_idx < n_features;
++feature_idx) {
volume_features_aa[feature_idx][z][y][x] +=
point_features[feature_idx] * weight * point_weight;
}
};
if (!splat) {
// Increment the location nearest the point.
long x = std::lround(
(point[0] + 1) * 0.5 * (grid_size_x - scale_offset) - offset);
long y = std::lround(
(point[1] + 1) * 0.5 * (grid_size_y - scale_offset) - offset);
long z = std::lround(
(point[2] + 1) * 0.5 * (grid_size_z - scale_offset) - offset);
increment_location(x, y, z, 1);
} else {
// There are 8 locations around the point which we need to worry about.
// Their coordinates are (x or x+1, y or y+1, z or z+1).
// rx is a number between 0 and 1 for the proportion in the x direction:
// rx==0 means all the weight is on the lower bound x; rx=1-eps means most of
// the weight is on x+1. Ditto for ry and rz.
float x = 0, y = 0, z = 0;
float rx = std::modf(
(point[0] + 1) * 0.5 * (grid_size_x - scale_offset) - offset, &x);
float ry = std::modf(
(point[1] + 1) * 0.5 * (grid_size_y - scale_offset) - offset, &y);
float rz = std::modf(
(point[2] + 1) * 0.5 * (grid_size_z - scale_offset) - offset, &z);
// Define how to fractionally increment one of the 8 locations around
// the point.
auto handle_point = [&](bool up_x, bool up_y, bool up_z) {
float weight = (up_x ? rx : 1 - rx) * (up_y ? ry : 1 - ry) *
(up_z ? rz : 1 - rz);
increment_location(x + up_x, y + up_y, z + up_z, weight);
};
// and do so.
EightDirections(handle_point);
}
}
}
}
// With nearest, the only smooth dependence is that volume features
// depend on points features.
//
// With splat, the dependencies are as follows, with gradients passing
// in the opposite direction.
//
// points_3d points_features
// │ │ │
// │ │ │
// │ └───────────┐ │
// │ │ │
// │ │ │
// ▼ ▼ ▼
// volume_densities volume_features
// It is also the case that the input volume_densities and
// volume_features affect the corresponding outputs (they are
// modified in place).
// But the forward pass just increments these by a value which
// does not depend on them. So our autograd backwards pass needs
// to copy the gradient for each of those outputs to the
// corresponding input. We just do that in the Python layer.
void PointsToVolumesBackwardCpu(
const torch::Tensor& points_3d,
const torch::Tensor& points_features,
const torch::Tensor& grid_sizes,
const torch::Tensor& mask,
const float point_weight,
const bool align_corners,
const bool splat,
const torch::Tensor& grad_volume_densities,
const torch::Tensor& grad_volume_features,
const torch::Tensor& grad_points_3d,
const torch::Tensor& grad_points_features) {
const int64_t batch_size = points_3d.size(0);
const int64_t P = points_3d.size(1);
const int64_t n_features = grad_points_features.size(2);
const int scale_offset = align_corners ? 1 : 0;
const float offset = align_corners ? 0 : 0.5;
auto points_3d_a = points_3d.accessor<float, 3>();
auto points_features_a = points_features.accessor<float, 3>();
auto grid_sizes_a = grid_sizes.accessor<int64_t, 2>();
auto mask_a = mask.accessor<float, 2>();
auto grad_volume_densities_a = grad_volume_densities.accessor<float, 5>();
auto grad_volume_features_a = grad_volume_features.accessor<float, 5>();
auto grad_points_3d_a = grad_points_3d.accessor<float, 3>();
auto grad_points_features_a = grad_points_features.accessor<float, 3>();
// For each batch element
for (int64_t batch_idx = 0; batch_idx < batch_size; ++batch_idx) {
auto points_3d_aa = points_3d_a[batch_idx];
auto points_features_aa = points_features_a[batch_idx];
auto grid_sizes_aa = grid_sizes_a[batch_idx];
auto mask_aa = mask_a[batch_idx];
auto grad_volume_densities_aa = grad_volume_densities_a[batch_idx][0];
auto grad_volume_features_aa = grad_volume_features_a[batch_idx];
auto grad_points_3d_aa = grad_points_3d_a[batch_idx];
auto grad_points_features_aa = grad_points_features_a[batch_idx];
const int64_t grid_size_x = grid_sizes_aa[2];
const int64_t grid_size_y = grid_sizes_aa[1];
const int64_t grid_size_z = grid_sizes_aa[0];
// For each point
for (int64_t point_idx = 0; point_idx < P; ++point_idx) {
if (mask_aa[point_idx] == 0) {
continue;
}
auto point = points_3d_aa[point_idx];
auto point_features = points_features_aa[point_idx];
auto grad_point_features = grad_points_features_aa[point_idx];
auto grad_point = grad_points_3d_aa[point_idx];
// Define, for a single volume location, how to accumulate its
// contribution to the gradient of the point's features.
// We return false if the location does not really exist (it is out of
// bounds), so there was nothing to do.
// This happens once per point for nearest, eight times for splat.
auto increment_location =
[&](int64_t x, int64_t y, int64_t z, float weight) {
if (x >= grid_size_x || y >= grid_size_y || z >= grid_size_z) {
return false;
}
if (x < 0 || y < 0 || z < 0) {
return false;
}
for (int64_t feature_idx = 0; feature_idx < n_features;
++feature_idx) {
// This is a forward line, for comparison
// volume_features_aa[feature_idx][z][y][x] +=
// point_features[feature_idx] * weight * point_weight;
grad_point_features[feature_idx] +=
grad_volume_features_aa[feature_idx][z][y][x] * weight *
point_weight;
}
return true;
};
if (!splat) {
long x = std::lround(
(point[0] + 1) * 0.5 * (grid_size_x - scale_offset) - offset);
long y = std::lround(
(point[1] + 1) * 0.5 * (grid_size_y - scale_offset) - offset);
long z = std::lround(
(point[2] + 1) * 0.5 * (grid_size_z - scale_offset) - offset);
increment_location(x, y, z, 1);
} else {
float x = 0, y = 0, z = 0;
float rx = std::modf(
(point[0] + 1) * 0.5 * (grid_size_x - scale_offset) - offset, &x);
float ry = std::modf(
(point[1] + 1) * 0.5 * (grid_size_y - scale_offset) - offset, &y);
float rz = std::modf(
(point[2] + 1) * 0.5 * (grid_size_z - scale_offset) - offset, &z);
auto handle_point = [&](bool up_x, bool up_y, bool up_z) {
float weight_x = (up_x ? rx : 1 - rx);
float weight_y = (up_y ? ry : 1 - ry);
float weight_z = (up_z ? rz : 1 - rz);
float weight = weight_x * weight_y * weight_z;
// For each of the eight locations, we first increment the feature
// gradient.
if (increment_location(x + up_x, y + up_y, z + up_z, weight)) {
// If the location is a real location, we also (in this splat
// case) need to update the gradient w.r.t. the point position,
// since the amount deposited in this location is controlled by
// the weight. There are two contributions:
// (1) The point position affects how much density we added
// to the location's density, so we have a contribution
// from grad_volume_density. Specifically,
// weight * point_weight has been added to
// volume_densities_aa[z+up_z][y+up_y][x+up_x]
//
// (2) The point position affects how much of each of the
// point's features were added to the corresponding feature
// of this location, so we have a contribution from
// grad_volume_features. Specifically, for each feature_idx,
// point_features[feature_idx] * weight * point_weight
// has been added to
// volume_features_aa[feature_idx][z+up_z][y+up_y][x+up_x]
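// In symbols: writing lx = (px + 1) * 0.5 * (grid_size_x - scale_offset)
// - offset and rx = lx - x, we have d(rx)/d(px) = 0.5 * (grid_size_x -
// scale_offset), so d(weight)/d(px) = (up_x ? 1 : -1) * weight_y *
// weight_z * 0.5 * (grid_size_x - scale_offset). The lines below apply
// the chain rule through source_gradient (the sum of contributions (1)
// and (2)) and scale by point_weight; py and pz are handled analogously.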
float source_gradient =
grad_volume_densities_aa[z + up_z][y + up_y][x + up_x];
for (int64_t feature_idx = 0; feature_idx < n_features;
++feature_idx) {
source_gradient += point_features[feature_idx] *
grad_volume_features_aa[feature_idx][z + up_z][y + up_y]
[x + up_x];
}
grad_point[0] += source_gradient * (up_x ? 1 : -1) * weight_y *
weight_z * 0.5 * (grid_size_x - scale_offset) * point_weight;
grad_point[1] += source_gradient * (up_y ? 1 : -1) * weight_x *
weight_z * 0.5 * (grid_size_y - scale_offset) * point_weight;
grad_point[2] += source_gradient * (up_z ? 1 : -1) * weight_x *
weight_y * 0.5 * (grid_size_z - scale_offset) * point_weight;
}
};
EightDirections(handle_point);
}
}
}
}

View File

@@ -7,12 +7,186 @@
from typing import TYPE_CHECKING, Optional, Tuple
import torch
from pytorch3d import _C
from torch.autograd import Function
from torch.autograd.function import once_differentiable
if TYPE_CHECKING:
from ..structures import Pointclouds, Volumes
class _points_to_volumes_function(Function):
"""
For each point in a pointcloud, add point_weight to the
corresponding volume density and point_weight times its features
to the corresponding volume features.
This function does not require any contiguity internally and therefore
doesn't need to make copies of its inputs, which is useful when GPU memory
is at a premium. (An implementation requiring contiguous inputs might be faster
though). The volumes are modified in place.
This function is differentiable with respect to
points_features, volume_densities and volume_features.
If splat is True then it is also differentiable with respect to
points_3d.
It may be useful to think about this function as a sort of opposite to
torch.nn.functional.grid_sample with 5D inputs.
Args:
points_3d: Batch of 3D point cloud coordinates of shape
`(minibatch, N, 3)` where N is the number of points
in each point cloud. Coordinates have to be specified in the
local volume coordinates (ranging in [-1, 1]).
points_features: Features of shape `(minibatch, N, feature_dim)`
corresponding to the points of the input point cloud `points_3d`.
volume_features: Batch of input feature volumes
of shape `(minibatch, feature_dim, D, H, W)`
volume_densities: Batch of input feature volume densities
of shape `(minibatch, 1, D, H, W)`. Each voxel should
contain a non-negative number corresponding to its
opaqueness (the higher, the less transparent).
grid_sizes: `LongTensor` of shape (minibatch, 3) representing the
spatial resolutions of each of the non-flattened `volumes`
tensors. Note that the following has to hold:
`torch.prod(grid_sizes, dim=1)==N_voxels`.
point_weight: A scalar controlling how much weight a single point has.
mask: A binary mask of shape `(minibatch, N)` determining
which 3D points are going to be converted to the resulting
volume. Set to `None` if all points are valid.
align_corners: as for torch.nn.functional.grid_sample.
splat: if true, each point's weight is spread over its eight nearest voxels
using trilinear interpolation. If false, all the weight goes to the
nearest voxel.
Returns:
volume_densities and volume_features, which have been modified in place.
"""
@staticmethod
# pyre-fixme[14]: `forward` overrides method defined in `Function` inconsistently.
def forward(
ctx,
points_3d: torch.Tensor,
points_features: torch.Tensor,
volume_densities: torch.Tensor,
volume_features: torch.Tensor,
grid_sizes: torch.LongTensor,
point_weight: float,
mask: torch.Tensor,
align_corners: bool,
splat: bool,
):
ctx.mark_dirty(volume_densities, volume_features)
N, P, D = points_3d.shape
if D != 3:
raise ValueError("points_3d must be 3D")
if points_3d.dtype != torch.float32:
raise ValueError("points_3d must be float32")
if points_features.dtype != torch.float32:
raise ValueError("points_features must be float32")
N1, P1, C = points_features.shape
if N1 != N or P1 != P:
raise ValueError("Bad points_features shape")
if volume_densities.dtype != torch.float32:
raise ValueError("volume_densities must be float32")
N2, one, D, H, W = volume_densities.shape
if N2 != N or one != 1:
raise ValueError("Bad volume_densities shape")
if volume_features.dtype != torch.float32:
raise ValueError("volume_features must be float32")
N3, C1, D1, H1, W1 = volume_features.shape
if N3 != N or C1 != C or D1 != D or H1 != H or W1 != W:
raise ValueError("Bad volume_features shape")
if grid_sizes.dtype != torch.int64:
raise ValueError("grid_sizes must be int64")
N4, D1 = grid_sizes.shape
if N4 != N or D1 != 3:
raise ValueError("Bad grid_sizes.shape")
if mask.dtype != torch.float32:
raise ValueError("mask must be float32")
N5, P2 = mask.shape
if N5 != N or P2 != P:
raise ValueError("Bad mask shape")
# pyre-fixme[16]: Module `pytorch3d` has no attribute `_C`.
_C.points_to_volumes_forward(
points_3d,
points_features,
volume_densities,
volume_features,
grid_sizes,
mask,
point_weight,
align_corners,
splat,
)
if splat:
ctx.save_for_backward(points_3d, points_features, grid_sizes, mask)
else:
ctx.save_for_backward(points_3d, grid_sizes, mask)
ctx.point_weight = point_weight
ctx.splat = splat
ctx.align_corners = align_corners
return volume_densities, volume_features
@staticmethod
@once_differentiable
def backward(ctx, grad_volume_densities, grad_volume_features):
splat = ctx.splat
N, C = grad_volume_features.shape[:2]
if splat:
points_3d, points_features, grid_sizes, mask = ctx.saved_tensors
P = points_3d.shape[1]
grad_points_3d = torch.zeros_like(points_3d)
else:
points_3d, grid_sizes, mask = ctx.saved_tensors
P = points_3d.shape[1]
ones = points_3d.new_zeros(1, 1, 1)
# There is no gradient. Just need something to let its accessors exist.
grad_points_3d = ones.expand_as(points_3d)
# points_features not needed. Just need something to let its accessors exist.
points_features = ones.expand(N, P, C)
grad_points_features = points_3d.new_zeros(N, P, C)
_C.points_to_volumes_backward(
points_3d,
points_features,
grid_sizes,
mask,
ctx.point_weight,
ctx.align_corners,
splat,
grad_volume_densities,
grad_volume_features,
grad_points_3d,
grad_points_features,
)
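# One gradient is returned per forward input, in order: points_3d,
# points_features, volume_densities, volume_features, grid_sizes,
# point_weight, mask, align_corners, splat. The volume gradients are
# passed through unchanged because the forward pass only adds to the
# volumes in place; this is the "copy in the Python layer" mentioned in
# the C++ comments.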
return (
(grad_points_3d if splat else None),
grad_points_features,
grad_volume_densities,
grad_volume_features,
None,
None,
None,
None,
None,
)
# pyre-fixme[16]: `_points_to_volumes_function` has no attribute `apply`.
_points_to_volumes = _points_to_volumes_function.apply
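# Illustrative sketch (not part of the module): calling the raw wrapper
# directly on CPU float32 tensors, with shapes and dtypes matching the checks
# in forward() above. add_pointclouds_to_volumes below remains the public
# entry point.
#
#   N, P, C, D, H, W = 2, 100, 4, 8, 8, 8
#   points_3d = torch.rand(N, P, 3) * 2 - 1  # local volume coords in [-1, 1]
#   points_features = torch.rand(N, P, C)
#   volume_densities = torch.zeros(N, 1, D, H, W)
#   volume_features = torch.zeros(N, C, D, H, W)
#   grid_sizes = torch.tensor([D, H, W]).expand(N, 3)  # int64, shape (N, 3)
#   mask = torch.ones(N, P)
#   volume_densities, volume_features = _points_to_volumes(
#       points_3d, points_features, volume_densities, volume_features,
#       grid_sizes, 1.0, mask, True, True,
#   )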
def add_pointclouds_to_volumes(
pointclouds: "Pointclouds",
initial_volumes: "Volumes",

View File

@@ -5,12 +5,14 @@
# LICENSE file in the root directory of this source tree.
import unittest
from functools import partial
from typing import Tuple
import numpy as np
import torch
from common_testing import TestCaseMixin
from pytorch3d.ops import add_pointclouds_to_volumes
from pytorch3d.ops.points_to_volumes import _points_to_volumes
from pytorch3d.ops.sample_points_from_meshes import sample_points_from_meshes
from pytorch3d.structures.meshes import Meshes
from pytorch3d.structures.pointclouds import Pointclouds
@@ -395,3 +397,140 @@ class TestPointsToVolumes(TestCaseMixin, unittest.TestCase):
# check that all per-slice avg errors vanish
self.assertClose(clr_diff, torch.zeros_like(clr_diff), atol=1e-2)
class TestRawFunction(TestCaseMixin, unittest.TestCase):
"""
Testing the _C.points_to_volumes function through its wrapper
_points_to_volumes.
"""
def setUp(self) -> None:
torch.manual_seed(42)
def test_grad_corners_splat_cpu(self):
self.do_gradcheck(torch.device("cpu"), True, True)
def test_grad_corners_round_cpu(self):
self.do_gradcheck(torch.device("cpu"), False, True)
def test_grad_splat_cpu(self):
self.do_gradcheck(torch.device("cpu"), True, False)
def test_grad_round_cpu(self):
self.do_gradcheck(torch.device("cpu"), False, False)
def do_gradcheck(self, device, splat: bool, align_corners: bool):
"""
Use gradcheck to verify the gradient of _points_to_volumes
with random input.
"""
N, C, D, H, W, P = 2, 4, 5, 6, 7, 5
points_3d = (
torch.rand((N, P, 3), device=device, dtype=torch.float64) * 0.8 + 0.1
)
points_features = torch.rand((N, P, C), device=device, dtype=torch.float64)
volume_densities = torch.zeros((N, 1, D, H, W), device=device)
volume_features = torch.zeros((N, C, D, H, W), device=device)
volume_densities_scale = torch.rand_like(volume_densities)
volume_features_scale = torch.rand_like(volume_features)
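# Multiplying the outputs by random scales and summing gives a scalar
# objective that depends on every voxel, so gradcheck exercises all output
# entries.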
grid_sizes = torch.tensor([D, H, W], dtype=torch.int64, device=device).expand(
N, 3
)
mask = torch.ones((N, P), device=device)
mask[:, 0] = 0
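# Mask out the first point of each cloud so the mask code path is exercised.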
def f(points_3d_, points_features_):
(volume_densities_, volume_features_) = _points_to_volumes(
points_3d_.to(torch.float32),
points_features_.to(torch.float32),
volume_densities.clone(),
volume_features.clone(),
grid_sizes,
2.0,
mask,
align_corners,
splat,
)
density = (volume_densities_ * volume_densities_scale).sum()
features = (volume_features_ * volume_features_scale).sum()
return density, features
base = f(points_3d.clone(), points_features.clone())
self.assertGreater(base[0], 0)
self.assertGreater(base[1], 0)
points_features.requires_grad = True
if splat:
points_3d.requires_grad = True
torch.autograd.gradcheck(
f,
(points_3d, points_features),
check_undefined_grad=False,
eps=2e-4,
atol=0.01,
)
else:
torch.autograd.gradcheck(
partial(f, points_3d),
points_features,
check_undefined_grad=False,
eps=2e-3,
atol=0.001,
)
def test_single_corners_round_cpu(self):
self.single_point(torch.device("cpu"), False, True)
def test_single_corners_splat_cpu(self):
self.single_point(torch.device("cpu"), True, True)
def test_single_round_cpu(self):
self.single_point(torch.device("cpu"), False, False)
def test_single_splat_cpu(self):
self.single_point(torch.device("cpu"), True, False)
def single_point(self, device, splat: bool, align_corners: bool):
"""
Check the outcome of _points_to_volumes where a single point
exists which lines up with a single voxel.
"""
D, H, W = (6, 6, 11) if align_corners else (5, 5, 10)
N, C, P = 1, 1, 1
if align_corners:
points_3d = torch.tensor([[[-0.2, 0.2, -0.2]]], device=device)
else:
points_3d = torch.tensor([[[-0.3, 0.4, -0.4]]], device=device)
points_features = torch.zeros((N, P, C), device=device)
volume_densities = torch.zeros((N, 1, D, H, W), device=device)
volume_densities_expected = torch.zeros((N, 1, D, H, W), device=device)
volume_features = torch.zeros((N, C, D, H, W), device=device)
grid_sizes = torch.tensor([D, H, W], dtype=torch.int64, device=device).expand(
N, 3
)
mask = torch.ones((N, P), device=device)
point_weight = 19.0
volume_densities_, volume_features_ = _points_to_volumes(
points_3d,
points_features,
volume_densities,
volume_features,
grid_sizes,
point_weight,
mask,
align_corners,
splat,
)
self.assertIs(volume_densities, volume_densities_)
self.assertIs(volume_features, volume_features_)
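# Expected voxel index (z, y, x), using the mapping described in the C++
# implementation:
#   align_corners=True : x = (-0.2+1)/2*(11-1) = 4, y = (0.2+1)/2*(6-1) = 3,
#                        z = (-0.2+1)/2*(6-1) = 2
#   align_corners=False: x = (-0.3+1)/2*10 - 0.5 = 3, y = (0.4+1)/2*5 - 0.5 = 3,
#                        z = (-0.4+1)/2*5 - 0.5 = 1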
if align_corners:
volume_densities_expected[0, 0, 2, 3, 4] = point_weight
else:
volume_densities_expected[0, 0, 1, 3, 3] = point_weight
self.assertClose(volume_densities, volume_densities_expected)