Add utils to approximate the conical frustums as multivariate gaussians.

Summary:
Introduce methods to approximate the radii of conical frustums along rays as described in [MipNerf](https://arxiv.org/abs/2103.13415):

- Two new attributes are added to ImplicitronRayBundle: bins and radii. Bins is of size n_pts_per_ray + 1. It allows us to easily manipulate the n_pts_per_ray intervals. For example, we need the interval coordinates in the radii computation for \(t_{\mu}, t_{\delta}\). Radii are used to store the radii of the conical frustums.

- Add 3 new methods to compute the radii:
   - approximate_conical_frustum_as_gaussians: It computes the mean along the ray direction, the variance of the
      conical frustum  with respect to t and variance of the conical frustum with respect to its radius. This
      implementation follows the stable computation defined in the paper.
   - compute_3d_diagonal_covariance_gaussian: Will leverage the two previously computed variances to find the
     diagonal covariance of the Gaussian.
   - conical_frustum_to_gaussian: Mix everything together to compute the means and the diagonal covariances along
     the ray of the Gaussians.

- In AbstractMaskRaySampler, introduce the attribute `cast_ray_bundle_as_cone`. If False, it won't change the previous behaviour of the RaySampler. However, if True, the samplers will sample `n_pts_per_ray + 1` points instead of `n_pts_per_ray`. These points are then used to set the bins attribute of ImplicitronRayBundle. Support for HeterogeneousRayBundle has not been added since the current code does not allow it. A safeguard has been added to avoid a silent bug in the future.

Reviewed By: shapovalov

Differential Revision: D45269190

fbshipit-source-id: bf22fad12d71d55392f054e3f680013aa0d59b78
This commit is contained in:
Emilien Garreau
2023-07-06 01:55:41 -07:00
committed by Facebook GitHub Bot
parent 4e7715ce66
commit 29b8ebd802
10 changed files with 977 additions and 65 deletions

View File

@@ -0,0 +1,77 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import typing
import torch
from pytorch3d.common.datatypes import Device
from pytorch3d.renderer.cameras import (
CamerasBase,
FoVOrthographicCameras,
FoVPerspectiveCameras,
OpenGLOrthographicCameras,
OpenGLPerspectiveCameras,
OrthographicCameras,
PerspectiveCameras,
SfMOrthographicCameras,
SfMPerspectiveCameras,
)
from pytorch3d.renderer.fisheyecameras import FishEyeCameras
from pytorch3d.transforms.so3 import so3_exp_map
def init_random_cameras(
    cam_type: typing.Type[CamerasBase],
    batch_size: int,
    random_z: bool = False,
    device: Device = "cpu",
):
    """
    Instantiate a batch of `cam_type` cameras from randomly drawn parameters.

    Args:
        cam_type: The camera class to instantiate. Handled classes are the
            OpenGL/FoV perspective and orthographic cameras, the SfM and plain
            perspective and orthographic cameras, and `FishEyeCameras`.
        batch_size: Leading dimension of every randomly drawn parameter.
        random_z: If False, the z component of every translation is
            overwritten with 4; if True it keeps its random value.
        device: Passed through to the camera constructor as `device`.

    Returns:
        The object returned by `cam_type(**params)`.

    Raises:
        ValueError: If `cam_type` is none of the handled camera classes.
    """
    translation = torch.randn(batch_size, 3) * 0.03
    if not random_z:
        translation[:, 2] = 4
    rotation = so3_exp_map(torch.randn(batch_size, 3) * 3.0)
    params = {"R": rotation, "T": translation, "device": device}

    frustum_types = (
        OpenGLPerspectiveCameras,
        OpenGLOrthographicCameras,
        FoVPerspectiveCameras,
        FoVOrthographicCameras,
    )
    focal_length_types = (
        SfMOrthographicCameras,
        SfMPerspectiveCameras,
        OrthographicCameras,
        PerspectiveCameras,
    )

    if cam_type in frustum_types:
        # The far plane is always drawn strictly behind the near plane.
        near = torch.rand(batch_size) * 10 + 0.1
        params["znear"] = near
        params["zfar"] = torch.rand(batch_size) * 4 + 1 + near
        if cam_type in (OpenGLPerspectiveCameras, FoVPerspectiveCameras):
            params["fov"] = torch.rand(batch_size) * 60 + 30
            params["aspect_ratio"] = torch.rand(batch_size) * 0.5 + 0.5
        else:
            # Both orthographic flavours draw the same four bounds in the same
            # order; only the keyword names differ.
            if cam_type == OpenGLOrthographicCameras:
                bound_names = ("top", "bottom", "left", "right")
            else:
                bound_names = ("max_y", "min_y", "min_x", "max_x")
            params[bound_names[0]] = torch.rand(batch_size) * 0.2 + 0.9
            params[bound_names[1]] = -(torch.rand(batch_size)) * 0.2 - 0.9
            params[bound_names[2]] = -(torch.rand(batch_size)) * 0.2 - 0.9
            params[bound_names[3]] = torch.rand(batch_size) * 0.2 + 0.9
    elif cam_type in focal_length_types:
        params["focal_length"] = torch.rand(batch_size) * 10 + 0.1
        params["principal_point"] = torch.randn((batch_size, 2))
    elif cam_type == FishEyeCameras:
        params["focal_length"] = torch.rand(batch_size, 1) * 10 + 0.1
        params["principal_point"] = torch.randn((batch_size, 2))
        params["radial_params"] = torch.randn((batch_size, 6))
        params["tangential_params"] = torch.randn((batch_size, 2))
        params["thin_prism_params"] = torch.randn((batch_size, 4))
    else:
        raise ValueError(str(cam_type))
    return cam_type(**params)

View File

@@ -62,6 +62,7 @@ raysampler_AdaptiveRaySampler_args:
n_rays_total_training: null
stratified_point_sampling_training: true
stratified_point_sampling_evaluation: false
cast_ray_bundle_as_cone: false
scene_extent: 8.0
scene_center:
- 0.0

View File

@@ -0,0 +1,254 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import unittest
import numpy as np
import torch
from pytorch3d.implicitron.models.renderer.base import (
approximate_conical_frustum_as_gaussians,
compute_3d_diagonal_covariance_gaussian,
conical_frustum_to_gaussian,
ImplicitronRayBundle,
)
from pytorch3d.implicitron.models.renderer.ray_sampler import AbstractMaskRaySampler
from tests.common_testing import TestCaseMixin
class TestRendererBase(TestCaseMixin, unittest.TestCase):
    """
    Tests for the bins support of `ImplicitronRayBundle` and for the helpers
    approximating conical frustums as Gaussians.
    """

    def test_implicitron_from_bins(self) -> None:
        """`from_bins` stores the bins and sets lengths to the bin midpoints."""
        bins = torch.randn(2, 3, 4, 5)
        ray_bundle = ImplicitronRayBundle.from_bins(
            origins=None,
            directions=None,
            xys=None,
            bins=bins,
        )
        self.assertClose(ray_bundle.lengths, 0.5 * (bins[..., 1:] + bins[..., :-1]))
        self.assertClose(ray_bundle.bins, bins)

    def test_implicitron_raise_value_error_if_bins_dim_equal_1(self) -> None:
        """`from_bins` raises ValueError when the last bins dimension is 1."""
        with self.assertRaises(ValueError):
            ImplicitronRayBundle.from_bins(
                origins=torch.rand(2, 3, 4, 3),
                directions=torch.rand(2, 3, 4, 3),
                xys=torch.rand(2, 3, 4, 2),
                bins=torch.rand(2, 3, 4, 1),
            )

    def test_conical_frustum_to_gaussian(self) -> None:
        """Compare means and diagonal covariances with reference values."""
        origins = torch.zeros(3, 3, 3)
        directions = torch.tensor(
            [
                [[0, 0, 0], [1, 0, 0], [3, 0, 0]],
                [[0, 0.25, 0], [1, 0.25, 0], [3, 0.25, 0]],
                [[0, 1, 0], [1, 1, 0], [3, 1, 0]],
            ]
        )
        bins = torch.tensor(
            [
                [[0.5, 1.5], [0.3, 0.7], [0.3, 0.7]],
                [[0.5, 1.5], [0.3, 0.7], [0.3, 0.7]],
                [[0.5, 1.5], [0.3, 0.7], [0.3, 0.7]],
            ]
        )
        # see test_compute_pixel_radii_from_ray_direction
        radii = torch.tensor(
            [
                [1.25, 2.25, 2.25],
                [1.75, 2.75, 2.75],
                [1.75, 2.75, 2.75],
            ]
        )
        radii = radii[..., None] / 12**0.5

        # The expected mean and diagonal covariance have been computed
        # by hand from the official code of MipNerf.
        # https://github.com/google/mipnerf/blob/84c969e0a623edd183b75693aed72a7e7c22902d/internal/mip.py#L125
        # mean, cov_diag = cast_rays(length, origins, directions, radii, 'cone', diag=True)
        expected_mean = torch.tensor(
            [
                [
                    [[0.0, 0.0, 0.0]],
                    [[0.5506329, 0.0, 0.0]],
                    [[1.6518986, 0.0, 0.0]],
                ],
                [
                    [[0.0, 0.28846154, 0.0]],
                    [[0.5506329, 0.13765822, 0.0]],
                    [[1.6518986, 0.13765822, 0.0]],
                ],
                [
                    [[0.0, 1.1538461, 0.0]],
                    [[0.5506329, 0.5506329, 0.0]],
                    [[1.6518986, 0.5506329, 0.0]],
                ],
            ]
        )
        expected_diag_cov = torch.tensor(
            [
                [
                    [[0.04544772, 0.04544772, 0.04544772]],
                    [[0.01130973, 0.03317059, 0.03317059]],
                    [[0.10178753, 0.03317059, 0.03317059]],
                ],
                [
                    [[0.08907752, 0.00404956, 0.08907752]],
                    [[0.0142245, 0.04734321, 0.04955113]],
                    [[0.10212927, 0.04991625, 0.04955113]],
                ],
                [
                    [[0.08907752, 0.0647929, 0.08907752]],
                    [[0.03608529, 0.03608529, 0.04955113]],
                    [[0.10674264, 0.05590574, 0.04955113]],
                ],
            ]
        )

        ray = ImplicitronRayBundle(
            origins=origins,
            directions=directions,
            bins=bins,
            lengths=None,
            pixel_radii_2d=radii,
            xys=None,
        )
        mean, diag_cov = conical_frustum_to_gaussian(ray)

        self.assertClose(mean, expected_mean)
        self.assertClose(diag_cov, expected_diag_cov)

    def test_scale_conical_frustum_to_gaussian(self) -> None:
        """
        Multiplying the bins by a factor multiplies the mean by that factor
        and the diagonal covariance by its square.
        """
        origins = torch.zeros(2, 2, 3)
        directions = torch.Tensor(
            [
                [[0, 1, 0], [0, 0, 1]],
                [[0, 1, 0], [0, 0, 1]],
            ]
        )
        bins = torch.Tensor(
            [
                [[0.5, 1.5], [0.3, 0.7]],
                [[0.5, 1.5], [0.3, 0.7]],
            ]
        )
        radii = torch.ones(2, 2, 1)

        ray = ImplicitronRayBundle(
            origins=origins,
            directions=directions,
            bins=bins,
            pixel_radii_2d=radii,
            lengths=None,
            xys=None,
        )
        mean, diag_cov = conical_frustum_to_gaussian(ray)

        scaling_factor = 2.5
        ray = ImplicitronRayBundle(
            origins=origins,
            directions=directions,
            bins=bins * scaling_factor,
            pixel_radii_2d=radii,
            lengths=None,
            xys=None,
        )
        mean_scaled, diag_cov_scaled = conical_frustum_to_gaussian(ray)

        np.testing.assert_allclose(mean * scaling_factor, mean_scaled)
        np.testing.assert_allclose(
            diag_cov * scaling_factor**2, diag_cov_scaled, atol=1e-6
        )

    def test_approximate_conical_frustum_as_gaussian(self) -> None:
        """Ensure that the computation modularity in our function is well done."""
        bins = torch.Tensor([[0.5, 1.5], [0.3, 0.7]])
        radii = torch.Tensor([[1.0], [1.0]])
        t_mean, t_var, r_var = approximate_conical_frustum_as_gaussians(bins, radii)

        self.assertEqual(t_mean.shape, (2, 1))
        self.assertEqual(t_var.shape, (2, 1))
        self.assertEqual(r_var.shape, (2, 1))

        # Midpoint and half-width of each interval in `bins`.
        mu = np.array([[1.0], [0.5]])
        delta = np.array([[0.5], [0.2]])

        np.testing.assert_allclose(
            mu + (2 * mu * delta**2) / (3 * mu**2 + delta**2), t_mean.numpy()
        )
        np.testing.assert_allclose(
            (delta**2) / 3
            - (4 / 15)
            * (
                (delta**4 * (12 * mu**2 - delta**2))
                / (3 * mu**2 + delta**2) ** 2
            ),
            t_var.numpy(),
        )
        np.testing.assert_allclose(
            radii**2
            * (
                (mu**2) / 4
                + (5 / 12) * delta**2
                - 4 / 15 * (delta**4) / (3 * mu**2 + delta**2)
            ),
            r_var.numpy(),
        )

    def test_compute_3d_diagonal_covariance_gaussian(self) -> None:
        """Check the diagonal covariance for a ray pointing along the z axis."""
        ray_directions = torch.Tensor([[0, 0, 1]])
        t_var = torch.Tensor([0.5, 0.5, 1])
        r_var = torch.Tensor([0.6, 0.3, 0.4])
        expected_diag_cov = np.array(
            [
                [
                    # t_cov_diag + xy_cov_diag
                    [0.0 + 0.6, 0.0 + 0.6, 0.5 + 0.0],
                    [0.0 + 0.3, 0.0 + 0.3, 0.5 + 0.0],
                    [0.0 + 0.4, 0.0 + 0.4, 1.0 + 0.0],
                ]
            ]
        )
        diag_cov = compute_3d_diagonal_covariance_gaussian(ray_directions, t_var, r_var)
        np.testing.assert_allclose(diag_cov.numpy(), expected_diag_cov)

    def test_conical_frustum_to_gaussian_raise_valueerror(self) -> None:
        """
        A ray bundle without bins and pixel_radii_2d must raise a ValueError
        whose message points at `cast_ray_bundle_as_cone`.
        """
        lengths = torch.linspace(0, 1, steps=6)
        directions = torch.tensor([0, 0, 1])
        origins = torch.tensor([1, 1, 1])
        ray = ImplicitronRayBundle(
            origins=origins, directions=directions, lengths=lengths, xys=None
        )

        with self.assertRaises(ValueError) as context:
            _ = conical_frustum_to_gaussian(ray)

        # Checked after the `with` block: statements following the raising
        # call inside the block would never execute.
        expected_error_message = (
            "RayBundle pixel_radii_2d or bins have not been provided."
            " Look at pytorch3d.renderer.implicit.renderer.ray_sampler::"
            "AbstractMaskRaySampler to see how to compute them. Have you forgot to set"
            "`cast_ray_bundle_as_cone` to True?"
        )
        self.assertEqual(expected_error_message, str(context.exception))

        # Ensure message is coherent with AbstractMaskRaySampler
        class FakeRaySampler(AbstractMaskRaySampler):
            def _get_min_max_depth_bounds(self, *args):
                return None

        # The trailing space keeps the two literals from being glued together
        # into "docconical_frustum_to_gaussian".
        message_assertion = (
            "If cast_ray_bundle_as_cone has been removed please update the doc "
            "conical_frustum_to_gaussian"
        )
        self.assertIsNotNone(
            getattr(FakeRaySampler(), "cast_ray_bundle_as_cone", None),
            message_assertion,
        )

View File

@@ -0,0 +1,290 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import unittest
from itertools import product
from typing import Tuple
from unittest.mock import patch
import torch
from pytorch3d.common.compat import meshgrid_ij
from pytorch3d.implicitron.models.renderer.base import EvaluationMode
from pytorch3d.implicitron.models.renderer.ray_sampler import (
AdaptiveRaySampler,
compute_radii,
NearFarRaySampler,
)
from pytorch3d.renderer.cameras import (
CamerasBase,
FoVOrthographicCameras,
FoVPerspectiveCameras,
OrthographicCameras,
PerspectiveCameras,
)
from pytorch3d.renderer.implicit.utils import HeterogeneousRayBundle
from tests.common_camera_utils import init_random_cameras
from tests.common_testing import TestCaseMixin
# Camera classes that the ray-sampler tests in this module iterate over.
CAMERA_TYPES = (
FoVPerspectiveCameras,
FoVOrthographicCameras,
OrthographicCameras,
PerspectiveCameras,
)
def unproject_xy_grid_from_ndc_to_world_coord(
    cameras: CamerasBase, xy_grid: torch.Tensor
) -> torch.Tensor:
    """
    Unproject a xy_grid from NDC coordinates to world coordinates.

    Each xy location is extended with a third coordinate equal to 1 and the
    resulting points are passed to `cameras.unproject_points(..., from_ndc=True)`.

    Args:
        cameras: CamerasBase.
        xy_grid: A tensor of shape `(..., H*W, 2)` representing the
            x, y coords.

    Returns:
        A single tensor of shape `(..., H*W, 3)` holding the unprojected
        points, i.e. `xy_grid.shape[:-1] + (3,)`.
    """
    batch_size = xy_grid.shape[0]
    # Product of every dimension between the first and the last one
    # (1 when xy_grid is two-dimensional).
    n_rays_per_image = xy_grid.shape[1:-1].numel()
    # reshape (rather than view) so non-contiguous grids are accepted as well.
    xy = xy_grid.reshape(batch_size, -1, 2)
    xyz = torch.cat([xy, xy_grid.new_ones(batch_size, n_rays_per_image, 1)], dim=-1)
    plane_at_depth1 = cameras.unproject_points(xyz, from_ndc=True)
    return plane_at_depth1.reshape(*xy_grid.shape[:-1], 3)
class TestRaysampler(TestCaseMixin, unittest.TestCase):
    """Tests for `NearFarRaySampler`, `AdaptiveRaySampler` and `compute_radii`."""

    def test_ndc_raysampler_n_ray_total_is_none(self):
        """Both wrapped raysamplers must have `_n_rays_total` set to None."""
        reminder = (
            "If you introduce the support of `n_rays_total` for {0}, please handle the "
            "packing and unpacking logic for the radii and lengths computation."
        )
        for sampler_cls in (NearFarRaySampler, AdaptiveRaySampler):
            sampler = sampler_cls()
            self.assertIsNone(
                sampler._training_raysampler._n_rays_total,
                reminder.format(type(sampler)),
            )
            self.assertIsNone(
                sampler._evaluation_raysampler._n_rays_total,
                reminder.format(type(sampler)),
            )

    def test_catch_heterogeneous_exception(self):
        """
        With the underlying raysampler replaced by one returning a
        `HeterogeneousRayBundle`, sampling raises TypeError when
        `cast_ray_bundle_as_cone` is True and does not raise when it is False.
        """
        cameras = init_random_cameras(FoVPerspectiveCameras, 1, random_z=True)

        class FakeSampler:
            def __init__(self):
                self.min_x, self.max_x = 1, 2
                self.min_y, self.max_y = 1, 2

            def __call__(self, **kwargs):
                return HeterogeneousRayBundle(
                    torch.rand(3), torch.rand(3), torch.rand(3), torch.rand(1)
                )

        patched_name = (
            "pytorch3d.implicitron.models.renderer.ray_sampler.NDCMultinomialRaysampler"
        )
        sampler_classes = (AdaptiveRaySampler, NearFarRaySampler)
        with patch(patched_name, return_value=FakeSampler()):
            cone_samplers = [
                sampler_cls(cast_ray_bundle_as_cone=True)
                for sampler_cls in sampler_classes
            ]
            for sampler in cone_samplers:
                with self.assertRaises(TypeError):
                    _ = sampler(cameras, EvaluationMode.TRAINING)

            plain_samplers = [
                sampler_cls(cast_ray_bundle_as_cone=False)
                for sampler_cls in sampler_classes
            ]
            for sampler in plain_samplers:
                _ = sampler(cameras, EvaluationMode.TRAINING)

    def test_compute_radii(self):
        """`compute_radii` must agree with the full-grid reference helper."""
        n_cameras = 1
        grid_h, grid_w = 20, 10
        y_lo, y_hi, x_lo, x_hi = -1.0, 1.0, -1.0, 1.0

        ys, xs = meshgrid_ij(
            torch.linspace(y_lo, y_hi, grid_h, dtype=torch.float32),
            torch.linspace(x_lo, x_hi, grid_w, dtype=torch.float32),
        )
        xy_grid = torch.stack([xs, ys], dim=-1).view(-1, 2)
        # Spacing between neighbouring grid points along x and y.
        step_x = (x_hi - x_lo) / (grid_w - 1)
        step_y = (y_hi - y_lo) / (grid_h - 1)

        for cam_type in CAMERA_TYPES:
            # init a batch of random cameras
            cameras = init_random_cameras(cam_type, n_cameras, random_z=True)
            # This method allows us to compute the radii without having
            # access to the full grid. Raysamplers during the training
            # will sample random rays from the grid.
            radii = compute_radii(cameras, xy_grid, pixel_hw_ndc=(step_y, step_x))

            world_plane = unproject_xy_grid_from_ndc_to_world_coord(cameras, xy_grid)
            # This method absolutely needs the full grid to work.
            reference_radii = compute_pixel_radii_from_grid(
                world_plane.reshape(1, grid_h, grid_w, 3)
            )
            self.assertClose(reference_radii.reshape(-1, 1), radii)

    def test_forward(self):
        """Without cone casting, bins and radii are None and shapes match."""
        n_rays_per_image = 16
        image_height, image_width = 20, 20
        kwargs = {
            "image_width": image_width,
            "image_height": image_height,
            "n_pts_per_ray_training": 32,
            "n_pts_per_ray_evaluation": 32,
            "n_rays_per_image_sampled_from_mask": n_rays_per_image,
            "cast_ray_bundle_as_cone": False,
        }
        batch_size = 2

        samplers = [NearFarRaySampler(**kwargs), AdaptiveRaySampler(**kwargs)]
        evaluation_modes = [EvaluationMode.TRAINING, EvaluationMode.EVALUATION]

        for cam_type, sampler, evaluation_mode in product(
            CAMERA_TYPES, samplers, evaluation_modes
        ):
            cameras = init_random_cameras(cam_type, batch_size, random_z=True)
            ray_bundle = sampler(cameras, evaluation_mode)

            # NOTE(review): width and height are both 20 here, so their order
            # in the expected shape cannot be told apart by this test.
            if evaluation_mode == EvaluationMode.EVALUATION:
                shape_out = (batch_size, image_width, image_height)
                n_pts_per_ray = kwargs["n_pts_per_ray_evaluation"]
            else:
                shape_out = (batch_size, n_rays_per_image, 1)
                n_pts_per_ray = kwargs["n_pts_per_ray_training"]

            self.assertIsNone(ray_bundle.bins)
            self.assertIsNone(ray_bundle.pixel_radii_2d)
            self.assertEqual(
                ray_bundle.lengths.shape,
                (*shape_out, n_pts_per_ray),
            )
            self.assertEqual(ray_bundle.directions.shape, (*shape_out, 3))
            self.assertEqual(ray_bundle.origins.shape, (*shape_out, 3))

    def test_forward_with_use_bins(self):
        """With cone casting, lengths are bin midpoints and radii are set."""
        n_rays_per_image = 16
        image_height, image_width = 20, 20
        kwargs = {
            "image_width": image_width,
            "image_height": image_height,
            "n_pts_per_ray_training": 32,
            "n_pts_per_ray_evaluation": 32,
            "n_rays_per_image_sampled_from_mask": n_rays_per_image,
            "cast_ray_bundle_as_cone": True,
        }
        batch_size = 1

        samplers = [NearFarRaySampler(**kwargs), AdaptiveRaySampler(**kwargs)]
        evaluation_modes = [EvaluationMode.TRAINING, EvaluationMode.EVALUATION]

        for cam_type, sampler, evaluation_mode in product(
            CAMERA_TYPES, samplers, evaluation_modes
        ):
            cameras = init_random_cameras(cam_type, batch_size, random_z=True)
            ray_bundle = sampler(cameras, evaluation_mode)

            bins = ray_bundle.bins
            midpoints = 0.5 * (bins[..., :-1] + bins[..., 1:])
            self.assertClose(ray_bundle.lengths, midpoints)

            if evaluation_mode == EvaluationMode.EVALUATION:
                shape_out = (batch_size, image_width, image_height)
            else:
                shape_out = (batch_size, n_rays_per_image, 1)

            self.assertEqual(ray_bundle.pixel_radii_2d.shape, (*shape_out, 1))
            self.assertEqual(ray_bundle.directions.shape, (*shape_out, 3))
            self.assertEqual(ray_bundle.origins.shape, (*shape_out, 3))
# Helper to test compute_radii
def compute_pixel_radii_from_grid(pixel_grid: torch.Tensor) -> torch.Tensor:
    """
    Compute the radii of a conical frustum given the pixel grid.

    For every pixel, the distance to its next neighbor along the x axis and
    along the y axis is measured (Euclidean norm of the difference). The last
    column / row has no next neighbor, so it reuses the distance of the
    previous one. The radii are then:

        (dx_norm + dy_norm) * 0.5 * 2 / 12**0.5

    where 2/12**0.5 is a scaling factor to match the variance of the pixels
    footprint.

    Args:
        pixel_grid: A tensor of shape `(..., H, W, dim)` representing the
            full grid of rays pixel_grid.

    Returns:
        The radii for each pixel, of shape `(..., H, W, 1)`.
    """
    # Neighbor distances along W: [..., H, W - 1, 1]
    step_along_x = torch.linalg.norm(
        torch.diff(pixel_grid, dim=-2), dim=-1, keepdim=True
    )
    # Neighbor distances along H: [..., H - 1, W, 1]
    step_along_y = torch.linalg.norm(
        torch.diff(pixel_grid, dim=-3), dim=-1, keepdim=True
    )

    # Repeat the trailing column / row to get back to [..., H, W, 1].
    step_along_x = torch.cat((step_along_x, step_along_x[..., -1:, :]), dim=-2)
    step_along_y = torch.cat((step_along_y, step_along_y[..., -1:, :, :]), dim=-3)

    # Halving the summed distances (* 0.5) and applying the scaling factor
    # (* 2 / 12**0.5) collapses into a single division by 12**0.5.
    return (step_along_x + step_along_y) / 12**0.5
class TestRadiiComputationOnFullGrid(TestCaseMixin, unittest.TestCase):
    """Checks the full-grid reference helper on a hand-computed 3x3 grid."""

    def test_compute_pixel_radii_from_grid(self):
        """Radii of a small grid must match the values worked out by hand."""
        pixel_grid = torch.tensor(
            [
                [[0.0, 0, 0], [1.0, 0.0, 0], [3.0, 0.0, 0.0]],
                [[0.0, 0.25, 0], [1.0, 0.25, 0], [3.0, 0.25, 0]],
                [[0.0, 1, 0], [1.0, 1.0, 0], [3.0000, 1.0, 0]],
            ]
        )

        # Distances between horizontal neighbors are the same on every row;
        # the 3rd column is duplicated from the 2nd.
        x_row = [1.0, 2.0, 2.0]
        expected_x_norm = torch.tensor([x_row, x_row, x_row])

        # Distances between vertical neighbors are the same on every column;
        # the last row is duplicated from the previous row.
        expected_y_norm = torch.tensor(
            [
                [0.25, 0.25, 0.25],
                [0.75, 0.75, 0.75],
                [0.75, 0.75, 0.75],
            ]
        )

        expected_radii = (expected_x_norm + expected_y_norm) / 12**0.5
        computed_radii = compute_pixel_radii_from_grid(pixel_grid)
        self.assertClose(computed_radii, expected_radii[..., None])

View File

@@ -32,7 +32,6 @@
import math
import pickle
import typing
import unittest
from itertools import product
@@ -60,6 +59,8 @@ from pytorch3d.transforms import Transform3d
from pytorch3d.transforms.rotation_conversions import random_rotations
from pytorch3d.transforms.so3 import so3_exp_map
from .common_camera_utils import init_random_cameras
from .common_testing import TestCaseMixin
@@ -151,60 +152,6 @@ def ndc_to_screen_points_naive(points, imsize):
return torch.stack((x, y, z), dim=2)
def init_random_cameras(
    cam_type: typing.Type[CamerasBase],
    batch_size: int,
    random_z: bool = False,
    device: Device = "cpu",
):
    """
    Instantiate a batch of `cam_type` cameras from randomly drawn parameters.

    Args:
        cam_type: The camera class to instantiate. Handled classes are the
            OpenGL/FoV perspective and orthographic cameras, the SfM and plain
            perspective and orthographic cameras, and `FishEyeCameras`.
        batch_size: Leading dimension of every randomly drawn parameter.
        random_z: If False, the z component of every translation is
            overwritten with 4; if True it keeps its random value.
        device: Passed through to the camera constructor as `device`.

    Returns:
        The object returned by `cam_type(**params)`.

    Raises:
        ValueError: If `cam_type` is none of the handled camera classes.
    """
    translation = torch.randn(batch_size, 3) * 0.03
    if not random_z:
        translation[:, 2] = 4
    rotation = so3_exp_map(torch.randn(batch_size, 3) * 3.0)
    params = {"R": rotation, "T": translation, "device": device}

    frustum_types = (
        OpenGLPerspectiveCameras,
        OpenGLOrthographicCameras,
        FoVPerspectiveCameras,
        FoVOrthographicCameras,
    )
    focal_length_types = (
        SfMOrthographicCameras,
        SfMPerspectiveCameras,
        OrthographicCameras,
        PerspectiveCameras,
    )

    if cam_type in frustum_types:
        # The far plane is always drawn strictly behind the near plane.
        near = torch.rand(batch_size) * 10 + 0.1
        params["znear"] = near
        params["zfar"] = torch.rand(batch_size) * 4 + 1 + near
        if cam_type in (OpenGLPerspectiveCameras, FoVPerspectiveCameras):
            params["fov"] = torch.rand(batch_size) * 60 + 30
            params["aspect_ratio"] = torch.rand(batch_size) * 0.5 + 0.5
        else:
            # Both orthographic flavours draw the same four bounds in the same
            # order; only the keyword names differ.
            if cam_type == OpenGLOrthographicCameras:
                bound_names = ("top", "bottom", "left", "right")
            else:
                bound_names = ("max_y", "min_y", "min_x", "max_x")
            params[bound_names[0]] = torch.rand(batch_size) * 0.2 + 0.9
            params[bound_names[1]] = -(torch.rand(batch_size)) * 0.2 - 0.9
            params[bound_names[2]] = -(torch.rand(batch_size)) * 0.2 - 0.9
            params[bound_names[3]] = torch.rand(batch_size) * 0.2 + 0.9
    elif cam_type in focal_length_types:
        params["focal_length"] = torch.rand(batch_size) * 10 + 0.1
        params["principal_point"] = torch.randn((batch_size, 2))
    elif cam_type == FishEyeCameras:
        params["focal_length"] = torch.rand(batch_size, 1) * 10 + 0.1
        params["principal_point"] = torch.randn((batch_size, 2))
        params["radial_params"] = torch.randn((batch_size, 6))
        params["tangential_params"] = torch.randn((batch_size, 2))
        params["thin_prism_params"] = torch.randn((batch_size, 4))
    else:
        raise ValueError(str(cam_type))
    return cam_type(**params)
class TestCameraHelpers(TestCaseMixin, unittest.TestCase):
def setUp(self) -> None:
super().setUp()