Convert from PyTorch3D NDC coordinates to grid_sample coordinates.

Summary: Implements utility functions that convert 2D points from PyTorch3D NDC space to `torch.nn.functional.grid_sample` coordinates (`ndc_to_grid_sample_coords`) and that sample a tensor at 2D locations given in NDC space (`ndc_grid_sample`).

Reviewed By: shapovalov

Differential Revision: D33741394

fbshipit-source-id: 88981653356588fe646e6dea48fe7f7298738437
David Novotny authored on 2022-02-09 12:48:47 -08:00, committed by Facebook GitHub Bot
parent 47c0997227, commit 12f20d799e
3 changed files with 260 additions and 3 deletions
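Before the diff itself, a minimal usage sketch of the main new helper (the tensor shapes and values here are illustrative, not part of the commit):

    import torch
    from pytorch3d.renderer.utils import ndc_grid_sample

    # sample a (B, dim, H, W) feature map at B x N 2D locations given in
    # PyTorch3D NDC coordinates; the result has shape (B, dim, N)
    features = torch.randn(2, 16, 100, 200)  # illustrative feature map
    pts_ndc = torch.rand(2, 50, 2) * 2 - 1   # 50 random NDC points per batch element
    sampled = ndc_grid_sample(features, pts_ndc, mode="bilinear", align_corners=False)
    assert sampled.shape == (2, 16, 50)

Any extra keyword arguments (here `mode` and `align_corners`) are forwarded verbatim to `torch.nn.functional.grid_sample`.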

pytorch3d/renderer/__init__.py

@@ -70,7 +70,12 @@ from .points import (
    PulsarPointsRenderer,
    rasterize_points,
)
-from .utils import TensorProperties, convert_to_tensors_and_broadcast
+from .utils import (
+    TensorProperties,
+    convert_to_tensors_and_broadcast,
+    ndc_to_grid_sample_coords,
+    ndc_grid_sample,
+)

__all__ = [k for k in globals().keys() if not k.startswith("_")]
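This re-export should make both helpers importable from the top-level renderer package as well; a one-line sketch of the intended import:

    from pytorch3d.renderer import ndc_grid_sample, ndc_to_grid_sample_coords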

pytorch3d/renderer/utils.py

@@ -8,7 +8,7 @@
import copy
import inspect
import warnings
-from typing import Any, Optional, Union
+from typing import Any, Optional, Union, Tuple
import numpy as np
import torch

@@ -350,3 +350,80 @@ def convert_to_tensors_and_broadcast(
        args_Nd.append(c.expand(*expand_sizes))
    return args_Nd


def ndc_grid_sample(
    input: torch.Tensor,
    grid_ndc: torch.Tensor,
    **grid_sample_kwargs,
) -> torch.Tensor:
    """
    Samples a tensor `input` of shape `(B, dim, H, W)` at 2D locations
    specified by a tensor `grid_ndc` of shape `(B, ..., 2)` using
    the `torch.nn.functional.grid_sample` function.
    `grid_ndc` is specified in the PyTorch3D NDC coordinate frame.

    Args:
        input: The tensor of shape `(B, dim, H, W)` to be sampled.
        grid_ndc: A tensor of shape `(B, ..., 2)` denoting the set of
            2D locations at which `input` is sampled.
            See [1] for a detailed description of the NDC coordinates.
        grid_sample_kwargs: Additional arguments forwarded to the
            `torch.nn.functional.grid_sample` call. See the corresponding
            docstring for a listing of the accepted arguments.

    Returns:
        sampled_input: A tensor of shape `(B, dim, ...)` containing the samples
            of `input` at the 2D locations `grid_ndc`.

    References:
        [1] https://pytorch3d.org/docs/cameras
    """
    batch, *spatial_size, pt_dim = grid_ndc.shape
    if batch != input.shape[0]:
        raise ValueError("'input' and 'grid_ndc' have to have the same batch size.")
    if input.ndim != 4:
        raise ValueError("'input' has to be a 4-dimensional Tensor.")
    if pt_dim != 2:
        raise ValueError("The last dimension of 'grid_ndc' has to be == 2.")

    # flatten the spatial dimensions of grid_ndc so grid_sample can consume it
    grid_ndc_flat = grid_ndc.reshape(batch, -1, 1, 2)

    grid_flat = ndc_to_grid_sample_coords(grid_ndc_flat, input.shape[2:])

    sampled_input_flat = torch.nn.functional.grid_sample(
        input, grid_flat, **grid_sample_kwargs
    )

    # restore the original spatial dimensions of grid_ndc
    sampled_input = sampled_input_flat.reshape([batch, input.shape[1], *spatial_size])

    return sampled_input


def ndc_to_grid_sample_coords(
    xy_ndc: torch.Tensor,
    image_size_hw: Tuple[int, int],
) -> torch.Tensor:
    """
    Convert from the PyTorch3D NDC coordinates to
    `torch.nn.functional.grid_sample` coordinates.

    Args:
        xy_ndc: Tensor of shape `(..., 2)` containing 2D points in
            PyTorch3D NDC coordinates.
        image_size_hw: A tuple `(image_height, image_width)` denoting the
            height and width of the image tensor to sample.

    Returns:
        xy_grid_sample: Tensor of shape `(..., 2)` containing 2D points in
            `torch.nn.functional.grid_sample` coordinates.
    """
    if len(image_size_hw) != 2 or any(s <= 0 for s in image_size_hw):
        raise ValueError("'image_size_hw' has to be a 2-tuple of positive integers")
    aspect = min(image_size_hw) / max(image_size_hw)
    xy_grid_sample = -xy_ndc  # first negate the coords
    if image_size_hw[0] >= image_size_hw[1]:
        xy_grid_sample[..., 1] *= aspect
    else:
        xy_grid_sample[..., 0] *= aspect
    return xy_grid_sample
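To make the convention concrete: PyTorch3D NDC has +X pointing left and +Y pointing up, with the shorter image side spanning [-1, 1] and the longer side spanning [-s, s] for s = long / short, whereas grid_sample expects +x right, +y down, and both axes in [-1, 1]. The conversion therefore negates both coordinates and rescales the longer axis by min(H, W) / max(H, W). A small worked sketch of the mapping (the expected values match the unit tests below):

    import torch

    # a 100 x 200 (H x W) image: the horizontal NDC axis spans [-2, 2]
    xy_ndc = torch.tensor([[-2.0, -1.0], [0.0, 0.0], [4.0, 0.5]])
    xy_gs = ndc_to_grid_sample_coords(xy_ndc, (100, 200))
    # xy_gs == [[1.0, 1.0], [0.0, 0.0], [-2.0, -0.5]]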

tests/test_rendering_utils.py

@@ -10,7 +10,20 @@ import unittest
import numpy as np
import torch
from common_testing import TestCaseMixin
-from pytorch3d.renderer.utils import TensorProperties
+from pytorch3d.ops import eyes
+from pytorch3d.renderer import (
+    PerspectiveCameras,
+    AlphaCompositor,
+    PointsRenderer,
+    PointsRasterizationSettings,
+    PointsRasterizer,
+)
+from pytorch3d.renderer.utils import (
+    TensorProperties,
+    ndc_to_grid_sample_coords,
+    ndc_grid_sample,
+)
+from pytorch3d.structures import Pointclouds

# Example class for testing
@@ -96,3 +109,165 @@ class TestTensorProperties(TestCaseMixin, unittest.TestCase):
            # the input.
            self.assertClose(test_class_gathered.x[inds].mean(dim=0), x[i, ...])
            self.assertClose(test_class_gathered.y[inds].mean(dim=0), y[i, ...])

    def test_ndc_grid_sample_rendering(self):
        """
        Use the PyTorch3D point renderer to render a colored point cloud, then
        sample the image at the locations of the point projections with
        `ndc_grid_sample`. Finally, assert that the sampled colors are equal
        to the original point cloud colors.

        Note that, in order to ensure correctness, we use a nearest-neighbor
        assignment point renderer (i.e. no soft splatting).
        """
        # generate a bunch of 3D points on a regular grid lying in the z-plane
        n_grid_pts = 10
        grid_scale = 0.9
        z_plane = 2.0
        image_size = [128, 128]
        point_radius = 0.015
        n_pts = n_grid_pts * n_grid_pts
        pts = torch.stack(
            torch.meshgrid(
                [torch.linspace(-grid_scale, grid_scale, n_grid_pts)] * 2, indexing="ij"
            ),
            dim=-1,
        )
        pts = torch.cat([pts, z_plane * torch.ones_like(pts[..., :1])], dim=-1)
        pts = pts.reshape(1, n_pts, 3)

        # color the points randomly
        pts_colors = torch.rand(1, n_pts, 3)

        # make trivial rendering cameras
        cameras = PerspectiveCameras(
            R=eyes(dim=3, N=1),
            device=pts.device,
            T=torch.zeros(1, 3, dtype=torch.float32, device=pts.device),
        )

        # render the point cloud
        pcl = Pointclouds(points=pts, features=pts_colors)
        renderer = NearestNeighborPointsRenderer(
            rasterizer=PointsRasterizer(
                cameras=cameras,
                raster_settings=PointsRasterizationSettings(
                    image_size=image_size,
                    radius=point_radius,
                    points_per_pixel=1,
                ),
            ),
            compositor=AlphaCompositor(),
        )
        im_render = renderer(pcl)

        # sample the render at the projected points
        pts_proj = cameras.transform_points(pcl.points_padded())[..., :2]
        pts_colors_sampled = ndc_grid_sample(
            im_render,
            pts_proj,
            mode="nearest",
            align_corners=False,
        ).permute(0, 2, 1)

        # assert that the samples are the same as the original point colors
        self.assertClose(pts_colors, pts_colors_sampled, atol=1e-4)

    def test_ndc_to_grid_sample_coords(self):
        """
        Test the conversion from NDC to grid_sample coords by comparing
        to known conversion results.
        """
        # square image tests
        image_size_square = [100, 100]
        xy_ndc_gs_square = torch.FloatTensor(
            [
                # 4 corners
                [[-1.0, -1.0], [1.0, 1.0]],
                [[1.0, 1.0], [-1.0, -1.0]],
                [[1.0, -1.0], [-1.0, 1.0]],
                [[-1.0, 1.0], [1.0, -1.0]],
                # center
                [[0.0, 0.0], [0.0, 0.0]],
            ]
        )

        # non-batched version
        for xy_ndc, xy_gs in xy_ndc_gs_square:
            xy_gs_predicted = ndc_to_grid_sample_coords(
                xy_ndc,
                image_size_square,
            )
            self.assertClose(xy_gs_predicted, xy_gs)

        # batched version
        xy_ndc, xy_gs = xy_ndc_gs_square[:, 0], xy_ndc_gs_square[:, 1]
        xy_gs_predicted = ndc_to_grid_sample_coords(
            xy_ndc,
            image_size_square,
        )
        self.assertClose(xy_gs_predicted, xy_gs)

        # non-square image tests
        image_size = [100, 200]
        xy_ndc_gs = torch.FloatTensor(
            [
                # 4 corners
                [[-2.0, -1.0], [1.0, 1.0]],
                [[2.0, -1.0], [-1.0, 1.0]],
                [[-2.0, 1.0], [1.0, -1.0]],
                [[2.0, 1.0], [-1.0, -1.0]],
                # center
                [[0.0, 0.0], [0.0, 0.0]],
                # non-corner points
                [[4.0, 0.5], [-2.0, -0.5]],
                [[1.0, -0.5], [-0.5, 0.5]],
            ]
        )

        # check both H > W and W > H
        for flip_axes in [False, True]:
            # non-batched version
            for xy_ndc, xy_gs in xy_ndc_gs:
                xy_gs_predicted = ndc_to_grid_sample_coords(
                    xy_ndc.flip(dims=(-1,)) if flip_axes else xy_ndc,
                    list(reversed(image_size)) if flip_axes else image_size,
                )
                self.assertClose(
                    xy_gs_predicted,
                    xy_gs.flip(dims=(-1,)) if flip_axes else xy_gs,
                )

            # batched version
            xy_ndc, xy_gs = xy_ndc_gs[:, 0], xy_ndc_gs[:, 1]
            xy_gs_predicted = ndc_to_grid_sample_coords(
                xy_ndc.flip(dims=(-1,)) if flip_axes else xy_ndc,
                list(reversed(image_size)) if flip_axes else image_size,
            )
            self.assertClose(
                xy_gs_predicted,
                xy_gs.flip(dims=(-1,)) if flip_axes else xy_gs,
            )


class NearestNeighborPointsRenderer(PointsRenderer):
    """
    A class for rendering a batch of points by a trivial nearest
    neighbor assignment.
    """

    def forward(self, point_clouds, **kwargs) -> torch.Tensor:
        fragments = self.rasterizer(point_clouds, **kwargs)
        # set all weights trivially to one
        dists2 = fragments.dists.permute(0, 3, 1, 2)
        weights = torch.ones_like(dists2)
        images = self.compositor(
            fragments.idx.long().permute(0, 3, 1, 2),
            weights,
            point_clouds.features_packed().permute(1, 0),
            **kwargs,
        )
        # unlike PointsRenderer, keep the (N, C, H, W) layout,
        # which is the input layout expected by ndc_grid_sample
        return images