mirror of
				https://github.com/facebookresearch/pytorch3d.git
				synced 2025-11-04 18:02:14 +08:00 
			
		
		
		
	voxel_grid_implicit_function scaffold fixes
Summary: Fix indexing of directions after filtering of points by scaffold. Reviewed By: shapovalov Differential Revision: D40853482 fbshipit-source-id: 9cfdb981e97cb82edcd27632c5848537ed2c6837
This commit is contained in:
		
							parent
							
								
									e4a3298149
								
							
						
					
					
						commit
						a1f2ded58a
					
				@ -7,7 +7,7 @@
 | 
			
		||||
import math
 | 
			
		||||
import warnings
 | 
			
		||||
from dataclasses import fields
 | 
			
		||||
from typing import Callable, Dict, Optional, Tuple, Union
 | 
			
		||||
from typing import Callable, Dict, Optional, Tuple
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
 | 
			
		||||
@ -118,11 +118,11 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
 | 
			
		||||
            the calculation.)
 | 
			
		||||
        scaffold_resolution (Tuple[int, int, int]): (width, height, depth) of the underlying
 | 
			
		||||
            voxel grid which stores scaffold
 | 
			
		||||
        scaffold_empty_space_threshold (float): if `self.get_density` evaluates to less than
 | 
			
		||||
        scaffold_empty_space_threshold (float): if `self._get_density` evaluates to less than
 | 
			
		||||
            this it will be considered as empty space and the scaffold at that point would
 | 
			
		||||
            evaluate as empty space.
 | 
			
		||||
        scaffold_occupancy_chunk_size (str or int): Number of xy scaffold planes to calculate
 | 
			
		||||
            at the same time. To calculate the scaffold we need to query `get_density()` at
 | 
			
		||||
            at the same time. To calculate the scaffold we need to query `_get_density()` at
 | 
			
		||||
            every voxel, this calculation can be split into scaffold depth number of xy plane
 | 
			
		||||
            calculations if you want the lowest memory usage, one calculation to calculate the
 | 
			
		||||
            whole scaffold, but with higher memory footprint or any other number of planes.
 | 
			
		||||
@ -242,14 +242,16 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
 | 
			
		||||
        points = ray_bundle_to_ray_points(ray_bundle)
 | 
			
		||||
        directions = ray_bundle.directions.reshape(-1, 3)
 | 
			
		||||
        input_shape = points.shape
 | 
			
		||||
        num_points_per_ray = input_shape[-2]
 | 
			
		||||
        points = points.view(-1, 3)
 | 
			
		||||
        non_empty_points = None
 | 
			
		||||
 | 
			
		||||
        # ########## filter the points using the scaffold ########## #
 | 
			
		||||
        if self._scaffold_ready and self.scaffold_filter_points:
 | 
			
		||||
            # pyre-ignore[29]
 | 
			
		||||
            non_empty_points = self.voxel_grid_scaffold(points)[..., 0] > 0
 | 
			
		||||
            with torch.no_grad():
 | 
			
		||||
                # pyre-ignore[29]
 | 
			
		||||
                non_empty_points = self.voxel_grid_scaffold(points)[..., 0] > 0
 | 
			
		||||
            points = points[non_empty_points]
 | 
			
		||||
            directions = directions[non_empty_points]
 | 
			
		||||
            if len(points) == 0:
 | 
			
		||||
                warnings.warn(
 | 
			
		||||
                    "The scaffold has filtered all the points."
 | 
			
		||||
@ -262,8 +264,8 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
        # ########## calculate color and density ########## #
 | 
			
		||||
        rays_densities, rays_colors = self.calculate_density_and_color(
 | 
			
		||||
            points, directions, camera
 | 
			
		||||
        rays_densities, rays_colors = self._calculate_density_and_color(
 | 
			
		||||
            points, directions, camera, non_empty_points, num_points_per_ray
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        if not (self._scaffold_ready and self.scaffold_filter_points):
 | 
			
		||||
@ -283,9 +285,8 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
 | 
			
		||||
        rays_colors_combined = rays_colors.new_zeros(
 | 
			
		||||
            (math.prod(input_shape[:-1]), rays_colors.shape[-1])
 | 
			
		||||
        )
 | 
			
		||||
        # pyre-ignore[61]
 | 
			
		||||
        assert non_empty_points is not None
 | 
			
		||||
        rays_densities_combined[non_empty_points] = rays_densities
 | 
			
		||||
        # pyre-ignore[61]
 | 
			
		||||
        rays_colors_combined[non_empty_points] = rays_colors
 | 
			
		||||
 | 
			
		||||
        return (
 | 
			
		||||
@ -294,11 +295,13 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
 | 
			
		||||
            {},
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def calculate_density_and_color(
 | 
			
		||||
    def _calculate_density_and_color(
 | 
			
		||||
        self,
 | 
			
		||||
        points: torch.Tensor,
 | 
			
		||||
        directions: torch.Tensor,
 | 
			
		||||
        camera: Optional[CamerasBase] = None,
 | 
			
		||||
        camera: Optional[CamerasBase],
 | 
			
		||||
        non_empty_points: Optional[torch.Tensor],
 | 
			
		||||
        num_points_per_ray: int,
 | 
			
		||||
    ) -> Tuple[torch.Tensor, torch.Tensor]:
 | 
			
		||||
        """
 | 
			
		||||
        Calculates density and color at `points`.
 | 
			
		||||
@ -306,11 +309,14 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            points: points at which to calculate density and color.
 | 
			
		||||
                Tensor of shape [..., 3].
 | 
			
		||||
            directions: from which directions are the points viewed
 | 
			
		||||
                Tensor of shape [..., 3].
 | 
			
		||||
                Tensor of shape [n_points, 3].
 | 
			
		||||
            directions: from which directions are the points viewed.
 | 
			
		||||
                One per ray. Tensor of shape [n_rays, 3].
 | 
			
		||||
            camera: A camera model which will be used to transform the viewing
 | 
			
		||||
                directions
 | 
			
		||||
            non_empty_points: indices of points which weren't filtered out;
 | 
			
		||||
                used for expanding directions
 | 
			
		||||
            num_points_per_ray: number of points per ray, needed to expand directions.
 | 
			
		||||
        Returns:
 | 
			
		||||
               Tuple of color (tensor of shape [..., 3]) and density
 | 
			
		||||
                (tensor of shape [..., 1])
 | 
			
		||||
@ -323,20 +329,24 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
 | 
			
		||||
            with torch.cuda.stream(other_stream):
 | 
			
		||||
                # rays_densities.shape =
 | 
			
		||||
                # [minibatch x n_rays_width x n_rays_height x pts_per_ray x density_dim]
 | 
			
		||||
                rays_densities = self.get_density(points)
 | 
			
		||||
                rays_densities = self._get_density(points)
 | 
			
		||||
 | 
			
		||||
            # rays_colors.shape =
 | 
			
		||||
            # [minibatch x n_rays_width x n_rays_height x pts_per_ray x color_dim]
 | 
			
		||||
            rays_colors = self.get_color(points, camera, directions)
 | 
			
		||||
            rays_colors = self._get_color(
 | 
			
		||||
                points, camera, directions, non_empty_points, num_points_per_ray
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
            current_stream.wait_stream(other_stream)
 | 
			
		||||
        else:
 | 
			
		||||
            # Same calculation as above, just serial.
 | 
			
		||||
            rays_densities = self.get_density(points)
 | 
			
		||||
            rays_colors = self.get_color(points, camera, directions)
 | 
			
		||||
            rays_densities = self._get_density(points)
 | 
			
		||||
            rays_colors = self._get_color(
 | 
			
		||||
                points, camera, directions, non_empty_points, num_points_per_ray
 | 
			
		||||
            )
 | 
			
		||||
        return rays_densities, rays_colors
 | 
			
		||||
 | 
			
		||||
    def get_density(self, points: torch.Tensor) -> torch.Tensor:
 | 
			
		||||
    def _get_density(self, points: torch.Tensor) -> torch.Tensor:
 | 
			
		||||
        """
 | 
			
		||||
        Calculates density at points:
 | 
			
		||||
            1) Evaluates the voxel grid on points
 | 
			
		||||
@ -356,11 +366,13 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
 | 
			
		||||
        # shape = [..., density_dim]
 | 
			
		||||
        return self.decoder_density(harmonic_embedding_density)
 | 
			
		||||
 | 
			
		||||
    def get_color(
 | 
			
		||||
    def _get_color(
 | 
			
		||||
        self,
 | 
			
		||||
        points: torch.Tensor,
 | 
			
		||||
        camera: Optional[CamerasBase],
 | 
			
		||||
        directions: torch.Tensor,
 | 
			
		||||
        non_empty_points: Optional[torch.Tensor],
 | 
			
		||||
        num_points_per_ray: int,
 | 
			
		||||
    ) -> torch.Tensor:
 | 
			
		||||
        """
 | 
			
		||||
        Calculates color at points using the viewing direction:
 | 
			
		||||
@ -376,6 +388,9 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
 | 
			
		||||
                directions
 | 
			
		||||
            directions: A tensor of shape `(..., 3)`
 | 
			
		||||
                containing the direction vectors of sampling rays in world coords.
 | 
			
		||||
            non_empty_points: indices of points which weren't filtered out;
 | 
			
		||||
                used for expanding directions
 | 
			
		||||
            num_points_per_ray: number of points per ray, needed to expand directions.
 | 
			
		||||
        """
 | 
			
		||||
        # ########## transform direction ########## #
 | 
			
		||||
        if self.xyz_ray_dir_in_camera_coords:
 | 
			
		||||
@ -400,12 +415,11 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
 | 
			
		||||
            rays_directions_normed
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        n_rays = directions.shape[0]
 | 
			
		||||
        points_per_ray: int = points.shape[0] // n_rays
 | 
			
		||||
 | 
			
		||||
        harmonic_embedding_dir = torch.repeat_interleave(
 | 
			
		||||
            harmonic_embedding_dir, points_per_ray, dim=0
 | 
			
		||||
            harmonic_embedding_dir, num_points_per_ray, dim=0
 | 
			
		||||
        )
 | 
			
		||||
        if non_empty_points is not None:
 | 
			
		||||
            harmonic_embedding_dir = harmonic_embedding_dir[non_empty_points]
 | 
			
		||||
 | 
			
		||||
        # total color embedding is concatenation of the harmonic embedding of voxel grid
 | 
			
		||||
        # output and harmonic embedding of the normalized direction
 | 
			
		||||
@ -505,7 +519,7 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
 | 
			
		||||
        )
 | 
			
		||||
        for k in range(0, points.shape[-1], chunk_size):
 | 
			
		||||
            points_in_planes = points[..., k : k + chunk_size]
 | 
			
		||||
            planes.append(self.get_density(points_in_planes)[..., 0])
 | 
			
		||||
            planes.append(self._get_density(points_in_planes)[..., 0])
 | 
			
		||||
 | 
			
		||||
        density_cube = torch.cat(planes, dim=-1)
 | 
			
		||||
        density_cube = torch.nn.functional.max_pool3d(
 | 
			
		||||
 | 
			
		||||
@ -89,7 +89,7 @@ class TestVoxelGridImplicitFunction(TestCaseMixin, unittest.TestCase):
 | 
			
		||||
                    out.append(torch.tensor([[0.0]]))
 | 
			
		||||
            return torch.cat(out).view(*inshape[:-1], 1).to(device)
 | 
			
		||||
 | 
			
		||||
        func.get_density = new_density
 | 
			
		||||
        func._get_density = new_density
 | 
			
		||||
        func._get_scaffold(0)
 | 
			
		||||
 | 
			
		||||
        points = torch.tensor(
 | 
			
		||||
@ -136,15 +136,15 @@ class TestVoxelGridImplicitFunction(TestCaseMixin, unittest.TestCase):
 | 
			
		||||
            assert torch.all(scaffold(points)), (scaffold(points), points.shape)
 | 
			
		||||
            return points.sum(dim=-1, keepdim=True)
 | 
			
		||||
 | 
			
		||||
        def new_color(points, camera, directions):
 | 
			
		||||
        def new_color(points, camera, directions, non_empty_points, num_points_per_ray):
 | 
			
		||||
            # check if all passed points should be passed here
 | 
			
		||||
            assert torch.all(scaffold(points))  # , (scaffold(points), points)
 | 
			
		||||
            return points * 2
 | 
			
		||||
 | 
			
		||||
        # check both computation paths that they contain only points
 | 
			
		||||
        # which are not in empty space
 | 
			
		||||
        func.get_density = new_density
 | 
			
		||||
        func.get_color = new_color
 | 
			
		||||
        func._get_density = new_density
 | 
			
		||||
        func._get_color = new_color
 | 
			
		||||
        func.voxel_grid_scaffold.forward = scaffold
 | 
			
		||||
        func._scaffold_ready = True
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user