Fisheye Camera for PyTorch3D

Summary:
1. A fisheye camera model that generalizes the pinhole camera by modeling radial, tangential, and thin-prism distortions; a usage sketch follows the summary list.

2. Added tests against perspective cameras when distortions are off, and against Aria data points when distortions are on.

3. Addressed comments to test unhandled shapes between points and transforms. Added tests for __FIELDS, shape broadcasting, CUDA, etc.

4. Addressed earlier comments on code efficiency (e.g., adopted torch.norm, used torch.solve for matrix inversion, removed in-place operations and an unnecessary clone, and used expand in place of repeat).
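
A minimal usage sketch of the new class, assembled from the constructor arguments and methods exercised in the tests below; the coefficient values are illustrative only:

import torch
from pytorch3d.renderer.fisheyecameras import FishEyeCameras

# one camera; zero coefficients keep only the base fisheye mapping
cameras = FishEyeCameras(
    focal_length=torch.tensor([[240.0]]),
    principal_point=torch.tensor([[320.0, 240.0]]),
    radial_params=torch.zeros(1, 6),
    tangential_params=torch.zeros(1, 2),
    thin_prism_params=torch.zeros(1, 4),
    use_radial=True,
    use_tangential=True,
    use_thin_prism=True,
)
points = torch.tensor([[2.0, 3.0, 1.0], [3.0, 2.0, 1.0]])
uv = cameras.transform_points(points)    # project to image coordinates
xyz = cameras.unproject_points(uv)       # invert the projection

Setting use_radial, use_tangential, and use_thin_prism to False reduces the model to a perspective projection, which test_fisheye_against_perspective_cameras below verifies.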

Reviewed By: jcjohnson

Differential Revision: D38407094

fbshipit-source-id: a3ab48c85c496ac87af692d5d461bb3fc2a2db13
Jiali Duan
2022-08-28 11:17:20 -07:00
committed by Facebook GitHub Bot
parent 4711d12a09
commit 2283c292a9
2 changed files with 907 additions and 1 deletion


@@ -52,6 +52,7 @@ from pytorch3d.renderer.cameras import (
SfMOrthographicCameras,
SfMPerspectiveCameras,
)
from pytorch3d.renderer.fisheyecameras import FishEyeCameras
from pytorch3d.transforms import Transform3d
from pytorch3d.transforms.rotation_conversions import random_rotations
from pytorch3d.transforms.so3 import so3_exp_map
@@ -186,6 +187,12 @@ def init_random_cameras(
):
cam_params["focal_length"] = torch.rand(batch_size) * 10 + 0.1
cam_params["principal_point"] = torch.randn((batch_size, 2))
elif cam_type == FishEyeCameras:
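# fisheye intrinsics: per-camera focal length and principal point plus
# 6 radial, 2 tangential, and 4 thin-prism distortion coefficients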
cam_params["focal_length"] = torch.rand(batch_size, 1) * 10 + 0.1
cam_params["principal_point"] = torch.randn((batch_size, 2))
cam_params["radial_params"] = torch.randn((batch_size, 6))
cam_params["tangential_params"] = torch.randn((batch_size, 2))
cam_params["thin_prism_params"] = torch.randn((batch_size, 4))
else:
raise ValueError(str(cam_type))
@@ -196,7 +203,6 @@ class TestCameraHelpers(TestCaseMixin, unittest.TestCase):
def setUp(self) -> None:
super().setUp()
torch.manual_seed(42)
np.random.seed(42)
def test_look_at_view_transform_from_eye_point_tuple(self):
dist = math.sqrt(2)
@@ -606,6 +612,22 @@ class TestCamerasCommon(TestCaseMixin, unittest.TestCase):
)
self.assertTrue(torch.allclose(xyz_unproj, matching_xyz, atol=1e-4))
@staticmethod
def unproject_points(cam_type, batch_size=50, num_points=100):
"""
Returns a closure that unprojects a random point cloud with
randomly initialized cameras of the given type.
"""
def run_cameras():
# init the cameras
cameras = init_random_cameras(cam_type, batch_size)
# xyz - the ground truth point cloud
xyz = torch.randn(num_points, 3) * 0.3
xyz = cameras.unproject_points(xyz, scaled_depth_input=True)
return run_cameras
def test_project_points_screen(self, batch_size=50, num_points=100):
"""
Checks that an unprojection of a randomly projected point cloud
@@ -643,6 +665,24 @@ class TestCamerasCommon(TestCaseMixin, unittest.TestCase):
# we set atol to 1e-4, remember that screen points are in [0, W]x[0, H] space
self.assertClose(xyz_project_screen, xyz_project_screen_naive, atol=1e-4)
@staticmethod
def transform_points(cam_type, batch_size=50, num_points=100):
"""
Returns a closure that projects a random point cloud with
randomly initialized cameras of the given type.
"""
def run_cameras():
# init the cameras
cameras = init_random_cameras(cam_type, batch_size)
# xyz - the ground truth point cloud
xy = torch.randn(num_points, 2) * 2.0 - 1.0
z = torch.randn(num_points, 1) * 3.0 + 1.0
xyz = torch.cat((xy, z), dim=-1)
xy = cameras.transform_points(xyz)
return run_cameras
def test_equiv_project_points(self, batch_size=50, num_points=100):
"""
Checks that NDC and screen cameras project points to ndc correctly.
@@ -1251,3 +1291,284 @@ class TestPerspectiveProjection(TestCaseMixin, unittest.TestCase):
# Check in_ndc is handled correctly
self.assertEqual(cam._in_ndc, c0._in_ndc)
############################################################
# FishEye Camera #
############################################################
class TestFishEyeProjection(TestCaseMixin, unittest.TestCase):
def setUpSimpleCase(self) -> None:
super().setUp()
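# a simple synthetic camera: focal length 240 and principal point (320, 240)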
focal = torch.tensor([[240]], dtype=torch.float32)
principal_point = torch.tensor([[320, 240]])
p_3d = torch.tensor(
[
[2.0, 3.0, 1.0],
[3.0, 2.0, 1.0],
],
dtype=torch.float32,
)
return focal, principal_point, p_3d
def setUpAriaCase(self) -> None:
super().setUp()
torch.manual_seed(42)
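# intrinsics and distortion coefficients taken from an Aria device calibration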
focal = torch.tensor([[608.9255557152]], dtype=torch.float32)
principal_point = torch.tensor(
[[712.0114821205, 706.8666571177]], dtype=torch.float32
)
radial_params = torch.tensor(
[
[
0.3877090026,
-0.315613384,
-0.3434984955,
1.8565874201,
-2.1799372221,
0.7713834763,
],
],
dtype=torch.float32,
)
tangential_params = torch.tensor(
[[-0.0002747019, 0.0005228974]], dtype=torch.float32
)
thin_prism_params = torch.tensor(
[
[0.000134884, -0.000084822, -0.0009420014, -0.0001276838],
],
dtype=torch.float32,
)
return (
focal,
principal_point,
radial_params,
tangential_params,
thin_prism_params,
)
def setUpBatchCameras(self) -> None:
super().setUp()
focal, principal_point, p_3d = self.setUpSimpleCase()
radial_params = torch.tensor(
[
[0, 0, 0, 0, 0, 0],
],
dtype=torch.float32,
)
tangential_params = torch.tensor([[0, 0]], dtype=torch.float32)
thin_prism_params = torch.tensor([[0, 0, 0, 0]], dtype=torch.float32)
(
focal1,
principal_point1,
radial_params1,
tangential_params1,
thin_prism_params1,
) = self.setUpAriaCase()
focal = torch.cat([focal, focal1], dim=0)
principal_point = torch.cat([principal_point, principal_point1], dim=0)
radial_params = torch.cat([radial_params, radial_params1], dim=0)
tangential_params = torch.cat([tangential_params, tangential_params1], dim=0)
thin_prism_params = torch.cat([thin_prism_params, thin_prism_params1], dim=0)
cameras = FishEyeCameras(
focal_length=focal,
principal_point=principal_point,
radial_params=radial_params,
tangential_params=tangential_params,
thin_prism_params=thin_prism_params,
)
return cameras
def test_distortion_params_set_to_zeros(self):
# test case 1: all distortion params are 0. Note that
# setting radial_params to zeros is not equivalent to
# disabling radial distortion; setting use_radial=False does that
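# even with zero coefficients the projection follows the equidistant
# fisheye mapping uv = focal * atan(r) / r * xy + principal_point, with
# r = sqrt(x^2 + y^2) at z = 1, which matches the expected values below
# and differs from the pinhole projection focal * xy + principal_point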
focal, principal_point, p_3d = self.setUpSimpleCase()
cameras = FishEyeCameras(
focal_length=focal,
principal_point=principal_point,
)
uv_case1 = cameras.transform_points(p_3d)
self.assertClose(
uv_case1,
torch.tensor(
[[493.0993, 499.6489, 1.0], [579.6489, 413.0993, 1.0]],
),
)
# test case 2: equivalent to test case 1, obtained by
# disabling use_tangential and use_thin_prism instead
cameras = FishEyeCameras(
focal_length=focal,
principal_point=principal_point,
use_tangential=False,
use_thin_prism=False,
)
uv_case2 = cameras.transform_points(p_3d)
self.assertClose(uv_case2, uv_case1)
def test_fisheye_against_perspective_cameras(self):
# test case: check equivalence with PerspectiveCameras
# by disabling all distortions
focal, principal_point, p_3d = self.setUpSimpleCase()
cameras = PerspectiveCameras(
focal_length=focal,
principal_point=principal_point,
)
P = cameras.get_projection_transform()
uv_perspective = P.transform_points(p_3d)
# disable all distortions
cameras = FishEyeCameras(
focal_length=focal,
principal_point=principal_point,
use_radial=False,
use_tangential=False,
use_thin_prism=False,
)
uv = cameras.transform_points(p_3d)
self.assertClose(uv, uv_perspective)
def test_project_shape_broadcasts(self):
focal, principal_point, p_3d = self.setUpSimpleCase()
# test case 1:
# 1 transform with points of shape (P, 3) -> (P, 3)
# 1 transform with points of shape (1, P, 3) -> (1, P, 3)
# 1 transform with points of shape (M, P, 3) -> (M, P, 3)
points = p_3d.repeat(1, 1, 1)
cameras = FishEyeCameras(
focal_length=focal,
principal_point=principal_point,
use_radial=False,
use_tangential=False,
use_thin_prism=False,
)
uv = cameras.transform_points(p_3d)
uv_point_batch = cameras.transform_points(points)
self.assertClose(uv_point_batch, uv.repeat(1, 1, 1))
points = p_3d.repeat(3, 1, 1)
uv_point_batch = cameras.transform_points(points)
self.assertClose(uv_point_batch, uv.repeat(3, 1, 1))
# test case 2
# test with N transforms and points of shape (P, 3) -> (N, P, 3)
# test with N transforms and points of shape (1, P, 3) -> (N, P, 3)
# batch of two cameras: the simple case and the Aria case
cameras = self.setUpBatchCameras()
p_3d = torch.tensor(
[
[2.0, 3.0, 1.0],
[2.0, 3.0, 1.0],
[3.0, 2.0, 1.0],
]
)
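# row 0: the simple camera (same values as test case 1 above);
# row 1: the Aria camera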
expected_res = torch.tensor(
[
[
[493.0993, 499.6489, 1.0],
[493.0993, 499.6489, 1.0],
[579.6489, 413.0993, 1.0],
],
[
[1660.2700, 2128.2273, 1.0],
[1660.2700, 2128.2273, 1.0],
[2134.5815, 1650.9565, 1.0],
],
]
)
uv_point_batch = cameras.transform_points(p_3d)
self.assertClose(uv_point_batch, expected_res)
uv_point_batch = cameras.transform_points(p_3d.repeat(1, 1, 1))
self.assertClose(uv_point_batch, expected_res)
def test_cuda(self):
"""
Test cuda device
"""
focal, principal_point, p_3d = self.setUpSimpleCase()
cameras_cuda = FishEyeCameras(
focal_length=focal,
principal_point=principal_point,
device="cuda:0",
)
uv = cameras_cuda.transform_points(p_3d)
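# expected projection matches the simple (all-zero distortion) CPU case above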
expected_res = torch.tensor(
[[493.0993, 499.6489, 1.0], [579.6489, 413.0993, 1.0]],
)
self.assertClose(uv, expected_res.to("cuda:0"))
rep_3d = cameras_cuda.unproject_points(uv)
self.assertClose(rep_3d, p_3d.to("cuda:0"))
def test_unproject_shape_broadcasts(self):
# test case 1:
# 1 transform with points of (P, 3) -> (P, 3)
# 1 transform with points of (M, P, 3) -> (M, P, 3)
(
focal,
principal_point,
radial_params,
tangential_params,
thin_prism_params,
) = self.setUpAriaCase()
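# the first point is the Aria-camera projection of (3, 2, 1) from
# test_project_shape_broadcasts, so unprojection should recover that point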
xy_depth = torch.tensor(
[
[2134.5814033, 1650.95653328, 1.0],
[1074.25442904, 1159.52461285, 1.0],
]
)
cameras = FishEyeCameras(
focal_length=focal,
principal_point=principal_point,
radial_params=radial_params,
tangential_params=tangential_params,
thin_prism_params=thin_prism_params,
)
rep_3d = cameras.unproject_points(xy_depth)
expected_res = torch.tensor(
[
[3.0000, 2.0000, 1.0000],
[0.666667, 0.833333, 1.0000],
],
)
self.assertClose(rep_3d, expected_res)
rep_3d = cameras.unproject_points(xy_depth.repeat(3, 1, 1))
self.assertClose(rep_3d, expected_res.repeat(3, 1, 1))
# test case 2:
# N transforms with points of (P, 3) -> (N, P, 3)
# N transforms with points of (1, P, 3) -> (N, P, 3)
cameras = FishEyeCameras(
focal_length=focal.repeat(2, 1),
principal_point=principal_point.repeat(2, 1),
radial_params=radial_params.repeat(2, 1),
tangential_params=tangential_params.repeat(2, 1),
thin_prism_params=thin_prism_params.repeat(2, 1),
)
rep_3d = cameras.unproject_points(xy_depth)
self.assertClose(rep_3d, expected_res.repeat(2, 1, 1))
def test_unhandled_shape(self):
"""
Test error handling when the shapes of the points
and transforms are not as expected.
"""
cameras = self.setUpBatchCameras()
points = torch.rand(3, 3, 1)
with self.assertRaises(ValueError):
cameras.transform_points(points)
def test_getitem(self):
# Check get item returns an instance of the same class
# with all the same keys
cam = self.setUpBatchCameras()
c0 = cam[0]
self.assertTrue(isinstance(c0, FishEyeCameras))
self.assertEqual(cam.__dict__.keys(), c0.__dict__.keys())