From 0dfc6e0eb8a252878784dc9ae749d5298c5830b2 Mon Sep 17 00:00:00 2001 From: Jeremy Reizenstein Date: Fri, 1 Oct 2021 11:57:07 -0700 Subject: [PATCH] CPU function for points2vols Summary: Single C++ function for the core of points2vols, not used anywhere yet. Added ability to control align_corners and the weight of each point, which may be useful later. Reviewed By: nikhilaravi Differential Revision: D29548607 fbshipit-source-id: a5cda7ec2c14836624e7dfe744c4bbb3f3d3dfe2 --- pytorch3d/csrc/ext.cpp | 3 + .../points_to_volumes/points_to_volumes.h | 135 ++++++++ .../points_to_volumes_cpu.cpp | 316 ++++++++++++++++++ pytorch3d/ops/points_to_volumes.py | 174 ++++++++++ tests/test_points_to_volumes.py | 139 ++++++++ 5 files changed, 767 insertions(+) create mode 100644 pytorch3d/csrc/points_to_volumes/points_to_volumes.h create mode 100644 pytorch3d/csrc/points_to_volumes/points_to_volumes_cpu.cpp diff --git a/pytorch3d/csrc/ext.cpp b/pytorch3d/csrc/ext.cpp index 449cb14b..6fd6d0c6 100644 --- a/pytorch3d/csrc/ext.cpp +++ b/pytorch3d/csrc/ext.cpp @@ -25,6 +25,7 @@ #include "mesh_normal_consistency/mesh_normal_consistency.h" #include "packed_to_padded_tensor/packed_to_padded_tensor.h" #include "point_mesh/point_mesh_cuda.h" +#include "points_to_volumes/points_to_volumes.h" #include "rasterize_meshes/rasterize_meshes.h" #include "rasterize_points/rasterize_points.h" #include "sample_farthest_points/sample_farthest_points.h" @@ -47,6 +48,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def( "mesh_normal_consistency_find_verts", &MeshNormalConsistencyFindVertices); m.def("gather_scatter", &GatherScatter); + m.def("points_to_volumes_forward", PointsToVolumesForward); + m.def("points_to_volumes_backward", PointsToVolumesBackward); m.def("rasterize_points", &RasterizePoints); m.def("rasterize_points_backward", &RasterizePointsBackward); m.def("rasterize_meshes_backward", &RasterizeMeshesBackward); diff --git a/pytorch3d/csrc/points_to_volumes/points_to_volumes.h b/pytorch3d/csrc/points_to_volumes/points_to_volumes.h new file mode 100644 index 00000000..a53c2ab4 --- /dev/null +++ b/pytorch3d/csrc/points_to_volumes/points_to_volumes.h @@ -0,0 +1,135 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once +#include +#include +#include +#include "utils/pytorch3d_cutils.h" + +/* + volume_features and volume_densities are modified in place. + + Args: + points_3d: Batch of 3D point cloud coordinates of shape + `(minibatch, N, 3)` where N is the number of points + in each point cloud. Coordinates have to be specified in the + local volume coordinates (ranging in [-1, 1]). + points_features: Features of shape `(minibatch, N, feature_dim)` + corresponding to the points of the input point cloud `points_3d`. + volume_features: Batch of input feature volumes + of shape `(minibatch, feature_dim, D, H, W)` + volume_densities: Batch of input feature volume densities + of shape `(minibatch, 1, D, H, W)`. Each voxel should + contain a non-negative number corresponding to its + opaqueness (the higher, the less transparent). + + grid_sizes: `LongTensor` of shape (minibatch, 3) representing the + spatial resolutions of each of the the non-flattened `volumes` + tensors. Note that the following has to hold: + `torch.prod(grid_sizes, dim=1)==N_voxels`. + + point_weight: A scalar controlling how much weight a single point has. + + mask: A binary mask of shape `(minibatch, N)` determining + which 3D points are going to be converted to the resulting + volume. Set to `None` if all points are valid. + + align_corners: as for grid_sample. + + splat: if true, trilinear interpolation. If false all the weight goes in + the nearest voxel. +*/ + +void PointsToVolumesForwardCpu( + const torch::Tensor& points_3d, + const torch::Tensor& points_features, + const torch::Tensor& volume_densities, + const torch::Tensor& volume_features, + const torch::Tensor& grid_sizes, + const torch::Tensor& mask, + float point_weight, + bool align_corners, + bool splat); + +inline void PointsToVolumesForward( + const torch::Tensor& points_3d, + const torch::Tensor& points_features, + const torch::Tensor& volume_densities, + const torch::Tensor& volume_features, + const torch::Tensor& grid_sizes, + const torch::Tensor& mask, + float point_weight, + bool align_corners, + bool splat) { + if (points_3d.is_cuda()) { +#ifdef WITH_CUDA + AT_ERROR("CUDA not implemented yet"); +#else + AT_ERROR("Not compiled with GPU support."); +#endif + } + PointsToVolumesForwardCpu( + points_3d, + points_features, + volume_densities, + volume_features, + grid_sizes, + mask, + point_weight, + align_corners, + splat); +} + +// grad_points_3d and grad_points_features are modified in place. + +void PointsToVolumesBackwardCpu( + const torch::Tensor& points_3d, + const torch::Tensor& points_features, + const torch::Tensor& grid_sizes, + const torch::Tensor& mask, + float point_weight, + bool align_corners, + bool splat, + const torch::Tensor& grad_volume_densities, + const torch::Tensor& grad_volume_features, + const torch::Tensor& grad_points_3d, + const torch::Tensor& grad_points_features); + +inline void PointsToVolumesBackward( + const torch::Tensor& points_3d, + const torch::Tensor& points_features, + const torch::Tensor& grid_sizes, + const torch::Tensor& mask, + float point_weight, + bool align_corners, + bool splat, + const torch::Tensor& grad_volume_densities, + const torch::Tensor& grad_volume_features, + const torch::Tensor& grad_points_3d, + const torch::Tensor& grad_points_features) { + if (points_3d.is_cuda()) { +#ifdef WITH_CUDA + AT_ERROR("CUDA not implemented yet"); +#else + AT_ERROR("Not compiled with GPU support."); +#endif + } + PointsToVolumesBackwardCpu( + points_3d, + points_features, + grid_sizes, + mask, + point_weight, + align_corners, + splat, + grad_volume_densities, + grad_volume_features, + grad_points_3d, + grad_points_features); +} diff --git a/pytorch3d/csrc/points_to_volumes/points_to_volumes_cpu.cpp b/pytorch3d/csrc/points_to_volumes/points_to_volumes_cpu.cpp new file mode 100644 index 00000000..532ee31c --- /dev/null +++ b/pytorch3d/csrc/points_to_volumes/points_to_volumes_cpu.cpp @@ -0,0 +1,316 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include +#include + +// In the x direction, the location {0, ..., grid_size_x - 1} correspond to +// points px in [-1, 1]. There are two ways to do this. + +// If align_corners=True, px=-1 is the exact location 0 and px=1 is the exact +// location grid_size_x - 1. +// So the location of px is {(px + 1) * 0.5} * (grid_size_x - 1). +// Note that if you generate random points within the bounds you are less likely +// to hit the edge locations than other locations. +// This can be thought of as saying "location i" means a specific point. + +// If align_corners=False, px=-1 is half way between the exact location 0 and +// the non-existent location -1, i.e. location -0.5. +// Similarly px=1 is is half way between the exact location grid_size_x-1 and +// the non-existent location grid_size, i.e. the location grid_size_x - 0.5. +// So the location of px is ({(px + 1) * 0.5} * grid_size_x) - 0.5. +// Note that if you generate random points within the bounds you are equally +// likely to hit any location. +// This can be thought of as saying "location i" means the whole box from +// (i-0.5) to (i+0.5) + +// EightDirections(t) runs t(a,b,c) for every combination of boolean a, b, c. +template +static void EightDirections(T&& t) { + t(false, false, false); + t(false, false, true); + t(false, true, false); + t(false, true, true); + t(true, false, false); + t(true, false, true); + t(true, true, false); + t(true, true, true); +} + +void PointsToVolumesForwardCpu( + const torch::Tensor& points_3d, + const torch::Tensor& points_features, + const torch::Tensor& volume_densities, + const torch::Tensor& volume_features, + const torch::Tensor& grid_sizes, + const torch::Tensor& mask, + const float point_weight, + const bool align_corners, + const bool splat) { + const int64_t batch_size = points_3d.size(0); + const int64_t P = points_3d.size(1); + const int64_t n_features = points_features.size(2); + + // We unify the formula for the location of px in the comment above as + // ({(px + 1) * 0.5} * (grid_size_x-scale_offset)) - offset. + const int scale_offset = align_corners ? 1 : 0; + const float offset = align_corners ? 0 : 0.5; + + auto points_3d_a = points_3d.accessor(); + auto points_features_a = points_features.accessor(); + auto volume_densities_a = volume_densities.accessor(); + auto volume_features_a = volume_features.accessor(); + auto grid_sizes_a = grid_sizes.accessor(); + auto mask_a = mask.accessor(); + + // For each batch element + for (int64_t batch_idx = 0; batch_idx < batch_size; ++batch_idx) { + auto points_3d_aa = points_3d_a[batch_idx]; + auto points_features_aa = points_features_a[batch_idx]; + auto volume_densities_aa = volume_densities_a[batch_idx][0]; + auto volume_features_aa = volume_features_a[batch_idx]; + auto grid_sizes_aa = grid_sizes_a[batch_idx]; + auto mask_aa = mask_a[batch_idx]; + + const int64_t grid_size_x = grid_sizes_aa[2]; + const int64_t grid_size_y = grid_sizes_aa[1]; + const int64_t grid_size_z = grid_sizes_aa[0]; + + // For each point + for (int64_t point_idx = 0; point_idx < P; ++point_idx) { + // Ignore point if mask is 0 + if (mask_aa[point_idx] == 0) { + continue; + } + auto point = points_3d_aa[point_idx]; + auto point_features = points_features_aa[point_idx]; + + // Define how to increment a location in the volume by an amount. The need + // for this depends on the interpolation method: + // once per point for nearest, eight times for splat. + auto increment_location = + [&](int64_t x, int64_t y, int64_t z, float weight) { + if (x >= grid_size_x || y >= grid_size_y || z >= grid_size_z) { + return; + } + if (x < 0 || y < 0 || z < 0) { + return; + } + + volume_densities_aa[z][y][x] += weight * point_weight; + + for (int64_t feature_idx = 0; feature_idx < n_features; + ++feature_idx) { + volume_features_aa[feature_idx][z][y][x] += + point_features[feature_idx] * weight * point_weight; + } + }; + + if (!splat) { + // Increment the location nearest the point. + long x = std::lround( + (point[0] + 1) * 0.5 * (grid_size_x - scale_offset) - offset); + long y = std::lround( + (point[1] + 1) * 0.5 * (grid_size_y - scale_offset) - offset); + long z = std::lround( + (point[2] + 1) * 0.5 * (grid_size_z - scale_offset) - offset); + increment_location(x, y, z, 1); + } else { + // There are 8 locations around the point which we need to worry about. + // Their coordinates are (x or x+1, y or y+1, z or z+1). + // rx is a number between 0 and 1 for the proportion in the x direction: + // rx==0 means weight all on the lower bound, x, rx=1-eps means most + // weight on x+1. Ditto for ry and yz. + float x = 0, y = 0, z = 0; + float rx = std::modf( + (point[0] + 1) * 0.5 * (grid_size_x - scale_offset) - offset, &x); + float ry = std::modf( + (point[1] + 1) * 0.5 * (grid_size_y - scale_offset) - offset, &y); + float rz = std::modf( + (point[2] + 1) * 0.5 * (grid_size_z - scale_offset) - offset, &z); + // Define how to fractionally increment one of the 8 locations around + // the point. + auto handle_point = [&](bool up_x, bool up_y, bool up_z) { + float weight = (up_x ? rx : 1 - rx) * (up_y ? ry : 1 - ry) * + (up_z ? rz : 1 - rz); + increment_location(x + up_x, y + up_y, z + up_z, weight); + }; + // and do so. + EightDirections(handle_point); + } + } + } +} + +// With nearest, the only smooth dependence is that volume features +// depend on points features. +// +// With splat, the dependencies are as follows, with gradients passing +// in the opposite direction. +// +// points_3d points_features +// │ │ │ +// │ │ │ +// │ └───────────┐ │ +// │ │ │ +// │ │ │ +// ▼ ▼ ▼ +// volume_densities volume_features + +// It is also the case that the input volume_densities and +// volume_features affect the corresponding outputs (they are +// modified in place). +// But the forward pass just increments these by a value which +// does not depend on them. So our autograd backwards pass needs +// to copy the gradient for each of those outputs to the +// corresponding input. We just do that in the Python layer. + +void PointsToVolumesBackwardCpu( + const torch::Tensor& points_3d, + const torch::Tensor& points_features, + const torch::Tensor& grid_sizes, + const torch::Tensor& mask, + const float point_weight, + const bool align_corners, + const bool splat, + const torch::Tensor& grad_volume_densities, + const torch::Tensor& grad_volume_features, + const torch::Tensor& grad_points_3d, + const torch::Tensor& grad_points_features) { + const int64_t batch_size = points_3d.size(0); + const int64_t P = points_3d.size(1); + const int64_t n_features = grad_points_features.size(2); + const int scale_offset = align_corners ? 1 : 0; + const float offset = align_corners ? 0 : 0.5; + + auto points_3d_a = points_3d.accessor(); + auto points_features_a = points_features.accessor(); + auto grid_sizes_a = grid_sizes.accessor(); + auto mask_a = mask.accessor(); + auto grad_volume_densities_a = grad_volume_densities.accessor(); + auto grad_volume_features_a = grad_volume_features.accessor(); + auto grad_points_3d_a = grad_points_3d.accessor(); + auto grad_points_features_a = grad_points_features.accessor(); + + // For each batch element + for (int64_t batch_idx = 0; batch_idx < batch_size; ++batch_idx) { + auto points_3d_aa = points_3d_a[batch_idx]; + auto points_features_aa = points_features_a[batch_idx]; + auto grid_sizes_aa = grid_sizes_a[batch_idx]; + auto mask_aa = mask_a[batch_idx]; + auto grad_volume_densities_aa = grad_volume_densities_a[batch_idx][0]; + auto grad_volume_features_aa = grad_volume_features_a[batch_idx]; + auto grad_points_3d_aa = grad_points_3d_a[batch_idx]; + auto grad_points_features_aa = grad_points_features_a[batch_idx]; + + const int64_t grid_size_x = grid_sizes_aa[2]; + const int64_t grid_size_y = grid_sizes_aa[1]; + const int64_t grid_size_z = grid_sizes_aa[0]; + + // For each point + for (int64_t point_idx = 0; point_idx < P; ++point_idx) { + if (mask_aa[point_idx] == 0) { + continue; + } + auto point = points_3d_aa[point_idx]; + auto point_features = points_features_aa[point_idx]; + auto grad_point_features = grad_points_features_aa[point_idx]; + auto grad_point = grad_points_3d_aa[point_idx]; + + // Define how to (backwards) increment a location in the point cloud, + // to take gradients to the features. + // We return false if the location does not really exist, so there was + // nothing to do. + // This happens once per point for nearest, eight times for splat. + auto increment_location = + [&](int64_t x, int64_t y, int64_t z, float weight) { + if (x >= grid_size_x || y >= grid_size_y || z >= grid_size_z) { + return false; + } + if (x < 0 || y < 0 || z < 0) { + return false; + } + + for (int64_t feature_idx = 0; feature_idx < n_features; + ++feature_idx) { + // This is a forward line, for comparison + // volume_features_aa[feature_idx][z][y][x] += + // point_features[feature_idx] * weight * point_weight; + grad_point_features[feature_idx] += + grad_volume_features_aa[feature_idx][z][y][x] * weight * + point_weight; + } + return true; + }; + + if (!splat) { + long x = std::lround( + (point[0] + 1) * 0.5 * (grid_size_x - scale_offset) - offset); + long y = std::lround( + (point[1] + 1) * 0.5 * (grid_size_y - scale_offset) - offset); + long z = std::lround( + (point[2] + 1) * 0.5 * (grid_size_z - scale_offset) - offset); + increment_location(x, y, z, 1); + } else { + float x = 0, y = 0, z = 0; + float rx = std::modf( + (point[0] + 1) * 0.5 * (grid_size_x - scale_offset) - offset, &x); + float ry = std::modf( + (point[1] + 1) * 0.5 * (grid_size_y - scale_offset) - offset, &y); + float rz = std::modf( + (point[2] + 1) * 0.5 * (grid_size_z - scale_offset) - offset, &z); + auto handle_point = [&](bool up_x, bool up_y, bool up_z) { + float weight_x = (up_x ? rx : 1 - rx); + float weight_y = (up_y ? ry : 1 - ry); + float weight_z = (up_z ? rz : 1 - rz); + float weight = weight_x * weight_y * weight_z; + // For each of the eight locations, we first increment the feature + // gradient. + if (increment_location(x + up_x, y + up_y, z + up_z, weight)) { + // If the location is a real location, we also (in this splat + // case) need to update the gradient w.r.t. the point position. + // - the amount in this location is controlled by the weight. + // There are two contributions: + // (1) The point position affects how much density we added + // to the location's density, so we have a contribution + // from grad_volume_density. Specifically, + // weight * point_weight has been added to + // volume_densities_aa[z+up_z][y+up_y][x+up_x] + // + // (2) The point position affects how much of each of the + // point's features were added to the corresponding feature + // of this location, so we have a contribution from + // grad_volume_features. Specifically, for each feature_idx, + // point_features[feature_idx] * weight * point_weight + // has been added to + // volume_features_aa[feature_idx][z+up_z][y+up_y][x+up_x] + + float source_gradient = + grad_volume_densities_aa[z + up_z][y + up_y][x + up_x]; + for (int64_t feature_idx = 0; feature_idx < n_features; + ++feature_idx) { + source_gradient += point_features[feature_idx] * + grad_volume_features_aa[feature_idx][z + up_z][y + up_y] + [x + up_x]; + } + grad_point[0] += source_gradient * (up_x ? 1 : -1) * weight_y * + weight_z * 0.5 * (grid_size_x - scale_offset) * point_weight; + grad_point[1] += source_gradient * (up_y ? 1 : -1) * weight_x * + weight_z * 0.5 * (grid_size_y - scale_offset) * point_weight; + grad_point[2] += source_gradient * (up_z ? 1 : -1) * weight_x * + weight_y * 0.5 * (grid_size_z - scale_offset) * point_weight; + } + }; + EightDirections(handle_point); + } + } + } +} diff --git a/pytorch3d/ops/points_to_volumes.py b/pytorch3d/ops/points_to_volumes.py index 59f43fe1..2b80a0fb 100644 --- a/pytorch3d/ops/points_to_volumes.py +++ b/pytorch3d/ops/points_to_volumes.py @@ -7,12 +7,186 @@ from typing import TYPE_CHECKING, Optional, Tuple import torch +from pytorch3d import _C +from torch.autograd import Function +from torch.autograd.function import once_differentiable if TYPE_CHECKING: from ..structures import Pointclouds, Volumes +class _points_to_volumes_function(Function): + """ + For each point in a pointcloud, add point_weight to the + corresponding volume density and point_weight times its features + to the corresponding volume features. + + This function does not require any contiguity internally and therefore + doesn't need to make copies of its inputs, which is useful when GPU memory + is at a premium. (An implementation requiring contiguous inputs might be faster + though). The volumes are modified in place. + + This function is differentiable with respect to + points_features, volume_densities and volume_features. + If splat is True then it is also differentiable with respect to + points_3d. + + It may be useful to think about this function as a sort of opposite to + torch.nn.functional.grid_sample with 5D inputs. + + Args: + points_3d: Batch of 3D point cloud coordinates of shape + `(minibatch, N, 3)` where N is the number of points + in each point cloud. Coordinates have to be specified in the + local volume coordinates (ranging in [-1, 1]). + points_features: Features of shape `(minibatch, N, feature_dim)` + corresponding to the points of the input point cloud `points_3d`. + volume_features: Batch of input feature volumes + of shape `(minibatch, feature_dim, D, H, W)` + volume_densities: Batch of input feature volume densities + of shape `(minibatch, 1, D, H, W)`. Each voxel should + contain a non-negative number corresponding to its + opaqueness (the higher, the less transparent). + + grid_sizes: `LongTensor` of shape (minibatch, 3) representing the + spatial resolutions of each of the the non-flattened `volumes` + tensors. Note that the following has to hold: + `torch.prod(grid_sizes, dim=1)==N_voxels`. + + point_weight: A scalar controlling how much weight a single point has. + + mask: A binary mask of shape `(minibatch, N)` determining + which 3D points are going to be converted to the resulting + volume. Set to `None` if all points are valid. + + align_corners: as for grid_sample. + + splat: if true, trilinear interpolation. If false all the weight goes in + the nearest voxel. + + Returns: + volume_densities and volume_features, which have been modified in place. + """ + + @staticmethod + # pyre-fixme[14]: `forward` overrides method defined in `Function` inconsistently. + def forward( + ctx, + points_3d: torch.Tensor, + points_features: torch.Tensor, + volume_densities: torch.Tensor, + volume_features: torch.Tensor, + grid_sizes: torch.LongTensor, + point_weight: float, + mask: torch.Tensor, + align_corners: bool, + splat: bool, + ): + + ctx.mark_dirty(volume_densities, volume_features) + + N, P, D = points_3d.shape + if D != 3: + raise ValueError("points_3d must be 3D") + if points_3d.dtype != torch.float32: + raise ValueError("points_3d must be float32") + if points_features.dtype != torch.float32: + raise ValueError("points_features must be float32") + N1, P1, C = points_features.shape + if N1 != N or P1 != P: + raise ValueError("Bad points_features shape") + if volume_densities.dtype != torch.float32: + raise ValueError("volume_densities must be float32") + N2, one, D, H, W = volume_densities.shape + if N2 != N or one != 1: + raise ValueError("Bad volume_densities shape") + if volume_features.dtype != torch.float32: + raise ValueError("volume_features must be float32") + N3, C1, D1, H1, W1 = volume_features.shape + if N3 != N or C1 != C or D1 != D or H1 != H or W1 != W: + raise ValueError("Bad volume_features shape") + if grid_sizes.dtype != torch.int64: + raise ValueError("grid_sizes must be int64") + N4, D1 = grid_sizes.shape + if N4 != N or D1 != 3: + raise ValueError("Bad grid_sizes.shape") + if mask.dtype != torch.float32: + raise ValueError("mask must be float32") + N5, P2 = mask.shape + if N5 != N or P2 != P: + raise ValueError("Bad mask shape") + + # pyre-fixme[16]: Module `pytorch3d` has no attribute `_C`. + _C.points_to_volumes_forward( + points_3d, + points_features, + volume_densities, + volume_features, + grid_sizes, + mask, + point_weight, + align_corners, + splat, + ) + if splat: + ctx.save_for_backward(points_3d, points_features, grid_sizes, mask) + else: + ctx.save_for_backward(points_3d, grid_sizes, mask) + ctx.point_weight = point_weight + ctx.splat = splat + ctx.align_corners = align_corners + return volume_densities, volume_features + + @staticmethod + @once_differentiable + def backward(ctx, grad_volume_densities, grad_volume_features): + splat = ctx.splat + N, C = grad_volume_features.shape[:2] + if splat: + points_3d, points_features, grid_sizes, mask = ctx.saved_tensors + P = points_3d.shape[1] + grad_points_3d = torch.zeros_like(points_3d) + else: + points_3d, grid_sizes, mask = ctx.saved_tensors + P = points_3d.shape[1] + ones = points_3d.new_zeros(1, 1, 1) + # There is no gradient. Just need something to let its accessors exist. + grad_points_3d = ones.expand_as(points_3d) + # points_features not needed. Just need something to let its accessors exist. + points_features = ones.expand(N, P, C) + grad_points_features = points_3d.new_zeros(N, P, C) + _C.points_to_volumes_backward( + points_3d, + points_features, + grid_sizes, + mask, + ctx.point_weight, + ctx.align_corners, + splat, + grad_volume_densities, + grad_volume_features, + grad_points_3d, + grad_points_features, + ) + + return ( + (grad_points_3d if splat else None), + grad_points_features, + grad_volume_densities, + grad_volume_features, + None, + None, + None, + None, + None, + ) + + +# pyre-fixme[16]: `_points_to_volumes_function` has no attribute `apply`. +_points_to_volumes = _points_to_volumes_function.apply + + def add_pointclouds_to_volumes( pointclouds: "Pointclouds", initial_volumes: "Volumes", diff --git a/tests/test_points_to_volumes.py b/tests/test_points_to_volumes.py index 9a17d77b..2ca4352a 100644 --- a/tests/test_points_to_volumes.py +++ b/tests/test_points_to_volumes.py @@ -5,12 +5,14 @@ # LICENSE file in the root directory of this source tree. import unittest +from functools import partial from typing import Tuple import numpy as np import torch from common_testing import TestCaseMixin from pytorch3d.ops import add_pointclouds_to_volumes +from pytorch3d.ops.points_to_volumes import _points_to_volumes from pytorch3d.ops.sample_points_from_meshes import sample_points_from_meshes from pytorch3d.structures.meshes import Meshes from pytorch3d.structures.pointclouds import Pointclouds @@ -395,3 +397,140 @@ class TestPointsToVolumes(TestCaseMixin, unittest.TestCase): # check that all per-slice avg errors vanish self.assertClose(clr_diff, torch.zeros_like(clr_diff), atol=1e-2) + + +class TestRawFunction(TestCaseMixin, unittest.TestCase): + """ + Testing the _C.points_to_volumes function through its wrapper + _points_to_volumes. + """ + + def setUp(self) -> None: + torch.manual_seed(42) + + def test_grad_corners_splat_cpu(self): + self.do_gradcheck(torch.device("cpu"), True, True) + + def test_grad_corners_round_cpu(self): + self.do_gradcheck(torch.device("cpu"), False, True) + + def test_grad_splat_cpu(self): + self.do_gradcheck(torch.device("cpu"), True, False) + + def test_grad_round_cpu(self): + self.do_gradcheck(torch.device("cpu"), False, False) + + def do_gradcheck(self, device, splat: bool, align_corners: bool): + """ + Use gradcheck to verify the gradient of _points_to_volumes + with random input. + """ + N, C, D, H, W, P = 2, 4, 5, 6, 7, 5 + points_3d = ( + torch.rand((N, P, 3), device=device, dtype=torch.float64) * 0.8 + 0.1 + ) + points_features = torch.rand((N, P, C), device=device, dtype=torch.float64) + volume_densities = torch.zeros((N, 1, D, H, W), device=device) + volume_features = torch.zeros((N, C, D, H, W), device=device) + volume_densities_scale = torch.rand_like(volume_densities) + volume_features_scale = torch.rand_like(volume_features) + grid_sizes = torch.tensor([D, H, W], dtype=torch.int64, device=device).expand( + N, 3 + ) + mask = torch.ones((N, P), device=device) + mask[:, 0] = 0 + align_corners = False + + def f(points_3d_, points_features_): + (volume_densities_, volume_features_) = _points_to_volumes( + points_3d_.to(torch.float32), + points_features_.to(torch.float32), + volume_densities.clone(), + volume_features.clone(), + grid_sizes, + 2.0, + mask, + align_corners, + splat, + ) + density = (volume_densities_ * volume_densities_scale).sum() + features = (volume_features_ * volume_features_scale).sum() + return density, features + + base = f(points_3d.clone(), points_features.clone()) + self.assertGreater(base[0], 0) + self.assertGreater(base[1], 0) + + points_features.requires_grad = True + if splat: + points_3d.requires_grad = True + torch.autograd.gradcheck( + f, + (points_3d, points_features), + check_undefined_grad=False, + eps=2e-4, + atol=0.01, + ) + else: + torch.autograd.gradcheck( + partial(f, points_3d), + points_features, + check_undefined_grad=False, + eps=2e-3, + atol=0.001, + ) + + def test_single_corners_round_cpu(self): + self.single_point(torch.device("cpu"), False, True) + + def test_single_corners_splat_cpu(self): + self.single_point(torch.device("cpu"), True, True) + + def test_single_round_cpu(self): + self.single_point(torch.device("cpu"), False, False) + + def test_single_splat_cpu(self): + self.single_point(torch.device("cpu"), True, False) + + def single_point(self, device, splat: bool, align_corners: bool): + """ + Check the outcome of _points_to_volumes where a single point + exists which lines up with a single voxel. + """ + D, H, W = (6, 6, 11) if align_corners else (5, 5, 10) + N, C, P = 1, 1, 1 + if align_corners: + points_3d = torch.tensor([[[-0.2, 0.2, -0.2]]], device=device) + else: + points_3d = torch.tensor([[[-0.3, 0.4, -0.4]]], device=device) + points_features = torch.zeros((N, P, C), device=device) + volume_densities = torch.zeros((N, 1, D, H, W), device=device) + volume_densities_expected = torch.zeros((N, 1, D, H, W), device=device) + volume_features = torch.zeros((N, C, D, H, W), device=device) + grid_sizes = torch.tensor([D, H, W], dtype=torch.int64, device=device).expand( + N, 3 + ) + mask = torch.ones((N, P), device=device) + point_weight = 19.0 + + volume_densities_, volume_features_ = _points_to_volumes( + points_3d, + points_features, + volume_densities, + volume_features, + grid_sizes, + point_weight, + mask, + align_corners, + splat, + ) + + self.assertIs(volume_densities, volume_densities_) + self.assertIs(volume_features, volume_features_) + + if align_corners: + volume_densities_expected[0, 0, 2, 3, 4] = point_weight + else: + volume_densities_expected[0, 0, 1, 3, 3] = point_weight + + self.assertClose(volume_densities, volume_densities_expected)