From a1f2ded58a502f0d65c86ff6c86f417689a5c8d4 Mon Sep 17 00:00:00 2001 From: Jeremy Reizenstein Date: Thu, 3 Nov 2022 05:46:31 -0700 Subject: [PATCH] voxel_grid_implicit_function scaffold fixes Summary: Fix indexing of directions after filtering of points by scaffold. Reviewed By: shapovalov Differential Revision: D40853482 fbshipit-source-id: 9cfdb981e97cb82edcd27632c5848537ed2c6837 --- .../voxel_grid_implicit_function.py | 66 +++++++++++-------- .../test_voxel_grid_implicit_function.py | 8 +-- 2 files changed, 44 insertions(+), 30 deletions(-) diff --git a/pytorch3d/implicitron/models/implicit_function/voxel_grid_implicit_function.py b/pytorch3d/implicitron/models/implicit_function/voxel_grid_implicit_function.py index 3283cfe0..f106e10f 100644 --- a/pytorch3d/implicitron/models/implicit_function/voxel_grid_implicit_function.py +++ b/pytorch3d/implicitron/models/implicit_function/voxel_grid_implicit_function.py @@ -7,7 +7,7 @@ import math import warnings from dataclasses import fields -from typing import Callable, Dict, Optional, Tuple, Union +from typing import Callable, Dict, Optional, Tuple import torch @@ -118,11 +118,11 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module): the calculation.) scaffold_resolution (Tuple[int, int, int]): (width, height, depth) of the underlying voxel grid which stores scaffold - scaffold_empty_space_threshold (float): if `self.get_density` evaluates to less than + scaffold_empty_space_threshold (float): if `self._get_density` evaluates to less than this it will be considered as empty space and the scaffold at that point would evaluate as empty space. scaffold_occupancy_chunk_size (str or int): Number of xy scaffold planes to calculate - at the same time. To calculate the scaffold we need to query `get_density()` at + at the same time. To calculate the scaffold we need to query `_get_density()` at every voxel, this calculation can be split into scaffold depth number of xy plane calculations if you want the lowest memory usage, one calculation to calculate the whole scaffold, but with higher memory footprint or any other number of planes. @@ -242,14 +242,16 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module): points = ray_bundle_to_ray_points(ray_bundle) directions = ray_bundle.directions.reshape(-1, 3) input_shape = points.shape + num_points_per_ray = input_shape[-2] points = points.view(-1, 3) + non_empty_points = None # ########## filter the points using the scaffold ########## # if self._scaffold_ready and self.scaffold_filter_points: - # pyre-ignore[29] - non_empty_points = self.voxel_grid_scaffold(points)[..., 0] > 0 + with torch.no_grad(): + # pyre-ignore[29] + non_empty_points = self.voxel_grid_scaffold(points)[..., 0] > 0 points = points[non_empty_points] - directions = directions[non_empty_points] if len(points) == 0: warnings.warn( "The scaffold has filtered all the points." @@ -262,8 +264,8 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module): ) # ########## calculate color and density ########## # - rays_densities, rays_colors = self.calculate_density_and_color( - points, directions, camera + rays_densities, rays_colors = self._calculate_density_and_color( + points, directions, camera, non_empty_points, num_points_per_ray ) if not (self._scaffold_ready and self.scaffold_filter_points): @@ -283,9 +285,8 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module): rays_colors_combined = rays_colors.new_zeros( (math.prod(input_shape[:-1]), rays_colors.shape[-1]) ) - # pyre-ignore[61] + assert non_empty_points is not None rays_densities_combined[non_empty_points] = rays_densities - # pyre-ignore[61] rays_colors_combined[non_empty_points] = rays_colors return ( @@ -294,11 +295,13 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module): {}, ) - def calculate_density_and_color( + def _calculate_density_and_color( self, points: torch.Tensor, directions: torch.Tensor, - camera: Optional[CamerasBase] = None, + camera: Optional[CamerasBase], + non_empty_points: Optional[torch.Tensor], + num_points_per_ray: int, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Calculates density and color at `points`. @@ -306,11 +309,14 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module): Args: points: points at which to calculate density and color. - Tensor of shape [..., 3]. - directions: from which directions are the points viewed - Tensor of shape [..., 3]. + Tensor of shape [n_points, 3]. + directions: from which directions are the points viewed. + One per ray. Tensor of shape [n_rays, 3]. camera: A camera model which will be used to transform the viewing directions + non_empty_points: indices of points which weren't filtered out; + used for expanding directions + num_points_per_ray: number of points per ray, needed to expand directions. Returns: Tuple of color (tensor of shape [..., 3]) and density (tensor of shape [..., 1]) @@ -323,20 +329,24 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module): with torch.cuda.stream(other_stream): # rays_densities.shape = # [minibatch x n_rays_width x n_rays_height x pts_per_ray x density_dim] - rays_densities = self.get_density(points) + rays_densities = self._get_density(points) # rays_colors.shape = # [minibatch x n_rays_width x n_rays_height x pts_per_ray x color_dim] - rays_colors = self.get_color(points, camera, directions) + rays_colors = self._get_color( + points, camera, directions, non_empty_points, num_points_per_ray + ) current_stream.wait_stream(other_stream) else: # Same calculation as above, just serial. - rays_densities = self.get_density(points) - rays_colors = self.get_color(points, camera, directions) + rays_densities = self._get_density(points) + rays_colors = self._get_color( + points, camera, directions, non_empty_points, num_points_per_ray + ) return rays_densities, rays_colors - def get_density(self, points: torch.Tensor) -> torch.Tensor: + def _get_density(self, points: torch.Tensor) -> torch.Tensor: """ Calculates density at points: 1) Evaluates the voxel grid on points @@ -356,11 +366,13 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module): # shape = [..., density_dim] return self.decoder_density(harmonic_embedding_density) - def get_color( + def _get_color( self, points: torch.Tensor, camera: Optional[CamerasBase], directions: torch.Tensor, + non_empty_points: Optional[torch.Tensor], + num_points_per_ray: int, ) -> torch.Tensor: """ Calculates color at points using the viewing direction: @@ -376,6 +388,9 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module): directions directions: A tensor of shape `(..., 3)` containing the direction vectors of sampling rays in world coords. + non_empty_points: indices of points which weren't filtered out; + used for expanding directions + num_points_per_ray: number of points per ray, needed to expand directions. """ # ########## transform direction ########## # if self.xyz_ray_dir_in_camera_coords: @@ -400,12 +415,11 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module): rays_directions_normed ) - n_rays = directions.shape[0] - points_per_ray: int = points.shape[0] // n_rays - harmonic_embedding_dir = torch.repeat_interleave( - harmonic_embedding_dir, points_per_ray, dim=0 + harmonic_embedding_dir, num_points_per_ray, dim=0 ) + if non_empty_points is not None: + harmonic_embedding_dir = harmonic_embedding_dir[non_empty_points] # total color embedding is concatenation of the harmonic embedding of voxel grid # output and harmonic embedding of the normalized direction @@ -505,7 +519,7 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module): ) for k in range(0, points.shape[-1], chunk_size): points_in_planes = points[..., k : k + chunk_size] - planes.append(self.get_density(points_in_planes)[..., 0]) + planes.append(self._get_density(points_in_planes)[..., 0]) density_cube = torch.cat(planes, dim=-1) density_cube = torch.nn.functional.max_pool3d( diff --git a/tests/implicitron/test_voxel_grid_implicit_function.py b/tests/implicitron/test_voxel_grid_implicit_function.py index b5d482c8..9727ba98 100644 --- a/tests/implicitron/test_voxel_grid_implicit_function.py +++ b/tests/implicitron/test_voxel_grid_implicit_function.py @@ -89,7 +89,7 @@ class TestVoxelGridImplicitFunction(TestCaseMixin, unittest.TestCase): out.append(torch.tensor([[0.0]])) return torch.cat(out).view(*inshape[:-1], 1).to(device) - func.get_density = new_density + func._get_density = new_density func._get_scaffold(0) points = torch.tensor( @@ -136,15 +136,15 @@ class TestVoxelGridImplicitFunction(TestCaseMixin, unittest.TestCase): assert torch.all(scaffold(points)), (scaffold(points), points.shape) return points.sum(dim=-1, keepdim=True) - def new_color(points, camera, directions): + def new_color(points, camera, directions, non_empty_points, num_points_per_ray): # check if all passed points should be passed here assert torch.all(scaffold(points)) # , (scaffold(points), points) return points * 2 # check both computation paths that they contain only points # which are not in empty space - func.get_density = new_density - func.get_color = new_color + func._get_density = new_density + func._get_color = new_color func.voxel_grid_scaffold.forward = scaffold func._scaffold_ready = True