mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2026-02-07 14:42:22 +08:00
CPU function for points2vols
Summary: Single C++ function for the core of points2vols, not used anywhere yet. Added ability to control align_corners and the weight of each point, which may be useful later. Reviewed By: nikhilaravi Differential Revision: D29548607 fbshipit-source-id: a5cda7ec2c14836624e7dfe744c4bbb3f3d3dfe2
This commit is contained in:
committed by
Facebook GitHub Bot
parent
c7c6deab86
commit
0dfc6e0eb8
@@ -7,12 +7,186 @@
|
||||
from typing import TYPE_CHECKING, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from pytorch3d import _C
|
||||
from torch.autograd import Function
|
||||
from torch.autograd.function import once_differentiable
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..structures import Pointclouds, Volumes
|
||||
|
||||
|
||||
class _points_to_volumes_function(Function):
|
||||
"""
|
||||
For each point in a pointcloud, add point_weight to the
|
||||
corresponding volume density and point_weight times its features
|
||||
to the corresponding volume features.
|
||||
|
||||
This function does not require any contiguity internally and therefore
|
||||
doesn't need to make copies of its inputs, which is useful when GPU memory
|
||||
is at a premium. (An implementation requiring contiguous inputs might be faster
|
||||
though). The volumes are modified in place.
|
||||
|
||||
This function is differentiable with respect to
|
||||
points_features, volume_densities and volume_features.
|
||||
If splat is True then it is also differentiable with respect to
|
||||
points_3d.
|
||||
|
||||
It may be useful to think about this function as a sort of opposite to
|
||||
torch.nn.functional.grid_sample with 5D inputs.
|
||||
|
||||
Args:
|
||||
points_3d: Batch of 3D point cloud coordinates of shape
|
||||
`(minibatch, N, 3)` where N is the number of points
|
||||
in each point cloud. Coordinates have to be specified in the
|
||||
local volume coordinates (ranging in [-1, 1]).
|
||||
points_features: Features of shape `(minibatch, N, feature_dim)`
|
||||
corresponding to the points of the input point cloud `points_3d`.
|
||||
volume_features: Batch of input feature volumes
|
||||
of shape `(minibatch, feature_dim, D, H, W)`
|
||||
volume_densities: Batch of input feature volume densities
|
||||
of shape `(minibatch, 1, D, H, W)`. Each voxel should
|
||||
contain a non-negative number corresponding to its
|
||||
opaqueness (the higher, the less transparent).
|
||||
|
||||
grid_sizes: `LongTensor` of shape (minibatch, 3) representing the
|
||||
spatial resolutions of each of the the non-flattened `volumes`
|
||||
tensors. Note that the following has to hold:
|
||||
`torch.prod(grid_sizes, dim=1)==N_voxels`.
|
||||
|
||||
point_weight: A scalar controlling how much weight a single point has.
|
||||
|
||||
mask: A binary mask of shape `(minibatch, N)` determining
|
||||
which 3D points are going to be converted to the resulting
|
||||
volume. Set to `None` if all points are valid.
|
||||
|
||||
align_corners: as for grid_sample.
|
||||
|
||||
splat: if true, trilinear interpolation. If false all the weight goes in
|
||||
the nearest voxel.
|
||||
|
||||
Returns:
|
||||
volume_densities and volume_features, which have been modified in place.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
# pyre-fixme[14]: `forward` overrides method defined in `Function` inconsistently.
|
||||
def forward(
|
||||
ctx,
|
||||
points_3d: torch.Tensor,
|
||||
points_features: torch.Tensor,
|
||||
volume_densities: torch.Tensor,
|
||||
volume_features: torch.Tensor,
|
||||
grid_sizes: torch.LongTensor,
|
||||
point_weight: float,
|
||||
mask: torch.Tensor,
|
||||
align_corners: bool,
|
||||
splat: bool,
|
||||
):
|
||||
|
||||
ctx.mark_dirty(volume_densities, volume_features)
|
||||
|
||||
N, P, D = points_3d.shape
|
||||
if D != 3:
|
||||
raise ValueError("points_3d must be 3D")
|
||||
if points_3d.dtype != torch.float32:
|
||||
raise ValueError("points_3d must be float32")
|
||||
if points_features.dtype != torch.float32:
|
||||
raise ValueError("points_features must be float32")
|
||||
N1, P1, C = points_features.shape
|
||||
if N1 != N or P1 != P:
|
||||
raise ValueError("Bad points_features shape")
|
||||
if volume_densities.dtype != torch.float32:
|
||||
raise ValueError("volume_densities must be float32")
|
||||
N2, one, D, H, W = volume_densities.shape
|
||||
if N2 != N or one != 1:
|
||||
raise ValueError("Bad volume_densities shape")
|
||||
if volume_features.dtype != torch.float32:
|
||||
raise ValueError("volume_features must be float32")
|
||||
N3, C1, D1, H1, W1 = volume_features.shape
|
||||
if N3 != N or C1 != C or D1 != D or H1 != H or W1 != W:
|
||||
raise ValueError("Bad volume_features shape")
|
||||
if grid_sizes.dtype != torch.int64:
|
||||
raise ValueError("grid_sizes must be int64")
|
||||
N4, D1 = grid_sizes.shape
|
||||
if N4 != N or D1 != 3:
|
||||
raise ValueError("Bad grid_sizes.shape")
|
||||
if mask.dtype != torch.float32:
|
||||
raise ValueError("mask must be float32")
|
||||
N5, P2 = mask.shape
|
||||
if N5 != N or P2 != P:
|
||||
raise ValueError("Bad mask shape")
|
||||
|
||||
# pyre-fixme[16]: Module `pytorch3d` has no attribute `_C`.
|
||||
_C.points_to_volumes_forward(
|
||||
points_3d,
|
||||
points_features,
|
||||
volume_densities,
|
||||
volume_features,
|
||||
grid_sizes,
|
||||
mask,
|
||||
point_weight,
|
||||
align_corners,
|
||||
splat,
|
||||
)
|
||||
if splat:
|
||||
ctx.save_for_backward(points_3d, points_features, grid_sizes, mask)
|
||||
else:
|
||||
ctx.save_for_backward(points_3d, grid_sizes, mask)
|
||||
ctx.point_weight = point_weight
|
||||
ctx.splat = splat
|
||||
ctx.align_corners = align_corners
|
||||
return volume_densities, volume_features
|
||||
|
||||
@staticmethod
|
||||
@once_differentiable
|
||||
def backward(ctx, grad_volume_densities, grad_volume_features):
|
||||
splat = ctx.splat
|
||||
N, C = grad_volume_features.shape[:2]
|
||||
if splat:
|
||||
points_3d, points_features, grid_sizes, mask = ctx.saved_tensors
|
||||
P = points_3d.shape[1]
|
||||
grad_points_3d = torch.zeros_like(points_3d)
|
||||
else:
|
||||
points_3d, grid_sizes, mask = ctx.saved_tensors
|
||||
P = points_3d.shape[1]
|
||||
ones = points_3d.new_zeros(1, 1, 1)
|
||||
# There is no gradient. Just need something to let its accessors exist.
|
||||
grad_points_3d = ones.expand_as(points_3d)
|
||||
# points_features not needed. Just need something to let its accessors exist.
|
||||
points_features = ones.expand(N, P, C)
|
||||
grad_points_features = points_3d.new_zeros(N, P, C)
|
||||
_C.points_to_volumes_backward(
|
||||
points_3d,
|
||||
points_features,
|
||||
grid_sizes,
|
||||
mask,
|
||||
ctx.point_weight,
|
||||
ctx.align_corners,
|
||||
splat,
|
||||
grad_volume_densities,
|
||||
grad_volume_features,
|
||||
grad_points_3d,
|
||||
grad_points_features,
|
||||
)
|
||||
|
||||
return (
|
||||
(grad_points_3d if splat else None),
|
||||
grad_points_features,
|
||||
grad_volume_densities,
|
||||
grad_volume_features,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
|
||||
|
||||
# pyre-fixme[16]: `_points_to_volumes_function` has no attribute `apply`.
|
||||
_points_to_volumes = _points_to_volumes_function.apply
|
||||
|
||||
|
||||
def add_pointclouds_to_volumes(
|
||||
pointclouds: "Pointclouds",
|
||||
initial_volumes: "Volumes",
|
||||
|
||||
Reference in New Issue
Block a user