mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
CPU function for points2vols
Summary: Single C++ function for the core of points2vols, not used anywhere yet. Added ability to control align_corners and the weight of each point, which may be useful later. Reviewed By: nikhilaravi Differential Revision: D29548607 fbshipit-source-id: a5cda7ec2c14836624e7dfe744c4bbb3f3d3dfe2
This commit is contained in:
parent
c7c6deab86
commit
0dfc6e0eb8
@ -25,6 +25,7 @@
|
||||
#include "mesh_normal_consistency/mesh_normal_consistency.h"
|
||||
#include "packed_to_padded_tensor/packed_to_padded_tensor.h"
|
||||
#include "point_mesh/point_mesh_cuda.h"
|
||||
#include "points_to_volumes/points_to_volumes.h"
|
||||
#include "rasterize_meshes/rasterize_meshes.h"
|
||||
#include "rasterize_points/rasterize_points.h"
|
||||
#include "sample_farthest_points/sample_farthest_points.h"
|
||||
@ -47,6 +48,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def(
|
||||
"mesh_normal_consistency_find_verts", &MeshNormalConsistencyFindVertices);
|
||||
m.def("gather_scatter", &GatherScatter);
|
||||
m.def("points_to_volumes_forward", PointsToVolumesForward);
|
||||
m.def("points_to_volumes_backward", PointsToVolumesBackward);
|
||||
m.def("rasterize_points", &RasterizePoints);
|
||||
m.def("rasterize_points_backward", &RasterizePointsBackward);
|
||||
m.def("rasterize_meshes_backward", &RasterizeMeshesBackward);
|
||||
|
135
pytorch3d/csrc/points_to_volumes/points_to_volumes.h
Normal file
135
pytorch3d/csrc/points_to_volumes/points_to_volumes.h
Normal file
@ -0,0 +1,135 @@
|
||||
/*
|
||||
* Copyright (c) Facebook, Inc. and its affiliates.
|
||||
* All rights reserved.
|
||||
*
|
||||
* This source code is licensed under the BSD-style license found in the
|
||||
* LICENSE file in the root directory of this source tree.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
#include <torch/extension.h>
|
||||
#include <cstdio>
|
||||
#include <tuple>
|
||||
#include "utils/pytorch3d_cutils.h"
|
||||
|
||||
/*
|
||||
volume_features and volume_densities are modified in place.
|
||||
|
||||
Args:
|
||||
points_3d: Batch of 3D point cloud coordinates of shape
|
||||
`(minibatch, N, 3)` where N is the number of points
|
||||
in each point cloud. Coordinates have to be specified in the
|
||||
local volume coordinates (ranging in [-1, 1]).
|
||||
points_features: Features of shape `(minibatch, N, feature_dim)`
|
||||
corresponding to the points of the input point cloud `points_3d`.
|
||||
volume_features: Batch of input feature volumes
|
||||
of shape `(minibatch, feature_dim, D, H, W)`
|
||||
volume_densities: Batch of input feature volume densities
|
||||
of shape `(minibatch, 1, D, H, W)`. Each voxel should
|
||||
contain a non-negative number corresponding to its
|
||||
opaqueness (the higher, the less transparent).
|
||||
|
||||
grid_sizes: `LongTensor` of shape (minibatch, 3) representing the
|
||||
spatial resolutions of each of the the non-flattened `volumes`
|
||||
tensors. Note that the following has to hold:
|
||||
`torch.prod(grid_sizes, dim=1)==N_voxels`.
|
||||
|
||||
point_weight: A scalar controlling how much weight a single point has.
|
||||
|
||||
mask: A binary mask of shape `(minibatch, N)` determining
|
||||
which 3D points are going to be converted to the resulting
|
||||
volume. Set to `None` if all points are valid.
|
||||
|
||||
align_corners: as for grid_sample.
|
||||
|
||||
splat: if true, trilinear interpolation. If false all the weight goes in
|
||||
the nearest voxel.
|
||||
*/
|
||||
|
||||
void PointsToVolumesForwardCpu(
|
||||
const torch::Tensor& points_3d,
|
||||
const torch::Tensor& points_features,
|
||||
const torch::Tensor& volume_densities,
|
||||
const torch::Tensor& volume_features,
|
||||
const torch::Tensor& grid_sizes,
|
||||
const torch::Tensor& mask,
|
||||
float point_weight,
|
||||
bool align_corners,
|
||||
bool splat);
|
||||
|
||||
inline void PointsToVolumesForward(
|
||||
const torch::Tensor& points_3d,
|
||||
const torch::Tensor& points_features,
|
||||
const torch::Tensor& volume_densities,
|
||||
const torch::Tensor& volume_features,
|
||||
const torch::Tensor& grid_sizes,
|
||||
const torch::Tensor& mask,
|
||||
float point_weight,
|
||||
bool align_corners,
|
||||
bool splat) {
|
||||
if (points_3d.is_cuda()) {
|
||||
#ifdef WITH_CUDA
|
||||
AT_ERROR("CUDA not implemented yet");
|
||||
#else
|
||||
AT_ERROR("Not compiled with GPU support.");
|
||||
#endif
|
||||
}
|
||||
PointsToVolumesForwardCpu(
|
||||
points_3d,
|
||||
points_features,
|
||||
volume_densities,
|
||||
volume_features,
|
||||
grid_sizes,
|
||||
mask,
|
||||
point_weight,
|
||||
align_corners,
|
||||
splat);
|
||||
}
|
||||
|
||||
// grad_points_3d and grad_points_features are modified in place.
|
||||
|
||||
void PointsToVolumesBackwardCpu(
|
||||
const torch::Tensor& points_3d,
|
||||
const torch::Tensor& points_features,
|
||||
const torch::Tensor& grid_sizes,
|
||||
const torch::Tensor& mask,
|
||||
float point_weight,
|
||||
bool align_corners,
|
||||
bool splat,
|
||||
const torch::Tensor& grad_volume_densities,
|
||||
const torch::Tensor& grad_volume_features,
|
||||
const torch::Tensor& grad_points_3d,
|
||||
const torch::Tensor& grad_points_features);
|
||||
|
||||
inline void PointsToVolumesBackward(
|
||||
const torch::Tensor& points_3d,
|
||||
const torch::Tensor& points_features,
|
||||
const torch::Tensor& grid_sizes,
|
||||
const torch::Tensor& mask,
|
||||
float point_weight,
|
||||
bool align_corners,
|
||||
bool splat,
|
||||
const torch::Tensor& grad_volume_densities,
|
||||
const torch::Tensor& grad_volume_features,
|
||||
const torch::Tensor& grad_points_3d,
|
||||
const torch::Tensor& grad_points_features) {
|
||||
if (points_3d.is_cuda()) {
|
||||
#ifdef WITH_CUDA
|
||||
AT_ERROR("CUDA not implemented yet");
|
||||
#else
|
||||
AT_ERROR("Not compiled with GPU support.");
|
||||
#endif
|
||||
}
|
||||
PointsToVolumesBackwardCpu(
|
||||
points_3d,
|
||||
points_features,
|
||||
grid_sizes,
|
||||
mask,
|
||||
point_weight,
|
||||
align_corners,
|
||||
splat,
|
||||
grad_volume_densities,
|
||||
grad_volume_features,
|
||||
grad_points_3d,
|
||||
grad_points_features);
|
||||
}
|
316
pytorch3d/csrc/points_to_volumes/points_to_volumes_cpu.cpp
Normal file
316
pytorch3d/csrc/points_to_volumes/points_to_volumes_cpu.cpp
Normal file
@ -0,0 +1,316 @@
|
||||
/*
|
||||
* Copyright (c) Facebook, Inc. and its affiliates.
|
||||
* All rights reserved.
|
||||
*
|
||||
* This source code is licensed under the BSD-style license found in the
|
||||
* LICENSE file in the root directory of this source tree.
|
||||
*/
|
||||
|
||||
#include <torch/extension.h>
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
#include <thread>
|
||||
#include <vector>
|
||||
|
||||
// In the x direction, the location {0, ..., grid_size_x - 1} correspond to
|
||||
// points px in [-1, 1]. There are two ways to do this.
|
||||
|
||||
// If align_corners=True, px=-1 is the exact location 0 and px=1 is the exact
|
||||
// location grid_size_x - 1.
|
||||
// So the location of px is {(px + 1) * 0.5} * (grid_size_x - 1).
|
||||
// Note that if you generate random points within the bounds you are less likely
|
||||
// to hit the edge locations than other locations.
|
||||
// This can be thought of as saying "location i" means a specific point.
|
||||
|
||||
// If align_corners=False, px=-1 is half way between the exact location 0 and
|
||||
// the non-existent location -1, i.e. location -0.5.
|
||||
// Similarly px=1 is is half way between the exact location grid_size_x-1 and
|
||||
// the non-existent location grid_size, i.e. the location grid_size_x - 0.5.
|
||||
// So the location of px is ({(px + 1) * 0.5} * grid_size_x) - 0.5.
|
||||
// Note that if you generate random points within the bounds you are equally
|
||||
// likely to hit any location.
|
||||
// This can be thought of as saying "location i" means the whole box from
|
||||
// (i-0.5) to (i+0.5)
|
||||
|
||||
// EightDirections(t) runs t(a,b,c) for every combination of boolean a, b, c.
|
||||
template <class T>
|
||||
static void EightDirections(T&& t) {
|
||||
t(false, false, false);
|
||||
t(false, false, true);
|
||||
t(false, true, false);
|
||||
t(false, true, true);
|
||||
t(true, false, false);
|
||||
t(true, false, true);
|
||||
t(true, true, false);
|
||||
t(true, true, true);
|
||||
}
|
||||
|
||||
void PointsToVolumesForwardCpu(
|
||||
const torch::Tensor& points_3d,
|
||||
const torch::Tensor& points_features,
|
||||
const torch::Tensor& volume_densities,
|
||||
const torch::Tensor& volume_features,
|
||||
const torch::Tensor& grid_sizes,
|
||||
const torch::Tensor& mask,
|
||||
const float point_weight,
|
||||
const bool align_corners,
|
||||
const bool splat) {
|
||||
const int64_t batch_size = points_3d.size(0);
|
||||
const int64_t P = points_3d.size(1);
|
||||
const int64_t n_features = points_features.size(2);
|
||||
|
||||
// We unify the formula for the location of px in the comment above as
|
||||
// ({(px + 1) * 0.5} * (grid_size_x-scale_offset)) - offset.
|
||||
const int scale_offset = align_corners ? 1 : 0;
|
||||
const float offset = align_corners ? 0 : 0.5;
|
||||
|
||||
auto points_3d_a = points_3d.accessor<float, 3>();
|
||||
auto points_features_a = points_features.accessor<float, 3>();
|
||||
auto volume_densities_a = volume_densities.accessor<float, 5>();
|
||||
auto volume_features_a = volume_features.accessor<float, 5>();
|
||||
auto grid_sizes_a = grid_sizes.accessor<int64_t, 2>();
|
||||
auto mask_a = mask.accessor<float, 2>();
|
||||
|
||||
// For each batch element
|
||||
for (int64_t batch_idx = 0; batch_idx < batch_size; ++batch_idx) {
|
||||
auto points_3d_aa = points_3d_a[batch_idx];
|
||||
auto points_features_aa = points_features_a[batch_idx];
|
||||
auto volume_densities_aa = volume_densities_a[batch_idx][0];
|
||||
auto volume_features_aa = volume_features_a[batch_idx];
|
||||
auto grid_sizes_aa = grid_sizes_a[batch_idx];
|
||||
auto mask_aa = mask_a[batch_idx];
|
||||
|
||||
const int64_t grid_size_x = grid_sizes_aa[2];
|
||||
const int64_t grid_size_y = grid_sizes_aa[1];
|
||||
const int64_t grid_size_z = grid_sizes_aa[0];
|
||||
|
||||
// For each point
|
||||
for (int64_t point_idx = 0; point_idx < P; ++point_idx) {
|
||||
// Ignore point if mask is 0
|
||||
if (mask_aa[point_idx] == 0) {
|
||||
continue;
|
||||
}
|
||||
auto point = points_3d_aa[point_idx];
|
||||
auto point_features = points_features_aa[point_idx];
|
||||
|
||||
// Define how to increment a location in the volume by an amount. The need
|
||||
// for this depends on the interpolation method:
|
||||
// once per point for nearest, eight times for splat.
|
||||
auto increment_location =
|
||||
[&](int64_t x, int64_t y, int64_t z, float weight) {
|
||||
if (x >= grid_size_x || y >= grid_size_y || z >= grid_size_z) {
|
||||
return;
|
||||
}
|
||||
if (x < 0 || y < 0 || z < 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
volume_densities_aa[z][y][x] += weight * point_weight;
|
||||
|
||||
for (int64_t feature_idx = 0; feature_idx < n_features;
|
||||
++feature_idx) {
|
||||
volume_features_aa[feature_idx][z][y][x] +=
|
||||
point_features[feature_idx] * weight * point_weight;
|
||||
}
|
||||
};
|
||||
|
||||
if (!splat) {
|
||||
// Increment the location nearest the point.
|
||||
long x = std::lround(
|
||||
(point[0] + 1) * 0.5 * (grid_size_x - scale_offset) - offset);
|
||||
long y = std::lround(
|
||||
(point[1] + 1) * 0.5 * (grid_size_y - scale_offset) - offset);
|
||||
long z = std::lround(
|
||||
(point[2] + 1) * 0.5 * (grid_size_z - scale_offset) - offset);
|
||||
increment_location(x, y, z, 1);
|
||||
} else {
|
||||
// There are 8 locations around the point which we need to worry about.
|
||||
// Their coordinates are (x or x+1, y or y+1, z or z+1).
|
||||
// rx is a number between 0 and 1 for the proportion in the x direction:
|
||||
// rx==0 means weight all on the lower bound, x, rx=1-eps means most
|
||||
// weight on x+1. Ditto for ry and yz.
|
||||
float x = 0, y = 0, z = 0;
|
||||
float rx = std::modf(
|
||||
(point[0] + 1) * 0.5 * (grid_size_x - scale_offset) - offset, &x);
|
||||
float ry = std::modf(
|
||||
(point[1] + 1) * 0.5 * (grid_size_y - scale_offset) - offset, &y);
|
||||
float rz = std::modf(
|
||||
(point[2] + 1) * 0.5 * (grid_size_z - scale_offset) - offset, &z);
|
||||
// Define how to fractionally increment one of the 8 locations around
|
||||
// the point.
|
||||
auto handle_point = [&](bool up_x, bool up_y, bool up_z) {
|
||||
float weight = (up_x ? rx : 1 - rx) * (up_y ? ry : 1 - ry) *
|
||||
(up_z ? rz : 1 - rz);
|
||||
increment_location(x + up_x, y + up_y, z + up_z, weight);
|
||||
};
|
||||
// and do so.
|
||||
EightDirections(handle_point);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// With nearest, the only smooth dependence is that volume features
|
||||
// depend on points features.
|
||||
//
|
||||
// With splat, the dependencies are as follows, with gradients passing
|
||||
// in the opposite direction.
|
||||
//
|
||||
// points_3d points_features
|
||||
// │ │ │
|
||||
// │ │ │
|
||||
// │ └───────────┐ │
|
||||
// │ │ │
|
||||
// │ │ │
|
||||
// ▼ ▼ ▼
|
||||
// volume_densities volume_features
|
||||
|
||||
// It is also the case that the input volume_densities and
|
||||
// volume_features affect the corresponding outputs (they are
|
||||
// modified in place).
|
||||
// But the forward pass just increments these by a value which
|
||||
// does not depend on them. So our autograd backwards pass needs
|
||||
// to copy the gradient for each of those outputs to the
|
||||
// corresponding input. We just do that in the Python layer.
|
||||
|
||||
void PointsToVolumesBackwardCpu(
|
||||
const torch::Tensor& points_3d,
|
||||
const torch::Tensor& points_features,
|
||||
const torch::Tensor& grid_sizes,
|
||||
const torch::Tensor& mask,
|
||||
const float point_weight,
|
||||
const bool align_corners,
|
||||
const bool splat,
|
||||
const torch::Tensor& grad_volume_densities,
|
||||
const torch::Tensor& grad_volume_features,
|
||||
const torch::Tensor& grad_points_3d,
|
||||
const torch::Tensor& grad_points_features) {
|
||||
const int64_t batch_size = points_3d.size(0);
|
||||
const int64_t P = points_3d.size(1);
|
||||
const int64_t n_features = grad_points_features.size(2);
|
||||
const int scale_offset = align_corners ? 1 : 0;
|
||||
const float offset = align_corners ? 0 : 0.5;
|
||||
|
||||
auto points_3d_a = points_3d.accessor<float, 3>();
|
||||
auto points_features_a = points_features.accessor<float, 3>();
|
||||
auto grid_sizes_a = grid_sizes.accessor<int64_t, 2>();
|
||||
auto mask_a = mask.accessor<float, 2>();
|
||||
auto grad_volume_densities_a = grad_volume_densities.accessor<float, 5>();
|
||||
auto grad_volume_features_a = grad_volume_features.accessor<float, 5>();
|
||||
auto grad_points_3d_a = grad_points_3d.accessor<float, 3>();
|
||||
auto grad_points_features_a = grad_points_features.accessor<float, 3>();
|
||||
|
||||
// For each batch element
|
||||
for (int64_t batch_idx = 0; batch_idx < batch_size; ++batch_idx) {
|
||||
auto points_3d_aa = points_3d_a[batch_idx];
|
||||
auto points_features_aa = points_features_a[batch_idx];
|
||||
auto grid_sizes_aa = grid_sizes_a[batch_idx];
|
||||
auto mask_aa = mask_a[batch_idx];
|
||||
auto grad_volume_densities_aa = grad_volume_densities_a[batch_idx][0];
|
||||
auto grad_volume_features_aa = grad_volume_features_a[batch_idx];
|
||||
auto grad_points_3d_aa = grad_points_3d_a[batch_idx];
|
||||
auto grad_points_features_aa = grad_points_features_a[batch_idx];
|
||||
|
||||
const int64_t grid_size_x = grid_sizes_aa[2];
|
||||
const int64_t grid_size_y = grid_sizes_aa[1];
|
||||
const int64_t grid_size_z = grid_sizes_aa[0];
|
||||
|
||||
// For each point
|
||||
for (int64_t point_idx = 0; point_idx < P; ++point_idx) {
|
||||
if (mask_aa[point_idx] == 0) {
|
||||
continue;
|
||||
}
|
||||
auto point = points_3d_aa[point_idx];
|
||||
auto point_features = points_features_aa[point_idx];
|
||||
auto grad_point_features = grad_points_features_aa[point_idx];
|
||||
auto grad_point = grad_points_3d_aa[point_idx];
|
||||
|
||||
// Define how to (backwards) increment a location in the point cloud,
|
||||
// to take gradients to the features.
|
||||
// We return false if the location does not really exist, so there was
|
||||
// nothing to do.
|
||||
// This happens once per point for nearest, eight times for splat.
|
||||
auto increment_location =
|
||||
[&](int64_t x, int64_t y, int64_t z, float weight) {
|
||||
if (x >= grid_size_x || y >= grid_size_y || z >= grid_size_z) {
|
||||
return false;
|
||||
}
|
||||
if (x < 0 || y < 0 || z < 0) {
|
||||
return false;
|
||||
}
|
||||
|
||||
for (int64_t feature_idx = 0; feature_idx < n_features;
|
||||
++feature_idx) {
|
||||
// This is a forward line, for comparison
|
||||
// volume_features_aa[feature_idx][z][y][x] +=
|
||||
// point_features[feature_idx] * weight * point_weight;
|
||||
grad_point_features[feature_idx] +=
|
||||
grad_volume_features_aa[feature_idx][z][y][x] * weight *
|
||||
point_weight;
|
||||
}
|
||||
return true;
|
||||
};
|
||||
|
||||
if (!splat) {
|
||||
long x = std::lround(
|
||||
(point[0] + 1) * 0.5 * (grid_size_x - scale_offset) - offset);
|
||||
long y = std::lround(
|
||||
(point[1] + 1) * 0.5 * (grid_size_y - scale_offset) - offset);
|
||||
long z = std::lround(
|
||||
(point[2] + 1) * 0.5 * (grid_size_z - scale_offset) - offset);
|
||||
increment_location(x, y, z, 1);
|
||||
} else {
|
||||
float x = 0, y = 0, z = 0;
|
||||
float rx = std::modf(
|
||||
(point[0] + 1) * 0.5 * (grid_size_x - scale_offset) - offset, &x);
|
||||
float ry = std::modf(
|
||||
(point[1] + 1) * 0.5 * (grid_size_y - scale_offset) - offset, &y);
|
||||
float rz = std::modf(
|
||||
(point[2] + 1) * 0.5 * (grid_size_z - scale_offset) - offset, &z);
|
||||
auto handle_point = [&](bool up_x, bool up_y, bool up_z) {
|
||||
float weight_x = (up_x ? rx : 1 - rx);
|
||||
float weight_y = (up_y ? ry : 1 - ry);
|
||||
float weight_z = (up_z ? rz : 1 - rz);
|
||||
float weight = weight_x * weight_y * weight_z;
|
||||
// For each of the eight locations, we first increment the feature
|
||||
// gradient.
|
||||
if (increment_location(x + up_x, y + up_y, z + up_z, weight)) {
|
||||
// If the location is a real location, we also (in this splat
|
||||
// case) need to update the gradient w.r.t. the point position.
|
||||
// - the amount in this location is controlled by the weight.
|
||||
// There are two contributions:
|
||||
// (1) The point position affects how much density we added
|
||||
// to the location's density, so we have a contribution
|
||||
// from grad_volume_density. Specifically,
|
||||
// weight * point_weight has been added to
|
||||
// volume_densities_aa[z+up_z][y+up_y][x+up_x]
|
||||
//
|
||||
// (2) The point position affects how much of each of the
|
||||
// point's features were added to the corresponding feature
|
||||
// of this location, so we have a contribution from
|
||||
// grad_volume_features. Specifically, for each feature_idx,
|
||||
// point_features[feature_idx] * weight * point_weight
|
||||
// has been added to
|
||||
// volume_features_aa[feature_idx][z+up_z][y+up_y][x+up_x]
|
||||
|
||||
float source_gradient =
|
||||
grad_volume_densities_aa[z + up_z][y + up_y][x + up_x];
|
||||
for (int64_t feature_idx = 0; feature_idx < n_features;
|
||||
++feature_idx) {
|
||||
source_gradient += point_features[feature_idx] *
|
||||
grad_volume_features_aa[feature_idx][z + up_z][y + up_y]
|
||||
[x + up_x];
|
||||
}
|
||||
grad_point[0] += source_gradient * (up_x ? 1 : -1) * weight_y *
|
||||
weight_z * 0.5 * (grid_size_x - scale_offset) * point_weight;
|
||||
grad_point[1] += source_gradient * (up_y ? 1 : -1) * weight_x *
|
||||
weight_z * 0.5 * (grid_size_y - scale_offset) * point_weight;
|
||||
grad_point[2] += source_gradient * (up_z ? 1 : -1) * weight_x *
|
||||
weight_y * 0.5 * (grid_size_z - scale_offset) * point_weight;
|
||||
}
|
||||
};
|
||||
EightDirections(handle_point);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
@ -7,12 +7,186 @@
|
||||
from typing import TYPE_CHECKING, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from pytorch3d import _C
|
||||
from torch.autograd import Function
|
||||
from torch.autograd.function import once_differentiable
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..structures import Pointclouds, Volumes
|
||||
|
||||
|
||||
class _points_to_volumes_function(Function):
|
||||
"""
|
||||
For each point in a pointcloud, add point_weight to the
|
||||
corresponding volume density and point_weight times its features
|
||||
to the corresponding volume features.
|
||||
|
||||
This function does not require any contiguity internally and therefore
|
||||
doesn't need to make copies of its inputs, which is useful when GPU memory
|
||||
is at a premium. (An implementation requiring contiguous inputs might be faster
|
||||
though). The volumes are modified in place.
|
||||
|
||||
This function is differentiable with respect to
|
||||
points_features, volume_densities and volume_features.
|
||||
If splat is True then it is also differentiable with respect to
|
||||
points_3d.
|
||||
|
||||
It may be useful to think about this function as a sort of opposite to
|
||||
torch.nn.functional.grid_sample with 5D inputs.
|
||||
|
||||
Args:
|
||||
points_3d: Batch of 3D point cloud coordinates of shape
|
||||
`(minibatch, N, 3)` where N is the number of points
|
||||
in each point cloud. Coordinates have to be specified in the
|
||||
local volume coordinates (ranging in [-1, 1]).
|
||||
points_features: Features of shape `(minibatch, N, feature_dim)`
|
||||
corresponding to the points of the input point cloud `points_3d`.
|
||||
volume_features: Batch of input feature volumes
|
||||
of shape `(minibatch, feature_dim, D, H, W)`
|
||||
volume_densities: Batch of input feature volume densities
|
||||
of shape `(minibatch, 1, D, H, W)`. Each voxel should
|
||||
contain a non-negative number corresponding to its
|
||||
opaqueness (the higher, the less transparent).
|
||||
|
||||
grid_sizes: `LongTensor` of shape (minibatch, 3) representing the
|
||||
spatial resolutions of each of the the non-flattened `volumes`
|
||||
tensors. Note that the following has to hold:
|
||||
`torch.prod(grid_sizes, dim=1)==N_voxels`.
|
||||
|
||||
point_weight: A scalar controlling how much weight a single point has.
|
||||
|
||||
mask: A binary mask of shape `(minibatch, N)` determining
|
||||
which 3D points are going to be converted to the resulting
|
||||
volume. Set to `None` if all points are valid.
|
||||
|
||||
align_corners: as for grid_sample.
|
||||
|
||||
splat: if true, trilinear interpolation. If false all the weight goes in
|
||||
the nearest voxel.
|
||||
|
||||
Returns:
|
||||
volume_densities and volume_features, which have been modified in place.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
# pyre-fixme[14]: `forward` overrides method defined in `Function` inconsistently.
|
||||
def forward(
|
||||
ctx,
|
||||
points_3d: torch.Tensor,
|
||||
points_features: torch.Tensor,
|
||||
volume_densities: torch.Tensor,
|
||||
volume_features: torch.Tensor,
|
||||
grid_sizes: torch.LongTensor,
|
||||
point_weight: float,
|
||||
mask: torch.Tensor,
|
||||
align_corners: bool,
|
||||
splat: bool,
|
||||
):
|
||||
|
||||
ctx.mark_dirty(volume_densities, volume_features)
|
||||
|
||||
N, P, D = points_3d.shape
|
||||
if D != 3:
|
||||
raise ValueError("points_3d must be 3D")
|
||||
if points_3d.dtype != torch.float32:
|
||||
raise ValueError("points_3d must be float32")
|
||||
if points_features.dtype != torch.float32:
|
||||
raise ValueError("points_features must be float32")
|
||||
N1, P1, C = points_features.shape
|
||||
if N1 != N or P1 != P:
|
||||
raise ValueError("Bad points_features shape")
|
||||
if volume_densities.dtype != torch.float32:
|
||||
raise ValueError("volume_densities must be float32")
|
||||
N2, one, D, H, W = volume_densities.shape
|
||||
if N2 != N or one != 1:
|
||||
raise ValueError("Bad volume_densities shape")
|
||||
if volume_features.dtype != torch.float32:
|
||||
raise ValueError("volume_features must be float32")
|
||||
N3, C1, D1, H1, W1 = volume_features.shape
|
||||
if N3 != N or C1 != C or D1 != D or H1 != H or W1 != W:
|
||||
raise ValueError("Bad volume_features shape")
|
||||
if grid_sizes.dtype != torch.int64:
|
||||
raise ValueError("grid_sizes must be int64")
|
||||
N4, D1 = grid_sizes.shape
|
||||
if N4 != N or D1 != 3:
|
||||
raise ValueError("Bad grid_sizes.shape")
|
||||
if mask.dtype != torch.float32:
|
||||
raise ValueError("mask must be float32")
|
||||
N5, P2 = mask.shape
|
||||
if N5 != N or P2 != P:
|
||||
raise ValueError("Bad mask shape")
|
||||
|
||||
# pyre-fixme[16]: Module `pytorch3d` has no attribute `_C`.
|
||||
_C.points_to_volumes_forward(
|
||||
points_3d,
|
||||
points_features,
|
||||
volume_densities,
|
||||
volume_features,
|
||||
grid_sizes,
|
||||
mask,
|
||||
point_weight,
|
||||
align_corners,
|
||||
splat,
|
||||
)
|
||||
if splat:
|
||||
ctx.save_for_backward(points_3d, points_features, grid_sizes, mask)
|
||||
else:
|
||||
ctx.save_for_backward(points_3d, grid_sizes, mask)
|
||||
ctx.point_weight = point_weight
|
||||
ctx.splat = splat
|
||||
ctx.align_corners = align_corners
|
||||
return volume_densities, volume_features
|
||||
|
||||
@staticmethod
|
||||
@once_differentiable
|
||||
def backward(ctx, grad_volume_densities, grad_volume_features):
|
||||
splat = ctx.splat
|
||||
N, C = grad_volume_features.shape[:2]
|
||||
if splat:
|
||||
points_3d, points_features, grid_sizes, mask = ctx.saved_tensors
|
||||
P = points_3d.shape[1]
|
||||
grad_points_3d = torch.zeros_like(points_3d)
|
||||
else:
|
||||
points_3d, grid_sizes, mask = ctx.saved_tensors
|
||||
P = points_3d.shape[1]
|
||||
ones = points_3d.new_zeros(1, 1, 1)
|
||||
# There is no gradient. Just need something to let its accessors exist.
|
||||
grad_points_3d = ones.expand_as(points_3d)
|
||||
# points_features not needed. Just need something to let its accessors exist.
|
||||
points_features = ones.expand(N, P, C)
|
||||
grad_points_features = points_3d.new_zeros(N, P, C)
|
||||
_C.points_to_volumes_backward(
|
||||
points_3d,
|
||||
points_features,
|
||||
grid_sizes,
|
||||
mask,
|
||||
ctx.point_weight,
|
||||
ctx.align_corners,
|
||||
splat,
|
||||
grad_volume_densities,
|
||||
grad_volume_features,
|
||||
grad_points_3d,
|
||||
grad_points_features,
|
||||
)
|
||||
|
||||
return (
|
||||
(grad_points_3d if splat else None),
|
||||
grad_points_features,
|
||||
grad_volume_densities,
|
||||
grad_volume_features,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
|
||||
|
||||
# pyre-fixme[16]: `_points_to_volumes_function` has no attribute `apply`.
|
||||
_points_to_volumes = _points_to_volumes_function.apply
|
||||
|
||||
|
||||
def add_pointclouds_to_volumes(
|
||||
pointclouds: "Pointclouds",
|
||||
initial_volumes: "Volumes",
|
||||
|
@ -5,12 +5,14 @@
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import unittest
|
||||
from functools import partial
|
||||
from typing import Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from common_testing import TestCaseMixin
|
||||
from pytorch3d.ops import add_pointclouds_to_volumes
|
||||
from pytorch3d.ops.points_to_volumes import _points_to_volumes
|
||||
from pytorch3d.ops.sample_points_from_meshes import sample_points_from_meshes
|
||||
from pytorch3d.structures.meshes import Meshes
|
||||
from pytorch3d.structures.pointclouds import Pointclouds
|
||||
@ -395,3 +397,140 @@ class TestPointsToVolumes(TestCaseMixin, unittest.TestCase):
|
||||
|
||||
# check that all per-slice avg errors vanish
|
||||
self.assertClose(clr_diff, torch.zeros_like(clr_diff), atol=1e-2)
|
||||
|
||||
|
||||
class TestRawFunction(TestCaseMixin, unittest.TestCase):
|
||||
"""
|
||||
Testing the _C.points_to_volumes function through its wrapper
|
||||
_points_to_volumes.
|
||||
"""
|
||||
|
||||
def setUp(self) -> None:
|
||||
torch.manual_seed(42)
|
||||
|
||||
def test_grad_corners_splat_cpu(self):
|
||||
self.do_gradcheck(torch.device("cpu"), True, True)
|
||||
|
||||
def test_grad_corners_round_cpu(self):
|
||||
self.do_gradcheck(torch.device("cpu"), False, True)
|
||||
|
||||
def test_grad_splat_cpu(self):
|
||||
self.do_gradcheck(torch.device("cpu"), True, False)
|
||||
|
||||
def test_grad_round_cpu(self):
|
||||
self.do_gradcheck(torch.device("cpu"), False, False)
|
||||
|
||||
def do_gradcheck(self, device, splat: bool, align_corners: bool):
|
||||
"""
|
||||
Use gradcheck to verify the gradient of _points_to_volumes
|
||||
with random input.
|
||||
"""
|
||||
N, C, D, H, W, P = 2, 4, 5, 6, 7, 5
|
||||
points_3d = (
|
||||
torch.rand((N, P, 3), device=device, dtype=torch.float64) * 0.8 + 0.1
|
||||
)
|
||||
points_features = torch.rand((N, P, C), device=device, dtype=torch.float64)
|
||||
volume_densities = torch.zeros((N, 1, D, H, W), device=device)
|
||||
volume_features = torch.zeros((N, C, D, H, W), device=device)
|
||||
volume_densities_scale = torch.rand_like(volume_densities)
|
||||
volume_features_scale = torch.rand_like(volume_features)
|
||||
grid_sizes = torch.tensor([D, H, W], dtype=torch.int64, device=device).expand(
|
||||
N, 3
|
||||
)
|
||||
mask = torch.ones((N, P), device=device)
|
||||
mask[:, 0] = 0
|
||||
align_corners = False
|
||||
|
||||
def f(points_3d_, points_features_):
|
||||
(volume_densities_, volume_features_) = _points_to_volumes(
|
||||
points_3d_.to(torch.float32),
|
||||
points_features_.to(torch.float32),
|
||||
volume_densities.clone(),
|
||||
volume_features.clone(),
|
||||
grid_sizes,
|
||||
2.0,
|
||||
mask,
|
||||
align_corners,
|
||||
splat,
|
||||
)
|
||||
density = (volume_densities_ * volume_densities_scale).sum()
|
||||
features = (volume_features_ * volume_features_scale).sum()
|
||||
return density, features
|
||||
|
||||
base = f(points_3d.clone(), points_features.clone())
|
||||
self.assertGreater(base[0], 0)
|
||||
self.assertGreater(base[1], 0)
|
||||
|
||||
points_features.requires_grad = True
|
||||
if splat:
|
||||
points_3d.requires_grad = True
|
||||
torch.autograd.gradcheck(
|
||||
f,
|
||||
(points_3d, points_features),
|
||||
check_undefined_grad=False,
|
||||
eps=2e-4,
|
||||
atol=0.01,
|
||||
)
|
||||
else:
|
||||
torch.autograd.gradcheck(
|
||||
partial(f, points_3d),
|
||||
points_features,
|
||||
check_undefined_grad=False,
|
||||
eps=2e-3,
|
||||
atol=0.001,
|
||||
)
|
||||
|
||||
def test_single_corners_round_cpu(self):
|
||||
self.single_point(torch.device("cpu"), False, True)
|
||||
|
||||
def test_single_corners_splat_cpu(self):
|
||||
self.single_point(torch.device("cpu"), True, True)
|
||||
|
||||
def test_single_round_cpu(self):
|
||||
self.single_point(torch.device("cpu"), False, False)
|
||||
|
||||
def test_single_splat_cpu(self):
|
||||
self.single_point(torch.device("cpu"), True, False)
|
||||
|
||||
def single_point(self, device, splat: bool, align_corners: bool):
|
||||
"""
|
||||
Check the outcome of _points_to_volumes where a single point
|
||||
exists which lines up with a single voxel.
|
||||
"""
|
||||
D, H, W = (6, 6, 11) if align_corners else (5, 5, 10)
|
||||
N, C, P = 1, 1, 1
|
||||
if align_corners:
|
||||
points_3d = torch.tensor([[[-0.2, 0.2, -0.2]]], device=device)
|
||||
else:
|
||||
points_3d = torch.tensor([[[-0.3, 0.4, -0.4]]], device=device)
|
||||
points_features = torch.zeros((N, P, C), device=device)
|
||||
volume_densities = torch.zeros((N, 1, D, H, W), device=device)
|
||||
volume_densities_expected = torch.zeros((N, 1, D, H, W), device=device)
|
||||
volume_features = torch.zeros((N, C, D, H, W), device=device)
|
||||
grid_sizes = torch.tensor([D, H, W], dtype=torch.int64, device=device).expand(
|
||||
N, 3
|
||||
)
|
||||
mask = torch.ones((N, P), device=device)
|
||||
point_weight = 19.0
|
||||
|
||||
volume_densities_, volume_features_ = _points_to_volumes(
|
||||
points_3d,
|
||||
points_features,
|
||||
volume_densities,
|
||||
volume_features,
|
||||
grid_sizes,
|
||||
point_weight,
|
||||
mask,
|
||||
align_corners,
|
||||
splat,
|
||||
)
|
||||
|
||||
self.assertIs(volume_densities, volume_densities_)
|
||||
self.assertIs(volume_features, volume_features_)
|
||||
|
||||
if align_corners:
|
||||
volume_densities_expected[0, 0, 2, 3, 4] = point_weight
|
||||
else:
|
||||
volume_densities_expected[0, 0, 1, 3, 3] = point_weight
|
||||
|
||||
self.assertClose(volume_densities, volume_densities_expected)
|
||||
|
Loading…
x
Reference in New Issue
Block a user