MC rasterization supports heterogeneous bundles; refactor bundle-to-padded conversion

Summary:
Rasterize MC was previously not adapted to heterogeneous bundles; this diff adds support for them.

There are some caveats though:
1) on CO3D, we get up to 18 points per image, which is too few for a reasonable visualisation (see below);
2) rasterising for a batch of 100 is slow.

I also moved the unpacking code closer to the bundle so that it can be reused.

{F789678778}
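For orientation, a toy sketch of the packed (heterogeneous) bundle layout this diff builds on; the camera ids and counts below are made up:

import torch

# Hypothetical packed bundle: 5 rays sampled from 2 cameras; rays from the
# same camera occupy consecutive batch elements.
camera_ids = torch.tensor([3, 7], dtype=torch.long)     # which cameras were sampled
camera_counts = torch.tensor([2, 3], dtype=torch.long)  # 2 rays from camera 3, 3 from camera 7
# sum(camera_counts) == 5 == origins.shape[0]; with n_rays_per_image == 1,
# the bundle's xys tensor has shape (5, 1, 2).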

Reviewed By: bottler, davnov134

Differential Revision: D41008600

fbshipit-source-id: 9f10f1f9f9a174cf8c534b9b9859587d69832b71
This commit is contained in:
Roman Shapovalov 2022-11-07 13:43:31 -08:00 committed by Facebook GitHub Bot
parent 7be49bf46f
commit f3c1e0837c
10 changed files with 210 additions and 111 deletions

View File

@ -9,7 +9,6 @@
# which are part of implicitron. They ensure that the registry is prepopulated.
import logging
import math
import warnings
from dataclasses import field
from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING, Union
@ -29,7 +28,7 @@ from pytorch3d.implicitron.tools.config import (
registry,
run_auto_creation,
)
from pytorch3d.implicitron.tools.rasterize_mc import rasterize_mc_samples
from pytorch3d.implicitron.tools.rasterize_mc import rasterize_sparse_ray_bundle
from pytorch3d.implicitron.tools.utils import cat_dataclass
from pytorch3d.renderer import utils as rend_utils
@ -502,9 +501,10 @@ class GenericModel(ImplicitronModelBase): # pyre-ignore: 13
preds["images_render"],
preds["depths_render"],
preds["masks_render"],
) = self._rasterize_mc_samples(
ray_bundle.xys,
) = rasterize_sparse_ray_bundle(
ray_bundle,
rendered.features,
(self.render_image_height, self.render_image_width),
rendered.depths,
masks=rendered.masks,
)
@ -828,61 +828,6 @@ class GenericModel(ImplicitronModelBase): # pyre-ignore: 13
return image_rgb, fg_mask, depth_map
@torch.no_grad()
def _rasterize_mc_samples(
self,
xys: torch.Tensor,
features: torch.Tensor,
depth: torch.Tensor,
masks: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Rasterizes Monte-Carlo features back onto the image.
Args:
xys: B x ... x 2 2D point locations in PyTorch3D NDC convention
features: B x ... x C tensor containing per-point rendered features.
depth: B x ... x 1 tensor containing per-point rendered depth.
"""
ba = xys.shape[0]
# Flatten the features and xy locations.
features_depth_ras = torch.cat(
(
features.reshape(ba, -1, features.shape[-1]),
depth.reshape(ba, -1, 1),
),
dim=-1,
)
xys_ras = xys.reshape(ba, -1, 2)
if masks is not None:
masks_ras = masks.reshape(ba, -1, 1)
else:
masks_ras = None
if min(self.render_image_height, self.render_image_width) <= 0:
raise ValueError(
"Need to specify a positive"
" self.render_image_height and self.render_image_width"
" for MC rasterisation."
)
# Estimate the rasterization point radius so that we approximately fill
# the whole image given the number of rasterized points.
pt_radius = 2.0 / math.sqrt(xys.shape[1])
# Rasterize the samples.
features_depth_render, masks_render = rasterize_mc_samples(
xys_ras,
features_depth_ras,
(self.render_image_height, self.render_image_width),
radius=pt_radius,
masks=masks_ras,
)
images_render = features_depth_render[:, :-1]
depths_render = features_depth_render[:, -1:]
return images_render, depths_render, masks_render
def _apply_chunked(func, chunk_generator, tensor_collator):
"""
@ -940,7 +885,7 @@ def _chunk_generator(
def _safe_slice(
tensor: Optional[torch.Tensor], start_idx: int, end_idx: int
) -> Optional[torch.Tensor]:
) -> Any:
return tensor[start_idx:end_idx] if tensor is not None else None
for start_idx in iter:

View File

@ -15,6 +15,7 @@ these classes.
"""
import logging
import warnings
from collections.abc import Mapping
from dataclasses import dataclass, field
@ -35,6 +36,9 @@ from pytorch3d.structures.volumes import VolumeLocator
from .utils import interpolate_line, interpolate_plane, interpolate_volume
logger = logging.getLogger(__name__)
@dataclass
class VoxelGridValuesBase:
pass
@ -250,6 +254,8 @@ class VoxelGridBase(ReplaceableBase, torch.nn.Module):
)
for name, shape in wanted_shapes.items()
}
res = self.get_resolution(epoch)
logger.info(f"Changed grid resolutiuon at epoch {epoch} to {res}")
else:
params = {
name: (
@ -261,6 +267,7 @@ class VoxelGridBase(ReplaceableBase, torch.nn.Module):
)
for name, tensor in vars(grid_values_with_wanted_resolution).items()
}
# pyre-ignore[29]
return self.values_type(**params), True

View File

@ -4,6 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import logging
import math
import warnings
from dataclasses import fields
@ -29,6 +30,9 @@ from pytorch3d.renderer import ray_bundle_to_ray_points
from pytorch3d.renderer.cameras import CamerasBase
from pytorch3d.renderer.implicit import HarmonicEmbedding
logger = logging.getLogger(__name__)
enable_get_default_args(HarmonicEmbedding)
@ -491,6 +495,11 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
max_indices = tuple(torch.max(non_zero_idxs, dim=0)[0])
min_point, max_point = points[min_indices], points[max_indices]
logger.info(
f"Cropping at epoch {epoch} to bounding box "
f"[{min_point.tolist()}, {max_point.tolist()}]."
)
# crop the voxel grids
self.voxel_grid_density.crop_self(min_point, max_point)
self.voxel_grid_color.crop_self(min_point, max_point)

View File

@ -12,7 +12,7 @@ import torch
from pytorch3d.implicitron.models.renderer.ray_sampler import ImplicitronRayBundle
from pytorch3d.implicitron.tools import metric_utils as utils
from pytorch3d.implicitron.tools.config import registry, ReplaceableBase
from pytorch3d.ops import packed_to_padded, padded_to_packed
from pytorch3d.ops import padded_to_packed
from pytorch3d.renderer import utils as rend_utils
from .renderer.base import RendererOutput
@ -257,20 +257,7 @@ class ViewMetrics(ViewMetricsBase):
# memory requirements. Instead of having one image for every element in
# ray_bundle we can then have one image per unique sampled camera.
if ray_bundle.is_packed():
# pyre-ignore[6]
cumsum = torch.cumsum(ray_bundle.camera_counts, dim=0, dtype=torch.long)
first_idxs = torch.cat(
(
# pyre-ignore[16]
ray_bundle.camera_counts.new_zeros((1,), dtype=torch.long),
cumsum[:-1],
)
)
# pyre-ignore[16]
num_inputs = int(ray_bundle.camera_counts.sum())
# pyre-ignore[6]
max_size = int(torch.max(ray_bundle.camera_counts))
xys = packed_to_padded(xys, first_idxs, max_size)
xys, first_idxs, num_inputs = ray_bundle.get_padded_xys()
# reshape the sampling grid as well
# TODO: we can get rid of the singular dimension here and in _reshape_nongrid_var
@ -278,23 +265,26 @@ class ViewMetrics(ViewMetricsBase):
xys = xys.reshape(xys.shape[0], -1, 1, 2)
# closure with the given xys
def sample(tensor, mode):
def sample_full(tensor, mode):
if tensor is None:
return tensor
return rend_utils.ndc_grid_sample(tensor, xys, mode=mode)
def sample_packed(tensor, mode):
if tensor is None:
return tensor
# select images that correspond to sampled cameras if the ray bundle is packed
tensor = tensor[ray_bundle.camera_ids]
if ray_bundle.is_packed():
# select images that correspond to sampled cameras if the ray bundle is packed
tensor = tensor[ray_bundle.camera_ids]
result = rend_utils.ndc_grid_sample(tensor, xys, mode=mode)
if ray_bundle.is_packed():
# Images after sampling are in the form [batch, 3, max_num_rays, 1];
# padded_to_packed combines the first two dimensions, so we need to swap the
# 1st and 2nd dimensions. The result is [n_rays_total_training, 1, 3, 1]
# (we use keepdim=True).
result = result.transpose(1, 2)
result = padded_to_packed(result, first_idxs, num_inputs)[:, None]
result = result.transpose(1, 2)
return padded_to_packed(result, first_idxs, num_inputs, max_size_dim=2)[
:, :, None
] # the result is [n_rays_total_training, 3, 1, 1]
return result
sample = sample_packed if ray_bundle.is_packed() else sample_full
# eval all results in this size
image_rgb = sample(image_rgb, mode="bilinear")
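To make the shape bookkeeping in sample_packed concrete, a runnable trace with made-up sizes (2 cameras contributing 4 and 3 rays; names are illustrative):

import torch
from pytorch3d.ops import padded_to_packed
from pytorch3d.renderer import utils as rend_utils

n_cams, C, H, W, max_rays = 2, 3, 8, 8, 4
images = torch.rand(n_cams, C, H, W)              # one image per unique camera
xys = torch.rand(n_cams, max_rays, 1, 2) * 2 - 1  # padded NDC sampling grid
first_idxs = torch.tensor([0, 4], dtype=torch.long)

sampled = rend_utils.ndc_grid_sample(images, xys)  # (2, 3, 4, 1)
packed = padded_to_packed(sampled, first_idxs, 7, max_size_dim=2)[:, :, None]
assert packed.shape == (7, 3, 1, 1)                # one row per packed ray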

View File

@ -11,10 +11,11 @@ import dataclasses
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Tuple
import torch
from pytorch3d.implicitron.tools.config import ReplaceableBase
from pytorch3d.ops import packed_to_padded
class EvaluationMode(Enum):
@ -37,20 +38,23 @@ class ImplicitronRayBundle:
in the respective 1D coordinate systems; see documentation for
:func:`ray_bundle_to_ray_points` for the conversion formula.
camera_ids: A tensor of shape (N, ) which indicates which camera
was used to sample the rays. `N` is the number of different
sampled cameras.
camera_counts: A tensor of shape (N, ) which defines how many times the
corresponding camera in `camera_ids` was sampled.
`sum(camera_counts)==minibatch`
A ray bundle may represent rays from multiple cameras. In that case, cameras
are stored in packed form (i.e. rays from the same camera are stored in
consecutive elements). The following indices will be set:
camera_ids: A tensor of shape (N, ) which indicates which camera
was used to sample the rays. `N` is the number of different
sampled cameras.
camera_counts: A tensor of shape (N, ) which defines how many times the
corresponding camera in `camera_ids` was sampled.
`sum(camera_counts) == minibatch`, where `minibatch = origins.shape[0]`.
"""
origins: torch.Tensor
directions: torch.Tensor
lengths: torch.Tensor
xys: torch.Tensor
camera_ids: Optional[torch.Tensor] = None
camera_counts: Optional[torch.Tensor] = None
camera_ids: Optional[torch.LongTensor] = None
camera_counts: Optional[torch.LongTensor] = None
def is_packed(self) -> bool:
"""
@ -58,6 +62,36 @@ class ImplicitronRayBundle:
"""
return self.camera_ids is not None and self.camera_counts is not None
def get_padded_xys(self) -> Tuple[torch.Tensor, torch.LongTensor, int]:
"""
For a packed ray bundle, returns the xys in padded form (one row per camera).
Assumes the input bundle is packed, i.e. `camera_ids` and `camera_counts` are set.
Returns:
- xys: Tensor of shape (N, max_size, ...) containing the padded
representation of the pixel coordinates,
where max_size is max of `camera_counts`. The values for camera id `i`
will be copied to `xys[i, :]`, with zeros padding out the extra inputs.
- first_idxs: exclusive cumulative sum of `camera_counts` defining the
boundaries between cameras in the packed representation.
- num_inputs: the total number of packed rays, `sum(camera_counts)`.
"""
if not self.is_packed():
raise ValueError("get_padded_xys can be called only on a packed bundle")
camera_counts = self.camera_counts
assert camera_counts is not None
cumsum = torch.cumsum(camera_counts, dim=0, dtype=torch.long)
first_idxs = torch.cat(
(camera_counts.new_zeros((1,), dtype=torch.long), cumsum[:-1])
)
num_inputs = camera_counts.sum().item()
max_size = torch.max(camera_counts).item()
xys = packed_to_padded(self.xys, first_idxs, max_size)
# pyre-ignore [7] pytorch typeshed inaccuracy
return xys, first_idxs, num_inputs
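As a sanity check, a small example mirroring what get_padded_xys computes, with made-up counts:

import torch
from pytorch3d.ops import packed_to_padded

camera_counts = torch.tensor([2, 3], dtype=torch.long)  # 2 cameras, 2 + 3 rays
xys = torch.rand(5, 1, 2)                               # packed pixel coordinates

cumsum = torch.cumsum(camera_counts, dim=0)
first_idxs = torch.cat((camera_counts.new_zeros(1), cumsum[:-1]))  # tensor([0, 2])
padded = packed_to_padded(xys, first_idxs, int(camera_counts.max()))
assert padded.shape == (2, 3, 1, 2)  # zero-padded to max(camera_counts)
# get_padded_xys would return (padded, first_idxs, 5)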
@dataclass
class RendererOutput:

View File

@ -4,15 +4,96 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import math
from typing import Optional, Tuple
import torch
from pytorch3d.implicitron.models.renderer.base import ImplicitronRayBundle
from pytorch3d.ops import packed_to_padded
from pytorch3d.renderer import PerspectiveCameras
from pytorch3d.structures import Pointclouds
from .point_cloud_utils import render_point_cloud_pytorch3d
@torch.no_grad()
def rasterize_sparse_ray_bundle(
ray_bundle: ImplicitronRayBundle,
features: torch.Tensor,
image_size_hw: Tuple[int, int],
depth: torch.Tensor,
masks: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Rasterizes sparse features corresponding to the coordinates defined by
the rays in the bundle.
Args:
ray_bundle: ray bundle object with B x ... x 2 pixel coordinates;
it can be packed.
features: B x ... x C tensor containing per-point rendered features.
image_size_hw: Tuple[image_height, image_width] containing
the size of rasterized image.
depth: B x ... x 1 tensor containing per-point rendered depth.
masks: B x ... x 1 tensor containing the alpha mask of the
rendered features.
Returns:
- image_render: B x C x H x W tensor of rasterized features
- depths_render: B x 1 x H x W tensor of rasterized depth maps
- masks_render: B x 1 x H x W tensor of opacities after splatting
"""
# Flatten the features and xy locations.
features_depth_ras = torch.cat(
(features.flatten(1, -2), depth.flatten(1, -2)), dim=-1
)
xys = ray_bundle.xys
masks_ras = None
if ray_bundle.is_packed():
camera_counts = ray_bundle.camera_counts
assert camera_counts is not None
xys, first_idxs, _ = ray_bundle.get_padded_xys()
masks_ras = (
torch.arange(xys.shape[1], device=xys.device)[:, None]
< camera_counts[:, None, None]
)
max_size = torch.max(camera_counts).item()
features_depth_ras = packed_to_padded(
features_depth_ras[:, 0], first_idxs, max_size
)
if masks is not None:
padded_mask = packed_to_padded(masks.flatten(1, -1), first_idxs, max_size)
masks_ras = padded_mask * masks_ras
xys_ras = xys.flatten(1, -2)
if masks_ras is None:
assert not ray_bundle.is_packed()
masks_ras = masks.flatten(1, -2) if masks is not None else None
if min(*image_size_hw) <= 0:
raise ValueError(
"Need to specify a positive output_size_hw for bundle rasterisation."
)
# Estimate the rasterization point radius so that we approximately fill
# the whole image given the number of rasterized points.
pt_radius = 2.0 / math.sqrt(xys.shape[1])
# Rasterize the samples.
features_depth_render, masks_render = rasterize_mc_samples(
xys_ras,
features_depth_ras,
image_size_hw,
radius=pt_radius,
masks=masks_ras,
)
images_render = features_depth_render[:, :-1]
depths_render = features_depth_render[:, -1:]
return images_render, depths_render, masks_render
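A minimal usage sketch of the new entry point on a dense (non-packed) bundle; all sizes are illustrative, and the constructor call assumes this revision's ImplicitronRayBundle fields:

import torch
from pytorch3d.implicitron.models.renderer.base import ImplicitronRayBundle
from pytorch3d.implicitron.tools.rasterize_mc import rasterize_sparse_ray_bundle

B, N, C = 2, 16, 3  # batch, rays per image, feature channels
bundle = ImplicitronRayBundle(
    origins=torch.zeros(B, N, 1, 3),
    directions=torch.zeros(B, N, 1, 3),
    lengths=torch.zeros(B, N, 1, 8),     # unused by rasterization
    xys=torch.rand(B, N, 1, 2) * 2 - 1,  # NDC locations in [-1, 1]
)
images, depths, masks = rasterize_sparse_ray_bundle(
    bundle,
    features=torch.rand(B, N, 1, C),
    image_size_hw=(64, 64),
    depth=torch.rand(B, N, 1, 1),
)
# images: (2, 3, 64, 64); depths and masks: (2, 1, 64, 64)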
def rasterize_mc_samples(
xys: torch.Tensor,
feats: torch.Tensor,

View File

@ -60,7 +60,9 @@ class _PackedToPadded(Function):
return grad_input, None, None
def packed_to_padded(inputs, first_idxs, max_size):
def packed_to_padded(
inputs: torch.Tensor, first_idxs: torch.LongTensor, max_size: int
) -> torch.Tensor:
"""
Torch wrapper that handles allowed input shapes. See description below.
@ -74,7 +76,7 @@ def packed_to_padded(inputs, first_idxs, max_size):
Returns:
inputs_padded: FloatTensor of shape (N, max_size) or (N, max_size, ...)
where max_size is max of `sizes`. The values for batch element i
which start at `inputs[first_idxs[i]]` will be copied to
`inputs_padded[i, :]`, with zeros padding out the extra inputs.
@ -89,6 +91,7 @@ def packed_to_padded(inputs, first_idxs, max_size):
inputs = inputs.unsqueeze(1)
else:
inputs = inputs.reshape(input_shape[0], -1)
# pyre-ignore [16]
inputs_padded = _PackedToPadded.apply(inputs, first_idxs, max_size)
# if flat is True, reshape output to (N, max_size) from (N, max_size, 1)
# else reshape output to (N, max_size, ...)
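For reference, a small worked example of packed_to_padded on a flat tensor (values chosen for illustration):

import torch
from pytorch3d.ops import packed_to_padded

inputs = torch.arange(5.0)                           # packed entries: 2 + 3
first_idxs = torch.tensor([0, 2], dtype=torch.long)  # batch element i starts at first_idxs[i]
padded = packed_to_padded(inputs, first_idxs, max_size=3)
# padded == tensor([[0., 1., 0.],
#                   [2., 3., 4.]])  # each row zero-padded to max_size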
@ -147,39 +150,49 @@ class _PaddedToPacked(Function):
return grad_input, None, None
def padded_to_packed(inputs, first_idxs, num_inputs):
def padded_to_packed(
inputs: torch.Tensor,
first_idxs: torch.LongTensor,
num_inputs: int,
max_size_dim: int = 1,
) -> torch.Tensor:
"""
Torch wrapper that handles allowed input shapes. See description below.
Args:
inputs: FloatTensor of shape (N, max_size) or (N, max_size, ...),
inputs: FloatTensor of shape (N, ..., max_size) or (N, ..., max_size, ...),
representing the padded tensor, e.g. areas for faces in a batch of
meshes.
meshes, where max_size occurs at the max_size_dim-th position.
first_idxs: LongTensor of shape (N,) where N is the number of
elements in the batch and `first_idxs[i] = f`
means that the inputs for batch element i begin at `inputs_packed[f]`.
num_inputs: Number of packed entries (= F)
max_size_dim: the dimension to be packed
Returns:
inputs_packed: FloatTensor of shape (F,) or (F, ...) where
`inputs_packed[first_idx[i]:first_idx[i+1]] = inputs[i, :]`.
`inputs_packed[first_idx[i]:first_idx[i+1]] = inputs[i, ..., :delta[i]]`,
where `delta[i] = first_idx[i+1] - first_idx[i]`.
To handle the allowed input shapes, we convert the inputs tensor of shape
(N, max_size) to (N, max_size, 1). We reshape the output back to (F,) from
(F, 1).
"""
n_dims = inputs.dim()
# move the variable dim to position 1
inputs = inputs.movedim(max_size_dim, 1)
# if inputs is of shape (N, max_size), reshape into (N, max_size, 1)
input_shape = inputs.shape
n_dims = inputs.dim()
if n_dims == 2:
inputs = inputs.unsqueeze(2)
else:
inputs = inputs.reshape(*input_shape[:2], -1)
# pyre-ignore [16]
inputs_packed = _PaddedToPacked.apply(inputs, first_idxs, num_inputs)
# if input is flat, reshape output to (F,) from (F, 1)
# else reshape output to (F, ...)
if n_dims == 2:
return inputs_packed.squeeze(1)
if n_dims == 3:
return inputs_packed
return inputs_packed.view(-1, *input_shape[2:])
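And the inverse using the new max_size_dim argument (shapes assumed; the variable-length dimension sits at position 2 here):

import torch
from pytorch3d.ops import padded_to_packed

padded = torch.tensor([[[0.0, 1.0, 0.0]],
                       [[2.0, 3.0, 4.0]]])           # (2, 1, 3): max_size on dim 2
first_idxs = torch.tensor([0, 2], dtype=torch.long)
packed = padded_to_packed(padded, first_idxs, num_inputs=5, max_size_dim=2)
assert packed.shape == (5, 1)  # padding zeros are dropped, not copied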

View File

@ -207,7 +207,8 @@ class MultinomialRaysampler(torch.nn.Module):
n_rays_per_image,
) = _sample_cameras_and_masks(n_rays_total, cameras, mask)
else:
camera_ids = torch.arange(len(cameras), dtype=torch.long)
# pyre-ignore[9]
camera_ids: torch.LongTensor = torch.arange(len(cameras), dtype=torch.long)
batch_size = cameras.R.shape[0]
device = cameras.device
@ -438,7 +439,8 @@ class MonteCarloRaysampler(torch.nn.Module):
n_rays_per_image,
) = _sample_cameras_and_masks(self._n_rays_total, cameras, None)
else:
camera_ids = torch.arange(len(cameras), dtype=torch.long)
# pyre-ignore[9]
camera_ids: torch.LongTensor = torch.arange(len(cameras), dtype=torch.long)
n_rays_per_image = self._n_rays_per_image
batch_size = cameras.R.shape[0]
@ -716,7 +718,11 @@ def _jiggle_within_stratas(bin_centers: torch.Tensor) -> torch.Tensor:
def _sample_cameras_and_masks(
n_samples: int, cameras: CamerasBase, mask: Optional[torch.Tensor] = None
) -> Tuple[
CamerasBase, Optional[torch.Tensor], torch.Tensor, torch.Tensor, torch.Tensor
CamerasBase,
Optional[torch.Tensor],
torch.LongTensor,
torch.LongTensor,
torch.LongTensor,
]:
"""
Samples n_rays_total cameras and masks and returns them in a form
@ -740,6 +746,7 @@ def _sample_cameras_and_masks(
dtype=torch.long,
)
unique_ids, counts = torch.unique(sampled_ids, return_counts=True)
# pyre-ignore[7]
return (
cameras[unique_ids],
mask[unique_ids] if mask is not None else None,
@ -749,8 +756,9 @@ def _sample_cameras_and_masks(
)
# TODO: this function can be unified with ImplicitronRayBundle.get_padded_xys
def _pack_ray_bundle(
ray_bundle: RayBundle, camera_ids: torch.Tensor, camera_counts: torch.Tensor
ray_bundle: RayBundle, camera_ids: torch.LongTensor, camera_counts: torch.LongTensor
) -> HeterogeneousRayBundle:
"""
Pack the ray bundle from [n_cameras, max(rays_per_camera), ...] to
@ -765,9 +773,11 @@ def _pack_ray_bundle(
Returns:
HeterogeneousRayBundle where batch_size=sum(camera_counts) and n_rays_per_image=1
"""
# pyre-ignore[9]
camera_counts = camera_counts.to(ray_bundle.origins.device)
cumsum = torch.cumsum(camera_counts, dim=0, dtype=torch.long)
first_idxs = torch.cat(
# pyre-ignore[9]
first_idxs: torch.LongTensor = torch.cat(
(camera_counts.new_zeros((1,), dtype=torch.long), cumsum[:-1])
)
num_inputs = int(camera_counts.sum())

View File

@ -60,8 +60,8 @@ class HeterogeneousRayBundle:
directions: torch.Tensor
lengths: torch.Tensor
xys: torch.Tensor
camera_ids: Optional[torch.Tensor] = None
camera_counts: Optional[torch.Tensor] = None
camera_ids: Optional[torch.LongTensor] = None
camera_counts: Optional[torch.LongTensor] = None
def ray_bundle_to_ray_points(

View File

@ -188,6 +188,16 @@ class TestPackedToPadded(TestCaseMixin, unittest.TestCase):
# check forward
self.assertClose(values_packed, values_packed_torch)
if len(dims) > 0:
values_packed_dim2 = padded_to_packed(
values.transpose(1, 2),
mesh_to_faces_packed_first_idx,
num_faces_per_mesh.sum().item(),
max_size_dim=2,
)
# check forward
self.assertClose(values_packed_dim2, values_packed_torch)
# check backward
if len(dims) == 0:
grad_inputs = torch.rand((num_faces_per_mesh.sum().item()), device=device)