mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
voxel_grid_implicit_function scaffold fixes
Summary: Fix indexing of directions after filtering of points by scaffold. Reviewed By: shapovalov Differential Revision: D40853482 fbshipit-source-id: 9cfdb981e97cb82edcd27632c5848537ed2c6837
This commit is contained in:
parent
e4a3298149
commit
a1f2ded58a
@ -7,7 +7,7 @@
|
|||||||
import math
|
import math
|
||||||
import warnings
|
import warnings
|
||||||
from dataclasses import fields
|
from dataclasses import fields
|
||||||
from typing import Callable, Dict, Optional, Tuple, Union
|
from typing import Callable, Dict, Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@ -118,11 +118,11 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
|
|||||||
the calculation.)
|
the calculation.)
|
||||||
scaffold_resolution (Tuple[int, int, int]): (width, height, depth) of the underlying
|
scaffold_resolution (Tuple[int, int, int]): (width, height, depth) of the underlying
|
||||||
voxel grid which stores scaffold
|
voxel grid which stores scaffold
|
||||||
scaffold_empty_space_threshold (float): if `self.get_density` evaluates to less than
|
scaffold_empty_space_threshold (float): if `self._get_density` evaluates to less than
|
||||||
this it will be considered as empty space and the scaffold at that point would
|
this it will be considered as empty space and the scaffold at that point would
|
||||||
evaluate as empty space.
|
evaluate as empty space.
|
||||||
scaffold_occupancy_chunk_size (str or int): Number of xy scaffold planes to calculate
|
scaffold_occupancy_chunk_size (str or int): Number of xy scaffold planes to calculate
|
||||||
at the same time. To calculate the scaffold we need to query `get_density()` at
|
at the same time. To calculate the scaffold we need to query `_get_density()` at
|
||||||
every voxel, this calculation can be split into scaffold depth number of xy plane
|
every voxel, this calculation can be split into scaffold depth number of xy plane
|
||||||
calculations if you want the lowest memory usage, one calculation to calculate the
|
calculations if you want the lowest memory usage, one calculation to calculate the
|
||||||
whole scaffold, but with higher memory footprint or any other number of planes.
|
whole scaffold, but with higher memory footprint or any other number of planes.
|
||||||
@ -242,14 +242,16 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
|
|||||||
points = ray_bundle_to_ray_points(ray_bundle)
|
points = ray_bundle_to_ray_points(ray_bundle)
|
||||||
directions = ray_bundle.directions.reshape(-1, 3)
|
directions = ray_bundle.directions.reshape(-1, 3)
|
||||||
input_shape = points.shape
|
input_shape = points.shape
|
||||||
|
num_points_per_ray = input_shape[-2]
|
||||||
points = points.view(-1, 3)
|
points = points.view(-1, 3)
|
||||||
|
non_empty_points = None
|
||||||
|
|
||||||
# ########## filter the points using the scaffold ########## #
|
# ########## filter the points using the scaffold ########## #
|
||||||
if self._scaffold_ready and self.scaffold_filter_points:
|
if self._scaffold_ready and self.scaffold_filter_points:
|
||||||
# pyre-ignore[29]
|
with torch.no_grad():
|
||||||
non_empty_points = self.voxel_grid_scaffold(points)[..., 0] > 0
|
# pyre-ignore[29]
|
||||||
|
non_empty_points = self.voxel_grid_scaffold(points)[..., 0] > 0
|
||||||
points = points[non_empty_points]
|
points = points[non_empty_points]
|
||||||
directions = directions[non_empty_points]
|
|
||||||
if len(points) == 0:
|
if len(points) == 0:
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
"The scaffold has filtered all the points."
|
"The scaffold has filtered all the points."
|
||||||
@ -262,8 +264,8 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# ########## calculate color and density ########## #
|
# ########## calculate color and density ########## #
|
||||||
rays_densities, rays_colors = self.calculate_density_and_color(
|
rays_densities, rays_colors = self._calculate_density_and_color(
|
||||||
points, directions, camera
|
points, directions, camera, non_empty_points, num_points_per_ray
|
||||||
)
|
)
|
||||||
|
|
||||||
if not (self._scaffold_ready and self.scaffold_filter_points):
|
if not (self._scaffold_ready and self.scaffold_filter_points):
|
||||||
@ -283,9 +285,8 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
|
|||||||
rays_colors_combined = rays_colors.new_zeros(
|
rays_colors_combined = rays_colors.new_zeros(
|
||||||
(math.prod(input_shape[:-1]), rays_colors.shape[-1])
|
(math.prod(input_shape[:-1]), rays_colors.shape[-1])
|
||||||
)
|
)
|
||||||
# pyre-ignore[61]
|
assert non_empty_points is not None
|
||||||
rays_densities_combined[non_empty_points] = rays_densities
|
rays_densities_combined[non_empty_points] = rays_densities
|
||||||
# pyre-ignore[61]
|
|
||||||
rays_colors_combined[non_empty_points] = rays_colors
|
rays_colors_combined[non_empty_points] = rays_colors
|
||||||
|
|
||||||
return (
|
return (
|
||||||
@ -294,11 +295,13 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
|
|||||||
{},
|
{},
|
||||||
)
|
)
|
||||||
|
|
||||||
def calculate_density_and_color(
|
def _calculate_density_and_color(
|
||||||
self,
|
self,
|
||||||
points: torch.Tensor,
|
points: torch.Tensor,
|
||||||
directions: torch.Tensor,
|
directions: torch.Tensor,
|
||||||
camera: Optional[CamerasBase] = None,
|
camera: Optional[CamerasBase],
|
||||||
|
non_empty_points: Optional[torch.Tensor],
|
||||||
|
num_points_per_ray: int,
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
"""
|
"""
|
||||||
Calculates density and color at `points`.
|
Calculates density and color at `points`.
|
||||||
@ -306,11 +309,14 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
points: points at which to calculate density and color.
|
points: points at which to calculate density and color.
|
||||||
Tensor of shape [..., 3].
|
Tensor of shape [n_points, 3].
|
||||||
directions: from which directions are the points viewed
|
directions: from which directions are the points viewed.
|
||||||
Tensor of shape [..., 3].
|
One per ray. Tensor of shape [n_rays, 3].
|
||||||
camera: A camera model which will be used to transform the viewing
|
camera: A camera model which will be used to transform the viewing
|
||||||
directions
|
directions
|
||||||
|
non_empty_points: indices of points which weren't filtered out;
|
||||||
|
used for expanding directions
|
||||||
|
num_points_per_ray: number of points per ray, needed to expand directions.
|
||||||
Returns:
|
Returns:
|
||||||
Tuple of color (tensor of shape [..., 3]) and density
|
Tuple of color (tensor of shape [..., 3]) and density
|
||||||
(tensor of shape [..., 1])
|
(tensor of shape [..., 1])
|
||||||
@ -323,20 +329,24 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
|
|||||||
with torch.cuda.stream(other_stream):
|
with torch.cuda.stream(other_stream):
|
||||||
# rays_densities.shape =
|
# rays_densities.shape =
|
||||||
# [minibatch x n_rays_width x n_rays_height x pts_per_ray x density_dim]
|
# [minibatch x n_rays_width x n_rays_height x pts_per_ray x density_dim]
|
||||||
rays_densities = self.get_density(points)
|
rays_densities = self._get_density(points)
|
||||||
|
|
||||||
# rays_colors.shape =
|
# rays_colors.shape =
|
||||||
# [minibatch x n_rays_width x n_rays_height x pts_per_ray x color_dim]
|
# [minibatch x n_rays_width x n_rays_height x pts_per_ray x color_dim]
|
||||||
rays_colors = self.get_color(points, camera, directions)
|
rays_colors = self._get_color(
|
||||||
|
points, camera, directions, non_empty_points, num_points_per_ray
|
||||||
|
)
|
||||||
|
|
||||||
current_stream.wait_stream(other_stream)
|
current_stream.wait_stream(other_stream)
|
||||||
else:
|
else:
|
||||||
# Same calculation as above, just serial.
|
# Same calculation as above, just serial.
|
||||||
rays_densities = self.get_density(points)
|
rays_densities = self._get_density(points)
|
||||||
rays_colors = self.get_color(points, camera, directions)
|
rays_colors = self._get_color(
|
||||||
|
points, camera, directions, non_empty_points, num_points_per_ray
|
||||||
|
)
|
||||||
return rays_densities, rays_colors
|
return rays_densities, rays_colors
|
||||||
|
|
||||||
def get_density(self, points: torch.Tensor) -> torch.Tensor:
|
def _get_density(self, points: torch.Tensor) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Calculates density at points:
|
Calculates density at points:
|
||||||
1) Evaluates the voxel grid on points
|
1) Evaluates the voxel grid on points
|
||||||
@ -356,11 +366,13 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
|
|||||||
# shape = [..., density_dim]
|
# shape = [..., density_dim]
|
||||||
return self.decoder_density(harmonic_embedding_density)
|
return self.decoder_density(harmonic_embedding_density)
|
||||||
|
|
||||||
def get_color(
|
def _get_color(
|
||||||
self,
|
self,
|
||||||
points: torch.Tensor,
|
points: torch.Tensor,
|
||||||
camera: Optional[CamerasBase],
|
camera: Optional[CamerasBase],
|
||||||
directions: torch.Tensor,
|
directions: torch.Tensor,
|
||||||
|
non_empty_points: Optional[torch.Tensor],
|
||||||
|
num_points_per_ray: int,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Calculates color at points using the viewing direction:
|
Calculates color at points using the viewing direction:
|
||||||
@ -376,6 +388,9 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
|
|||||||
directions
|
directions
|
||||||
directions: A tensor of shape `(..., 3)`
|
directions: A tensor of shape `(..., 3)`
|
||||||
containing the direction vectors of sampling rays in world coords.
|
containing the direction vectors of sampling rays in world coords.
|
||||||
|
non_empty_points: indices of points which weren't filtered out;
|
||||||
|
used for expanding directions
|
||||||
|
num_points_per_ray: number of points per ray, needed to expand directions.
|
||||||
"""
|
"""
|
||||||
# ########## transform direction ########## #
|
# ########## transform direction ########## #
|
||||||
if self.xyz_ray_dir_in_camera_coords:
|
if self.xyz_ray_dir_in_camera_coords:
|
||||||
@ -400,12 +415,11 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
|
|||||||
rays_directions_normed
|
rays_directions_normed
|
||||||
)
|
)
|
||||||
|
|
||||||
n_rays = directions.shape[0]
|
|
||||||
points_per_ray: int = points.shape[0] // n_rays
|
|
||||||
|
|
||||||
harmonic_embedding_dir = torch.repeat_interleave(
|
harmonic_embedding_dir = torch.repeat_interleave(
|
||||||
harmonic_embedding_dir, points_per_ray, dim=0
|
harmonic_embedding_dir, num_points_per_ray, dim=0
|
||||||
)
|
)
|
||||||
|
if non_empty_points is not None:
|
||||||
|
harmonic_embedding_dir = harmonic_embedding_dir[non_empty_points]
|
||||||
|
|
||||||
# total color embedding is concatenation of the harmonic embedding of voxel grid
|
# total color embedding is concatenation of the harmonic embedding of voxel grid
|
||||||
# output and harmonic embedding of the normalized direction
|
# output and harmonic embedding of the normalized direction
|
||||||
@ -505,7 +519,7 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
|
|||||||
)
|
)
|
||||||
for k in range(0, points.shape[-1], chunk_size):
|
for k in range(0, points.shape[-1], chunk_size):
|
||||||
points_in_planes = points[..., k : k + chunk_size]
|
points_in_planes = points[..., k : k + chunk_size]
|
||||||
planes.append(self.get_density(points_in_planes)[..., 0])
|
planes.append(self._get_density(points_in_planes)[..., 0])
|
||||||
|
|
||||||
density_cube = torch.cat(planes, dim=-1)
|
density_cube = torch.cat(planes, dim=-1)
|
||||||
density_cube = torch.nn.functional.max_pool3d(
|
density_cube = torch.nn.functional.max_pool3d(
|
||||||
|
@ -89,7 +89,7 @@ class TestVoxelGridImplicitFunction(TestCaseMixin, unittest.TestCase):
|
|||||||
out.append(torch.tensor([[0.0]]))
|
out.append(torch.tensor([[0.0]]))
|
||||||
return torch.cat(out).view(*inshape[:-1], 1).to(device)
|
return torch.cat(out).view(*inshape[:-1], 1).to(device)
|
||||||
|
|
||||||
func.get_density = new_density
|
func._get_density = new_density
|
||||||
func._get_scaffold(0)
|
func._get_scaffold(0)
|
||||||
|
|
||||||
points = torch.tensor(
|
points = torch.tensor(
|
||||||
@ -136,15 +136,15 @@ class TestVoxelGridImplicitFunction(TestCaseMixin, unittest.TestCase):
|
|||||||
assert torch.all(scaffold(points)), (scaffold(points), points.shape)
|
assert torch.all(scaffold(points)), (scaffold(points), points.shape)
|
||||||
return points.sum(dim=-1, keepdim=True)
|
return points.sum(dim=-1, keepdim=True)
|
||||||
|
|
||||||
def new_color(points, camera, directions):
|
def new_color(points, camera, directions, non_empty_points, num_points_per_ray):
|
||||||
# check if all passed points should be passed here
|
# check if all passed points should be passed here
|
||||||
assert torch.all(scaffold(points)) # , (scaffold(points), points)
|
assert torch.all(scaffold(points)) # , (scaffold(points), points)
|
||||||
return points * 2
|
return points * 2
|
||||||
|
|
||||||
# check both computation paths that they contain only points
|
# check both computation paths that they contain only points
|
||||||
# which are not in empty space
|
# which are not in empty space
|
||||||
func.get_density = new_density
|
func._get_density = new_density
|
||||||
func.get_color = new_color
|
func._get_color = new_color
|
||||||
func.voxel_grid_scaffold.forward = scaffold
|
func.voxel_grid_scaffold.forward = scaffold
|
||||||
func._scaffold_ready = True
|
func._scaffold_ready = True
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user