mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 11:52:50 +08:00
Point clouds to volumes
Summary: Conversion from point clouds to volumes ``` Benchmark Avg Time(μs) Peak Time(μs) Iterations -------------------------------------------------------------------------------- ADD_POINTS_TO_VOLUMES_10_trilinear_[25, 25, 25]_1000 43219 44067 12 ADD_POINTS_TO_VOLUMES_10_trilinear_[25, 25, 25]_10000 43274 45313 12 ADD_POINTS_TO_VOLUMES_10_trilinear_[25, 25, 25]_100000 46281 47100 11 ADD_POINTS_TO_VOLUMES_10_trilinear_[101, 111, 121]_1000 51224 51912 10 ADD_POINTS_TO_VOLUMES_10_trilinear_[101, 111, 121]_10000 52092 54487 10 ADD_POINTS_TO_VOLUMES_10_trilinear_[101, 111, 121]_100000 59262 60514 9 ADD_POINTS_TO_VOLUMES_10_nearest_[25, 25, 25]_1000 15998 17237 32 ADD_POINTS_TO_VOLUMES_10_nearest_[25, 25, 25]_10000 15964 16994 32 ADD_POINTS_TO_VOLUMES_10_nearest_[25, 25, 25]_100000 16881 19286 30 ADD_POINTS_TO_VOLUMES_10_nearest_[101, 111, 121]_1000 19150 25277 27 ADD_POINTS_TO_VOLUMES_10_nearest_[101, 111, 121]_10000 18746 19999 27 ADD_POINTS_TO_VOLUMES_10_nearest_[101, 111, 121]_100000 22321 24568 23 ADD_POINTS_TO_VOLUMES_100_trilinear_[25, 25, 25]_1000 49693 50288 11 ADD_POINTS_TO_VOLUMES_100_trilinear_[25, 25, 25]_10000 51429 52449 10 ADD_POINTS_TO_VOLUMES_100_trilinear_[25, 25, 25]_100000 237076 237377 3 ADD_POINTS_TO_VOLUMES_100_trilinear_[101, 111, 121]_1000 81875 82597 7 ADD_POINTS_TO_VOLUMES_100_trilinear_[101, 111, 121]_10000 106671 107045 5 ADD_POINTS_TO_VOLUMES_100_trilinear_[101, 111, 121]_100000 483740 484607 2 ADD_POINTS_TO_VOLUMES_100_nearest_[25, 25, 25]_1000 16667 18143 31 ADD_POINTS_TO_VOLUMES_100_nearest_[25, 25, 25]_10000 17682 18922 29 ADD_POINTS_TO_VOLUMES_100_nearest_[25, 25, 25]_100000 65463 67116 8 ADD_POINTS_TO_VOLUMES_100_nearest_[101, 111, 121]_1000 48058 48826 11 ADD_POINTS_TO_VOLUMES_100_nearest_[101, 111, 121]_10000 53529 53998 10 ADD_POINTS_TO_VOLUMES_100_nearest_[101, 111, 121]_100000 123684 123901 5 -------------------------------------------------------------------------------- ``` Output with `DEBUG=True` {F338561209} Reviewed By: nikhilaravi Differential Revision: D22017500 fbshipit-source-id: ed3e8ed13940c593841d93211623dd533974012f
This commit is contained in:
parent
03ee1dbf82
commit
aa9bcaf04c
@ -14,6 +14,10 @@ from .points_normals import (
|
|||||||
estimate_pointcloud_local_coord_frames,
|
estimate_pointcloud_local_coord_frames,
|
||||||
estimate_pointcloud_normals,
|
estimate_pointcloud_normals,
|
||||||
)
|
)
|
||||||
|
from .points_to_volumes import (
|
||||||
|
add_pointclouds_to_volumes,
|
||||||
|
add_points_features_to_volume_densities_features,
|
||||||
|
)
|
||||||
from .sample_points_from_meshes import sample_points_from_meshes
|
from .sample_points_from_meshes import sample_points_from_meshes
|
||||||
from .subdivide_meshes import SubdivideMeshes
|
from .subdivide_meshes import SubdivideMeshes
|
||||||
from .utils import (
|
from .utils import (
|
||||||
|
491
pytorch3d/ops/points_to_volumes.py
Normal file
491
pytorch3d/ops/points_to_volumes.py
Normal file
@ -0,0 +1,491 @@
|
|||||||
|
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||||
|
from typing import TYPE_CHECKING, Optional, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from ..structures import Pointclouds, Volumes
|
||||||
|
|
||||||
|
|
||||||
|
def add_pointclouds_to_volumes(
|
||||||
|
pointclouds: "Pointclouds",
|
||||||
|
initial_volumes: "Volumes",
|
||||||
|
mode: str = "trilinear",
|
||||||
|
min_weight: float = 1e-4,
|
||||||
|
) -> "Volumes":
|
||||||
|
"""
|
||||||
|
Add a batch of point clouds represented with a `Pointclouds` structure
|
||||||
|
`pointclouds` to a batch of existing volumes represented with a
|
||||||
|
`Volumes` structure `initial_volumes`.
|
||||||
|
|
||||||
|
More specifically, the method casts a set of weighted votes (the weights are
|
||||||
|
determined based on `mode="trilinear"|"nearest"`) into the pre-initialized
|
||||||
|
`features` and `densities` fields of `initial_volumes`.
|
||||||
|
|
||||||
|
The method returns an updated `Volumes` object that contains a copy
|
||||||
|
of `initial_volumes` with its `features` and `densities` updated with the
|
||||||
|
result of the pointcloud addition.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```
|
||||||
|
# init a random point cloud
|
||||||
|
pointclouds = Pointclouds(
|
||||||
|
points=torch.randn(4, 100, 3), features=torch.rand(4, 100, 5)
|
||||||
|
)
|
||||||
|
# init an empty volume centered around [0.5, 0.5, 0.5] in world coordinates
|
||||||
|
# with a voxel size of 1.0.
|
||||||
|
initial_volumes = Volumes(
|
||||||
|
features = torch.zeros(4, 5, 25, 25, 25),
|
||||||
|
densities = torch.zeros(4, 1, 25, 25, 25),
|
||||||
|
volume_translation = [-0.5, -0.5, -0.5],
|
||||||
|
voxel_size = 1.0,
|
||||||
|
)
|
||||||
|
# add the pointcloud to the 'initial_volumes' buffer using
|
||||||
|
# trilinear splatting
|
||||||
|
updated_volumes = add_pointclouds_to_volumes(
|
||||||
|
pointclouds=pointclouds,
|
||||||
|
initial_volumes=initial_volumes,
|
||||||
|
mode="trilinear",
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pointclouds: Batch of 3D pointclouds represented with a `Pointclouds`
|
||||||
|
structure. Note that `pointclouds.features` have to be defined.
|
||||||
|
initial_volumes: Batch of initial `Volumes` with pre-initialized 1-dimensional
|
||||||
|
densities which contain non-negative numbers corresponding to the
|
||||||
|
opaqueness of each voxel (the higher, the less transparent).
|
||||||
|
mode: The mode of the conversion of individual points into the volume.
|
||||||
|
Set either to `nearest` or `trilinear`:
|
||||||
|
`nearest`: Each 3D point is first rounded to the volumetric
|
||||||
|
lattice. Each voxel is then labeled with the average
|
||||||
|
over features that fall into the given voxel.
|
||||||
|
The gradients of nearest neighbor conversion w.r.t. the
|
||||||
|
3D locations of the points in `pointclouds` are *not* defined.
|
||||||
|
`trilinear`: Each 3D point casts 8 weighted votes to the 8-neighborhood
|
||||||
|
of its floating point coordinate. The weights are
|
||||||
|
determined using a trilinear interpolation scheme.
|
||||||
|
Trilinear splatting is fully differentiable w.r.t. all input arguments.
|
||||||
|
min_weight: A scalar controlling the lowest possible total per-voxel
|
||||||
|
weight used to normalize the features accumulated in a voxel.
|
||||||
|
Only active for `mode==trilinear`.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
updated_volumes: Output `Volumes` structure containing the conversion result.
|
||||||
|
"""
|
||||||
|
|
||||||
|
if len(initial_volumes) != len(pointclouds):
|
||||||
|
raise ValueError(
|
||||||
|
"'initial_volumes' and 'pointclouds' have to have the same batch size."
|
||||||
|
)
|
||||||
|
|
||||||
|
# obtain the features and densities
|
||||||
|
pcl_feats = pointclouds.features_padded()
|
||||||
|
pcl_3d = pointclouds.points_padded()
|
||||||
|
|
||||||
|
if pcl_feats is None:
|
||||||
|
raise ValueError("'pointclouds' have to have their 'features' defined.")
|
||||||
|
|
||||||
|
# obtain the conversion mask
|
||||||
|
n_per_pcl = pointclouds.num_points_per_cloud().type_as(pcl_feats)
|
||||||
|
mask = torch.arange(n_per_pcl.max(), dtype=pcl_feats.dtype, device=pcl_feats.device)
|
||||||
|
mask = (mask[None, :] < n_per_pcl[:, None]).type_as(mask)
|
||||||
|
|
||||||
|
# convert to the coord frame of the volume
|
||||||
|
pcl_3d_local = initial_volumes.world_to_local_coords(pcl_3d)
|
||||||
|
|
||||||
|
features_new, densities_new = add_points_features_to_volume_densities_features(
|
||||||
|
points_3d=pcl_3d_local,
|
||||||
|
points_features=pcl_feats,
|
||||||
|
volume_features=initial_volumes.features(),
|
||||||
|
volume_densities=initial_volumes.densities(),
|
||||||
|
min_weight=min_weight,
|
||||||
|
grid_sizes=initial_volumes.get_grid_sizes(),
|
||||||
|
mask=mask,
|
||||||
|
mode=mode,
|
||||||
|
)
|
||||||
|
|
||||||
|
return initial_volumes.update_padded(
|
||||||
|
new_densities=densities_new, new_features=features_new
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def add_points_features_to_volume_densities_features(
|
||||||
|
points_3d: torch.Tensor,
|
||||||
|
points_features: torch.Tensor,
|
||||||
|
volume_densities: torch.Tensor,
|
||||||
|
volume_features: Optional[torch.Tensor],
|
||||||
|
mode: str = "trilinear",
|
||||||
|
min_weight: float = 1e-4,
|
||||||
|
mask: Optional[torch.Tensor] = None,
|
||||||
|
grid_sizes: Optional[torch.LongTensor] = None,
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
"""
|
||||||
|
Convert a batch of point clouds represented with tensors of per-point
|
||||||
|
3d coordinates and their features to a batch of volumes represented
|
||||||
|
with tensors of densities and features.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
points_3d: Batch of 3D point cloud coordinates of shape
|
||||||
|
`(minibatch, N, 3)` where N is the number of points
|
||||||
|
in each point cloud. Coordinates have to be specified in the
|
||||||
|
local volume coordinates (ranging in [-1, 1]).
|
||||||
|
points_features: Features of shape `(minibatch, N, feature_dim)` corresponding
|
||||||
|
to the points of the input point clouds `pointcloud`.
|
||||||
|
volume_densities: Batch of input feature volume densities of shape
|
||||||
|
`(minibatch, 1, D, H, W)`. Each voxel should
|
||||||
|
contain a non-negative number corresponding to its
|
||||||
|
opaqueness (the higher, the less transparent).
|
||||||
|
volume_features: Batch of input feature volumes of shape
|
||||||
|
`(minibatch, feature_dim, D, H, W)`
|
||||||
|
If set to `None`, the `volume_features` will be automatically
|
||||||
|
instantiatied with a correct size and filled with 0s.
|
||||||
|
mode: The mode of the conversion of individual points into the volume.
|
||||||
|
Set either to `nearest` or `trilinear`:
|
||||||
|
`nearest`: Each 3D point is first rounded to the volumetric
|
||||||
|
lattice. Each voxel is then labeled with the average
|
||||||
|
over features that fall into the given voxel.
|
||||||
|
The gradients of nearest neighbor rounding w.r.t. the
|
||||||
|
input point locations `points_3d` are *not* defined.
|
||||||
|
`trilinear`: Each 3D point casts 8 weighted votes to the 8-neighborhood
|
||||||
|
of its floating point coordinate. The weights are
|
||||||
|
determined using a trilinear interpolation scheme.
|
||||||
|
Trilinear splatting is fully differentiable w.r.t. all input arguments.
|
||||||
|
mask: A binary mask of shape `(minibatch, N)` determining which 3D points
|
||||||
|
are going to be converted to the resulting volume.
|
||||||
|
Set to `None` if all points are valid.
|
||||||
|
min_weight: A scalar controlling the lowest possible total per-voxel
|
||||||
|
weight used to normalize the features accumulated in a voxel.
|
||||||
|
Only active for `mode==trilinear`.
|
||||||
|
Returns:
|
||||||
|
volume_features: Output volume of shape `(minibatch, feature_dim, D, H, W)`
|
||||||
|
volume_densities: Occupancy volume of shape `(minibatch, 1, D, H, W)`
|
||||||
|
containing the total amount of votes cast to each of the voxels.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# number of points in the point cloud, its dim and batch size
|
||||||
|
ba, n_points, feature_dim = points_features.shape
|
||||||
|
ba_volume, density_dim = volume_densities.shape[:2]
|
||||||
|
|
||||||
|
if density_dim != 1:
|
||||||
|
raise ValueError("Only one-dimensional densities are allowed.")
|
||||||
|
|
||||||
|
# init the volumetric grid sizes if uninitialized
|
||||||
|
if grid_sizes is None:
|
||||||
|
grid_sizes = torch.LongTensor(list(volume_densities.shape[2:])).to(
|
||||||
|
volume_densities
|
||||||
|
)
|
||||||
|
|
||||||
|
# flatten densities and features
|
||||||
|
v_shape = volume_densities.shape[2:]
|
||||||
|
volume_densities_flatten = volume_densities.view(ba, -1, 1)
|
||||||
|
n_voxels = volume_densities_flatten.shape[1]
|
||||||
|
|
||||||
|
if volume_features is None:
|
||||||
|
# initialize features if not passed in
|
||||||
|
volume_features_flatten = volume_densities.new_zeros(ba, feature_dim, n_voxels)
|
||||||
|
else:
|
||||||
|
# otherwise just flatten
|
||||||
|
volume_features_flatten = volume_features.view(ba, feature_dim, n_voxels)
|
||||||
|
|
||||||
|
if mode == "trilinear": # do the splatting (trilinear interp)
|
||||||
|
volume_features, volume_densities = splat_points_to_volumes(
|
||||||
|
points_3d,
|
||||||
|
points_features,
|
||||||
|
volume_densities_flatten,
|
||||||
|
volume_features_flatten,
|
||||||
|
grid_sizes,
|
||||||
|
mask=mask,
|
||||||
|
min_weight=min_weight,
|
||||||
|
)
|
||||||
|
elif mode == "nearest": # nearest neighbor interp
|
||||||
|
volume_features, volume_densities = round_points_to_volumes(
|
||||||
|
points_3d,
|
||||||
|
points_features,
|
||||||
|
volume_densities_flatten,
|
||||||
|
volume_features_flatten,
|
||||||
|
grid_sizes,
|
||||||
|
mask=mask,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError('No such interpolation mode "%s"' % mode)
|
||||||
|
|
||||||
|
# reshape into the volume shape
|
||||||
|
volume_features = volume_features.view(ba, feature_dim, *v_shape)
|
||||||
|
volume_densities = volume_densities.view(ba, 1, *v_shape)
|
||||||
|
|
||||||
|
return volume_features, volume_densities
|
||||||
|
|
||||||
|
|
||||||
|
def _check_points_to_volumes_inputs(
|
||||||
|
points_3d: torch.Tensor,
|
||||||
|
points_features: torch.Tensor,
|
||||||
|
volume_densities: torch.Tensor,
|
||||||
|
volume_features: torch.Tensor,
|
||||||
|
grid_sizes: torch.LongTensor,
|
||||||
|
mask: Optional[torch.Tensor] = None,
|
||||||
|
):
|
||||||
|
|
||||||
|
max_grid_size = grid_sizes.max(dim=0).values
|
||||||
|
if torch.prod(max_grid_size) > volume_densities.shape[1]:
|
||||||
|
raise ValueError(
|
||||||
|
"One of the grid sizes corresponds to a larger number"
|
||||||
|
+ " of elements than the number of elements in volume_densities."
|
||||||
|
)
|
||||||
|
|
||||||
|
_, n_voxels, density_dim = volume_densities.shape
|
||||||
|
|
||||||
|
if density_dim != 1:
|
||||||
|
raise ValueError("Only one-dimensional densities are allowed.")
|
||||||
|
|
||||||
|
ba, n_points, feature_dim = points_features.shape
|
||||||
|
|
||||||
|
if volume_features.shape[1] != feature_dim:
|
||||||
|
raise ValueError(
|
||||||
|
"volume_features have a different number of channels"
|
||||||
|
+ " than points_features."
|
||||||
|
)
|
||||||
|
|
||||||
|
if volume_features.shape[2] != n_voxels:
|
||||||
|
raise ValueError(
|
||||||
|
"volume_features have a different number of elements"
|
||||||
|
+ " than volume_densities."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def splat_points_to_volumes(
|
||||||
|
points_3d: torch.Tensor,
|
||||||
|
points_features: torch.Tensor,
|
||||||
|
volume_densities: torch.Tensor,
|
||||||
|
volume_features: torch.Tensor,
|
||||||
|
grid_sizes: torch.LongTensor,
|
||||||
|
min_weight: float = 1e-4,
|
||||||
|
mask: Optional[torch.Tensor] = None,
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
"""
|
||||||
|
Convert a batch of point clouds to a batch of volumes using trilinear
|
||||||
|
splatting into a volume.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
points_3d: Batch of 3D point cloud coordinates of shape
|
||||||
|
`(minibatch, N, 3)` where N is the number of points
|
||||||
|
in each point cloud. Coordinates have to be specified in the
|
||||||
|
local volume coordinates (ranging in [-1, 1]).
|
||||||
|
points_features: Features of shape `(minibatch, N, feature_dim)`
|
||||||
|
corresponding to the points of the input point cloud `points_3d`.
|
||||||
|
volume_features: Batch of input *flattened* feature volumes
|
||||||
|
of shape `(minibatch, feature_dim, N_voxels)`
|
||||||
|
volume_densities: Batch of input *flattened* feature volume densities
|
||||||
|
of shape `(minibatch, 1, N_voxels)`. Each voxel should
|
||||||
|
contain a non-negative number corresponding to its
|
||||||
|
opaqueness (the higher, the less transparent).
|
||||||
|
grid_sizes: `LongTensor` of shape (minibatch, 3) representing the
|
||||||
|
spatial resolutions of each of the the non-flattened `volumes` tensors.
|
||||||
|
Note that the following has to hold:
|
||||||
|
`torch.prod(grid_sizes, dim=1)==N_voxels`
|
||||||
|
mask: A binary mask of shape `(minibatch, N)` determining which 3D points
|
||||||
|
are going to be converted to the resulting volume.
|
||||||
|
Set to `None` if all points are valid.
|
||||||
|
Returns:
|
||||||
|
volume_features: Output volume of shape `(minibatch, D, N_voxels)`.
|
||||||
|
volume_densities: Occupancy volume of shape `(minibatch, 1, N_voxels)`
|
||||||
|
containing the total amount of votes cast to each of the voxels.
|
||||||
|
"""
|
||||||
|
|
||||||
|
_check_points_to_volumes_inputs(
|
||||||
|
points_3d,
|
||||||
|
points_features,
|
||||||
|
volume_densities,
|
||||||
|
volume_features,
|
||||||
|
grid_sizes,
|
||||||
|
mask=mask,
|
||||||
|
)
|
||||||
|
|
||||||
|
_, n_voxels, density_dim = volume_densities.shape
|
||||||
|
ba, n_points, feature_dim = points_features.shape
|
||||||
|
|
||||||
|
# minibatch x n_points x feature_dim -> minibatch x feature_dim x n_points
|
||||||
|
points_features = points_features.permute(0, 2, 1).contiguous()
|
||||||
|
|
||||||
|
# XYZ = the upper-left volume index of the 8-neigborhood of every point
|
||||||
|
# grid_sizes is of the form (minibatch, depth-height-width)
|
||||||
|
grid_sizes_xyz = grid_sizes[:, [2, 1, 0]]
|
||||||
|
|
||||||
|
# Convert from points_3d in the range [-1, 1] to
|
||||||
|
# indices in the volume grid in the range [0, grid_sizes_xyz-1]
|
||||||
|
points_3d_indices = ((points_3d + 1) * 0.5) * (
|
||||||
|
grid_sizes_xyz[:, None].type_as(points_3d) - 1
|
||||||
|
)
|
||||||
|
XYZ = points_3d_indices.floor().long()
|
||||||
|
rXYZ = points_3d_indices - XYZ.type_as(points_3d) # remainder of floor
|
||||||
|
|
||||||
|
# split into separate coordinate vectors
|
||||||
|
X, Y, Z = XYZ.split(1, dim=2)
|
||||||
|
# rX = remainder after floor = 1-"the weight of each vote into
|
||||||
|
# the X coordinate of the 8-neighborhood"
|
||||||
|
rX, rY, rZ = rXYZ.split(1, dim=2)
|
||||||
|
|
||||||
|
# get random indices for the purpose of adding out-of-bounds values
|
||||||
|
rand_idx = X.new_zeros(X.shape).random_(0, n_voxels)
|
||||||
|
|
||||||
|
# iterate over the x, y, z indices of the 8-neighborhood (xdiff, ydiff, zdiff)
|
||||||
|
for xdiff in (0, 1):
|
||||||
|
X_ = X + xdiff
|
||||||
|
wX = (1 - xdiff) + (2 * xdiff - 1) * rX
|
||||||
|
for ydiff in (0, 1):
|
||||||
|
Y_ = Y + ydiff
|
||||||
|
wY = (1 - ydiff) + (2 * ydiff - 1) * rY
|
||||||
|
for zdiff in (0, 1):
|
||||||
|
Z_ = Z + zdiff
|
||||||
|
wZ = (1 - zdiff) + (2 * zdiff - 1) * rZ
|
||||||
|
|
||||||
|
# weight of each vote into the given cell of 8-neighborhood
|
||||||
|
w = wX * wY * wZ
|
||||||
|
|
||||||
|
# valid - binary indicators of votes that fall into the volume
|
||||||
|
valid = (
|
||||||
|
(0 <= X_)
|
||||||
|
* (X_ < grid_sizes_xyz[:, None, 0:1])
|
||||||
|
* (0 <= Y_)
|
||||||
|
* (Y_ < grid_sizes_xyz[:, None, 1:2])
|
||||||
|
* (0 <= Z_)
|
||||||
|
* (Z_ < grid_sizes_xyz[:, None, 2:3])
|
||||||
|
).long()
|
||||||
|
|
||||||
|
# linearized indices into the volume
|
||||||
|
idx = (Z_ * grid_sizes[:, None, 1:2] + Y_) * grid_sizes[
|
||||||
|
:, None, 2:3
|
||||||
|
] + X_
|
||||||
|
|
||||||
|
# out-of-bounds features added to a random voxel idx with weight=0.
|
||||||
|
idx_valid = idx * valid + rand_idx * (1 - valid)
|
||||||
|
w_valid = w * valid.type_as(w)
|
||||||
|
if mask is not None:
|
||||||
|
w_valid = w_valid * mask.type_as(w)[:, :, None]
|
||||||
|
|
||||||
|
# scatter add casts the votes into the weight accumulator
|
||||||
|
# and the feature accumulator
|
||||||
|
volume_densities.scatter_add_(1, idx_valid, w_valid)
|
||||||
|
|
||||||
|
# reshape idx_valid -> (minibatch, feature_dim, n_points)
|
||||||
|
idx_valid = idx_valid.view(ba, 1, n_points).expand_as(points_features)
|
||||||
|
w_valid = w_valid.view(ba, 1, n_points)
|
||||||
|
|
||||||
|
# volume_features of shape (minibatch, feature_dim, n_voxels)
|
||||||
|
volume_features.scatter_add_(2, idx_valid, w_valid * points_features)
|
||||||
|
|
||||||
|
# divide each feature by the total weight of the votes
|
||||||
|
volume_features = volume_features / volume_densities.view(ba, 1, n_voxels).clamp(
|
||||||
|
min_weight
|
||||||
|
)
|
||||||
|
|
||||||
|
return volume_features, volume_densities
|
||||||
|
|
||||||
|
|
||||||
|
def round_points_to_volumes(
|
||||||
|
points_3d: torch.Tensor,
|
||||||
|
points_features: torch.Tensor,
|
||||||
|
volume_densities: torch.Tensor,
|
||||||
|
volume_features: torch.Tensor,
|
||||||
|
grid_sizes: torch.LongTensor,
|
||||||
|
mask: Optional[torch.Tensor] = None,
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
"""
|
||||||
|
Convert a batch of point clouds to a batch of volumes using rounding to the
|
||||||
|
nearest integer coordinate of the volume. Features that fall into the same
|
||||||
|
voxel are averaged.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
points_3d: Batch of 3D point cloud coordinates of shape
|
||||||
|
`(minibatch, N, 3)` where N is the number of points
|
||||||
|
in each point cloud. Coordinates have to be specified in the
|
||||||
|
local volume coordinates (ranging in [-1, 1]).
|
||||||
|
points_features: Features of shape `(minibatch, N, feature_dim)`
|
||||||
|
corresponding to the points of the input point cloud `points_3d`.
|
||||||
|
volume_features: Batch of input *flattened* feature volumes
|
||||||
|
of shape `(minibatch, feature_dim, N_voxels)`
|
||||||
|
volume_densities: Batch of input *flattened* feature volume densities
|
||||||
|
of shape `(minibatch, 1, N_voxels)`. Each voxel should
|
||||||
|
contain a non-negative number corresponding to its
|
||||||
|
opaqueness (the higher, the less transparent).
|
||||||
|
grid_sizes: `LongTensor` of shape (minibatch, 3) representing the
|
||||||
|
spatial resolutions of each of the the non-flattened `volumes` tensors.
|
||||||
|
Note that the following has to hold:
|
||||||
|
`torch.prod(grid_sizes, dim=1)==N_voxels`
|
||||||
|
mask: A binary mask of shape `(minibatch, N)` determining which 3D points
|
||||||
|
are going to be converted to the resulting volume.
|
||||||
|
Set to `None` if all points are valid.
|
||||||
|
Returns:
|
||||||
|
volume_features: Output volume of shape `(minibatch, D, N_voxels)`.
|
||||||
|
volume_densities: Occupancy volume of shape `(minibatch, 1, N_voxels)`
|
||||||
|
containing the total amount of votes cast to each of the voxels.
|
||||||
|
"""
|
||||||
|
|
||||||
|
_check_points_to_volumes_inputs(
|
||||||
|
points_3d,
|
||||||
|
points_features,
|
||||||
|
volume_densities,
|
||||||
|
volume_features,
|
||||||
|
grid_sizes,
|
||||||
|
mask=mask,
|
||||||
|
)
|
||||||
|
|
||||||
|
_, n_voxels, density_dim = volume_densities.shape
|
||||||
|
ba, n_points, feature_dim = points_features.shape
|
||||||
|
|
||||||
|
# minibatch x n_points x feature_dim-> minibatch x feature_dim x n_points
|
||||||
|
points_features = points_features.permute(0, 2, 1).contiguous()
|
||||||
|
|
||||||
|
# round the coordinates to nearest integer
|
||||||
|
# grid_sizes is of the form (minibatch, depth-height-width)
|
||||||
|
grid_sizes_xyz = grid_sizes[:, [2, 1, 0]]
|
||||||
|
XYZ = ((points_3d.detach() + 1) * 0.5) * (
|
||||||
|
grid_sizes_xyz[:, None].type_as(points_3d) - 1
|
||||||
|
)
|
||||||
|
XYZ = torch.round(XYZ).long()
|
||||||
|
|
||||||
|
# split into separate coordinate vectors
|
||||||
|
X, Y, Z = XYZ.split(1, dim=2)
|
||||||
|
|
||||||
|
# get random indices for the purpose of adding out-of-bounds values
|
||||||
|
rand_idx = X.new_zeros(X.shape).random_(0, n_voxels)
|
||||||
|
|
||||||
|
# valid - binary indicators of votes that fall into the volume
|
||||||
|
grid_sizes = grid_sizes.type_as(XYZ)
|
||||||
|
valid = (
|
||||||
|
(0 <= X)
|
||||||
|
* (X < grid_sizes_xyz[:, None, 0:1])
|
||||||
|
* (0 <= Y)
|
||||||
|
* (Y < grid_sizes_xyz[:, None, 1:2])
|
||||||
|
* (0 <= Z)
|
||||||
|
* (Z < grid_sizes_xyz[:, None, 2:3])
|
||||||
|
).long()
|
||||||
|
|
||||||
|
# get random indices for the purpose of adding out-of-bounds values
|
||||||
|
rand_idx = valid.new_zeros(X.shape).random_(0, n_voxels)
|
||||||
|
|
||||||
|
# linearized indices into the volume
|
||||||
|
idx = (Z * grid_sizes[:, None, 1:2] + Y) * grid_sizes[:, None, 2:3] + X
|
||||||
|
|
||||||
|
# out-of-bounds features added to a random voxel idx with weight=0.
|
||||||
|
idx_valid = idx * valid + rand_idx * (1 - valid)
|
||||||
|
w_valid = valid.type_as(volume_features)
|
||||||
|
|
||||||
|
# scatter add casts the votes into the weight accumulator
|
||||||
|
# and the feature accumulator
|
||||||
|
volume_densities.scatter_add_(1, idx_valid, w_valid)
|
||||||
|
|
||||||
|
# reshape idx_valid -> (minibatch, feature_dim, n_points)
|
||||||
|
idx_valid = idx_valid.view(ba, 1, n_points).expand_as(points_features)
|
||||||
|
w_valid = w_valid.view(ba, 1, n_points)
|
||||||
|
|
||||||
|
# volume_features of shape (minibatch, feature_dim, n_voxels)
|
||||||
|
volume_features.scatter_add_(2, idx_valid, w_valid * points_features)
|
||||||
|
|
||||||
|
# divide each feature by the total weight of the votes
|
||||||
|
volume_features = volume_features / volume_densities.view(ba, 1, n_voxels).clamp(
|
||||||
|
1.0
|
||||||
|
)
|
||||||
|
|
||||||
|
return volume_features, volume_densities
|
24
tests/bm_points_to_volumes.py
Normal file
24
tests/bm_points_to_volumes.py
Normal file
@ -0,0 +1,24 @@
|
|||||||
|
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||||
|
|
||||||
|
import itertools
|
||||||
|
|
||||||
|
from fvcore.common.benchmark import benchmark
|
||||||
|
from test_points_to_volumes import TestPointsToVolumes
|
||||||
|
|
||||||
|
|
||||||
|
def bm_points_to_volumes() -> None:
|
||||||
|
case_grid = {
|
||||||
|
"batch_size": [10, 100],
|
||||||
|
"interp_mode": ["trilinear", "nearest"],
|
||||||
|
"volume_size": [[25, 25, 25], [101, 111, 121]],
|
||||||
|
"n_points": [1000, 10000, 100000],
|
||||||
|
}
|
||||||
|
test_cases = itertools.product(*case_grid.values())
|
||||||
|
kwargs_list = [dict(zip(case_grid.keys(), case)) for case in test_cases]
|
||||||
|
|
||||||
|
benchmark(
|
||||||
|
TestPointsToVolumes.add_points_to_volumes,
|
||||||
|
"ADD_POINTS_TO_VOLUMES",
|
||||||
|
kwargs_list,
|
||||||
|
warmup_iters=1,
|
||||||
|
)
|
385
tests/test_points_to_volumes.py
Normal file
385
tests/test_points_to_volumes.py
Normal file
@ -0,0 +1,385 @@
|
|||||||
|
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||||
|
import unittest
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from common_testing import TestCaseMixin
|
||||||
|
from pytorch3d.ops import add_pointclouds_to_volumes
|
||||||
|
from pytorch3d.ops.sample_points_from_meshes import sample_points_from_meshes
|
||||||
|
from pytorch3d.structures.meshes import Meshes
|
||||||
|
from pytorch3d.structures.pointclouds import Pointclouds
|
||||||
|
from pytorch3d.structures.volumes import Volumes
|
||||||
|
from pytorch3d.transforms.so3 import so3_exponential_map
|
||||||
|
|
||||||
|
|
||||||
|
DEBUG = False
|
||||||
|
if DEBUG:
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
|
||||||
|
def init_cube_point_cloud(
|
||||||
|
batch_size: int = 10, n_points: int = 100000, rotate_y: bool = True
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Generate a random point cloud of `n_points` whose points of
|
||||||
|
which are sampled from faces of a 3D cube.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# create the cube mesh batch_size times
|
||||||
|
meshes = TestPointsToVolumes.init_cube_mesh(batch_size)
|
||||||
|
|
||||||
|
# generate point clouds by sampling points from the meshes
|
||||||
|
pcl = sample_points_from_meshes(meshes, num_samples=n_points, return_normals=False)
|
||||||
|
|
||||||
|
# colors of the cube sides
|
||||||
|
clrs = [
|
||||||
|
[1.0, 0.0, 0.0],
|
||||||
|
[1.0, 1.0, 0.0],
|
||||||
|
[0.0, 1.0, 0.0],
|
||||||
|
[0.0, 1.0, 1.0],
|
||||||
|
[1.0, 1.0, 1.0],
|
||||||
|
[1.0, 0.0, 1.0],
|
||||||
|
]
|
||||||
|
|
||||||
|
# init the color tensor "rgb"
|
||||||
|
rgb = torch.zeros_like(pcl)
|
||||||
|
|
||||||
|
# color each side of the cube with a constant color
|
||||||
|
clri = 0
|
||||||
|
for dim in (0, 1, 2):
|
||||||
|
for offs in (0.0, 1.0):
|
||||||
|
current_face_verts = (pcl[:, :, dim] - offs).abs() <= 1e-2
|
||||||
|
for bi in range(batch_size):
|
||||||
|
rgb[bi, current_face_verts[bi], :] = torch.tensor(clrs[clri]).type_as(
|
||||||
|
pcl
|
||||||
|
)
|
||||||
|
clri += 1
|
||||||
|
|
||||||
|
if rotate_y:
|
||||||
|
# uniformly spaced rotations around y axis
|
||||||
|
R = init_uniform_y_rotations(batch_size=batch_size)
|
||||||
|
# rotate the point clouds around y axis
|
||||||
|
pcl = torch.bmm(pcl - 0.5, R) + 0.5
|
||||||
|
|
||||||
|
return pcl, rgb
|
||||||
|
|
||||||
|
|
||||||
|
def init_volume_boundary_pointcloud(
|
||||||
|
batch_size: int,
|
||||||
|
volume_size: Tuple[int, int, int],
|
||||||
|
n_points: int,
|
||||||
|
interp_mode: str,
|
||||||
|
require_grad: bool = False,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Initialize a point cloud that closely follows a boundary of
|
||||||
|
a volume with a given size. The volume buffer is initialized as well.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# generate a 3D point cloud sampled from sides of a [0,1] cube
|
||||||
|
xyz, rgb = init_cube_point_cloud(batch_size, n_points=n_points, rotate_y=True)
|
||||||
|
|
||||||
|
# make volume_size tensor
|
||||||
|
volume_size_t = torch.tensor(volume_size, dtype=xyz.dtype, device=xyz.device)
|
||||||
|
|
||||||
|
if interp_mode == "trilinear":
|
||||||
|
# make the xyz locations fall on the boundary of the
|
||||||
|
# first/last two voxels along each spatial dimension of the
|
||||||
|
# volume - this properly checks the correctness of the
|
||||||
|
# trilinear interpolation scheme
|
||||||
|
xyz = (xyz - 0.5) * ((volume_size_t - 2) / (volume_size_t - 1))[[2, 1, 0]] + 0.5
|
||||||
|
|
||||||
|
# rescale the cube pointcloud to overlap with the volume sides
|
||||||
|
# of the volume
|
||||||
|
rel_scale = volume_size_t / volume_size[0]
|
||||||
|
xyz = xyz * rel_scale[[2, 1, 0]][None, None]
|
||||||
|
|
||||||
|
# enable grad accumulation for the differentiability check
|
||||||
|
xyz.requires_grad = require_grad
|
||||||
|
rgb.requires_grad = require_grad
|
||||||
|
|
||||||
|
# create the pointclouds structure
|
||||||
|
pointclouds = Pointclouds(xyz, features=rgb)
|
||||||
|
|
||||||
|
# set the volume translation so that the point cloud is centered
|
||||||
|
# around 0
|
||||||
|
volume_translation = -0.5 * rel_scale[[2, 1, 0]]
|
||||||
|
|
||||||
|
# set the voxel size to 1 / (volume_size-1)
|
||||||
|
volume_voxel_size = 1 / (volume_size[0] - 1.0)
|
||||||
|
|
||||||
|
# instantiate the volumes
|
||||||
|
initial_volumes = Volumes(
|
||||||
|
features=xyz.new_zeros(batch_size, 3, *volume_size),
|
||||||
|
densities=xyz.new_zeros(batch_size, 1, *volume_size),
|
||||||
|
volume_translation=volume_translation,
|
||||||
|
voxel_size=volume_voxel_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
return pointclouds, initial_volumes
|
||||||
|
|
||||||
|
|
||||||
|
def init_uniform_y_rotations(batch_size: int = 10):
|
||||||
|
"""
|
||||||
|
Generate a batch of `batch_size` 3x3 rotation matrices around y-axis
|
||||||
|
whose angles are uniformly distributed between 0 and 2 pi.
|
||||||
|
"""
|
||||||
|
device = torch.device("cuda:0")
|
||||||
|
axis = torch.tensor([0.0, 1.0, 0.0], device=device, dtype=torch.float32)
|
||||||
|
angles = torch.linspace(0, 2.0 * np.pi, batch_size + 1, device=device)
|
||||||
|
angles = angles[:batch_size]
|
||||||
|
log_rots = axis[None, :] * angles[:, None]
|
||||||
|
R = so3_exponential_map(log_rots)
|
||||||
|
return R
|
||||||
|
|
||||||
|
|
||||||
|
class TestPointsToVolumes(TestCaseMixin, unittest.TestCase):
|
||||||
|
def setUp(self) -> None:
|
||||||
|
np.random.seed(42)
|
||||||
|
torch.manual_seed(42)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def add_points_to_volumes(
|
||||||
|
batch_size: int,
|
||||||
|
volume_size: Tuple[int, int, int],
|
||||||
|
n_points: int,
|
||||||
|
interp_mode: str,
|
||||||
|
):
|
||||||
|
(pointclouds, initial_volumes) = init_volume_boundary_pointcloud(
|
||||||
|
batch_size=batch_size,
|
||||||
|
volume_size=volume_size,
|
||||||
|
n_points=n_points,
|
||||||
|
interp_mode=interp_mode,
|
||||||
|
require_grad=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _add_points_to_volumes():
|
||||||
|
add_pointclouds_to_volumes(pointclouds, initial_volumes, mode=interp_mode)
|
||||||
|
|
||||||
|
return _add_points_to_volumes
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def stack_4d_tensor_to_3d(arr):
|
||||||
|
n = arr.shape[0]
|
||||||
|
H = int(np.ceil(np.sqrt(n)))
|
||||||
|
W = int(np.ceil(n / H))
|
||||||
|
n_add = H * W - n
|
||||||
|
arr = torch.cat((arr, torch.zeros_like(arr[:1]).repeat(n_add, 1, 1, 1)))
|
||||||
|
rows = torch.chunk(arr, chunks=W, dim=0)
|
||||||
|
arr3d = torch.cat([torch.cat(list(row), dim=2) for row in rows], dim=1)
|
||||||
|
return arr3d
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def init_cube_mesh(batch_size: int = 10):
|
||||||
|
"""
|
||||||
|
Generate a batch of `batch_size` cube meshes.
|
||||||
|
"""
|
||||||
|
|
||||||
|
device = torch.device("cuda:0")
|
||||||
|
|
||||||
|
verts, faces = [], []
|
||||||
|
|
||||||
|
for _ in range(batch_size):
|
||||||
|
v = torch.tensor(
|
||||||
|
[
|
||||||
|
[0.0, 0.0, 0.0],
|
||||||
|
[1.0, 0.0, 0.0],
|
||||||
|
[1.0, 1.0, 0.0],
|
||||||
|
[0.0, 1.0, 0.0],
|
||||||
|
[0.0, 1.0, 1.0],
|
||||||
|
[1.0, 1.0, 1.0],
|
||||||
|
[1.0, 0.0, 1.0],
|
||||||
|
[0.0, 0.0, 1.0],
|
||||||
|
],
|
||||||
|
dtype=torch.float32,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
verts.append(v)
|
||||||
|
faces.append(
|
||||||
|
torch.tensor(
|
||||||
|
[
|
||||||
|
[0, 2, 1],
|
||||||
|
[0, 3, 2],
|
||||||
|
[2, 3, 4],
|
||||||
|
[2, 4, 5],
|
||||||
|
[1, 2, 5],
|
||||||
|
[1, 5, 6],
|
||||||
|
[0, 7, 4],
|
||||||
|
[0, 4, 3],
|
||||||
|
[5, 4, 7],
|
||||||
|
[5, 7, 6],
|
||||||
|
[0, 6, 7],
|
||||||
|
[0, 1, 6],
|
||||||
|
],
|
||||||
|
dtype=torch.int64,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
faces = torch.stack(faces)
|
||||||
|
verts = torch.stack(verts)
|
||||||
|
|
||||||
|
simpleces = Meshes(verts=verts, faces=faces)
|
||||||
|
|
||||||
|
return simpleces
|
||||||
|
|
||||||
|
def test_from_point_cloud(self, interp_mode="trilinear"):
|
||||||
|
"""
|
||||||
|
Generates a volume from a random point cloud sampled from faces
|
||||||
|
of a 3D cube. Since each side of the cube is homogenously colored with
|
||||||
|
a different color, this should result in a volume with a
|
||||||
|
predefined homogenous color of the cells along its borders
|
||||||
|
and black interior. The test is run for both cube and non-cube shaped
|
||||||
|
volumes.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# batch_size = 4 sides of the cube
|
||||||
|
batch_size = 4
|
||||||
|
|
||||||
|
for volume_size in ([25, 25, 25], [30, 25, 15]):
|
||||||
|
|
||||||
|
for interp_mode in ("trilinear", "nearest"):
|
||||||
|
|
||||||
|
(pointclouds, initial_volumes) = init_volume_boundary_pointcloud(
|
||||||
|
volume_size=volume_size,
|
||||||
|
n_points=int(1e5),
|
||||||
|
interp_mode=interp_mode,
|
||||||
|
batch_size=batch_size,
|
||||||
|
require_grad=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
volumes = add_pointclouds_to_volumes(
|
||||||
|
pointclouds, initial_volumes, mode=interp_mode
|
||||||
|
)
|
||||||
|
|
||||||
|
V_color, V_density = volumes.features(), volumes.densities()
|
||||||
|
|
||||||
|
# expected colors of different cube sides
|
||||||
|
clr_sides = torch.tensor(
|
||||||
|
[
|
||||||
|
[[1.0, 1.0, 1.0], [1.0, 0.0, 1.0]],
|
||||||
|
[[1.0, 0.0, 0.0], [1.0, 1.0, 0.0]],
|
||||||
|
[[1.0, 0.0, 1.0], [1.0, 1.0, 1.0]],
|
||||||
|
[[1.0, 1.0, 0.0], [1.0, 0.0, 0.0]],
|
||||||
|
],
|
||||||
|
dtype=V_color.dtype,
|
||||||
|
device=V_color.device,
|
||||||
|
)
|
||||||
|
clr_ambient = torch.tensor(
|
||||||
|
[0.0, 0.0, 0.0], dtype=V_color.dtype, device=V_color.device
|
||||||
|
)
|
||||||
|
clr_top_bot = torch.tensor(
|
||||||
|
[[0.0, 1.0, 0.0], [0.0, 1.0, 1.0]],
|
||||||
|
dtype=V_color.dtype,
|
||||||
|
device=V_color.device,
|
||||||
|
)
|
||||||
|
|
||||||
|
if DEBUG:
|
||||||
|
outdir = tempfile.gettempdir() + "/test_points_to_volumes"
|
||||||
|
os.makedirs(outdir, exist_ok=True)
|
||||||
|
|
||||||
|
for slice_dim in (1, 2):
|
||||||
|
for vidx in range(V_color.shape[0]):
|
||||||
|
vim = V_color.detach()[vidx].split(1, dim=slice_dim)
|
||||||
|
vim = torch.stack([v.squeeze() for v in vim])
|
||||||
|
vim = TestPointsToVolumes.stack_4d_tensor_to_3d(vim.cpu())
|
||||||
|
im = Image.fromarray(
|
||||||
|
(vim.numpy() * 255.0)
|
||||||
|
.astype(np.uint8)
|
||||||
|
.transpose(1, 2, 0)
|
||||||
|
)
|
||||||
|
outfile = (
|
||||||
|
outdir
|
||||||
|
+ f"/rgb_{interp_mode}"
|
||||||
|
+ f"_{str(volume_size).replace(' ','')}"
|
||||||
|
+ f"_{vidx:003d}_sldim{slice_dim}.png"
|
||||||
|
)
|
||||||
|
im.save(outfile)
|
||||||
|
print("exported %s" % outfile)
|
||||||
|
|
||||||
|
# check the density V_density
|
||||||
|
# first binarize the density
|
||||||
|
V_density_bin = (V_density > 1e-4).type_as(V_density)
|
||||||
|
d_one = V_density.new_ones(1)
|
||||||
|
d_zero = V_density.new_zeros(1)
|
||||||
|
for vidx in range(V_color.shape[0]):
|
||||||
|
# the first/last depth-wise slice has to be filled with 1.0
|
||||||
|
self._check_volume_slice_color_density(
|
||||||
|
V_density_bin[vidx], 1, interp_mode, d_one, "first"
|
||||||
|
)
|
||||||
|
self._check_volume_slice_color_density(
|
||||||
|
V_density_bin[vidx], 1, interp_mode, d_one, "last"
|
||||||
|
)
|
||||||
|
# the middle depth-wise slices have to be empty
|
||||||
|
self._check_volume_slice_color_density(
|
||||||
|
V_density_bin[vidx], 1, interp_mode, d_zero, "middle"
|
||||||
|
)
|
||||||
|
# the top/bottom slices have to be filled with 1.0
|
||||||
|
self._check_volume_slice_color_density(
|
||||||
|
V_density_bin[vidx], 2, interp_mode, d_one, "first"
|
||||||
|
)
|
||||||
|
self._check_volume_slice_color_density(
|
||||||
|
V_density_bin[vidx], 2, interp_mode, d_one, "last"
|
||||||
|
)
|
||||||
|
|
||||||
|
# check the colors
|
||||||
|
for vidx in range(V_color.shape[0]):
|
||||||
|
self._check_volume_slice_color_density(
|
||||||
|
V_color[vidx], 1, interp_mode, clr_sides[vidx][0], "first"
|
||||||
|
)
|
||||||
|
self._check_volume_slice_color_density(
|
||||||
|
V_color[vidx], 1, interp_mode, clr_sides[vidx][1], "last"
|
||||||
|
)
|
||||||
|
self._check_volume_slice_color_density(
|
||||||
|
V_color[vidx], 1, interp_mode, clr_ambient, "middle"
|
||||||
|
)
|
||||||
|
self._check_volume_slice_color_density(
|
||||||
|
V_color[vidx], 2, interp_mode, clr_top_bot[0], "first"
|
||||||
|
)
|
||||||
|
self._check_volume_slice_color_density(
|
||||||
|
V_color[vidx], 2, interp_mode, clr_top_bot[1], "last"
|
||||||
|
)
|
||||||
|
|
||||||
|
# check differentiability
|
||||||
|
loss = V_color.mean() + V_density.mean()
|
||||||
|
loss.backward()
|
||||||
|
rgb = pointclouds.features_padded()
|
||||||
|
xyz = pointclouds.points_padded()
|
||||||
|
for field in (xyz, rgb):
|
||||||
|
if interp_mode == "nearest" and (field is xyz):
|
||||||
|
# this does not produce grads w.r.t. xyz
|
||||||
|
self.assertIsNone(field.grad)
|
||||||
|
else:
|
||||||
|
self.assertTrue(field.grad.data.isfinite().all())
|
||||||
|
|
||||||
|
def _check_volume_slice_color_density(
|
||||||
|
self, V, split_dim, interp_mode, clr_gt, slice_type, border=3
|
||||||
|
):
|
||||||
|
# decompose the volume to individual slices along split_dim
|
||||||
|
vim = V.detach().split(1, dim=split_dim)
|
||||||
|
vim = torch.stack([v.squeeze(split_dim) for v in vim])
|
||||||
|
|
||||||
|
# determine which slices should be compared to clr_gt based on
|
||||||
|
# the 'slice_type' input
|
||||||
|
if slice_type == "first":
|
||||||
|
slice_dims = (0, 1) if interp_mode == "trilinear" else (0,)
|
||||||
|
elif slice_type == "last":
|
||||||
|
slice_dims = (-1, -2) if interp_mode == "trilinear" else (-1,)
|
||||||
|
elif slice_type == "middle":
|
||||||
|
internal_border = 2 if interp_mode == "trilinear" else 1
|
||||||
|
slice_dims = torch.arange(internal_border, vim.shape[0] - internal_border)
|
||||||
|
else:
|
||||||
|
raise ValueError(slice_type)
|
||||||
|
|
||||||
|
# compute the average error within each slice
|
||||||
|
clr_diff = (
|
||||||
|
vim[slice_dims, :, border:-border, border:-border]
|
||||||
|
- clr_gt[None, :, None, None]
|
||||||
|
)
|
||||||
|
clr_diff = clr_diff.abs().mean(dim=(2, 3)).view(-1)
|
||||||
|
|
||||||
|
# check that all per-slice avg errors vanish
|
||||||
|
self.assertClose(clr_diff, torch.zeros_like(clr_diff), atol=1e-2)
|
Loading…
x
Reference in New Issue
Block a user