CPU function for points2vols

Summary: Single C++ function for the core of points2vols, not used anywhere yet. Added ability to control align_corners and the weight of each point, which may be useful later.

Reviewed By: nikhilaravi

Differential Revision: D29548607

fbshipit-source-id: a5cda7ec2c14836624e7dfe744c4bbb3f3d3dfe2
This commit is contained in:
Jeremy Reizenstein
2021-10-01 11:57:07 -07:00
committed by Facebook GitHub Bot
parent c7c6deab86
commit 0dfc6e0eb8
5 changed files with 767 additions and 0 deletions

View File

@@ -7,12 +7,186 @@
from typing import TYPE_CHECKING, Optional, Tuple
import torch
from pytorch3d import _C
from torch.autograd import Function
from torch.autograd.function import once_differentiable
if TYPE_CHECKING:
from ..structures import Pointclouds, Volumes
class _points_to_volumes_function(Function):
"""
For each point in a pointcloud, add point_weight to the
corresponding volume density and point_weight times its features
to the corresponding volume features.
This function does not require any contiguity internally and therefore
doesn't need to make copies of its inputs, which is useful when GPU memory
is at a premium. (An implementation requiring contiguous inputs might be faster
though). The volumes are modified in place.
This function is differentiable with respect to
points_features, volume_densities and volume_features.
If splat is True then it is also differentiable with respect to
points_3d.
It may be useful to think about this function as a sort of opposite to
torch.nn.functional.grid_sample with 5D inputs.
Args:
points_3d: Batch of 3D point cloud coordinates of shape
`(minibatch, N, 3)` where N is the number of points
in each point cloud. Coordinates have to be specified in the
local volume coordinates (ranging in [-1, 1]).
points_features: Features of shape `(minibatch, N, feature_dim)`
corresponding to the points of the input point cloud `points_3d`.
volume_features: Batch of input feature volumes
of shape `(minibatch, feature_dim, D, H, W)`
volume_densities: Batch of input feature volume densities
of shape `(minibatch, 1, D, H, W)`. Each voxel should
contain a non-negative number corresponding to its
opaqueness (the higher, the less transparent).
grid_sizes: `LongTensor` of shape (minibatch, 3) representing the
spatial resolutions of each of the the non-flattened `volumes`
tensors. Note that the following has to hold:
`torch.prod(grid_sizes, dim=1)==N_voxels`.
point_weight: A scalar controlling how much weight a single point has.
mask: A binary mask of shape `(minibatch, N)` determining
which 3D points are going to be converted to the resulting
volume. Set to `None` if all points are valid.
align_corners: as for grid_sample.
splat: if true, trilinear interpolation. If false all the weight goes in
the nearest voxel.
Returns:
volume_densities and volume_features, which have been modified in place.
"""
@staticmethod
# pyre-fixme[14]: `forward` overrides method defined in `Function` inconsistently.
def forward(
ctx,
points_3d: torch.Tensor,
points_features: torch.Tensor,
volume_densities: torch.Tensor,
volume_features: torch.Tensor,
grid_sizes: torch.LongTensor,
point_weight: float,
mask: torch.Tensor,
align_corners: bool,
splat: bool,
):
ctx.mark_dirty(volume_densities, volume_features)
N, P, D = points_3d.shape
if D != 3:
raise ValueError("points_3d must be 3D")
if points_3d.dtype != torch.float32:
raise ValueError("points_3d must be float32")
if points_features.dtype != torch.float32:
raise ValueError("points_features must be float32")
N1, P1, C = points_features.shape
if N1 != N or P1 != P:
raise ValueError("Bad points_features shape")
if volume_densities.dtype != torch.float32:
raise ValueError("volume_densities must be float32")
N2, one, D, H, W = volume_densities.shape
if N2 != N or one != 1:
raise ValueError("Bad volume_densities shape")
if volume_features.dtype != torch.float32:
raise ValueError("volume_features must be float32")
N3, C1, D1, H1, W1 = volume_features.shape
if N3 != N or C1 != C or D1 != D or H1 != H or W1 != W:
raise ValueError("Bad volume_features shape")
if grid_sizes.dtype != torch.int64:
raise ValueError("grid_sizes must be int64")
N4, D1 = grid_sizes.shape
if N4 != N or D1 != 3:
raise ValueError("Bad grid_sizes.shape")
if mask.dtype != torch.float32:
raise ValueError("mask must be float32")
N5, P2 = mask.shape
if N5 != N or P2 != P:
raise ValueError("Bad mask shape")
# pyre-fixme[16]: Module `pytorch3d` has no attribute `_C`.
_C.points_to_volumes_forward(
points_3d,
points_features,
volume_densities,
volume_features,
grid_sizes,
mask,
point_weight,
align_corners,
splat,
)
if splat:
ctx.save_for_backward(points_3d, points_features, grid_sizes, mask)
else:
ctx.save_for_backward(points_3d, grid_sizes, mask)
ctx.point_weight = point_weight
ctx.splat = splat
ctx.align_corners = align_corners
return volume_densities, volume_features
@staticmethod
@once_differentiable
def backward(ctx, grad_volume_densities, grad_volume_features):
splat = ctx.splat
N, C = grad_volume_features.shape[:2]
if splat:
points_3d, points_features, grid_sizes, mask = ctx.saved_tensors
P = points_3d.shape[1]
grad_points_3d = torch.zeros_like(points_3d)
else:
points_3d, grid_sizes, mask = ctx.saved_tensors
P = points_3d.shape[1]
ones = points_3d.new_zeros(1, 1, 1)
# There is no gradient. Just need something to let its accessors exist.
grad_points_3d = ones.expand_as(points_3d)
# points_features not needed. Just need something to let its accessors exist.
points_features = ones.expand(N, P, C)
grad_points_features = points_3d.new_zeros(N, P, C)
_C.points_to_volumes_backward(
points_3d,
points_features,
grid_sizes,
mask,
ctx.point_weight,
ctx.align_corners,
splat,
grad_volume_densities,
grad_volume_features,
grad_points_3d,
grad_points_features,
)
return (
(grad_points_3d if splat else None),
grad_points_features,
grad_volume_densities,
grad_volume_features,
None,
None,
None,
None,
None,
)
# pyre-fixme[16]: `_points_to_volumes_function` has no attribute `apply`.
_points_to_volumes = _points_to_volumes_function.apply
def add_pointclouds_to_volumes(
pointclouds: "Pointclouds",
initial_volumes: "Volumes",