mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
voxel_grid_implicit_function scaffold fixes
Summary: Fix indexing of directions after filtering of points by scaffold. Reviewed By: shapovalov Differential Revision: D40853482 fbshipit-source-id: 9cfdb981e97cb82edcd27632c5848537ed2c6837
This commit is contained in:
parent
e4a3298149
commit
a1f2ded58a
@ -7,7 +7,7 @@
|
||||
import math
|
||||
import warnings
|
||||
from dataclasses import fields
|
||||
from typing import Callable, Dict, Optional, Tuple, Union
|
||||
from typing import Callable, Dict, Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
@ -118,11 +118,11 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
|
||||
the calculation.)
|
||||
scaffold_resolution (Tuple[int, int, int]): (width, height, depth) of the underlying
|
||||
voxel grid which stores scaffold
|
||||
scaffold_empty_space_threshold (float): if `self.get_density` evaluates to less than
|
||||
scaffold_empty_space_threshold (float): if `self._get_density` evaluates to less than
|
||||
this it will be considered as empty space and the scaffold at that point would
|
||||
evaluate as empty space.
|
||||
scaffold_occupancy_chunk_size (str or int): Number of xy scaffold planes to calculate
|
||||
at the same time. To calculate the scaffold we need to query `get_density()` at
|
||||
at the same time. To calculate the scaffold we need to query `_get_density()` at
|
||||
every voxel, this calculation can be split into scaffold depth number of xy plane
|
||||
calculations if you want the lowest memory usage, one calculation to calculate the
|
||||
whole scaffold, but with higher memory footprint or any other number of planes.
|
||||
@ -242,14 +242,16 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
|
||||
points = ray_bundle_to_ray_points(ray_bundle)
|
||||
directions = ray_bundle.directions.reshape(-1, 3)
|
||||
input_shape = points.shape
|
||||
num_points_per_ray = input_shape[-2]
|
||||
points = points.view(-1, 3)
|
||||
non_empty_points = None
|
||||
|
||||
# ########## filter the points using the scaffold ########## #
|
||||
if self._scaffold_ready and self.scaffold_filter_points:
|
||||
with torch.no_grad():
|
||||
# pyre-ignore[29]
|
||||
non_empty_points = self.voxel_grid_scaffold(points)[..., 0] > 0
|
||||
points = points[non_empty_points]
|
||||
directions = directions[non_empty_points]
|
||||
if len(points) == 0:
|
||||
warnings.warn(
|
||||
"The scaffold has filtered all the points."
|
||||
@ -262,8 +264,8 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
|
||||
)
|
||||
|
||||
# ########## calculate color and density ########## #
|
||||
rays_densities, rays_colors = self.calculate_density_and_color(
|
||||
points, directions, camera
|
||||
rays_densities, rays_colors = self._calculate_density_and_color(
|
||||
points, directions, camera, non_empty_points, num_points_per_ray
|
||||
)
|
||||
|
||||
if not (self._scaffold_ready and self.scaffold_filter_points):
|
||||
@ -283,9 +285,8 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
|
||||
rays_colors_combined = rays_colors.new_zeros(
|
||||
(math.prod(input_shape[:-1]), rays_colors.shape[-1])
|
||||
)
|
||||
# pyre-ignore[61]
|
||||
assert non_empty_points is not None
|
||||
rays_densities_combined[non_empty_points] = rays_densities
|
||||
# pyre-ignore[61]
|
||||
rays_colors_combined[non_empty_points] = rays_colors
|
||||
|
||||
return (
|
||||
@ -294,11 +295,13 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
|
||||
{},
|
||||
)
|
||||
|
||||
def calculate_density_and_color(
|
||||
def _calculate_density_and_color(
|
||||
self,
|
||||
points: torch.Tensor,
|
||||
directions: torch.Tensor,
|
||||
camera: Optional[CamerasBase] = None,
|
||||
camera: Optional[CamerasBase],
|
||||
non_empty_points: Optional[torch.Tensor],
|
||||
num_points_per_ray: int,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Calculates density and color at `points`.
|
||||
@ -306,11 +309,14 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
|
||||
|
||||
Args:
|
||||
points: points at which to calculate density and color.
|
||||
Tensor of shape [..., 3].
|
||||
directions: from which directions are the points viewed
|
||||
Tensor of shape [..., 3].
|
||||
Tensor of shape [n_points, 3].
|
||||
directions: from which directions are the points viewed.
|
||||
One per ray. Tensor of shape [n_rays, 3].
|
||||
camera: A camera model which will be used to transform the viewing
|
||||
directions
|
||||
non_empty_points: indices of points which weren't filtered out;
|
||||
used for expanding directions
|
||||
num_points_per_ray: number of points per ray, needed to expand directions.
|
||||
Returns:
|
||||
Tuple of color (tensor of shape [..., 3]) and density
|
||||
(tensor of shape [..., 1])
|
||||
@ -323,20 +329,24 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
|
||||
with torch.cuda.stream(other_stream):
|
||||
# rays_densities.shape =
|
||||
# [minibatch x n_rays_width x n_rays_height x pts_per_ray x density_dim]
|
||||
rays_densities = self.get_density(points)
|
||||
rays_densities = self._get_density(points)
|
||||
|
||||
# rays_colors.shape =
|
||||
# [minibatch x n_rays_width x n_rays_height x pts_per_ray x color_dim]
|
||||
rays_colors = self.get_color(points, camera, directions)
|
||||
rays_colors = self._get_color(
|
||||
points, camera, directions, non_empty_points, num_points_per_ray
|
||||
)
|
||||
|
||||
current_stream.wait_stream(other_stream)
|
||||
else:
|
||||
# Same calculation as above, just serial.
|
||||
rays_densities = self.get_density(points)
|
||||
rays_colors = self.get_color(points, camera, directions)
|
||||
rays_densities = self._get_density(points)
|
||||
rays_colors = self._get_color(
|
||||
points, camera, directions, non_empty_points, num_points_per_ray
|
||||
)
|
||||
return rays_densities, rays_colors
|
||||
|
||||
def get_density(self, points: torch.Tensor) -> torch.Tensor:
|
||||
def _get_density(self, points: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Calculates density at points:
|
||||
1) Evaluates the voxel grid on points
|
||||
@ -356,11 +366,13 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
|
||||
# shape = [..., density_dim]
|
||||
return self.decoder_density(harmonic_embedding_density)
|
||||
|
||||
def get_color(
|
||||
def _get_color(
|
||||
self,
|
||||
points: torch.Tensor,
|
||||
camera: Optional[CamerasBase],
|
||||
directions: torch.Tensor,
|
||||
non_empty_points: Optional[torch.Tensor],
|
||||
num_points_per_ray: int,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Calculates color at points using the viewing direction:
|
||||
@ -376,6 +388,9 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
|
||||
directions
|
||||
directions: A tensor of shape `(..., 3)`
|
||||
containing the direction vectors of sampling rays in world coords.
|
||||
non_empty_points: indices of points which weren't filtered out;
|
||||
used for expanding directions
|
||||
num_points_per_ray: number of points per ray, needed to expand directions.
|
||||
"""
|
||||
# ########## transform direction ########## #
|
||||
if self.xyz_ray_dir_in_camera_coords:
|
||||
@ -400,12 +415,11 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
|
||||
rays_directions_normed
|
||||
)
|
||||
|
||||
n_rays = directions.shape[0]
|
||||
points_per_ray: int = points.shape[0] // n_rays
|
||||
|
||||
harmonic_embedding_dir = torch.repeat_interleave(
|
||||
harmonic_embedding_dir, points_per_ray, dim=0
|
||||
harmonic_embedding_dir, num_points_per_ray, dim=0
|
||||
)
|
||||
if non_empty_points is not None:
|
||||
harmonic_embedding_dir = harmonic_embedding_dir[non_empty_points]
|
||||
|
||||
# total color embedding is concatenation of the harmonic embedding of voxel grid
|
||||
# output and harmonic embedding of the normalized direction
|
||||
@ -505,7 +519,7 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
|
||||
)
|
||||
for k in range(0, points.shape[-1], chunk_size):
|
||||
points_in_planes = points[..., k : k + chunk_size]
|
||||
planes.append(self.get_density(points_in_planes)[..., 0])
|
||||
planes.append(self._get_density(points_in_planes)[..., 0])
|
||||
|
||||
density_cube = torch.cat(planes, dim=-1)
|
||||
density_cube = torch.nn.functional.max_pool3d(
|
||||
|
@ -89,7 +89,7 @@ class TestVoxelGridImplicitFunction(TestCaseMixin, unittest.TestCase):
|
||||
out.append(torch.tensor([[0.0]]))
|
||||
return torch.cat(out).view(*inshape[:-1], 1).to(device)
|
||||
|
||||
func.get_density = new_density
|
||||
func._get_density = new_density
|
||||
func._get_scaffold(0)
|
||||
|
||||
points = torch.tensor(
|
||||
@ -136,15 +136,15 @@ class TestVoxelGridImplicitFunction(TestCaseMixin, unittest.TestCase):
|
||||
assert torch.all(scaffold(points)), (scaffold(points), points.shape)
|
||||
return points.sum(dim=-1, keepdim=True)
|
||||
|
||||
def new_color(points, camera, directions):
|
||||
def new_color(points, camera, directions, non_empty_points, num_points_per_ray):
|
||||
# check if all passed points should be passed here
|
||||
assert torch.all(scaffold(points)) # , (scaffold(points), points)
|
||||
return points * 2
|
||||
|
||||
# check both computation paths that they contain only points
|
||||
# which are not in empty space
|
||||
func.get_density = new_density
|
||||
func.get_color = new_color
|
||||
func._get_density = new_density
|
||||
func._get_color = new_color
|
||||
func.voxel_grid_scaffold.forward = scaffold
|
||||
func._scaffold_ready = True
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user