voxel_grid_implicit_function scaffold fixes

Summary: Fix indexing of directions after filtering of points by scaffold.

Reviewed By: shapovalov

Differential Revision: D40853482

fbshipit-source-id: 9cfdb981e97cb82edcd27632c5848537ed2c6837
This commit is contained in:
Jeremy Reizenstein 2022-11-03 05:46:31 -07:00 committed by Facebook GitHub Bot
parent e4a3298149
commit a1f2ded58a
2 changed files with 44 additions and 30 deletions

View File

@ -7,7 +7,7 @@
import math
import warnings
from dataclasses import fields
from typing import Callable, Dict, Optional, Tuple, Union
from typing import Callable, Dict, Optional, Tuple
import torch
@ -118,11 +118,11 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
the calculation.)
scaffold_resolution (Tuple[int, int, int]): (width, height, depth) of the underlying
voxel grid which stores scaffold
scaffold_empty_space_threshold (float): if `self.get_density` evaluates to less than
scaffold_empty_space_threshold (float): if `self._get_density` evaluates to less than
this it will be considered as empty space and the scaffold at that point would
evaluate as empty space.
scaffold_occupancy_chunk_size (str or int): Number of xy scaffold planes to calculate
at the same time. To calculate the scaffold we need to query `get_density()` at
at the same time. To calculate the scaffold we need to query `_get_density()` at
every voxel, this calculation can be split into scaffold depth number of xy plane
calculations if you want the lowest memory usage, one calculation to calculate the
whole scaffold, but with higher memory footprint or any other number of planes.
@ -242,14 +242,16 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
points = ray_bundle_to_ray_points(ray_bundle)
directions = ray_bundle.directions.reshape(-1, 3)
input_shape = points.shape
num_points_per_ray = input_shape[-2]
points = points.view(-1, 3)
non_empty_points = None
# ########## filter the points using the scaffold ########## #
if self._scaffold_ready and self.scaffold_filter_points:
# pyre-ignore[29]
non_empty_points = self.voxel_grid_scaffold(points)[..., 0] > 0
with torch.no_grad():
# pyre-ignore[29]
non_empty_points = self.voxel_grid_scaffold(points)[..., 0] > 0
points = points[non_empty_points]
directions = directions[non_empty_points]
if len(points) == 0:
warnings.warn(
"The scaffold has filtered all the points."
@ -262,8 +264,8 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
)
# ########## calculate color and density ########## #
rays_densities, rays_colors = self.calculate_density_and_color(
points, directions, camera
rays_densities, rays_colors = self._calculate_density_and_color(
points, directions, camera, non_empty_points, num_points_per_ray
)
if not (self._scaffold_ready and self.scaffold_filter_points):
@ -283,9 +285,8 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
rays_colors_combined = rays_colors.new_zeros(
(math.prod(input_shape[:-1]), rays_colors.shape[-1])
)
# pyre-ignore[61]
assert non_empty_points is not None
rays_densities_combined[non_empty_points] = rays_densities
# pyre-ignore[61]
rays_colors_combined[non_empty_points] = rays_colors
return (
@ -294,11 +295,13 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
{},
)
def calculate_density_and_color(
def _calculate_density_and_color(
self,
points: torch.Tensor,
directions: torch.Tensor,
camera: Optional[CamerasBase] = None,
camera: Optional[CamerasBase],
non_empty_points: Optional[torch.Tensor],
num_points_per_ray: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Calculates density and color at `points`.
@ -306,11 +309,14 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
Args:
points: points at which to calculate density and color.
Tensor of shape [..., 3].
directions: from which directions are the points viewed
Tensor of shape [..., 3].
Tensor of shape [n_points, 3].
directions: from which directions are the points viewed.
One per ray. Tensor of shape [n_rays, 3].
camera: A camera model which will be used to transform the viewing
directions
non_empty_points: indices of points which weren't filtered out;
used for expanding directions
num_points_per_ray: number of points per ray, needed to expand directions.
Returns:
Tuple of color (tensor of shape [..., 3]) and density
(tensor of shape [..., 1])
@ -323,20 +329,24 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
with torch.cuda.stream(other_stream):
# rays_densities.shape =
# [minibatch x n_rays_width x n_rays_height x pts_per_ray x density_dim]
rays_densities = self.get_density(points)
rays_densities = self._get_density(points)
# rays_colors.shape =
# [minibatch x n_rays_width x n_rays_height x pts_per_ray x color_dim]
rays_colors = self.get_color(points, camera, directions)
rays_colors = self._get_color(
points, camera, directions, non_empty_points, num_points_per_ray
)
current_stream.wait_stream(other_stream)
else:
# Same calculation as above, just serial.
rays_densities = self.get_density(points)
rays_colors = self.get_color(points, camera, directions)
rays_densities = self._get_density(points)
rays_colors = self._get_color(
points, camera, directions, non_empty_points, num_points_per_ray
)
return rays_densities, rays_colors
def get_density(self, points: torch.Tensor) -> torch.Tensor:
def _get_density(self, points: torch.Tensor) -> torch.Tensor:
"""
Calculates density at points:
1) Evaluates the voxel grid on points
@ -356,11 +366,13 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
# shape = [..., density_dim]
return self.decoder_density(harmonic_embedding_density)
def get_color(
def _get_color(
self,
points: torch.Tensor,
camera: Optional[CamerasBase],
directions: torch.Tensor,
non_empty_points: Optional[torch.Tensor],
num_points_per_ray: int,
) -> torch.Tensor:
"""
Calculates color at points using the viewing direction:
@ -376,6 +388,9 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
directions
directions: A tensor of shape `(..., 3)`
containing the direction vectors of sampling rays in world coords.
non_empty_points: indices of points which weren't filtered out;
used for expanding directions
num_points_per_ray: number of points per ray, needed to expand directions.
"""
# ########## transform direction ########## #
if self.xyz_ray_dir_in_camera_coords:
@ -400,12 +415,11 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
rays_directions_normed
)
n_rays = directions.shape[0]
points_per_ray: int = points.shape[0] // n_rays
harmonic_embedding_dir = torch.repeat_interleave(
harmonic_embedding_dir, points_per_ray, dim=0
harmonic_embedding_dir, num_points_per_ray, dim=0
)
if non_empty_points is not None:
harmonic_embedding_dir = harmonic_embedding_dir[non_empty_points]
# total color embedding is concatenation of the harmonic embedding of voxel grid
# output and harmonic embedding of the normalized direction
@ -505,7 +519,7 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
)
for k in range(0, points.shape[-1], chunk_size):
points_in_planes = points[..., k : k + chunk_size]
planes.append(self.get_density(points_in_planes)[..., 0])
planes.append(self._get_density(points_in_planes)[..., 0])
density_cube = torch.cat(planes, dim=-1)
density_cube = torch.nn.functional.max_pool3d(

View File

@ -89,7 +89,7 @@ class TestVoxelGridImplicitFunction(TestCaseMixin, unittest.TestCase):
out.append(torch.tensor([[0.0]]))
return torch.cat(out).view(*inshape[:-1], 1).to(device)
func.get_density = new_density
func._get_density = new_density
func._get_scaffold(0)
points = torch.tensor(
@ -136,15 +136,15 @@ class TestVoxelGridImplicitFunction(TestCaseMixin, unittest.TestCase):
assert torch.all(scaffold(points)), (scaffold(points), points.shape)
return points.sum(dim=-1, keepdim=True)
def new_color(points, camera, directions):
def new_color(points, camera, directions, non_empty_points, num_points_per_ray):
# check if all passed points should be passed here
assert torch.all(scaffold(points)) # , (scaffold(points), points)
return points * 2
# check both computation paths that they contain only points
# which are not in empty space
func.get_density = new_density
func.get_color = new_color
func._get_density = new_density
func._get_color = new_color
func.voxel_grid_scaffold.forward = scaffold
func._scaffold_ready = True