Roman Shapovalov a999fc22ee Type safety fixes
Summary: Pyre expects Mapping for ** operator.

Reviewed By: bottler

Differential Revision: D35288632

fbshipit-source-id: 34d6f26ad912b3a5046f440922bb6ed2fd86f533
2022-04-01 04:24:46 -07:00

141 lines
5.4 KiB
Python

#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass
from typing import NamedTuple, Optional, Tuple, Union
import torch
import torch.nn as nn
from pytorch3d.structures import Pointclouds
from .rasterize_points import rasterize_points
# Class to store the outputs of point rasterization
class PointFragments(NamedTuple):
idx: torch.Tensor
zbuf: torch.Tensor
dists: torch.Tensor
@dataclass
class PointsRasterizationSettings:
"""
Class to store the point rasterization params with defaults
Members:
image_size: Either common height and width or (height, width), in pixels.
radius: The radius (in NDC units) of each disk to be rasterized.
This can either be a float in which case the same radius is used
for each point, or a torch.Tensor of shape (N, P) giving a radius
per point in the batch.
points_per_pixel: (int) Number of points to keep track of per pixel.
We return the nearest points_per_pixel points along the z-axis.
bin_size: Size of bins to use for coarse-to-fine rasterization. Setting
bin_size=0 uses naive rasterization; setting bin_size=None attempts
to set it heuristically based on the shape of the input. This should
not affect the output, but can affect the speed of the forward pass.
max_points_per_bin: Only applicable when using coarse-to-fine
rasterization (bin_size != 0); this is the maximum number of points
allowed within each bin. This should not affect the output values,
but can affect the memory usage in the forward pass.
Setting max_points_per_bin=None attempts to set with a heuristic.
"""
image_size: Union[int, Tuple[int, int]] = 256
radius: Union[float, torch.Tensor] = 0.01
points_per_pixel: int = 8
bin_size: Optional[int] = None
max_points_per_bin: Optional[int] = None
class PointsRasterizer(nn.Module):
"""
This class implements methods for rasterizing a batch of pointclouds.
"""
def __init__(self, cameras=None, raster_settings=None) -> None:
"""
cameras: A cameras object which has a `transform_points` method
which returns the transformed points after applying the
world-to-view and view-to-ndc transformations.
raster_settings: the parameters for rasterization. This should be a
named tuple.
All these initial settings can be overridden by passing keyword
arguments to the forward function.
"""
super().__init__()
if raster_settings is None:
raster_settings = PointsRasterizationSettings()
self.cameras = cameras
self.raster_settings = raster_settings
def transform(self, point_clouds, **kwargs) -> Pointclouds:
"""
Args:
point_clouds: a set of point clouds
Returns:
points_proj: the points with positions projected
in NDC space
NOTE: keeping this as a separate function for readability but it could
be moved into forward.
"""
cameras = kwargs.get("cameras", self.cameras)
if cameras is None:
msg = "Cameras must be specified either at initialization \
or in the forward pass of PointsRasterizer"
raise ValueError(msg)
pts_world = point_clouds.points_padded()
# NOTE: Retaining view space z coordinate for now.
# TODO: Remove this line when the convention for the z coordinate in
# the rasterizer is decided. i.e. retain z in view space or transform
# to a different range.
eps = kwargs.get("eps", None)
pts_view = cameras.get_world_to_view_transform(**kwargs).transform_points(
pts_world, eps=eps
)
# view to NDC transform
to_ndc_transform = cameras.get_ndc_camera_transform(**kwargs)
projection_transform = cameras.get_projection_transform(**kwargs).compose(
to_ndc_transform
)
pts_ndc = projection_transform.transform_points(pts_view, eps=eps)
pts_ndc[..., 2] = pts_view[..., 2]
point_clouds = point_clouds.update_padded(pts_ndc)
return point_clouds
def to(self, device):
# Manually move to device cameras as it is not a subclass of nn.Module
if self.cameras is not None:
self.cameras = self.cameras.to(device)
return self
def forward(self, point_clouds, **kwargs) -> PointFragments:
"""
Args:
point_clouds: a set of point clouds with coordinates in world space.
Returns:
PointFragments: Rasterization outputs as a named tuple.
"""
points_proj = self.transform(point_clouds, **kwargs)
raster_settings = kwargs.get("raster_settings", self.raster_settings)
idx, zbuf, dists2 = rasterize_points(
points_proj,
image_size=raster_settings.image_size,
radius=raster_settings.radius,
points_per_pixel=raster_settings.points_per_pixel,
bin_size=raster_settings.bin_size,
max_points_per_bin=raster_settings.max_points_per_bin,
)
return PointFragments(idx=idx, zbuf=zbuf, dists=dists2)