mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-12-22 23:30:35 +08:00
Non Square image rasterization for pointclouds
Summary: Similar to non square image rasterization for meshes, apply the same updates to the pointcloud rasterizer. Main API Change: - PointRasterizationSettings now accepts a tuple/list of (H, W) for the image size. Reviewed By: jcjohnson Differential Revision: D25465206 fbshipit-source-id: 7370d83c431af1b972158cecae19d82364623380
This commit is contained in:
committed by
Facebook GitHub Bot
parent
569e5229a9
commit
3d769a66cb
@@ -75,7 +75,7 @@ def _add_background_color_to_images(pix_idxs, images, background_color):
|
||||
pixels with accumulated features have unchanged values.
|
||||
"""
|
||||
# Initialize background mask
|
||||
background_mask = pix_idxs[:, 0] < 0 # (N, image_size, image_size)
|
||||
background_mask = pix_idxs[:, 0] < 0 # (N, H, W)
|
||||
|
||||
# Convert background_color to an appropriate tensor and check shape
|
||||
if not torch.is_tensor(background_color):
|
||||
|
||||
@@ -6,7 +6,7 @@ import torch
|
||||
|
||||
# pyre-fixme[21]: Could not find name `_C` in `pytorch3d`.
|
||||
from pytorch3d import _C
|
||||
from pytorch3d.renderer.mesh.rasterize_meshes import pix_to_ndc
|
||||
from pytorch3d.renderer.mesh.rasterize_meshes import pix_to_non_square_ndc
|
||||
|
||||
|
||||
# Maxinum number of faces per bins for
|
||||
@@ -14,17 +14,30 @@ from pytorch3d.renderer.mesh.rasterize_meshes import pix_to_ndc
|
||||
kMaxPointsPerBin = 22
|
||||
|
||||
|
||||
# TODO(jcjohns): Support non-square images
|
||||
def rasterize_points(
|
||||
pointclouds,
|
||||
image_size: int = 256,
|
||||
image_size: Union[int, List[int], Tuple[int, int]] = 256,
|
||||
radius: Union[float, List, Tuple, torch.Tensor] = 0.01,
|
||||
points_per_pixel: int = 8,
|
||||
bin_size: Optional[int] = None,
|
||||
max_points_per_bin: Optional[int] = None,
|
||||
):
|
||||
"""
|
||||
Pointcloud rasterization
|
||||
Each pointcloud is rasterized onto a separate image of shape
|
||||
(H, W) if `image_size` is a tuple or (image_size, image_size) if it
|
||||
is an int.
|
||||
|
||||
If the desired image size is non square (i.e. a tuple of (H, W) where H != W)
|
||||
the aspect ratio needs special consideration. There are two aspect ratios
|
||||
to be aware of:
|
||||
- the aspect ratio of each pixel
|
||||
- the aspect ratio of the output image
|
||||
The camera can be used to set the pixel aspect ratio. In the rasterizer,
|
||||
we assume square pixels, but variable image aspect ratio (i.e rectangle images).
|
||||
|
||||
In most cases you will want to set the camera aspect ratio to
|
||||
1.0 (i.e. square pixels) and only vary the
|
||||
`image_size` (i.e. the output image dimensions in pix
|
||||
|
||||
Args:
|
||||
pointclouds: A Pointclouds object representing a batch of point clouds to be
|
||||
@@ -34,7 +47,8 @@ def rasterize_points(
|
||||
be in normalized device coordinates (NDC): [-1, 1]^3 with the camera at
|
||||
(0, 0, 0); In the camera coordinate frame the x-axis goes from right-to-left,
|
||||
the y-axis goes from bottom-to-top, and the z-axis goes from back-to-front.
|
||||
image_size: Integer giving the resolution of the rasterized image
|
||||
image_size: Size in pixels of the output image to be rasterized.
|
||||
Can optionally be a tuple of (H, W) in the case of non square images.
|
||||
radius (Optional): The radius (in NDC units) of the disk to
|
||||
be rasterized. This can either be a float in which case the same radius is used
|
||||
for each point, or a torch.Tensor of shape (N, P) giving a radius per point
|
||||
@@ -71,6 +85,9 @@ def rasterize_points(
|
||||
then `dists[n, y, x, k]` is the squared distance between the pixel (y, x)
|
||||
and the point `(points[n, p, 0], points[n, p, 1])`. Pixels hit with fewer
|
||||
than points_per_pixel are padded with -1.
|
||||
|
||||
In the case that image_size is a tuple of (H, W) then the outputs
|
||||
will be of shape `(N, H, W, ...)`.
|
||||
"""
|
||||
points_packed = pointclouds.points_packed()
|
||||
cloud_to_packed_first_idx = pointclouds.cloud_to_packed_first_idx()
|
||||
@@ -78,26 +95,46 @@ def rasterize_points(
|
||||
|
||||
radius = _format_radius(radius, pointclouds)
|
||||
|
||||
# In the case that H != W use the max image size to set the bin_size
|
||||
# to accommodate the num bins constraint in the coarse rasteizer.
|
||||
# If the ratio of H:W is large this might cause issues as the smaller
|
||||
# dimension will have fewer bins.
|
||||
# TODO: consider a better way of setting the bin size.
|
||||
if isinstance(image_size, (tuple, list)):
|
||||
if len(image_size) != 2:
|
||||
raise ValueError("Image size can only be a tuple/list of (H, W)")
|
||||
if not all(i > 0 for i in image_size):
|
||||
raise ValueError(
|
||||
"Image sizes must be greater than 0; got %d, %d" % image_size
|
||||
)
|
||||
if not all(type(i) == int for i in image_size):
|
||||
raise ValueError("Image sizes must be integers; got %f, %f" % image_size)
|
||||
max_image_size = max(*image_size)
|
||||
im_size = image_size
|
||||
else:
|
||||
im_size = (image_size, image_size)
|
||||
max_image_size = image_size
|
||||
|
||||
if bin_size is None:
|
||||
if not points_packed.is_cuda:
|
||||
# Binned CPU rasterization not fully implemented
|
||||
bin_size = 0
|
||||
else:
|
||||
# TODO: These heuristics are not well-thought out!
|
||||
if image_size <= 64:
|
||||
if max_image_size <= 64:
|
||||
bin_size = 8
|
||||
elif image_size <= 256:
|
||||
elif max_image_size <= 256:
|
||||
bin_size = 16
|
||||
elif image_size <= 512:
|
||||
elif max_image_size <= 512:
|
||||
bin_size = 32
|
||||
elif image_size <= 1024:
|
||||
elif max_image_size <= 1024:
|
||||
bin_size = 64
|
||||
|
||||
if bin_size != 0:
|
||||
# There is a limit on the number of points per bin in the cuda kernel.
|
||||
# pyre-fixme[58]: `//` is not supported for operand types `int` and
|
||||
# `Union[int, None, int]`.
|
||||
points_per_bin = 1 + (image_size - 1) // bin_size
|
||||
points_per_bin = 1 + (max_image_size - 1) // bin_size
|
||||
if points_per_bin >= kMaxPointsPerBin:
|
||||
raise ValueError(
|
||||
"bin_size too small, number of points per bin must be less than %d; got %d"
|
||||
@@ -114,7 +151,7 @@ def rasterize_points(
|
||||
points_packed,
|
||||
cloud_to_packed_first_idx,
|
||||
num_points_per_cloud,
|
||||
image_size,
|
||||
im_size,
|
||||
radius,
|
||||
points_per_pixel,
|
||||
bin_size,
|
||||
@@ -173,7 +210,7 @@ class _RasterizePoints(torch.autograd.Function):
|
||||
points, # (P, 3)
|
||||
cloud_to_packed_first_idx,
|
||||
num_points_per_cloud,
|
||||
image_size: int = 256,
|
||||
image_size: Union[List[int], Tuple[int, int]] = (256, 256),
|
||||
radius: Union[float, torch.Tensor] = 0.01,
|
||||
points_per_pixel: int = 8,
|
||||
bin_size: int = 0,
|
||||
@@ -225,7 +262,7 @@ class _RasterizePoints(torch.autograd.Function):
|
||||
|
||||
def rasterize_points_python(
|
||||
pointclouds,
|
||||
image_size: int = 256,
|
||||
image_size: Union[int, Tuple[int, int]] = 256,
|
||||
radius: Union[float, torch.Tensor] = 0.01,
|
||||
points_per_pixel: int = 8,
|
||||
):
|
||||
@@ -235,7 +272,12 @@ def rasterize_points_python(
|
||||
Inputs / Outputs: Same as above
|
||||
"""
|
||||
N = len(pointclouds)
|
||||
S, K = image_size, points_per_pixel
|
||||
H, W = (
|
||||
image_size
|
||||
if isinstance(image_size, (tuple, list))
|
||||
else (image_size, image_size)
|
||||
)
|
||||
K = points_per_pixel
|
||||
device = pointclouds.device
|
||||
|
||||
points_packed = pointclouds.points_packed()
|
||||
@@ -247,11 +289,11 @@ def rasterize_points_python(
|
||||
|
||||
# Intialize output tensors.
|
||||
point_idxs = torch.full(
|
||||
(N, S, S, K), fill_value=-1, dtype=torch.int32, device=device
|
||||
(N, H, W, K), fill_value=-1, dtype=torch.int32, device=device
|
||||
)
|
||||
zbuf = torch.full((N, S, S, K), fill_value=-1, dtype=torch.float32, device=device)
|
||||
zbuf = torch.full((N, H, W, K), fill_value=-1, dtype=torch.float32, device=device)
|
||||
pix_dists = torch.full(
|
||||
(N, S, S, K), fill_value=-1, dtype=torch.float32, device=device
|
||||
(N, H, W, K), fill_value=-1, dtype=torch.float32, device=device
|
||||
)
|
||||
|
||||
# NDC is from [-1, 1]. Get pixel size using specified image size.
|
||||
@@ -263,18 +305,18 @@ def rasterize_points_python(
|
||||
point_stop_idx = point_start_idx + num_points_per_cloud[n]
|
||||
|
||||
# Iterate through the horizontal lines of the image from top to bottom.
|
||||
for yi in range(S):
|
||||
for yi in range(H):
|
||||
# Y coordinate of one end of the image. Reverse the ordering
|
||||
# of yi so that +Y is pointing up in the image.
|
||||
yfix = S - 1 - yi
|
||||
yf = pix_to_ndc(yfix, S)
|
||||
yfix = H - 1 - yi
|
||||
yf = pix_to_non_square_ndc(yfix, H, W)
|
||||
|
||||
# Iterate through pixels on this horizontal line, left to right.
|
||||
for xi in range(S):
|
||||
for xi in range(W):
|
||||
# X coordinate of one end of the image. Reverse the ordering
|
||||
# of xi so that +X is pointing to the left in the image.
|
||||
xfix = S - 1 - xi
|
||||
xf = pix_to_ndc(xfix, S)
|
||||
xfix = W - 1 - xi
|
||||
xf = pix_to_non_square_ndc(xfix, W, H)
|
||||
|
||||
top_k_points = []
|
||||
# Check whether each point in the batch affects this pixel.
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
|
||||
from typing import NamedTuple, Optional, Union
|
||||
from typing import NamedTuple, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -29,7 +29,7 @@ class PointsRasterizationSettings:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
image_size: int = 256,
|
||||
image_size: Union[int, Tuple[int, int]] = 256,
|
||||
radius: Union[float, torch.Tensor] = 0.01,
|
||||
points_per_pixel: int = 8,
|
||||
bin_size: Optional[int] = None,
|
||||
|
||||
Reference in New Issue
Block a user