Point clouds to volumes

Summary:
Adds conversion of point clouds to volumes: a batch of `Pointclouds` is splatted into the `features` and `densities` of a batch of `Volumes`, with `trilinear` and `nearest` interpolation modes, plus benchmarks and tests.

```
Benchmark                                                        Avg Time(μs)      Peak Time(μs) Iterations
--------------------------------------------------------------------------------
ADD_POINTS_TO_VOLUMES_10_trilinear_[25, 25, 25]_1000                 43219           44067             12
ADD_POINTS_TO_VOLUMES_10_trilinear_[25, 25, 25]_10000                43274           45313             12
ADD_POINTS_TO_VOLUMES_10_trilinear_[25, 25, 25]_100000               46281           47100             11
ADD_POINTS_TO_VOLUMES_10_trilinear_[101, 111, 121]_1000              51224           51912             10
ADD_POINTS_TO_VOLUMES_10_trilinear_[101, 111, 121]_10000             52092           54487             10
ADD_POINTS_TO_VOLUMES_10_trilinear_[101, 111, 121]_100000            59262           60514              9
ADD_POINTS_TO_VOLUMES_10_nearest_[25, 25, 25]_1000                   15998           17237             32
ADD_POINTS_TO_VOLUMES_10_nearest_[25, 25, 25]_10000                  15964           16994             32
ADD_POINTS_TO_VOLUMES_10_nearest_[25, 25, 25]_100000                 16881           19286             30
ADD_POINTS_TO_VOLUMES_10_nearest_[101, 111, 121]_1000                19150           25277             27
ADD_POINTS_TO_VOLUMES_10_nearest_[101, 111, 121]_10000               18746           19999             27
ADD_POINTS_TO_VOLUMES_10_nearest_[101, 111, 121]_100000              22321           24568             23
ADD_POINTS_TO_VOLUMES_100_trilinear_[25, 25, 25]_1000                49693           50288             11
ADD_POINTS_TO_VOLUMES_100_trilinear_[25, 25, 25]_10000               51429           52449             10
ADD_POINTS_TO_VOLUMES_100_trilinear_[25, 25, 25]_100000             237076          237377              3
ADD_POINTS_TO_VOLUMES_100_trilinear_[101, 111, 121]_1000             81875           82597              7
ADD_POINTS_TO_VOLUMES_100_trilinear_[101, 111, 121]_10000           106671          107045              5
ADD_POINTS_TO_VOLUMES_100_trilinear_[101, 111, 121]_100000          483740          484607              2
ADD_POINTS_TO_VOLUMES_100_nearest_[25, 25, 25]_1000                  16667           18143             31
ADD_POINTS_TO_VOLUMES_100_nearest_[25, 25, 25]_10000                 17682           18922             29
ADD_POINTS_TO_VOLUMES_100_nearest_[25, 25, 25]_100000                65463           67116              8
ADD_POINTS_TO_VOLUMES_100_nearest_[101, 111, 121]_1000               48058           48826             11
ADD_POINTS_TO_VOLUMES_100_nearest_[101, 111, 121]_10000              53529           53998             10
ADD_POINTS_TO_VOLUMES_100_nearest_[101, 111, 121]_100000            123684          123901              5
--------------------------------------------------------------------------------
```
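Benchmark names encode the sweep as `ADD_POINTS_TO_VOLUMES_<batch_size>_<interp_mode>_<volume_size>_<n_points>` (see `bm_points_to_volumes.py` below). For orientation, a minimal usage sketch of the added API; it mirrors the docstring example from `points_to_volumes.py`, so all names are taken from the diff:

```python
import torch
from pytorch3d.ops import add_pointclouds_to_volumes
from pytorch3d.structures import Pointclouds, Volumes

# a batch of 4 random point clouds with 100 points and 5-dim features each
pointclouds = Pointclouds(
    points=torch.randn(4, 100, 3), features=torch.rand(4, 100, 5)
)
# empty 25**3 volumes centered around [0.5, 0.5, 0.5] in world coordinates,
# with a voxel size of 1.0
initial_volumes = Volumes(
    features=torch.zeros(4, 5, 25, 25, 25),
    densities=torch.zeros(4, 1, 25, 25, 25),
    volume_translation=[-0.5, -0.5, -0.5],
    voxel_size=1.0,
)
# splat the point features and densities into the volumes
updated_volumes = add_pointclouds_to_volumes(
    pointclouds=pointclouds, initial_volumes=initial_volumes, mode="trilinear"
)
```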

Output with `DEBUG=True`: {F338561209}

Reviewed By: nikhilaravi

Differential Revision: D22017500

fbshipit-source-id: ed3e8ed13940c593841d93211623dd533974012f
Author: David Novotny, 2021-01-05 03:37:38 -08:00 (committed by Facebook GitHub Bot)
parent 03ee1dbf82, commit aa9bcaf04c
4 changed files with 904 additions and 0 deletions

`pytorch3d/ops/__init__.py`:

```diff
@@ -14,6 +14,10 @@ from .points_normals import (
     estimate_pointcloud_local_coord_frames,
     estimate_pointcloud_normals,
 )
+from .points_to_volumes import (
+    add_pointclouds_to_volumes,
+    add_points_features_to_volume_densities_features,
+)
 from .sample_points_from_meshes import sample_points_from_meshes
 from .subdivide_meshes import SubdivideMeshes
 from .utils import (
```

`pytorch3d/ops/points_to_volumes.py` (new file, 491 lines):

````python
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

from typing import TYPE_CHECKING, Optional, Tuple

import torch


if TYPE_CHECKING:
    from ..structures import Pointclouds, Volumes


def add_pointclouds_to_volumes(
    pointclouds: "Pointclouds",
    initial_volumes: "Volumes",
    mode: str = "trilinear",
    min_weight: float = 1e-4,
) -> "Volumes":
    """
    Add a batch of point clouds represented with a `Pointclouds` structure
    `pointclouds` to a batch of existing volumes represented with a
    `Volumes` structure `initial_volumes`.

    More specifically, the method casts a set of weighted votes (the weights are
    determined based on `mode="trilinear"|"nearest"`) into the pre-initialized
    `features` and `densities` fields of `initial_volumes`.

    The method returns an updated `Volumes` object that contains a copy
    of `initial_volumes` with its `features` and `densities` updated with the
    result of the pointcloud addition.

    Example:
        ```
        # init a random point cloud
        pointclouds = Pointclouds(
            points=torch.randn(4, 100, 3), features=torch.rand(4, 100, 5)
        )
        # init an empty volume centered around [0.5, 0.5, 0.5] in world coordinates
        # with a voxel size of 1.0.
        initial_volumes = Volumes(
            features = torch.zeros(4, 5, 25, 25, 25),
            densities = torch.zeros(4, 1, 25, 25, 25),
            volume_translation = [-0.5, -0.5, -0.5],
            voxel_size = 1.0,
        )
        # add the pointcloud to the 'initial_volumes' buffer using
        # trilinear splatting
        updated_volumes = add_pointclouds_to_volumes(
            pointclouds=pointclouds,
            initial_volumes=initial_volumes,
            mode="trilinear",
        )
        ```

    Args:
        pointclouds: Batch of 3D pointclouds represented with a `Pointclouds`
            structure. Note that `pointclouds.features` have to be defined.
        initial_volumes: Batch of initial `Volumes` with pre-initialized 1-dimensional
            densities which contain non-negative numbers corresponding to the
            opaqueness of each voxel (the higher, the less transparent).
        mode: The mode of the conversion of individual points into the volume.
            Set either to `nearest` or `trilinear`:
            `nearest`: Each 3D point is first rounded to the volumetric
                lattice. Each voxel is then labeled with the average
                over features that fall into the given voxel.
                The gradients of nearest neighbor conversion w.r.t. the
                3D locations of the points in `pointclouds` are *not* defined.
            `trilinear`: Each 3D point casts 8 weighted votes to the 8-neighborhood
                of its floating point coordinate. The weights are
                determined using a trilinear interpolation scheme.
                Trilinear splatting is fully differentiable w.r.t. all input arguments.
        min_weight: A scalar controlling the lowest possible total per-voxel
            weight used to normalize the features accumulated in a voxel.
            Only active for `mode==trilinear`.

    Returns:
        updated_volumes: Output `Volumes` structure containing the conversion result.
    """
    if len(initial_volumes) != len(pointclouds):
        raise ValueError(
            "'initial_volumes' and 'pointclouds' have to have the same batch size."
        )

    # obtain the features and densities
    pcl_feats = pointclouds.features_padded()
    pcl_3d = pointclouds.points_padded()
    if pcl_feats is None:
        raise ValueError("'pointclouds' have to have their 'features' defined.")

    # obtain the conversion mask
    n_per_pcl = pointclouds.num_points_per_cloud().type_as(pcl_feats)
    mask = torch.arange(n_per_pcl.max(), dtype=pcl_feats.dtype, device=pcl_feats.device)
    mask = (mask[None, :] < n_per_pcl[:, None]).type_as(mask)

    # convert to the coord frame of the volume
    pcl_3d_local = initial_volumes.world_to_local_coords(pcl_3d)

    features_new, densities_new = add_points_features_to_volume_densities_features(
        points_3d=pcl_3d_local,
        points_features=pcl_feats,
        volume_features=initial_volumes.features(),
        volume_densities=initial_volumes.densities(),
        min_weight=min_weight,
        grid_sizes=initial_volumes.get_grid_sizes(),
        mask=mask,
        mode=mode,
    )

    return initial_volumes.update_padded(
        new_densities=densities_new, new_features=features_new
    )

def add_points_features_to_volume_densities_features(
    points_3d: torch.Tensor,
    points_features: torch.Tensor,
    volume_densities: torch.Tensor,
    volume_features: Optional[torch.Tensor],
    mode: str = "trilinear",
    min_weight: float = 1e-4,
    mask: Optional[torch.Tensor] = None,
    grid_sizes: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Convert a batch of point clouds represented with tensors of per-point
    3d coordinates and their features to a batch of volumes represented
    with tensors of densities and features.

    Args:
        points_3d: Batch of 3D point cloud coordinates of shape
            `(minibatch, N, 3)` where N is the number of points
            in each point cloud. Coordinates have to be specified in the
            local volume coordinates (ranging in [-1, 1]).
        points_features: Features of shape `(minibatch, N, feature_dim)`
            corresponding to the points of the input point clouds `points_3d`.
        volume_densities: Batch of input feature volume densities of shape
            `(minibatch, 1, D, H, W)`. Each voxel should
            contain a non-negative number corresponding to its
            opaqueness (the higher, the less transparent).
        volume_features: Batch of input feature volumes of shape
            `(minibatch, feature_dim, D, H, W)`.
            If set to `None`, the `volume_features` will be automatically
            instantiated with a correct size and filled with 0s.
        mode: The mode of the conversion of individual points into the volume.
            Set either to `nearest` or `trilinear`:
            `nearest`: Each 3D point is first rounded to the volumetric
                lattice. Each voxel is then labeled with the average
                over features that fall into the given voxel.
                The gradients of nearest neighbor rounding w.r.t. the
                input point locations `points_3d` are *not* defined.
            `trilinear`: Each 3D point casts 8 weighted votes to the 8-neighborhood
                of its floating point coordinate. The weights are
                determined using a trilinear interpolation scheme.
                Trilinear splatting is fully differentiable w.r.t. all input arguments.
        mask: A binary mask of shape `(minibatch, N)` determining which 3D points
            are going to be converted to the resulting volume.
            Set to `None` if all points are valid.
        min_weight: A scalar controlling the lowest possible total per-voxel
            weight used to normalize the features accumulated in a voxel.
            Only active for `mode==trilinear`.

    Returns:
        volume_features: Output volume of shape `(minibatch, feature_dim, D, H, W)`.
        volume_densities: Occupancy volume of shape `(minibatch, 1, D, H, W)`
            containing the total amount of votes cast to each of the voxels.
    """
    # batch size, number of points per cloud and feature dimensionality
    ba, n_points, feature_dim = points_features.shape
    ba_volume, density_dim = volume_densities.shape[:2]
    if density_dim != 1:
        raise ValueError("Only one-dimensional densities are allowed.")

    # init the volumetric grid sizes if uninitialized;
    # keep the integer dtype and broadcast to the (minibatch, 3) shape
    # expected by the conversion routines below
    if grid_sizes is None:
        grid_sizes = (
            torch.LongTensor(list(volume_densities.shape[2:]))
            .to(volume_densities.device)
            .expand(ba, 3)
        )

    # flatten densities and features
    v_shape = volume_densities.shape[2:]
    volume_densities_flatten = volume_densities.view(ba, -1, 1)
    n_voxels = volume_densities_flatten.shape[1]

    if volume_features is None:
        # initialize features if not passed in
        volume_features_flatten = volume_densities.new_zeros(ba, feature_dim, n_voxels)
    else:
        # otherwise just flatten
        volume_features_flatten = volume_features.view(ba, feature_dim, n_voxels)

    if mode == "trilinear":  # do the splatting (trilinear interp)
        volume_features, volume_densities = splat_points_to_volumes(
            points_3d,
            points_features,
            volume_densities_flatten,
            volume_features_flatten,
            grid_sizes,
            mask=mask,
            min_weight=min_weight,
        )
    elif mode == "nearest":  # nearest neighbor interp
        volume_features, volume_densities = round_points_to_volumes(
            points_3d,
            points_features,
            volume_densities_flatten,
            volume_features_flatten,
            grid_sizes,
            mask=mask,
        )
    else:
        raise ValueError('No such interpolation mode "%s"' % mode)

    # reshape into the volume shape
    volume_features = volume_features.view(ba, feature_dim, *v_shape)
    volume_densities = volume_densities.view(ba, 1, *v_shape)

    return volume_features, volume_densities
````
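The same conversion is available at the tensor level, without the `Pointclouds`/`Volumes` wrappers. A short sketch (not part of the diff; the coordinates are assumed to already live in the local `[-1, 1]` volume frame, and the grid size is inferred from `volume_densities`):

```python
import torch
from pytorch3d.ops import add_points_features_to_volume_densities_features

B, N, C, D, H, W = 2, 1000, 3, 16, 16, 16
points_3d = torch.rand(B, N, 3) * 2 - 1  # local coordinates in [-1, 1]
points_features = torch.rand(B, N, C)

features, densities = add_points_features_to_volume_densities_features(
    points_3d=points_3d,
    points_features=points_features,
    volume_densities=torch.zeros(B, 1, D, H, W),
    volume_features=None,  # auto-instantiated as zeros of shape (B, C, D, H, W)
    mode="trilinear",
)
assert features.shape == (B, C, D, H, W)
assert densities.shape == (B, 1, D, H, W)
```

The file continues with the input validation helper and the two conversion routines: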
```python
def _check_points_to_volumes_inputs(
    points_3d: torch.Tensor,
    points_features: torch.Tensor,
    volume_densities: torch.Tensor,
    volume_features: torch.Tensor,
    grid_sizes: torch.LongTensor,
    mask: Optional[torch.Tensor] = None,
):
    max_grid_size = grid_sizes.max(dim=0).values
    if torch.prod(max_grid_size) > volume_densities.shape[1]:
        raise ValueError(
            "One of the grid sizes corresponds to a larger number"
            + " of elements than the number of elements in volume_densities."
        )

    _, n_voxels, density_dim = volume_densities.shape
    if density_dim != 1:
        raise ValueError("Only one-dimensional densities are allowed.")

    ba, n_points, feature_dim = points_features.shape
    if volume_features.shape[1] != feature_dim:
        raise ValueError(
            "volume_features have a different number of channels"
            + " than points_features."
        )
    if volume_features.shape[2] != n_voxels:
        raise ValueError(
            "volume_features have a different number of elements"
            + " than volume_densities."
        )

def splat_points_to_volumes(
    points_3d: torch.Tensor,
    points_features: torch.Tensor,
    volume_densities: torch.Tensor,
    volume_features: torch.Tensor,
    grid_sizes: torch.LongTensor,
    min_weight: float = 1e-4,
    mask: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Convert a batch of point clouds to a batch of volumes using trilinear
    splatting into a volume.

    Args:
        points_3d: Batch of 3D point cloud coordinates of shape
            `(minibatch, N, 3)` where N is the number of points
            in each point cloud. Coordinates have to be specified in the
            local volume coordinates (ranging in [-1, 1]).
        points_features: Features of shape `(minibatch, N, feature_dim)`
            corresponding to the points of the input point cloud `points_3d`.
        volume_features: Batch of input *flattened* feature volumes
            of shape `(minibatch, feature_dim, N_voxels)`.
        volume_densities: Batch of input *flattened* feature volume densities
            of shape `(minibatch, 1, N_voxels)`. Each voxel should
            contain a non-negative number corresponding to its
            opaqueness (the higher, the less transparent).
        grid_sizes: `LongTensor` of shape `(minibatch, 3)` representing the
            spatial resolutions of each of the non-flattened `volumes` tensors.
            Note that the following has to hold:
            `torch.prod(grid_sizes, dim=1) == N_voxels`.
        mask: A binary mask of shape `(minibatch, N)` determining which 3D points
            are going to be converted to the resulting volume.
            Set to `None` if all points are valid.

    Returns:
        volume_features: Output volume of shape `(minibatch, feature_dim, N_voxels)`.
        volume_densities: Occupancy volume of shape `(minibatch, 1, N_voxels)`
            containing the total amount of votes cast to each of the voxels.
    """
    _check_points_to_volumes_inputs(
        points_3d,
        points_features,
        volume_densities,
        volume_features,
        grid_sizes,
        mask=mask,
    )

    _, n_voxels, density_dim = volume_densities.shape
    ba, n_points, feature_dim = points_features.shape

    # minibatch x n_points x feature_dim -> minibatch x feature_dim x n_points
    points_features = points_features.permute(0, 2, 1).contiguous()

    # XYZ = the upper-left volume index of the 8-neighborhood of every point
    # grid_sizes is of the form (minibatch, depth-height-width)
    grid_sizes_xyz = grid_sizes[:, [2, 1, 0]]

    # convert from points_3d in the range [-1, 1] to
    # indices in the volume grid in the range [0, grid_sizes_xyz - 1]
    points_3d_indices = ((points_3d + 1) * 0.5) * (
        grid_sizes_xyz[:, None].type_as(points_3d) - 1
    )
    XYZ = points_3d_indices.floor().long()
    rXYZ = points_3d_indices - XYZ.type_as(points_3d)  # remainder of floor

    # split into separate coordinate vectors
    X, Y, Z = XYZ.split(1, dim=2)
    # rX = remainder after floor = 1 - "the weight of each vote into
    # the X coordinate of the 8-neighborhood"
    rX, rY, rZ = rXYZ.split(1, dim=2)

    # get random indices for the purpose of adding out-of-bounds values
    rand_idx = X.new_zeros(X.shape).random_(0, n_voxels)

    # iterate over the x, y, z indices of the 8-neighborhood (xdiff, ydiff, zdiff)
    for xdiff in (0, 1):
        X_ = X + xdiff
        wX = (1 - xdiff) + (2 * xdiff - 1) * rX
        for ydiff in (0, 1):
            Y_ = Y + ydiff
            wY = (1 - ydiff) + (2 * ydiff - 1) * rY
            for zdiff in (0, 1):
                Z_ = Z + zdiff
                wZ = (1 - zdiff) + (2 * zdiff - 1) * rZ

                # weight of each vote into the given cell of the 8-neighborhood
                w = wX * wY * wZ

                # valid - binary indicators of votes that fall into the volume
                valid = (
                    (0 <= X_)
                    * (X_ < grid_sizes_xyz[:, None, 0:1])
                    * (0 <= Y_)
                    * (Y_ < grid_sizes_xyz[:, None, 1:2])
                    * (0 <= Z_)
                    * (Z_ < grid_sizes_xyz[:, None, 2:3])
                ).long()

                # linearized indices into the volume
                idx = (Z_ * grid_sizes[:, None, 1:2] + Y_) * grid_sizes[
                    :, None, 2:3
                ] + X_

                # out-of-bounds features added to a random voxel idx with weight=0.
                idx_valid = idx * valid + rand_idx * (1 - valid)
                w_valid = w * valid.type_as(w)
                if mask is not None:
                    w_valid = w_valid * mask.type_as(w)[:, :, None]

                # scatter-add casts the votes into the weight accumulator
                # and the feature accumulator
                volume_densities.scatter_add_(1, idx_valid, w_valid)

                # reshape idx_valid -> (minibatch, feature_dim, n_points)
                idx_valid = idx_valid.view(ba, 1, n_points).expand_as(points_features)
                w_valid = w_valid.view(ba, 1, n_points)

                # volume_features of shape (minibatch, feature_dim, n_voxels)
                volume_features.scatter_add_(2, idx_valid, w_valid * points_features)

    # divide each feature by the total weight of the votes
    volume_features = volume_features / volume_densities.view(ba, 1, n_voxels).clamp(
        min_weight
    )

    return volume_features, volume_densities
```
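The triple loop above implements standard trilinear splatting: along each axis, the two cells of the neighborhood receive the weights `1 - r` and `r`, where `r` is the fractional part of the point's grid coordinate, so the 8 votes of every point sum to 1. A standalone sketch of this invariant (not part of the diff):

```python
import torch

# fractional offsets of a point inside its voxel cell
rX, rY, rZ = torch.rand(3)

total = 0.0
for xdiff in (0, 1):
    wX = (1 - xdiff) + (2 * xdiff - 1) * rX  # 1 - rX for xdiff=0, rX for xdiff=1
    for ydiff in (0, 1):
        wY = (1 - ydiff) + (2 * ydiff - 1) * rY
        for zdiff in (0, 1):
            wZ = (1 - zdiff) + (2 * zdiff - 1) * rZ
            total += wX * wY * wZ  # weight of one of the 8 votes

assert torch.isclose(total, torch.tensor(1.0))  # the weights form a partition of unity
```

The nearest-neighbor variant uses the same linearized indexing, but rounds instead of interpolating: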
```python
def round_points_to_volumes(
    points_3d: torch.Tensor,
    points_features: torch.Tensor,
    volume_densities: torch.Tensor,
    volume_features: torch.Tensor,
    grid_sizes: torch.LongTensor,
    mask: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Convert a batch of point clouds to a batch of volumes using rounding to the
    nearest integer coordinate of the volume. Features that fall into the same
    voxel are averaged.

    Args:
        points_3d: Batch of 3D point cloud coordinates of shape
            `(minibatch, N, 3)` where N is the number of points
            in each point cloud. Coordinates have to be specified in the
            local volume coordinates (ranging in [-1, 1]).
        points_features: Features of shape `(minibatch, N, feature_dim)`
            corresponding to the points of the input point cloud `points_3d`.
        volume_features: Batch of input *flattened* feature volumes
            of shape `(minibatch, feature_dim, N_voxels)`.
        volume_densities: Batch of input *flattened* feature volume densities
            of shape `(minibatch, 1, N_voxels)`. Each voxel should
            contain a non-negative number corresponding to its
            opaqueness (the higher, the less transparent).
        grid_sizes: `LongTensor` of shape `(minibatch, 3)` representing the
            spatial resolutions of each of the non-flattened `volumes` tensors.
            Note that the following has to hold:
            `torch.prod(grid_sizes, dim=1) == N_voxels`.
        mask: A binary mask of shape `(minibatch, N)` determining which 3D points
            are going to be converted to the resulting volume.
            Set to `None` if all points are valid.

    Returns:
        volume_features: Output volume of shape `(minibatch, feature_dim, N_voxels)`.
        volume_densities: Occupancy volume of shape `(minibatch, 1, N_voxels)`
            containing the total amount of votes cast to each of the voxels.
    """
    _check_points_to_volumes_inputs(
        points_3d,
        points_features,
        volume_densities,
        volume_features,
        grid_sizes,
        mask=mask,
    )

    _, n_voxels, density_dim = volume_densities.shape
    ba, n_points, feature_dim = points_features.shape

    # minibatch x n_points x feature_dim -> minibatch x feature_dim x n_points
    points_features = points_features.permute(0, 2, 1).contiguous()

    # round the coordinates to the nearest integer voxel index
    # grid_sizes is of the form (minibatch, depth-height-width)
    grid_sizes_xyz = grid_sizes[:, [2, 1, 0]]
    XYZ = ((points_3d.detach() + 1) * 0.5) * (
        grid_sizes_xyz[:, None].type_as(points_3d) - 1
    )
    XYZ = torch.round(XYZ).long()

    # split into separate coordinate vectors
    X, Y, Z = XYZ.split(1, dim=2)

    # valid - binary indicators of votes that fall into the volume
    grid_sizes = grid_sizes.type_as(XYZ)
    valid = (
        (0 <= X)
        * (X < grid_sizes_xyz[:, None, 0:1])
        * (0 <= Y)
        * (Y < grid_sizes_xyz[:, None, 1:2])
        * (0 <= Z)
        * (Z < grid_sizes_xyz[:, None, 2:3])
    ).long()

    # get random indices for the purpose of adding out-of-bounds values
    rand_idx = valid.new_zeros(X.shape).random_(0, n_voxels)

    # linearized indices into the volume
    idx = (Z * grid_sizes[:, None, 1:2] + Y) * grid_sizes[:, None, 2:3] + X

    # out-of-bounds features added to a random voxel idx with weight=0.
    idx_valid = idx * valid + rand_idx * (1 - valid)
    w_valid = valid.type_as(volume_features)
    if mask is not None:
        # zero the votes of masked-out (padded) points,
        # mirroring splat_points_to_volumes
        w_valid = w_valid * mask.type_as(w_valid)[:, :, None]

    # scatter-add casts the votes into the weight accumulator
    # and the feature accumulator
    volume_densities.scatter_add_(1, idx_valid, w_valid)

    # reshape idx_valid -> (minibatch, feature_dim, n_points)
    idx_valid = idx_valid.view(ba, 1, n_points).expand_as(points_features)
    w_valid = w_valid.view(ba, 1, n_points)

    # volume_features of shape (minibatch, feature_dim, n_voxels)
    volume_features.scatter_add_(2, idx_valid, w_valid * points_features)

    # divide each feature by the total weight of the votes
    volume_features = volume_features / volume_densities.view(ba, 1, n_voxels).clamp(
        1.0
    )

    return volume_features, volume_densities
```
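Note the different normalizers of the two modes: `splat_points_to_volumes` clamps the accumulated vote weight at `min_weight` (trilinear weights can be arbitrarily small), while `round_points_to_volumes` clamps at `1.0`, so a voxel stores the plain average of the features rounded into it and empty voxels stay zero. A small sanity sketch of the `nearest` semantics (not part of the diff):

```python
import torch
from pytorch3d.ops import add_points_features_to_volume_densities_features

# two points whose local coordinates both round to the central voxel of a 5**3 grid
points_3d = torch.tensor([[[0.0, 0.0, 0.0], [0.05, 0.0, 0.0]]])
points_features = torch.tensor([[[1.0], [3.0]]])

features, densities = add_points_features_to_volume_densities_features(
    points_3d=points_3d,
    points_features=points_features,
    volume_densities=torch.zeros(1, 1, 5, 5, 5),
    volume_features=None,
    mode="nearest",
)
# the central voxel accumulates 2 votes and stores the mean feature (1 + 3) / 2
assert torch.isclose(densities[0, 0, 2, 2, 2], torch.tensor(2.0))
assert torch.isclose(features[0, 0, 2, 2, 2], torch.tensor(2.0))
```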

`tests/bm_points_to_volumes.py` (new file, 24 lines):

```python
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

import itertools

from fvcore.common.benchmark import benchmark
from test_points_to_volumes import TestPointsToVolumes


def bm_points_to_volumes() -> None:
    case_grid = {
        "batch_size": [10, 100],
        "interp_mode": ["trilinear", "nearest"],
        "volume_size": [[25, 25, 25], [101, 111, 121]],
        "n_points": [1000, 10000, 100000],
    }
    test_cases = itertools.product(*case_grid.values())
    kwargs_list = [dict(zip(case_grid.keys(), case)) for case in test_cases]

    benchmark(
        TestPointsToVolumes.add_points_to_volumes,
        "ADD_POINTS_TO_VOLUMES",
        kwargs_list,
        warmup_iters=1,
    )
```
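To reproduce the table in the summary, the entry point can be run directly after appending a standard main guard, sketched below (assumes the working directory is `tests/` so that `test_points_to_volumes` is importable, and a CUDA device, since the test helpers allocate on `cuda:0`):

```python
if __name__ == "__main__":
    bm_points_to_volumes()
```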

`tests/test_points_to_volumes.py` (new file, 385 lines):

```python
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

import unittest
from typing import Tuple

import numpy as np
import torch
from common_testing import TestCaseMixin
from pytorch3d.ops import add_pointclouds_to_volumes
from pytorch3d.ops.sample_points_from_meshes import sample_points_from_meshes
from pytorch3d.structures.meshes import Meshes
from pytorch3d.structures.pointclouds import Pointclouds
from pytorch3d.structures.volumes import Volumes
from pytorch3d.transforms.so3 import so3_exponential_map

DEBUG = False
if DEBUG:
    import os
    import tempfile

    from PIL import Image


def init_cube_point_cloud(
    batch_size: int = 10, n_points: int = 100000, rotate_y: bool = True
):
    """
    Generate a random point cloud of `n_points` points sampled from
    the faces of a 3D cube.
    """
    # create the cube mesh batch_size times
    meshes = TestPointsToVolumes.init_cube_mesh(batch_size)

    # generate point clouds by sampling points from the meshes
    pcl = sample_points_from_meshes(meshes, num_samples=n_points, return_normals=False)

    # colors of the cube sides
    clrs = [
        [1.0, 0.0, 0.0],
        [1.0, 1.0, 0.0],
        [0.0, 1.0, 0.0],
        [0.0, 1.0, 1.0],
        [1.0, 1.0, 1.0],
        [1.0, 0.0, 1.0],
    ]

    # init the color tensor "rgb"
    rgb = torch.zeros_like(pcl)

    # color each side of the cube with a constant color
    clri = 0
    for dim in (0, 1, 2):
        for offs in (0.0, 1.0):
            current_face_verts = (pcl[:, :, dim] - offs).abs() <= 1e-2
            for bi in range(batch_size):
                rgb[bi, current_face_verts[bi], :] = torch.tensor(clrs[clri]).type_as(
                    pcl
                )
            clri += 1

    if rotate_y:
        # uniformly spaced rotations around y axis
        R = init_uniform_y_rotations(batch_size=batch_size)
        # rotate the point clouds around y axis
        pcl = torch.bmm(pcl - 0.5, R) + 0.5

    return pcl, rgb

def init_volume_boundary_pointcloud(
    batch_size: int,
    volume_size: Tuple[int, int, int],
    n_points: int,
    interp_mode: str,
    require_grad: bool = False,
):
    """
    Initialize a point cloud that closely follows the boundary of
    a volume with a given size. The volume buffer is initialized as well.
    """
    # generate a 3D point cloud sampled from sides of a [0,1] cube
    xyz, rgb = init_cube_point_cloud(batch_size, n_points=n_points, rotate_y=True)

    # make volume_size tensor
    volume_size_t = torch.tensor(volume_size, dtype=xyz.dtype, device=xyz.device)

    if interp_mode == "trilinear":
        # make the xyz locations fall on the boundary of the
        # first/last two voxels along each spatial dimension of the
        # volume - this properly checks the correctness of the
        # trilinear interpolation scheme
        xyz = (xyz - 0.5) * ((volume_size_t - 2) / (volume_size_t - 1))[[2, 1, 0]] + 0.5

    # rescale the cube pointcloud to overlap with the sides of the volume
    rel_scale = volume_size_t / volume_size[0]
    xyz = xyz * rel_scale[[2, 1, 0]][None, None]

    # enable grad accumulation for the differentiability check
    xyz.requires_grad = require_grad
    rgb.requires_grad = require_grad

    # create the pointclouds structure
    pointclouds = Pointclouds(xyz, features=rgb)

    # set the volume translation so that the point cloud is centered
    # around 0
    volume_translation = -0.5 * rel_scale[[2, 1, 0]]

    # set the voxel size to 1 / (volume_size - 1)
    volume_voxel_size = 1 / (volume_size[0] - 1.0)

    # instantiate the volumes
    initial_volumes = Volumes(
        features=xyz.new_zeros(batch_size, 3, *volume_size),
        densities=xyz.new_zeros(batch_size, 1, *volume_size),
        volume_translation=volume_translation,
        voxel_size=volume_voxel_size,
    )

    return pointclouds, initial_volumes

def init_uniform_y_rotations(batch_size: int = 10):
    """
    Generate a batch of `batch_size` 3x3 rotation matrices around the y-axis
    whose angles are uniformly distributed between 0 and 2 pi.
    """
    device = torch.device("cuda:0")
    axis = torch.tensor([0.0, 1.0, 0.0], device=device, dtype=torch.float32)
    angles = torch.linspace(0, 2.0 * np.pi, batch_size + 1, device=device)
    angles = angles[:batch_size]
    log_rots = axis[None, :] * angles[:, None]
    R = so3_exponential_map(log_rots)
    return R

class TestPointsToVolumes(TestCaseMixin, unittest.TestCase):
    def setUp(self) -> None:
        np.random.seed(42)
        torch.manual_seed(42)

    @staticmethod
    def add_points_to_volumes(
        batch_size: int,
        volume_size: Tuple[int, int, int],
        n_points: int,
        interp_mode: str,
    ):
        (pointclouds, initial_volumes) = init_volume_boundary_pointcloud(
            batch_size=batch_size,
            volume_size=volume_size,
            n_points=n_points,
            interp_mode=interp_mode,
            require_grad=False,
        )

        def _add_points_to_volumes():
            add_pointclouds_to_volumes(pointclouds, initial_volumes, mode=interp_mode)

        return _add_points_to_volumes

    @staticmethod
    def stack_4d_tensor_to_3d(arr):
        n = arr.shape[0]
        H = int(np.ceil(np.sqrt(n)))
        W = int(np.ceil(n / H))
        n_add = H * W - n
        arr = torch.cat((arr, torch.zeros_like(arr[:1]).repeat(n_add, 1, 1, 1)))
        rows = torch.chunk(arr, chunks=W, dim=0)
        arr3d = torch.cat([torch.cat(list(row), dim=2) for row in rows], dim=1)
        return arr3d
    @staticmethod
    def init_cube_mesh(batch_size: int = 10):
        """
        Generate a batch of `batch_size` cube meshes.
        """
        device = torch.device("cuda:0")

        verts, faces = [], []
        for _ in range(batch_size):
            v = torch.tensor(
                [
                    [0.0, 0.0, 0.0],
                    [1.0, 0.0, 0.0],
                    [1.0, 1.0, 0.0],
                    [0.0, 1.0, 0.0],
                    [0.0, 1.0, 1.0],
                    [1.0, 1.0, 1.0],
                    [1.0, 0.0, 1.0],
                    [0.0, 0.0, 1.0],
                ],
                dtype=torch.float32,
                device=device,
            )
            verts.append(v)
            faces.append(
                torch.tensor(
                    [
                        [0, 2, 1],
                        [0, 3, 2],
                        [2, 3, 4],
                        [2, 4, 5],
                        [1, 2, 5],
                        [1, 5, 6],
                        [0, 7, 4],
                        [0, 4, 3],
                        [5, 4, 7],
                        [5, 7, 6],
                        [0, 6, 7],
                        [0, 1, 6],
                    ],
                    dtype=torch.int64,
                    device=device,
                )
            )
        faces = torch.stack(faces)
        verts = torch.stack(verts)
        simpleces = Meshes(verts=verts, faces=faces)
        return simpleces
    def test_from_point_cloud(self, interp_mode="trilinear"):
        """
        Generates a volume from a random point cloud sampled from the faces
        of a 3D cube. Since each side of the cube is homogeneously colored with
        a different color, this should result in a volume with a
        predefined homogeneous color of the cells along its borders
        and a black interior. The test is run for both cube and non-cube shaped
        volumes.
        """
        # batch_size = 4 sides of the cube
        batch_size = 4

        for volume_size in ([25, 25, 25], [30, 25, 15]):
            for interp_mode in ("trilinear", "nearest"):
                (pointclouds, initial_volumes) = init_volume_boundary_pointcloud(
                    volume_size=volume_size,
                    n_points=int(1e5),
                    interp_mode=interp_mode,
                    batch_size=batch_size,
                    require_grad=True,
                )

                volumes = add_pointclouds_to_volumes(
                    pointclouds, initial_volumes, mode=interp_mode
                )

                V_color, V_density = volumes.features(), volumes.densities()

                # expected colors of different cube sides
                clr_sides = torch.tensor(
                    [
                        [[1.0, 1.0, 1.0], [1.0, 0.0, 1.0]],
                        [[1.0, 0.0, 0.0], [1.0, 1.0, 0.0]],
                        [[1.0, 0.0, 1.0], [1.0, 1.0, 1.0]],
                        [[1.0, 1.0, 0.0], [1.0, 0.0, 0.0]],
                    ],
                    dtype=V_color.dtype,
                    device=V_color.device,
                )
                clr_ambient = torch.tensor(
                    [0.0, 0.0, 0.0], dtype=V_color.dtype, device=V_color.device
                )
                clr_top_bot = torch.tensor(
                    [[0.0, 1.0, 0.0], [0.0, 1.0, 1.0]],
                    dtype=V_color.dtype,
                    device=V_color.device,
                )

                if DEBUG:
                    outdir = tempfile.gettempdir() + "/test_points_to_volumes"
                    os.makedirs(outdir, exist_ok=True)

                    for slice_dim in (1, 2):
                        for vidx in range(V_color.shape[0]):
                            vim = V_color.detach()[vidx].split(1, dim=slice_dim)
                            vim = torch.stack([v.squeeze() for v in vim])
                            vim = TestPointsToVolumes.stack_4d_tensor_to_3d(vim.cpu())
                            im = Image.fromarray(
                                (vim.numpy() * 255.0)
                                .astype(np.uint8)
                                .transpose(1, 2, 0)
                            )
                            outfile = (
                                outdir
                                + f"/rgb_{interp_mode}"
                                + f"_{str(volume_size).replace(' ','')}"
                                + f"_{vidx:003d}_sldim{slice_dim}.png"
                            )
                            im.save(outfile)
                            print("exported %s" % outfile)
                # check the density V_density
                # first binarize the density
                V_density_bin = (V_density > 1e-4).type_as(V_density)
                d_one = V_density.new_ones(1)
                d_zero = V_density.new_zeros(1)

                for vidx in range(V_color.shape[0]):
                    # the first/last depth-wise slice has to be filled with 1.0
                    self._check_volume_slice_color_density(
                        V_density_bin[vidx], 1, interp_mode, d_one, "first"
                    )
                    self._check_volume_slice_color_density(
                        V_density_bin[vidx], 1, interp_mode, d_one, "last"
                    )
                    # the middle depth-wise slices have to be empty
                    self._check_volume_slice_color_density(
                        V_density_bin[vidx], 1, interp_mode, d_zero, "middle"
                    )
                    # the top/bottom slices have to be filled with 1.0
                    self._check_volume_slice_color_density(
                        V_density_bin[vidx], 2, interp_mode, d_one, "first"
                    )
                    self._check_volume_slice_color_density(
                        V_density_bin[vidx], 2, interp_mode, d_one, "last"
                    )

                # check the colors
                for vidx in range(V_color.shape[0]):
                    self._check_volume_slice_color_density(
                        V_color[vidx], 1, interp_mode, clr_sides[vidx][0], "first"
                    )
                    self._check_volume_slice_color_density(
                        V_color[vidx], 1, interp_mode, clr_sides[vidx][1], "last"
                    )
                    self._check_volume_slice_color_density(
                        V_color[vidx], 1, interp_mode, clr_ambient, "middle"
                    )
                    self._check_volume_slice_color_density(
                        V_color[vidx], 2, interp_mode, clr_top_bot[0], "first"
                    )
                    self._check_volume_slice_color_density(
                        V_color[vidx], 2, interp_mode, clr_top_bot[1], "last"
                    )

                # check differentiability
                loss = V_color.mean() + V_density.mean()
                loss.backward()
                rgb = pointclouds.features_padded()
                xyz = pointclouds.points_padded()
                for field in (xyz, rgb):
                    if interp_mode == "nearest" and (field is xyz):
                        # nearest rounding does not produce grads w.r.t. xyz
                        self.assertIsNone(field.grad)
                    else:
                        self.assertTrue(field.grad.data.isfinite().all())
    def _check_volume_slice_color_density(
        self, V, split_dim, interp_mode, clr_gt, slice_type, border=3
    ):
        # decompose the volume into individual slices along split_dim
        vim = V.detach().split(1, dim=split_dim)
        vim = torch.stack([v.squeeze(split_dim) for v in vim])

        # determine which slices should be compared to clr_gt based on
        # the 'slice_type' input
        if slice_type == "first":
            slice_dims = (0, 1) if interp_mode == "trilinear" else (0,)
        elif slice_type == "last":
            slice_dims = (-1, -2) if interp_mode == "trilinear" else (-1,)
        elif slice_type == "middle":
            internal_border = 2 if interp_mode == "trilinear" else 1
            slice_dims = torch.arange(internal_border, vim.shape[0] - internal_border)
        else:
            raise ValueError(slice_type)

        # compute the average error within each slice
        clr_diff = (
            vim[slice_dims, :, border:-border, border:-border]
            - clr_gt[None, :, None, None]
        )
        clr_diff = clr_diff.abs().mean(dim=(2, 3)).view(-1)

        # check that all per-slice avg errors vanish
        self.assertClose(clr_diff, torch.zeros_like(clr_diff), atol=1e-2)
```
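The test needs a GPU, since the helpers build the meshes and rotations on `cuda:0`; setting `DEBUG = True` additionally exports per-slice PNG renders of the produced volumes to `<tempdir>/test_points_to_volumes` (the attachment referenced in the summary). A runner sketch (not part of the diff):

```python
# run from the tests/ directory on a machine with a CUDA device
import unittest

from test_points_to_volumes import TestPointsToVolumes

suite = unittest.defaultTestLoader.loadTestsFromTestCase(TestPointsToVolumes)
unittest.TextTestRunner(verbosity=2).run(suite)
```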