Accumulate points (#4)

Summary:
Code for accumulating points in the z-buffer in three ways:
1. weighted sum
2. normalised weighted sum
3. alpha compositing

Pull Request resolved: https://github.com/fairinternal/pytorch3d/pull/4

Reviewed By: nikhilaravi

Differential Revision: D20522422

Pulled By: gkioxari

fbshipit-source-id: 5023baa05f15e338f3821ef08f5552c2dcbfc06c
This commit is contained in:
Olivia
2020-03-19 11:19:39 -07:00
committed by Facebook GitHub Bot
parent 5218f45c2c
commit 53599770dd
21 changed files with 2466 additions and 4 deletions

View File

@@ -1 +1,8 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
from .compositor import AlphaCompositor, NormWeightedCompositor
from .rasterize_points import rasterize_points
from .rasterizer import PointsRasterizationSettings, PointsRasterizer
from .renderer import PointsRenderer
# Export every name bound in this module so far that is not underscore-prefixed.
__all__ = [name for name in globals() if not name.startswith("_")]

View File

@@ -0,0 +1,51 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import torch
import torch.nn as nn
from ..compositing import CompositeParams, alpha_composite, norm_weighted_sum
# A compositor takes as input 3D points and some corresponding information.
# Given this information, the compositor can:
# - blend colors across the top K points at a pixel
class AlphaCompositor(nn.Module):
    """
    Accumulate points using alpha compositing.

    Args:
        composite_params: settings object handed to `alpha_composite` on
            every forward call. When None, a default-constructed
            `CompositeParams` is used.
    """

    def __init__(self, composite_params=None):
        super().__init__()
        if composite_params is None:
            composite_params = CompositeParams()
        self.composite_params = composite_params

    def forward(self, fragments, alphas, ptclds, **kwargs) -> torch.Tensor:
        """
        Call `alpha_composite` on the inputs with the stored params.

        Args:
            fragments, alphas, ptclds: passed through unchanged, in this
                order, to `alpha_composite`.
            **kwargs: accepted so callers can forward extra options, but
                not used here.

        Returns:
            Whatever `alpha_composite` returns.
        """
        return alpha_composite(fragments, alphas, ptclds, self.composite_params)
class NormWeightedCompositor(nn.Module):
    """
    Accumulate points using a normalized weighted sum.

    Args:
        composite_params: settings object handed to `norm_weighted_sum` on
            every forward call. When None, a default-constructed
            `CompositeParams` is used.
    """

    def __init__(self, composite_params=None):
        super().__init__()
        if composite_params is None:
            composite_params = CompositeParams()
        self.composite_params = composite_params

    def forward(self, fragments, alphas, ptclds, **kwargs) -> torch.Tensor:
        """
        Call `norm_weighted_sum` on the inputs with the stored params.

        Args:
            fragments, alphas, ptclds: passed through unchanged, in this
                order, to `norm_weighted_sum`.
            **kwargs: accepted so callers can forward extra options, but
                not used here.

        Returns:
            Whatever `norm_weighted_sum` returns.
        """
        return norm_weighted_sum(fragments, alphas, ptclds, self.composite_params)

View File

@@ -0,0 +1,103 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
from typing import NamedTuple, Optional
import torch
import torch.nn as nn
from ..cameras import get_world_to_view_transform
from .rasterize_points import rasterize_points
class PointFragments(NamedTuple):
    """
    Named tuple holding the outputs of point rasterization.
    """

    # First value returned by `rasterize_points`.
    idx: torch.Tensor
    # Second value returned by `rasterize_points`.
    zbuf: torch.Tensor
    # Third value returned by `rasterize_points` (called `dists2` at the call
    # site, so presumably squared distances -- TODO confirm).
    dists: torch.Tensor
class PointsRasterizationSettings(NamedTuple):
    """
    Named tuple holding the point rasterization parameters and their defaults.

    Each field is forwarded as the keyword argument of the same name to
    `rasterize_points`.
    """

    image_size: int = 256
    radius: float = 0.01
    points_per_pixel: int = 8
    # None is passed through to `rasterize_points` as-is.
    bin_size: Optional[int] = None
    max_points_per_bin: Optional[int] = None
class PointsRasterizer(nn.Module):
    """
    This class implements methods for rasterizing a batch of pointclouds.
    """

    def __init__(self, cameras, raster_settings=None):
        """
        Args:
            cameras: A cameras object which has a `transform_points` method
                which returns the transformed points after applying the
                world-to-view and view-to-screen transformations.
            raster_settings: the parameters for rasterization. This should
                be a named tuple. Defaults to `PointsRasterizationSettings()`
                when None.

        All these initial settings can be overridden by passing keyword
        arguments to the forward function.
        """
        super().__init__()
        if raster_settings is None:
            raster_settings = PointsRasterizationSettings()

        self.cameras = cameras
        self.raster_settings = raster_settings

    def transform(self, point_clouds, **kwargs):
        """
        Move the points of `point_clouds` from world space to screen space.

        Args:
            point_clouds: a set of point clouds. Must provide
                `points_padded`, `points_packed`, `padded_to_packed_idx`
                and `offset`.
            **kwargs: may contain `cameras` to use instead of
                `self.cameras`; all keyword arguments are also forwarded to
                `cameras.transform_points`.

        Returns:
            The value returned by `point_clouds.offset(...)`, i.e. a point
            clouds object (not a tensor) whose points are offset to their
            screen-space positions.

        NOTE: keeping this as a separate function for readability but it could
        be moved into forward.
        """
        cameras = kwargs.get("cameras", self.cameras)

        pts_world = point_clouds.points_padded()
        pts_world_packed = point_clouds.points_packed()
        pts_screen = cameras.transform_points(pts_world, **kwargs)

        # NOTE: Retaining view space z coordinate for now.
        # TODO: Remove this line when the convention for the z coordinate in
        # the rasterizer is decided. i.e. retain z in view space or transform
        # to a different range.
        view_transform = get_world_to_view_transform(R=cameras.R, T=cameras.T)
        verts_view = view_transform.transform_points(pts_world)
        pts_screen[..., 2] = verts_view[..., 2]

        # Offset points of input pointcloud to reuse cached padded/packed
        # calculations: express the screen-space positions as a per-point
        # delta from the packed world-space positions.
        pad_to_packed_idx = point_clouds.padded_to_packed_idx()
        pts_screen_packed = pts_screen.view(-1, 3)[pad_to_packed_idx, :]
        pts_packed_offset = pts_screen_packed - pts_world_packed
        point_clouds = point_clouds.offset(pts_packed_offset)
        return point_clouds

    def forward(self, point_clouds, **kwargs) -> PointFragments:
        """
        Args:
            point_clouds: a set of point clouds with coordinates in world space.
            **kwargs: may contain `raster_settings` to use instead of
                `self.raster_settings` for this call; all keyword arguments
                are also passed to `transform`.

        Returns:
            PointFragments: Rasterization outputs as a named tuple.
        """
        points_screen = self.transform(point_clouds, **kwargs)
        raster_settings = kwargs.get("raster_settings", self.raster_settings)
        idx, zbuf, dists2 = rasterize_points(
            points_screen,
            image_size=raster_settings.image_size,
            radius=raster_settings.radius,
            points_per_pixel=raster_settings.points_per_pixel,
            bin_size=raster_settings.bin_size,
            max_points_per_bin=raster_settings.max_points_per_bin,
        )
        return PointFragments(idx=idx, zbuf=zbuf, dists=dists2)

View File

@@ -0,0 +1,56 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import torch
import torch.nn as nn
# A renderer class should be initialized with a
# function for rasterization and a function for compositing.
# The rasterizer should:
# - transform inputs from world -> screen space
# - rasterize inputs
# - return fragments
# The compositor can take fragments as input along with any other properties of
# the scene and generate images.
# E.g. rasterize inputs and then composite
#
# fragments = self.rasterizer(point_clouds)
# images = self.compositor(fragments, point_clouds)
# return images
class PointsRenderer(nn.Module):
    """
    A class for rendering a batch of points. The class should
    be initialized with a rasterizer and compositor class which each have a forward
    function.
    """

    def __init__(self, rasterizer, compositor):
        """
        Args:
            rasterizer: callable invoked as `rasterizer(point_clouds, **kwargs)`.
                It must expose a `raster_settings` attribute with a `radius`
                field and return an object with `idx` and `dists` tensors.
            compositor: callable invoked with the permuted indices, the
                per-point weights and the permuted packed features.
        """
        super().__init__()
        self.rasterizer = rasterizer
        self.compositor = compositor

    def forward(self, point_clouds, **kwargs) -> torch.Tensor:
        """
        Rasterize `point_clouds`, weight the fragments and composite them.

        Args:
            point_clouds: object passed to the rasterizer; must also provide
                `features_packed()`.
            **kwargs: forwarded to both the rasterizer and the compositor.
                If `raster_settings` is given, its `radius` is used for the
                weights; otherwise `self.rasterizer.raster_settings` is used.

        Returns:
            The compositor output with its dimension 1 moved to the end.
        """
        fragments = self.rasterizer(point_clouds, **kwargs)

        # Use the same settings the rasterizer resolves for this call, so that
        # a per-call `raster_settings` override changes the radius here too.
        raster_settings = kwargs.get(
            "raster_settings", self.rasterizer.raster_settings
        )
        r = raster_settings.radius

        # Construct weights based on the distance of a point to the true point.
        # However, this could be done differently: e.g. predicted as opposed
        # to a function of the distances.
        dists2 = fragments.dists.permute(0, 3, 1, 2)
        weights = 1 - dists2 / (r * r)

        images = self.compositor(
            fragments.idx.long().permute(0, 3, 1, 2),
            weights,
            point_clouds.features_packed().permute(1, 0),
            **kwargs
        )

        # permute so image comes at the end
        images = images.permute(0, 2, 3, 1)
        return images