voxel_grid_implicit_function scaffold fixes

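Summary: Fix the indexing of ray directions after points are filtered by the scaffold. Directions are stored one per ray, whereas the scaffold mask is one entry per point, so the direction embedding is now expanded to one entry per point before the mask is applied.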

Reviewed By: shapovalov

Differential Revision: D40853482

fbshipit-source-id: 9cfdb981e97cb82edcd27632c5848537ed2c6837
This commit is contained in:
Jeremy Reizenstein 2022-11-03 05:46:31 -07:00 committed by Facebook GitHub Bot
parent e4a3298149
commit a1f2ded58a
2 changed files with 44 additions and 30 deletions

View File

@ -7,7 +7,7 @@
import math import math
import warnings import warnings
from dataclasses import fields from dataclasses import fields
from typing import Callable, Dict, Optional, Tuple, Union from typing import Callable, Dict, Optional, Tuple
import torch import torch
@ -118,11 +118,11 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
the calculation.) the calculation.)
scaffold_resolution (Tuple[int, int, int]): (width, height, depth) of the underlying scaffold_resolution (Tuple[int, int, int]): (width, height, depth) of the underlying
voxel grid which stores scaffold voxel grid which stores scaffold
scaffold_empty_space_threshold (float): if `self.get_density` evaluates to less than scaffold_empty_space_threshold (float): if `self._get_density` evaluates to less than
this it will be considered as empty space and the scaffold at that point would this it will be considered as empty space and the scaffold at that point would
evaluate as empty space. evaluate as empty space.
scaffold_occupancy_chunk_size (str or int): Number of xy scaffold planes to calculate scaffold_occupancy_chunk_size (str or int): Number of xy scaffold planes to calculate
at the same time. To calculate the scaffold we need to query `get_density()` at at the same time. To calculate the scaffold we need to query `_get_density()` at
every voxel, this calculation can be split into scaffold depth number of xy plane every voxel, this calculation can be split into scaffold depth number of xy plane
calculations if you want the lowest memory usage, one calculation to calculate the calculations if you want the lowest memory usage, one calculation to calculate the
whole scaffold, but with higher memory footprint or any other number of planes. whole scaffold, but with higher memory footprint or any other number of planes.
@ -242,14 +242,16 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
points = ray_bundle_to_ray_points(ray_bundle) points = ray_bundle_to_ray_points(ray_bundle)
directions = ray_bundle.directions.reshape(-1, 3) directions = ray_bundle.directions.reshape(-1, 3)
input_shape = points.shape input_shape = points.shape
num_points_per_ray = input_shape[-2]
points = points.view(-1, 3) points = points.view(-1, 3)
non_empty_points = None
# ########## filter the points using the scaffold ########## # # ########## filter the points using the scaffold ########## #
if self._scaffold_ready and self.scaffold_filter_points: if self._scaffold_ready and self.scaffold_filter_points:
# pyre-ignore[29] with torch.no_grad():
non_empty_points = self.voxel_grid_scaffold(points)[..., 0] > 0 # pyre-ignore[29]
non_empty_points = self.voxel_grid_scaffold(points)[..., 0] > 0
points = points[non_empty_points] points = points[non_empty_points]
directions = directions[non_empty_points]
if len(points) == 0: if len(points) == 0:
warnings.warn( warnings.warn(
"The scaffold has filtered all the points." "The scaffold has filtered all the points."
@ -262,8 +264,8 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
) )
# ########## calculate color and density ########## # # ########## calculate color and density ########## #
rays_densities, rays_colors = self.calculate_density_and_color( rays_densities, rays_colors = self._calculate_density_and_color(
points, directions, camera points, directions, camera, non_empty_points, num_points_per_ray
) )
if not (self._scaffold_ready and self.scaffold_filter_points): if not (self._scaffold_ready and self.scaffold_filter_points):
@ -283,9 +285,8 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
rays_colors_combined = rays_colors.new_zeros( rays_colors_combined = rays_colors.new_zeros(
(math.prod(input_shape[:-1]), rays_colors.shape[-1]) (math.prod(input_shape[:-1]), rays_colors.shape[-1])
) )
# pyre-ignore[61] assert non_empty_points is not None
rays_densities_combined[non_empty_points] = rays_densities rays_densities_combined[non_empty_points] = rays_densities
# pyre-ignore[61]
rays_colors_combined[non_empty_points] = rays_colors rays_colors_combined[non_empty_points] = rays_colors
return ( return (
@ -294,11 +295,13 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
{}, {},
) )
def calculate_density_and_color( def _calculate_density_and_color(
self, self,
points: torch.Tensor, points: torch.Tensor,
directions: torch.Tensor, directions: torch.Tensor,
camera: Optional[CamerasBase] = None, camera: Optional[CamerasBase],
non_empty_points: Optional[torch.Tensor],
num_points_per_ray: int,
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
""" """
Calculates density and color at `points`. Calculates density and color at `points`.
@ -306,11 +309,14 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
Args: Args:
points: points at which to calculate density and color. points: points at which to calculate density and color.
Tensor of shape [..., 3]. Tensor of shape [n_points, 3].
directions: from which directions are the points viewed directions: from which directions are the points viewed.
Tensor of shape [..., 3]. One per ray. Tensor of shape [n_rays, 3].
camera: A camera model which will be used to transform the viewing camera: A camera model which will be used to transform the viewing
directions directions
non_empty_points: indices of points which weren't filtered out;
used for expanding directions
num_points_per_ray: number of points per ray, needed to expand directions.
Returns: Returns:
Tuple of color (tensor of shape [..., 3]) and density Tuple of color (tensor of shape [..., 3]) and density
(tensor of shape [..., 1]) (tensor of shape [..., 1])
@ -323,20 +329,24 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
with torch.cuda.stream(other_stream): with torch.cuda.stream(other_stream):
# rays_densities.shape = # rays_densities.shape =
# [minibatch x n_rays_width x n_rays_height x pts_per_ray x density_dim] # [minibatch x n_rays_width x n_rays_height x pts_per_ray x density_dim]
rays_densities = self.get_density(points) rays_densities = self._get_density(points)
# rays_colors.shape = # rays_colors.shape =
# [minibatch x n_rays_width x n_rays_height x pts_per_ray x color_dim] # [minibatch x n_rays_width x n_rays_height x pts_per_ray x color_dim]
rays_colors = self.get_color(points, camera, directions) rays_colors = self._get_color(
points, camera, directions, non_empty_points, num_points_per_ray
)
current_stream.wait_stream(other_stream) current_stream.wait_stream(other_stream)
else: else:
# Same calculation as above, just serial. # Same calculation as above, just serial.
rays_densities = self.get_density(points) rays_densities = self._get_density(points)
rays_colors = self.get_color(points, camera, directions) rays_colors = self._get_color(
points, camera, directions, non_empty_points, num_points_per_ray
)
return rays_densities, rays_colors return rays_densities, rays_colors
def get_density(self, points: torch.Tensor) -> torch.Tensor: def _get_density(self, points: torch.Tensor) -> torch.Tensor:
""" """
Calculates density at points: Calculates density at points:
1) Evaluates the voxel grid on points 1) Evaluates the voxel grid on points
@ -356,11 +366,13 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
# shape = [..., density_dim] # shape = [..., density_dim]
return self.decoder_density(harmonic_embedding_density) return self.decoder_density(harmonic_embedding_density)
def get_color( def _get_color(
self, self,
points: torch.Tensor, points: torch.Tensor,
camera: Optional[CamerasBase], camera: Optional[CamerasBase],
directions: torch.Tensor, directions: torch.Tensor,
non_empty_points: Optional[torch.Tensor],
num_points_per_ray: int,
) -> torch.Tensor: ) -> torch.Tensor:
""" """
Calculates color at points using the viewing direction: Calculates color at points using the viewing direction:
@ -376,6 +388,9 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
directions directions
directions: A tensor of shape `(..., 3)` directions: A tensor of shape `(..., 3)`
containing the direction vectors of sampling rays in world coords. containing the direction vectors of sampling rays in world coords.
non_empty_points: indices of points which weren't filtered out;
used for expanding directions
num_points_per_ray: number of points per ray, needed to expand directions.
""" """
# ########## transform direction ########## # # ########## transform direction ########## #
if self.xyz_ray_dir_in_camera_coords: if self.xyz_ray_dir_in_camera_coords:
@ -400,12 +415,11 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
rays_directions_normed rays_directions_normed
) )
n_rays = directions.shape[0]
points_per_ray: int = points.shape[0] // n_rays
harmonic_embedding_dir = torch.repeat_interleave( harmonic_embedding_dir = torch.repeat_interleave(
harmonic_embedding_dir, points_per_ray, dim=0 harmonic_embedding_dir, num_points_per_ray, dim=0
) )
if non_empty_points is not None:
harmonic_embedding_dir = harmonic_embedding_dir[non_empty_points]
# total color embedding is concatenation of the harmonic embedding of voxel grid # total color embedding is concatenation of the harmonic embedding of voxel grid
# output and harmonic embedding of the normalized direction # output and harmonic embedding of the normalized direction
@ -505,7 +519,7 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
) )
for k in range(0, points.shape[-1], chunk_size): for k in range(0, points.shape[-1], chunk_size):
points_in_planes = points[..., k : k + chunk_size] points_in_planes = points[..., k : k + chunk_size]
planes.append(self.get_density(points_in_planes)[..., 0]) planes.append(self._get_density(points_in_planes)[..., 0])
density_cube = torch.cat(planes, dim=-1) density_cube = torch.cat(planes, dim=-1)
density_cube = torch.nn.functional.max_pool3d( density_cube = torch.nn.functional.max_pool3d(

View File

@ -89,7 +89,7 @@ class TestVoxelGridImplicitFunction(TestCaseMixin, unittest.TestCase):
out.append(torch.tensor([[0.0]])) out.append(torch.tensor([[0.0]]))
return torch.cat(out).view(*inshape[:-1], 1).to(device) return torch.cat(out).view(*inshape[:-1], 1).to(device)
func.get_density = new_density func._get_density = new_density
func._get_scaffold(0) func._get_scaffold(0)
points = torch.tensor( points = torch.tensor(
@ -136,15 +136,15 @@ class TestVoxelGridImplicitFunction(TestCaseMixin, unittest.TestCase):
assert torch.all(scaffold(points)), (scaffold(points), points.shape) assert torch.all(scaffold(points)), (scaffold(points), points.shape)
return points.sum(dim=-1, keepdim=True) return points.sum(dim=-1, keepdim=True)
def new_color(points, camera, directions): def new_color(points, camera, directions, non_empty_points, num_points_per_ray):
# check if all passed points should be passed here # check if all passed points should be passed here
assert torch.all(scaffold(points)) # , (scaffold(points), points) assert torch.all(scaffold(points)) # , (scaffold(points), points)
return points * 2 return points * 2
# check both computation paths that they contain only points # check both computation paths that they contain only points
# which are not in empty space # which are not in empty space
func.get_density = new_density func._get_density = new_density
func.get_color = new_color func._get_color = new_color
func.voxel_grid_scaffold.forward = scaffold func.voxel_grid_scaffold.forward = scaffold
func._scaffold_ready = True func._scaffold_ready = True