Fisheye Camera for PyTorch3D

Summary:
1. A Fisheye camera model that generalizes pinhole camera by considering distortions (i.e. radial, tangential and thin-prism distortions).

2. Added tests against perspective cameras when distortions are off and Aria data points when distortions are on.

3. Address comments to test unhandled shapes between points and transforms. Added tests for __FIELDS, shape broadcasts, cuda etc.

4. Address earlier comments for code efficiency (e.g., adopted torch.norm; torch.solve for matrix inverse; removed inplace operations; unnecessary clone; expand in place of repeat etc).

Reviewed By: jcjohnson

Differential Revision: D38407094

fbshipit-source-id: a3ab48c85c496ac87af692d5d461bb3fc2a2db13
This commit is contained in:
Jiali Duan 2022-08-28 11:17:20 -07:00 committed by Facebook GitHub Bot
parent 4711d12a09
commit 2283c292a9
2 changed files with 907 additions and 1 deletions

View File

@ -0,0 +1,585 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import math
from typing import List, Optional, Tuple, Union
import torch
from pytorch3d.common.datatypes import Device
from pytorch3d.renderer.cameras import _R, _T, CamerasBase
_focal_length = torch.tensor(((1.0,),))
_principal_point = torch.tensor(((0.0, 0.0),))
_radial_params = torch.tensor(((0.0, 0.0, 0.0, 0.0, 0.0, 0.0),))
_tangential_params = torch.tensor(((0.0, 0.0),))
_thin_prism_params = torch.tensor(((0.0, 0.0, 0.0, 0.0),))
class FishEyeCameras(CamerasBase):
"""
A class which extends Pinhole camera by considering radial, tangential and
thin-prism distortion. For the fisheye camera model, k1, k2, ..., k_n_radial are
polynomial coefficents to model radial distortions. Two common types of radial
distortions are barrel and pincusion radial distortions.
a = x / z, b = y / z, r = (a*a+b*b)^(1/2)
th = atan(r)
[x_r] = (th+ k0 * th^3 + k1* th^5 + ...) [a/r]
[y_r] [b/r] [1]
The tangential distortion parameters are p1 and p2. The primary cause is
due to the lens assembly not being centered over and parallel to the image plane.
tangentialDistortion = [(2 x_r^2 + rd^2)*p_0 + 2*x_r*y_r*p_1]
[(2 y_r^2 + rd^2)*p_1 + 2*x_r*y_r*p_0] [2]
where rd^2 = x_r^2 + y_r^2
The thin-prism distortion is modeled with s1, s2, s3, s4 coefficients
thinPrismDistortion = [s0 * rd^2 + s1 rd^4]
[s2 * rd^2 + s3 rd^4] [3]
The projection
proj = diag(f, f) * uvDistorted + [cu; cv]
uvDistorted = [x_r] + tangentialDistortion + thinPrismDistortion [4]
[y_r]
f is the focal length and cu, cv are principal points in x, y axis.
"""
_FIELDS = (
"focal_length",
"principal_point",
"R",
"T",
"radial_params",
"tangential_params",
"thin_prism_params",
"world_coordinates",
"use_radial",
"use_tangential",
"use_tin_prism",
"device",
"image_size",
)
def __init__(
self,
focal_length=_focal_length,
principal_point=_principal_point,
radial_params=_radial_params,
tangential_params=_tangential_params,
thin_prism_params=_thin_prism_params,
R: torch.Tensor = _R,
T: torch.Tensor = _T,
world_coordinates: bool = False,
use_radial: bool = True,
use_tangential: bool = True,
use_thin_prism: bool = True,
device: Device = "cpu",
image_size: Optional[Union[List, Tuple, torch.Tensor]] = None,
) -> None:
"""
Args:
focal_ength: Focal length of the camera in world units.
A tensor of shape (N, 1) for square pixels,
where N is number of transforms.
principal_point: xy coordinates of the center of
the principal point of the camera in pixels.
A tensor of shape (N, 2).
radial_params: parameters for radial distortions.
A tensor of shape (N, num_radial).
tangential_params:parameters for tangential distortions.
A tensor of shape (N, 2).
thin_prism_params: parameters for thin-prism distortions.
A tensor of shape (N, 4).
R: Rotation matrix of shape (N, 3, 3)
T: Translation matrix of shape (N, 3)
world_coordinates: if True, project from world coordinates; otherwise from camera
coordinates
use_radial: radial_distortion, default to True
use_tangential: tangential distortion, default to True
use_thin_prism: thin prism distortion, default to True
device: torch.device or string
image_size: (height, width) of image size.
A tensor of shape (N, 2) or a list/tuple. Required for screen cameras.
"""
kwargs = {"image_size": image_size} if image_size is not None else {}
super().__init__(
device=device,
R=R,
T=T,
**kwargs, # pyre-ignore
)
if image_size is not None:
if (self.image_size < 1).any(): # pyre-ignore
raise ValueError("Image_size provided has invalid values")
else:
self.image_size = None
self.device = device
self.focal = focal_length.to(self.device)
self.principal_point = principal_point.to(self.device)
self.radial_params = radial_params.to(self.device)
self.tangential_params = tangential_params.to(self.device)
self.thin_prism_params = thin_prism_params.to(self.device)
self.R = R
self.T = T
self.world_coordinates = world_coordinates
self.use_radial = use_radial
self.use_tangential = use_tangential
self.use_thin_prism = use_thin_prism
self.epsilon = 1e-10
self.num_distortion_iters = 50
self.R = self.R.to(self.device)
self.T = self.T.to(self.device)
self.num_radial = radial_params.shape[-1]
def _project_points_batch(
self,
focal,
principal_point,
radial_params,
tangential_params,
thin_prism_params,
points,
) -> torch.Tensor:
"""
Takes in points in the local reference frame of the camera and projects it
onto the image plan. Since this is a symmetric model, points with negative z are
projected to the positive sphere. i.e project(1,1,-1) == project(-1,-1,1)
Args:
focal: (1)
principal_point: (2)
radial_params: (num_radial)
tangential_params: (2)
thin_prism_params: (4)
points in the camera coordinate frame: (..., 3). E.g., (P, 3) (1, P, 3)
or (M, P, 3) where P is the number of points
Returns:
projected_points in the image plane: (..., 3). E.g., (P, 3) or
(1, P, 3) or (M, P, 3)
"""
assert points.shape[-1] == 3, "points shape incorrect"
ab = points[..., :2] / points[..., 2:]
uv_distorted = ab
r = ab.norm(dim=-1)
th = r.atan()
theta_sq = th * th
# compute radial distortions, eq 1
t = theta_sq
theta_pow = torch.stack([t, t**2, t**3, t**4, t**5, t**6], dim=-1)
th_radial = 1 + torch.sum(theta_pow * radial_params, dim=-1)
# compute th/r, using the limit for small values
th_divr = th / r
boolean_mask = abs(r) < self.epsilon
th_divr[boolean_mask] = 1.0
# the distorted coordinates -- except for focal length and principal point
# start with the radial term
coeff = th_radial * th_divr
xr_yr = coeff[..., None] * ab
xr_yr_squared_norm = torch.pow(xr_yr, 2).sum(dim=-1, keepdim=True)
if self.use_radial:
uv_distorted = xr_yr
# compute tangential distortions, eq 2
if self.use_tangential:
temp = 2 * torch.sum(
xr_yr * tangential_params,
dim=-1,
)
uv_distorted = uv_distorted + (
temp[..., None] * xr_yr + xr_yr_squared_norm * tangential_params
)
# compute thin-prism distortions, eq 3
sh = uv_distorted.shape[:-1]
if self.use_thin_prism:
radial_powers = torch.cat(
[xr_yr_squared_norm, xr_yr_squared_norm * xr_yr_squared_norm], dim=-1
)
uv_distorted[..., 0] = uv_distorted[..., 0] + torch.sum(
thin_prism_params[..., 0:2] * radial_powers,
dim=-1,
)
uv_distorted[..., 1] = uv_distorted[..., 1] + torch.sum(
thin_prism_params[..., 2:4] * radial_powers,
dim=-1,
)
# return value: distorted points on the uv plane, eq 4
projected_points = focal * uv_distorted + principal_point
return torch.cat(
[projected_points, torch.ones(list(sh) + [1], device=self.device)], dim=-1
)
def check_input(self, points: torch.Tensor, batch_size: int):
"""
Check if the shapes are broadcastable between points and transforms.
Accept points of shape (P, 3) or (1, P, 3) or (M, P, 3). The batch_size
for transforms should be 1 when points take (M, P, 3). The batch_size
can be 1 or N when points take shape (P, 3).
Args:
points: tensor of shape (P, 3) or (1, P, 3) or (M, P, 3)
batch_size: number of transforms
Returns:
Boolean value if the input shapes are compatible.
"""
if points.ndim > 3:
return False
if points.ndim == 3:
M, P, K = points.shape
if K != 3:
return False
if M > 1 and batch_size > 1:
return False
return True
def transform_points(
self, points, eps: Optional[float] = None, **kwargs
) -> torch.Tensor:
"""
Transform input points from camera space to image space.
Args:
points: tensor of (..., 3). E.g., (P, 3) or (1, P, 3), (M, P, 3)
eps: tiny number to avoid zero divsion
Returns:
torch.Tensor
when points take shape (P, 3) or (1, P, 3), output is (N, P, 3)
when points take shape (M, P, 3), output is (M, P, 3)
where N is the number of transforms, P number of points
"""
# project from world space to camera space
if self.world_coordinates:
world_to_view_transform = self.get_world_to_view_transform(
R=self.R, T=self.T
)
points = world_to_view_transform.transform_points(
points.to(self.device), eps=eps
)
else:
points = points.to(self.device)
# project from camera space to image space
N = len(self.radial_params)
if not self.check_input(points, N):
msg = "Expected points of (P, 3) with batch_size 1 or N, or shape (M, P, 3) \
with batch_size 1; got points of shape %r and batch_size %r"
raise ValueError(msg % (points.shape, N))
if N == 1:
return self._project_points_batch(
self.focal[0],
self.principal_point[0],
self.radial_params[0],
self.tangential_params[0],
self.thin_prism_params[0],
points,
)
else:
outputs = []
for i in range(N):
outputs.append(
self._project_points_batch(
self.focal[i],
self.principal_point[i],
self.radial_params[i],
self.tangential_params[i],
self.thin_prism_params[i],
points,
)
)
outputs = torch.stack(outputs, dim=0)
return outputs.squeeze()
def _unproject_points_batch(
self,
focal,
principal_point,
radial_params,
tangential_params,
thin_prism_params,
xy: torch.Tensor,
) -> torch.Tensor:
"""
Args:
focal: (1)
principal_point: (2)
radial_params: (num_radial)
tangential_params: (2)
thin_prism_params: (4)
xy: (..., 2)
Returns:
point3d_est: (..., 3)
"""
sh = list(xy.shape[:-1])
assert xy.shape[-1] == 2, "xy_depth shape incorrect"
uv_distorted = (xy - principal_point) / focal
# get xr_yr from uvDistorted
xr_yr = self._compute_xr_yr_from_uv_distorted(
tangential_params, thin_prism_params, uv_distorted
)
xr_yrNorm = torch.norm(xr_yr, dim=-1)
# find theta
theta = self._get_theta_from_norm_xr_yr(radial_params, xr_yrNorm)
# get the point coordinates:
point3d_est = theta.new_ones(*sh, 3)
point3d_est[..., :2] = theta.tan()[..., None] / xr_yrNorm[..., None] * xr_yr
return point3d_est
def unproject_points(
self,
xy_depth: torch.Tensor,
world_coordinates: bool = True,
scaled_depth_input: bool = False,
**kwargs,
) -> torch.Tensor:
"""
Takes in 3-point ``uv_depth`` in the image plane of the camera and unprojects it
into the reference frame of the camera.
This function is the inverse of ``transform_points``. In particular it holds that
X = unproject(project(X))
and
x = project(unproject(s*x))
Args:
xy_depth: points in the image plane of shape (..., 3). E.g.,
(P, 3) or (1, P, 3) or (M, P, 3)
world_coordinates: if the output is in world_coordinate, if False, convert to
camera coordinate
scaled_depth_input: False
Returns:
unprojected_points in the camera frame with z = 1
when points take shape (P, 3) or (1, P, 3), output is (N, P, 3)
when points take shape (M, P, 3), output is (M, P, 3)
where N is the number of transforms, P number of point
"""
xy_depth = xy_depth.to(self.device)
N = len(self.radial_params)
if N == 1:
return self._unproject_points_batch(
self.focal[0],
self.principal_point[0],
self.radial_params[0],
self.tangential_params[0],
self.thin_prism_params[0],
xy_depth[..., 0:2],
)
else:
outputs = []
for i in range(N):
outputs.append(
self._unproject_points_batch(
self.focal[i],
self.principal_point[i],
self.radial_params[i],
self.tangential_params[i],
self.thin_prism_params[i],
xy_depth[..., 0:2],
)
)
outputs = torch.stack(outputs, dim=0)
return outputs.squeeze()
def _compute_xr_yr_from_uv_distorted(
self, tangential_params, thin_prism_params, uv_distorted: torch.Tensor
) -> torch.Tensor:
"""
Helper function to compute the vector [x_r; y_r] from uvDistorted
Args:
tangential_params: (2)
thin_prism_params: (4)
uv_distorted: (..., 2), E.g., (P, 2), (1, P, 2), (M, P, 2)
Returns:
xr_yr: (..., 2)
"""
# early exit if we're not using any tangential/ thin prism distortions
if not self.use_tangential and not self.use_thin_prism:
return uv_distorted
xr_yr = uv_distorted
# do Newton iterations to find xr_yr
for _ in range(self.num_distortion_iters):
# compute the estimated uvDistorted
uv_distorted_est = xr_yr
xr_yr_squared_norm = torch.pow(xr_yr, 2).sum(dim=-1, keepdim=True)
if self.use_tangential:
temp = 2.0 * torch.sum(
xr_yr * tangential_params[..., 0:2],
dim=-1,
keepdim=True,
)
uv_distorted_est = uv_distorted_est + (
temp * xr_yr + xr_yr_squared_norm * tangential_params[..., 0:2]
)
if self.use_thin_prism:
radial_powers = torch.cat(
[xr_yr_squared_norm, xr_yr_squared_norm * xr_yr_squared_norm],
dim=-1,
)
uv_distorted_est[..., 0] = uv_distorted_est[..., 0] + torch.sum(
thin_prism_params[..., 0:2] * radial_powers,
dim=-1,
)
uv_distorted_est[..., 1] = uv_distorted_est[..., 1] + torch.sum(
thin_prism_params[..., 2:4] * radial_powers,
dim=-1,
)
# compute the derivative of uvDistorted wrt xr_yr
duv_distorted_dxryr = self._compute_duv_distorted_dxryr(
tangential_params, thin_prism_params, xr_yr, xr_yr_squared_norm[..., 0]
)
# compute correction:
# note: the matrix duvDistorted_dxryr will be close to identity (for reasonable
# values of tangential/thin prism distortions)
correction = torch.linalg.solve(
duv_distorted_dxryr, (uv_distorted - uv_distorted_est)[..., None]
)
xr_yr = xr_yr + correction[..., 0]
return xr_yr
def _get_theta_from_norm_xr_yr(
self, radial_params, th_radial_desired
) -> torch.Tensor:
"""
Helper function to compute the angle theta from the norm of the vector [x_r; y_r]
Args:
radial_params: k1, k2, ..., k_num_radial, (num_radial)
th_radial_desired: desired angle of shape (...), E.g., (P), (1, P), (M, P)
Returns:
th: angle theta (in radians) of shape (...), E.g., (P), (1, P), (M, P)
"""
sh = list(th_radial_desired.shape)
# th = th_radial_desired.clone()
th = th_radial_desired
c = torch.tensor(
[2.0 * i + 3 for i in range(self.num_radial)], device=self.device
)
for _ in range(self.num_distortion_iters):
theta_sq = th * th
th_radial = 1.0
dthD_dth = 1.0
# compute the theta polynomial and its derivative wrt theta
t = theta_sq
theta_pow = torch.stack([t, t**2, t**3, t**4, t**5, t**6], dim=-1)
th_radial = th_radial + torch.sum(theta_pow * radial_params, dim=-1)
dthD_dth = dthD_dth + torch.sum(c * radial_params * theta_pow, dim=-1)
th_radial = th_radial * th
# compute the correction
step = torch.zeros(*sh, device=self.device)
# make sure don't divide by zero
nonzero_mask = dthD_dth.abs() > self.epsilon
step = step + nonzero_mask * (th_radial_desired - th_radial) / dthD_dth
# if derivative is close to zero, apply small correction in the appropriate
# direction to avoid numerical explosions
close_to_zero_mask = dthD_dth.abs() <= self.epsilon
dir_mask = (th_radial_desired - th_radial) * dthD_dth > 0.0
boolean_mask = close_to_zero_mask & dir_mask
step = step + 10.0 * self.epsilon * boolean_mask
step = step - 10 * self.epsilon * (~nonzero_mask & ~boolean_mask)
# apply correction
th = th + step
# revert to within 180 degrees FOV to avoid numerical overflow
idw = th.abs() >= math.pi / 2.0
th[idw] = 0.999 * math.pi / 2.0
return th
def _compute_duv_distorted_dxryr(
self, tangential_params, thin_prism_params, xr_yr, xr_yr_squareNorm
) -> torch.Tensor:
"""
Helper function, computes the Jacobian of uvDistorted wrt the vector [x_r;y_r]
Args:
tangential_params: (2)
thin_prism_params: (4)
xr_yr: (P, 2)
xr_yr_squareNorm: (...), E.g., (P), (1, P), (M, P)
Returns:
duv_distorted_dxryr: (..., 2, 2) Jacobian
"""
sh = list(xr_yr.shape[:-1])
duv_distorted_dxryr = torch.empty((*sh, 2, 2), device=self.device)
if self.use_tangential:
duv_distorted_dxryr[..., 0, 0] = (
1.0
+ 6.0 * xr_yr[..., 0] * tangential_params[..., 0]
+ 2.0 * xr_yr[..., 1] * tangential_params[..., 1]
)
offdiag = 2.0 * (
xr_yr[..., 0] * tangential_params[..., 1]
+ xr_yr[..., 1] * tangential_params[..., 0]
)
duv_distorted_dxryr[..., 0, 1] = offdiag
duv_distorted_dxryr[..., 1, 0] = offdiag
duv_distorted_dxryr[..., 1, 1] = (
1.0
+ 6.0 * xr_yr[..., 1] * tangential_params[..., 1]
+ 2.0 * xr_yr[..., 0] * tangential_params[..., 0]
)
else:
duv_distorted_dxryr = torch.eye(2).view(*sh, 2, 2).expand(*sh, 1, 1)
if self.use_thin_prism:
temp1 = 2.0 * (
thin_prism_params[..., 0]
+ 2.0 * thin_prism_params[..., 1] * xr_yr_squareNorm[...]
)
duv_distorted_dxryr[..., 0, 0] = (
duv_distorted_dxryr[..., 0, 0] + xr_yr[..., 0] * temp1
)
duv_distorted_dxryr[..., 0, 1] = (
duv_distorted_dxryr[..., 0, 1] + xr_yr[..., 1] * temp1
)
temp2 = 2.0 * (
thin_prism_params[..., 2]
+ 2.0 * thin_prism_params[..., 3] * xr_yr_squareNorm[...]
)
duv_distorted_dxryr[..., 1, 0] = (
duv_distorted_dxryr[..., 1, 0] + xr_yr[..., 0] * temp2
)
duv_distorted_dxryr[..., 1, 1] = (
duv_distorted_dxryr[..., 1, 1] + xr_yr[..., 1] * temp2
)
return duv_distorted_dxryr
def in_ndc(self):
return True
def is_perspective(self):
return False

View File

@ -52,6 +52,7 @@ from pytorch3d.renderer.cameras import (
SfMOrthographicCameras,
SfMPerspectiveCameras,
)
from pytorch3d.renderer.fisheyecameras import FishEyeCameras
from pytorch3d.transforms import Transform3d
from pytorch3d.transforms.rotation_conversions import random_rotations
from pytorch3d.transforms.so3 import so3_exp_map
@ -186,6 +187,12 @@ def init_random_cameras(
):
cam_params["focal_length"] = torch.rand(batch_size) * 10 + 0.1
cam_params["principal_point"] = torch.randn((batch_size, 2))
elif cam_type == FishEyeCameras:
cam_params["focal_length"] = torch.rand(batch_size, 1) * 10 + 0.1
cam_params["principal_point"] = torch.randn((batch_size, 2))
cam_params["radial_params"] = torch.randn((batch_size, 6))
cam_params["tangential_params"] = torch.randn((batch_size, 2))
cam_params["thin_prism_params"] = torch.randn((batch_size, 4))
else:
raise ValueError(str(cam_type))
@ -196,7 +203,6 @@ class TestCameraHelpers(TestCaseMixin, unittest.TestCase):
def setUp(self) -> None:
super().setUp()
torch.manual_seed(42)
np.random.seed(42)
def test_look_at_view_transform_from_eye_point_tuple(self):
dist = math.sqrt(2)
@ -606,6 +612,22 @@ class TestCamerasCommon(TestCaseMixin, unittest.TestCase):
)
self.assertTrue(torch.allclose(xyz_unproj, matching_xyz, atol=1e-4))
@staticmethod
def unproject_points(cam_type, batch_size=50, num_points=100):
"""
Checks that an unprojection of a randomly projected point cloud
stays the same.
"""
def run_cameras():
# init the cameras
cameras = init_random_cameras(cam_type, batch_size)
# xyz - the ground truth point cloud
xyz = torch.randn(num_points, 3) * 0.3
xyz = cameras.unproject_points(xyz, scaled_depth_input=True)
return run_cameras
def test_project_points_screen(self, batch_size=50, num_points=100):
"""
Checks that an unprojection of a randomly projected point cloud
@ -643,6 +665,24 @@ class TestCamerasCommon(TestCaseMixin, unittest.TestCase):
# we set atol to 1e-4, remember that screen points are in [0, W]x[0, H] space
self.assertClose(xyz_project_screen, xyz_project_screen_naive, atol=1e-4)
@staticmethod
def transform_points(cam_type, batch_size=50, num_points=100):
"""
Checks that an unprojection of a randomly projected point cloud
stays the same.
"""
def run_cameras():
# init the cameras
cameras = init_random_cameras(cam_type, batch_size)
# xyz - the ground truth point cloud
xy = torch.randn(num_points, 2) * 2.0 - 1.0
z = torch.randn(num_points, 1) * 3.0 + 1.0
xyz = torch.cat((xy, z), dim=-1)
xy = cameras.transform_points(xyz)
return run_cameras
def test_equiv_project_points(self, batch_size=50, num_points=100):
"""
Checks that NDC and screen cameras project points to ndc correctly.
@ -1251,3 +1291,284 @@ class TestPerspectiveProjection(TestCaseMixin, unittest.TestCase):
# Check in_ndc is handled correctly
self.assertEqual(cam._in_ndc, c0._in_ndc)
############################################################
# FishEye Camera #
############################################################
class TestFishEyeProjection(TestCaseMixin, unittest.TestCase):
def setUpSimpleCase(self) -> None:
super().setUp()
focal = torch.tensor([[240]], dtype=torch.float32)
principal_point = torch.tensor([[320, 240]])
p_3d = torch.tensor(
[
[2.0, 3.0, 1.0],
[3.0, 2.0, 1.0],
],
dtype=torch.float32,
)
return focal, principal_point, p_3d
def setUpAriaCase(self) -> None:
super().setUp()
torch.manual_seed(42)
focal = torch.tensor([[608.9255557152]], dtype=torch.float32)
principal_point = torch.tensor(
[[712.0114821205, 706.8666571177]], dtype=torch.float32
)
radial_params = torch.tensor(
[
[
0.3877090026,
-0.315613384,
-0.3434984955,
1.8565874201,
-2.1799372221,
0.7713834763,
],
],
dtype=torch.float32,
)
tangential_params = torch.tensor(
[[-0.0002747019, 0.0005228974]], dtype=torch.float32
)
thin_prism_params = torch.tensor(
[
[0.000134884, -0.000084822, -0.0009420014, -0.0001276838],
],
dtype=torch.float32,
)
return (
focal,
principal_point,
radial_params,
tangential_params,
thin_prism_params,
)
def setUpBatchCameras(self) -> None:
super().setUp()
focal, principal_point, p_3d = self.setUpSimpleCase()
radial_params = torch.tensor(
[
[0, 0, 0, 0, 0, 0],
],
dtype=torch.float32,
)
tangential_params = torch.tensor([[0, 0]], dtype=torch.float32)
thin_prism_params = torch.tensor([[0, 0, 0, 0]], dtype=torch.float32)
(
focal1,
principal_point1,
radial_params1,
tangential_params1,
thin_prism_params1,
) = self.setUpAriaCase()
focal = torch.cat([focal, focal1], dim=0)
principal_point = torch.cat([principal_point, principal_point1], dim=0)
radial_params = torch.cat([radial_params, radial_params1], dim=0)
tangential_params = torch.cat([tangential_params, tangential_params1], dim=0)
thin_prism_params = torch.cat([thin_prism_params, thin_prism_params1], dim=0)
cameras = FishEyeCameras(
focal_length=focal,
principal_point=principal_point,
radial_params=radial_params,
tangential_params=tangential_params,
thin_prism_params=thin_prism_params,
)
return cameras
def test_distortion_params_set_to_zeors(self):
# test case 1: all distortion params are 0. Note that
# setting radial_params to zeros is not equivalent to
# disabling radial distortions, set use_radial=False does
focal, principal_point, p_3d = self.setUpSimpleCase()
cameras = FishEyeCameras(
focal_length=focal,
principal_point=principal_point,
)
uv_case1 = cameras.transform_points(p_3d)
self.assertClose(
uv_case1,
torch.tensor(
[[493.0993, 499.6489, 1.0], [579.6489, 413.0993, 1.0]],
),
)
# test case 2: equivalent of test case 1 by
# disabling use_tangential and use_thin_prism
cameras = FishEyeCameras(
focal_length=focal,
principal_point=principal_point,
use_tangential=False,
use_thin_prism=False,
)
uv_case2 = cameras.transform_points(p_3d)
self.assertClose(uv_case2, uv_case1)
def test_fisheye_against_perspective_cameras(self):
# test case: check equivalence with PerspectiveCameras
# by disabling all distortions
focal, principal_point, p_3d = self.setUpSimpleCase()
cameras = PerspectiveCameras(
focal_length=focal,
principal_point=principal_point,
)
P = cameras.get_projection_transform()
uv_perspective = P.transform_points(p_3d)
# disable all distortions
cameras = FishEyeCameras(
focal_length=focal,
principal_point=principal_point,
use_radial=False,
use_tangential=False,
use_thin_prism=False,
)
uv = cameras.transform_points(p_3d)
self.assertClose(uv, uv_perspective)
def test_project_shape_broadcasts(self):
focal, principal_point, p_3d = self.setUpSimpleCase()
# test case 1:
# 1 transform with points of shape (P, 3) -> (P, 3)
# 1 transform with points of shape (1, P, 3) -> (1, P, 3)
# 1 transform with points of shape (M, P, 3) -> (M, P, 3)
points = p_3d.repeat(1, 1, 1)
cameras = FishEyeCameras(
focal_length=focal,
principal_point=principal_point,
use_radial=False,
use_tangential=False,
use_thin_prism=False,
)
uv = cameras.transform_points(p_3d)
uv_point_batch = cameras.transform_points(points)
self.assertClose(uv_point_batch, uv.repeat(1, 1, 1))
points = p_3d.repeat(3, 1, 1)
uv_point_batch = cameras.transform_points(points)
self.assertClose(uv_point_batch, uv.repeat(3, 1, 1))
# test case 2
# test with N transforms and points of shape (P, 3) -> (N, P, 3)
# test with N transforms and points of shape (1, P, 3) -> (N, P, 3)
# first camera transform params
cameras = self.setUpBatchCameras()
p_3d = torch.tensor(
[
[2.0, 3.0, 1.0],
[2.0, 3.0, 1.0],
[3.0, 2.0, 1.0],
]
)
expected_res = torch.tensor(
[
[
[493.0993, 499.6489, 1.0],
[493.0993, 499.6489, 1.0],
[579.6489, 413.0993, 1.0],
],
[
[1660.2700, 2128.2273, 1.0],
[1660.2700, 2128.2273, 1.0],
[2134.5815, 1650.9565, 1.0],
],
]
)
uv_point_batch = cameras.transform_points(p_3d)
self.assertClose(uv_point_batch, expected_res)
uv_point_batch = cameras.transform_points(p_3d.repeat(1, 1, 1))
self.assertClose(uv_point_batch, expected_res)
def test_cuda(self):
"""
Test cuda device
"""
focal, principal_point, p_3d = self.setUpSimpleCase()
cameras_cuda = FishEyeCameras(
focal_length=focal,
principal_point=principal_point,
device="cuda:0",
)
uv = cameras_cuda.transform_points(p_3d)
expected_res = torch.tensor(
[[493.0993, 499.6489, 1.0], [579.6489, 413.0993, 1.0]],
)
self.assertClose(uv, expected_res.to("cuda:0"))
rep_3d = cameras_cuda.unproject_points(uv)
self.assertClose(rep_3d, p_3d.to("cuda:0"))
def test_unproject_shape_broadcasts(self):
# test case 1:
# 1 transform with points of (P, 3) -> (P, 3)
# 1 transform with points of (M, P, 3) -> (M, P, 3)
(
focal,
principal_point,
radial_params,
tangential_params,
thin_prism_params,
) = self.setUpAriaCase()
xy_depth = torch.tensor(
[
[2134.5814033, 1650.95653328, 1.0],
[1074.25442904, 1159.52461285, 1.0],
]
)
cameras = FishEyeCameras(
focal_length=focal,
principal_point=principal_point,
radial_params=radial_params,
tangential_params=tangential_params,
thin_prism_params=thin_prism_params,
)
rep_3d = cameras.unproject_points(xy_depth)
expected_res = torch.tensor(
[
[3.0000, 2.0000, 1.0000],
[0.666667, 0.833333, 1.0000],
],
)
self.assertClose(rep_3d, expected_res)
rep_3d = cameras.unproject_points(xy_depth.repeat(3, 1, 1))
self.assertClose(rep_3d, expected_res.repeat(3, 1, 1))
# test case 2:
# N transforms with points of (P, 3) -> (N, P, 3)
# N transforms with points of (1, P, 3) -> (N, P, 3)
cameras = FishEyeCameras(
focal_length=focal.repeat(2, 1),
principal_point=principal_point.repeat(2, 1),
radial_params=radial_params.repeat(2, 1),
tangential_params=tangential_params.repeat(2, 1),
thin_prism_params=thin_prism_params.repeat(2, 1),
)
rep_3d = cameras.unproject_points(xy_depth)
self.assertClose(rep_3d, expected_res.repeat(2, 1, 1))
def test_unhandled_shape(self):
"""
Test error handling when shape of transforms
and points are not expected.
"""
cameras = self.setUpBatchCameras()
points = torch.rand(3, 3, 1)
with self.assertRaises(ValueError):
cameras.transform_points(points)
def test_getitem(self):
# Check get item returns an instance of the same class
# with all the same keys
cam = self.setUpBatchCameras()
c0 = cam[0]
self.assertTrue(isinstance(c0, FishEyeCameras))
self.assertEqual(cam.__dict__.keys(), c0.__dict__.keys())