Non Square image rasterization for pointclouds

Summary:
Similar to non-square image rasterization for meshes, apply the same updates to the pointcloud rasterizer.

Main API Change:
- PointsRasterizationSettings now accepts a tuple/list of (H, W) for the image size.
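
For illustration, a minimal sketch of the updated API (the camera, point cloud, sizes and radius below are toy choices, not values from this commit):

    import torch
    from pytorch3d.renderer import (
        FoVOrthographicCameras,
        PointsRasterizationSettings,
        PointsRasterizer,
    )
    from pytorch3d.structures import Pointclouds

    # Toy point cloud with coordinates roughly inside the [-1, 1]^3 view volume.
    points = torch.rand(1, 1000, 3) * 2 - 1
    pointclouds = Pointclouds(points=points)

    raster_settings = PointsRasterizationSettings(
        image_size=(512, 1024),  # (H, W) tuple -> non-square output image
        radius=0.01,
        points_per_pixel=8,
    )
    rasterizer = PointsRasterizer(
        cameras=FoVOrthographicCameras(), raster_settings=raster_settings
    )
    fragments = rasterizer(pointclouds)
    # fragments.idx, fragments.zbuf and fragments.dists are (1, 512, 1024, 8).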

Reviewed By: jcjohnson

Differential Revision: D25465206

fbshipit-source-id: 7370d83c431af1b972158cecae19d82364623380
Authored by Nikhila Ravi on 2020-12-15 14:14:27 -08:00, committed by Facebook GitHub Bot
parent 569e5229a9
commit 3d769a66cb
22 changed files with 712 additions and 263 deletions

View File

@@ -75,7 +75,7 @@ def _add_background_color_to_images(pix_idxs, images, background_color):
pixels with accumulated features have unchanged values.
"""
# Initialize background mask
- background_mask = pix_idxs[:, 0] < 0 # (N, image_size, image_size)
+ background_mask = pix_idxs[:, 0] < 0 # (N, H, W)
# Convert background_color to an appropriate tensor and check shape
if not torch.is_tensor(background_color):

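The hunk above only touches a comment, but it documents the new output shape. A sketch of the masking idea it describes, with hypothetical shapes and tensors standing in for the real compositor inputs:

    import torch

    # Hypothetical sizes: N clouds, K points per pixel, C channels, H x W image.
    N, K, C, H, W = 2, 8, 3, 64, 128
    pix_idxs = torch.randint(-1, 100, (N, K, H, W))   # -1 marks "no point here"
    images = torch.rand(N, C, H, W)
    background_color = torch.tensor([0.0, 0.0, 0.0])

    # A pixel is background if even its closest entry is empty.
    background_mask = pix_idxs[:, 0] < 0  # (N, H, W)

    # Fill background pixels; pixels with accumulated features keep their values.
    images = images.permute(0, 2, 3, 1)   # (N, H, W, C)
    images[background_mask] = background_color
    images = images.permute(0, 3, 1, 2)   # back to (N, C, H, W)
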
View File

@@ -6,7 +6,7 @@ import torch
# pyre-fixme[21]: Could not find name `_C` in `pytorch3d`.
from pytorch3d import _C
- from pytorch3d.renderer.mesh.rasterize_meshes import pix_to_ndc
+ from pytorch3d.renderer.mesh.rasterize_meshes import pix_to_non_square_ndc
# Maximum number of points per bin for coarse-to-fine rasterization.
@@ -14,17 +14,30 @@ from pytorch3d.renderer.mesh.rasterize_meshes import pix_to_ndc
kMaxPointsPerBin = 22
- # TODO(jcjohns): Support non-square images
def rasterize_points(
pointclouds,
- image_size: int = 256,
+ image_size: Union[int, List[int], Tuple[int, int]] = 256,
radius: Union[float, List, Tuple, torch.Tensor] = 0.01,
points_per_pixel: int = 8,
bin_size: Optional[int] = None,
max_points_per_bin: Optional[int] = None,
):
"""
Pointcloud rasterization
+ Each pointcloud is rasterized onto a separate image of shape
+ (H, W) if `image_size` is a tuple or (image_size, image_size) if it
+ is an int.
+ If the desired image size is non square (i.e. a tuple of (H, W) where H != W)
+ the aspect ratio needs special consideration. There are two aspect ratios
+ to be aware of:
+ - the aspect ratio of each pixel
+ - the aspect ratio of the output image
+ The camera can be used to set the pixel aspect ratio. In the rasterizer,
+ we assume square pixels, but variable image aspect ratio (i.e. rectangular images).
+ In most cases you will want to set the camera aspect ratio to
+ 1.0 (i.e. square pixels) and only vary the
+ `image_size` (i.e. the output image dimensions in pixels).
Args:
pointclouds: A Pointclouds object representing a batch of point clouds to be
@@ -34,7 +47,8 @@ def rasterize_points(
be in normalized device coordinates (NDC): [-1, 1]^3 with the camera at
(0, 0, 0); In the camera coordinate frame the x-axis goes from right-to-left,
the y-axis goes from bottom-to-top, and the z-axis goes from back-to-front.
- image_size: Integer giving the resolution of the rasterized image
+ image_size: Size in pixels of the output image to be rasterized.
+ Can optionally be a tuple of (H, W) in the case of non square images.
radius (Optional): The radius (in NDC units) of the disk to
be rasterized. This can either be a float in which case the same radius is used
for each point, or a torch.Tensor of shape (N, P) giving a radius per point
@@ -71,6 +85,9 @@ def rasterize_points(
then `dists[n, y, x, k]` is the squared distance between the pixel (y, x)
and the point `(points[n, p, 0], points[n, p, 1])`. Pixels hit with fewer
than points_per_pixel are padded with -1.
+ In the case that image_size is a tuple of (H, W) then the outputs
+ will be of shape `(N, H, W, ...)`.
"""
points_packed = pointclouds.points_packed()
cloud_to_packed_first_idx = pointclouds.cloud_to_packed_first_idx()
@@ -78,26 +95,46 @@ def rasterize_points(
radius = _format_radius(radius, pointclouds)
+ # In the case that H != W use the max image size to set the bin_size
+ # to accommodate the num bins constraint in the coarse rasterizer.
+ # If the ratio of H:W is large this might cause issues as the smaller
+ # dimension will have fewer bins.
+ # TODO: consider a better way of setting the bin size.
+ if isinstance(image_size, (tuple, list)):
+ if len(image_size) != 2:
+ raise ValueError("Image size can only be a tuple/list of (H, W)")
+ if not all(i > 0 for i in image_size):
+ raise ValueError(
+ "Image sizes must be greater than 0; got %d, %d" % image_size
+ )
+ if not all(type(i) == int for i in image_size):
+ raise ValueError("Image sizes must be integers; got %f, %f" % image_size)
+ max_image_size = max(*image_size)
+ im_size = image_size
+ else:
+ im_size = (image_size, image_size)
+ max_image_size = image_size
if bin_size is None:
if not points_packed.is_cuda:
# Binned CPU rasterization not fully implemented
bin_size = 0
else:
# TODO: These heuristics are not well-thought out!
- if image_size <= 64:
+ if max_image_size <= 64:
bin_size = 8
- elif image_size <= 256:
+ elif max_image_size <= 256:
bin_size = 16
- elif image_size <= 512:
+ elif max_image_size <= 512:
bin_size = 32
- elif image_size <= 1024:
+ elif max_image_size <= 1024:
bin_size = 64
if bin_size != 0:
# There is a limit on the number of points per bin in the cuda kernel.
# pyre-fixme[58]: `//` is not supported for operand types `int` and
# `Union[int, None, int]`.
- points_per_bin = 1 + (image_size - 1) // bin_size
+ points_per_bin = 1 + (max_image_size - 1) // bin_size
if points_per_bin >= kMaxPointsPerBin:
raise ValueError(
"bin_size too small, number of points per bin must be less than %d; got %d"
@@ -114,7 +151,7 @@ def rasterize_points(
points_packed,
cloud_to_packed_first_idx,
num_points_per_cloud,
- image_size,
+ im_size,
radius,
points_per_pixel,
bin_size,
@@ -173,7 +210,7 @@ class _RasterizePoints(torch.autograd.Function):
points, # (P, 3)
cloud_to_packed_first_idx,
num_points_per_cloud,
- image_size: int = 256,
+ image_size: Union[List[int], Tuple[int, int]] = (256, 256),
radius: Union[float, torch.Tensor] = 0.01,
points_per_pixel: int = 8,
bin_size: int = 0,
@@ -225,7 +262,7 @@ class _RasterizePoints(torch.autograd.Function):
def rasterize_points_python(
pointclouds,
- image_size: int = 256,
+ image_size: Union[int, Tuple[int, int]] = 256,
radius: Union[float, torch.Tensor] = 0.01,
points_per_pixel: int = 8,
):
@@ -235,7 +272,12 @@ def rasterize_points_python(
Inputs / Outputs: Same as above
"""
N = len(pointclouds)
- S, K = image_size, points_per_pixel
+ H, W = (
+ image_size
+ if isinstance(image_size, (tuple, list))
+ else (image_size, image_size)
+ )
+ K = points_per_pixel
device = pointclouds.device
points_packed = pointclouds.points_packed()
@@ -247,11 +289,11 @@ def rasterize_points_python(
# Initialize output tensors.
point_idxs = torch.full(
- (N, S, S, K), fill_value=-1, dtype=torch.int32, device=device
+ (N, H, W, K), fill_value=-1, dtype=torch.int32, device=device
)
- zbuf = torch.full((N, S, S, K), fill_value=-1, dtype=torch.float32, device=device)
+ zbuf = torch.full((N, H, W, K), fill_value=-1, dtype=torch.float32, device=device)
pix_dists = torch.full(
- (N, S, S, K), fill_value=-1, dtype=torch.float32, device=device
+ (N, H, W, K), fill_value=-1, dtype=torch.float32, device=device
)
# NDC is from [-1, 1]. Get pixel size using specified image size.
@@ -263,18 +305,18 @@ def rasterize_points_python(
point_stop_idx = point_start_idx + num_points_per_cloud[n]
# Iterate through the horizontal lines of the image from top to bottom.
- for yi in range(S):
+ for yi in range(H):
# Y coordinate of one end of the image. Reverse the ordering
# of yi so that +Y is pointing up in the image.
- yfix = S - 1 - yi
- yf = pix_to_ndc(yfix, S)
+ yfix = H - 1 - yi
+ yf = pix_to_non_square_ndc(yfix, H, W)
# Iterate through pixels on this horizontal line, left to right.
- for xi in range(S):
+ for xi in range(W):
# X coordinate of one end of the image. Reverse the ordering
# of xi so that +X is pointing to the left in the image.
- xfix = S - 1 - xi
- xf = pix_to_ndc(xfix, S)
+ xfix = W - 1 - xi
+ xf = pix_to_non_square_ndc(xfix, W, H)
top_k_points = []
# Check whether each point in the batch affects this pixel.

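To make the shape conventions above concrete, a hedged usage sketch of the functional entry point (the import path is assumed, the inputs are toy values, and the point coordinates are assumed to already be in NDC space):

    import torch
    from pytorch3d.renderer.points.rasterize_points import rasterize_points
    from pytorch3d.structures import Pointclouds

    points = torch.rand(2, 500, 3)  # (N, P, 3) toy points in front of the camera
    pointclouds = Pointclouds(points=points)

    idx, zbuf, dists = rasterize_points(
        pointclouds,
        image_size=(256, 128),  # (H, W) with H != W
        radius=0.05,
        points_per_pixel=4,
    )
    # All three outputs follow the rectangular image size:
    # idx, zbuf and dists are each (2, 256, 128, 4).
    # On the GPU, bin_size is chosen from max(H, W) = 256;
    # on CPU the naive path (bin_size = 0) is used.
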
View File

@@ -2,7 +2,7 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
- from typing import NamedTuple, Optional, Union
+ from typing import NamedTuple, Optional, Tuple, Union
import torch
import torch.nn as nn
@@ -29,7 +29,7 @@ class PointsRasterizationSettings:
def __init__(
self,
- image_size: int = 256,
+ image_size: Union[int, Tuple[int, int]] = 256,
radius: Union[float, torch.Tensor] = 0.01,
points_per_pixel: int = 8,
bin_size: Optional[int] = None,
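
Putting the pieces together, a sketch of an end-to-end non-square render. The camera, colors and sizes are illustrative choices, and the compositor's background_color argument is assumed to feed the background masking shown in the first hunk:

    import torch
    from pytorch3d.renderer import (
        AlphaCompositor,
        FoVOrthographicCameras,
        PointsRasterizationSettings,
        PointsRasterizer,
        PointsRenderer,
    )
    from pytorch3d.structures import Pointclouds

    points = torch.rand(1, 2000, 3) * 2 - 1   # toy cloud
    colors = torch.rand(1, 2000, 3)           # per-point RGB features
    pointclouds = Pointclouds(points=points, features=colors)

    renderer = PointsRenderer(
        rasterizer=PointsRasterizer(
            cameras=FoVOrthographicCameras(),
            raster_settings=PointsRasterizationSettings(
                image_size=(240, 320),  # rectangular (H, W) output
                radius=0.02,
                points_per_pixel=8,
            ),
        ),
        # Pixels not covered by any point get this color (assumed parameter).
        compositor=AlphaCompositor(background_color=(1.0, 1.0, 1.0)),
    )
    images = renderer(pointclouds)  # (1, 240, 320, 3)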