mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
Amend D38407094: Fisheye Camera for PyTorch3D
Summary: Amend FisheyeCamera by adding tests for all combination of params and for different batch_sizes. Reviewed By: kjchalup Differential Revision: D39176747 fbshipit-source-id: 830d30da24beeb2f0df52db0b17a4303ed53b59c
This commit is contained in:
parent
d4a1051e0f
commit
b0515e1461
@ -425,7 +425,7 @@ class FishEyeCameras(CamerasBase):
|
||||
# do Newton iterations to find xr_yr
|
||||
for _ in range(self.num_distortion_iters):
|
||||
# compute the estimated uvDistorted
|
||||
uv_distorted_est = xr_yr
|
||||
uv_distorted_est = xr_yr.clone()
|
||||
xr_yr_squared_norm = torch.pow(xr_yr, 2).sum(dim=-1, keepdim=True)
|
||||
|
||||
if self.use_tangential:
|
||||
@ -479,7 +479,6 @@ class FishEyeCameras(CamerasBase):
|
||||
th: angle theta (in radians) of shape (...), E.g., (P), (1, P), (M, P)
|
||||
"""
|
||||
sh = list(th_radial_desired.shape)
|
||||
# th = th_radial_desired.clone()
|
||||
th = th_radial_desired
|
||||
c = torch.tensor(
|
||||
[2.0 * i + 3 for i in range(self.num_radial)], device=self.device
|
||||
@ -552,7 +551,7 @@ class FishEyeCameras(CamerasBase):
|
||||
+ 2.0 * xr_yr[..., 0] * tangential_params[..., 0]
|
||||
)
|
||||
else:
|
||||
duv_distorted_dxryr = torch.eye(2).view(*sh, 2, 2).expand(*sh, 1, 1)
|
||||
duv_distorted_dxryr = torch.eye(2).repeat(*sh, 1, 1)
|
||||
|
||||
if self.use_thin_prism:
|
||||
temp1 = 2.0 * (
|
||||
|
@ -33,6 +33,7 @@
|
||||
import math
|
||||
import typing
|
||||
import unittest
|
||||
from itertools import product
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@ -1389,7 +1390,7 @@ class TestFishEyeProjection(TestCaseMixin, unittest.TestCase):
|
||||
thin_prism_params,
|
||||
)
|
||||
|
||||
def setUpBatchCameras(self) -> None:
|
||||
def setUpBatchCameras(self, combination: None) -> None:
|
||||
super().setUp()
|
||||
focal, principal_point, p_3d = self.setUpSimpleCase()
|
||||
radial_params = torch.tensor(
|
||||
@ -1412,8 +1413,12 @@ class TestFishEyeProjection(TestCaseMixin, unittest.TestCase):
|
||||
radial_params = torch.cat([radial_params, radial_params1], dim=0)
|
||||
tangential_params = torch.cat([tangential_params, tangential_params1], dim=0)
|
||||
thin_prism_params = torch.cat([thin_prism_params, thin_prism_params1], dim=0)
|
||||
|
||||
if combination is None:
|
||||
combination = [True, True, True]
|
||||
cameras = FishEyeCameras(
|
||||
use_radial=combination[0],
|
||||
use_tangential=combination[1],
|
||||
use_thin_prism=combination[2],
|
||||
focal_length=focal,
|
||||
principal_point=principal_point,
|
||||
radial_params=radial_params,
|
||||
@ -1474,21 +1479,31 @@ class TestFishEyeProjection(TestCaseMixin, unittest.TestCase):
|
||||
|
||||
def test_project_shape_broadcasts(self):
|
||||
focal, principal_point, p_3d = self.setUpSimpleCase()
|
||||
# test case 1:
|
||||
# 1 transform with points of shape (P, 3) -> (P, 3)
|
||||
# 1 transform with points of shape (1, P, 3) -> (1, P, 3)
|
||||
# 1 transform with points of shape (M, P, 3) -> (M, P, 3)
|
||||
points = p_3d.repeat(1, 1, 1)
|
||||
cameras = FishEyeCameras(
|
||||
focal_length=focal,
|
||||
principal_point=principal_point,
|
||||
use_radial=False,
|
||||
use_tangential=False,
|
||||
use_thin_prism=False,
|
||||
)
|
||||
uv = cameras.transform_points(p_3d)
|
||||
uv_point_batch = cameras.transform_points(points)
|
||||
self.assertClose(uv_point_batch, uv.repeat(1, 1, 1))
|
||||
torch.set_printoptions(precision=6)
|
||||
combinations = product([0, 1], repeat=3)
|
||||
for combination in combinations:
|
||||
cameras = FishEyeCameras(
|
||||
use_radial=combination[0],
|
||||
use_tangential=combination[1],
|
||||
use_thin_prism=combination[2],
|
||||
focal_length=focal,
|
||||
principal_point=principal_point,
|
||||
)
|
||||
# test case 1:
|
||||
# 1 transform with points of shape (P, 3) -> (P, 3)
|
||||
# 1 transform with points of shape (1, P, 3) -> (1, P, 3)
|
||||
# 1 transform with points of shape (M, P, 3) -> (M, P, 3)
|
||||
points = p_3d.repeat(1, 1, 1)
|
||||
cameras = FishEyeCameras(
|
||||
focal_length=focal,
|
||||
principal_point=principal_point,
|
||||
use_radial=False,
|
||||
use_tangential=False,
|
||||
use_thin_prism=False,
|
||||
)
|
||||
uv = cameras.transform_points(p_3d)
|
||||
uv_point_batch = cameras.transform_points(points)
|
||||
self.assertClose(uv_point_batch, uv.repeat(1, 1, 1))
|
||||
|
||||
points = p_3d.repeat(3, 1, 1)
|
||||
uv_point_batch = cameras.transform_points(points)
|
||||
@ -1497,12 +1512,9 @@ class TestFishEyeProjection(TestCaseMixin, unittest.TestCase):
|
||||
# test case 2
|
||||
# test with N transforms and points of shape (P, 3) -> (N, P, 3)
|
||||
# test with N transforms and points of shape (1, P, 3) -> (N, P, 3)
|
||||
|
||||
# first camera transform params
|
||||
cameras = self.setUpBatchCameras()
|
||||
torch.set_printoptions(sci_mode=False)
|
||||
p_3d = torch.tensor(
|
||||
[
|
||||
[2.0, 3.0, 1.0],
|
||||
[2.0, 3.0, 1.0],
|
||||
[3.0, 2.0, 1.0],
|
||||
]
|
||||
@ -1510,22 +1522,95 @@ class TestFishEyeProjection(TestCaseMixin, unittest.TestCase):
|
||||
expected_res = torch.tensor(
|
||||
[
|
||||
[
|
||||
[493.0993, 499.6489, 1.0],
|
||||
[493.0993, 499.6489, 1.0],
|
||||
[579.6489, 413.0993, 1.0],
|
||||
[
|
||||
[800.000000, 960.000000, 1.000000],
|
||||
[1040.000000, 720.000000, 1.000000],
|
||||
],
|
||||
[
|
||||
[1929.862549, 2533.643311, 1.000000],
|
||||
[2538.788086, 1924.717773, 1.000000],
|
||||
],
|
||||
],
|
||||
[
|
||||
[1660.2700, 2128.2273, 1.0],
|
||||
[1660.2700, 2128.2273, 1.0],
|
||||
[2134.5815, 1650.9565, 1.0],
|
||||
[
|
||||
[800.000000, 960.000000, 1.000000],
|
||||
[1040.000000, 720.000000, 1.000000],
|
||||
],
|
||||
[
|
||||
[1927.272095, 2524.220459, 1.000000],
|
||||
[2536.197754, 1915.295166, 1.000000],
|
||||
],
|
||||
],
|
||||
[
|
||||
[
|
||||
[800.000000, 960.000000, 1.000000],
|
||||
[1040.000000, 720.000000, 1.000000],
|
||||
],
|
||||
[
|
||||
[1930.050293, 2538.434814, 1.000000],
|
||||
[2537.956543, 1927.569092, 1.000000],
|
||||
],
|
||||
],
|
||||
[
|
||||
[
|
||||
[800.000000, 960.000000, 1.000000],
|
||||
[1040.000000, 720.000000, 1.000000],
|
||||
],
|
||||
[
|
||||
[1927.459839, 2529.011963, 1.000000],
|
||||
[2535.366211, 1918.146484, 1.000000],
|
||||
],
|
||||
],
|
||||
[
|
||||
[
|
||||
[493.099304, 499.648926, 1.000000],
|
||||
[579.648926, 413.099304, 1.000000],
|
||||
],
|
||||
[
|
||||
[1662.673950, 2132.860352, 1.000000],
|
||||
[2138.005127, 1657.529053, 1.000000],
|
||||
],
|
||||
],
|
||||
[
|
||||
[
|
||||
[493.099304, 499.648926, 1.000000],
|
||||
[579.648926, 413.099304, 1.000000],
|
||||
],
|
||||
[
|
||||
[1660.083496, 2123.437744, 1.000000],
|
||||
[2135.414795, 1648.106445, 1.000000],
|
||||
],
|
||||
],
|
||||
[
|
||||
[
|
||||
[493.099304, 499.648926, 1.000000],
|
||||
[579.648926, 413.099304, 1.000000],
|
||||
],
|
||||
[
|
||||
[1662.861816, 2137.651855, 1.000000],
|
||||
[2137.173828, 1660.380371, 1.000000],
|
||||
],
|
||||
],
|
||||
[
|
||||
[
|
||||
[493.099304, 499.648926, 1.000000],
|
||||
[579.648926, 413.099304, 1.000000],
|
||||
],
|
||||
[
|
||||
[1660.271240, 2128.229248, 1.000000],
|
||||
[2134.583496, 1650.957764, 1.000000],
|
||||
],
|
||||
],
|
||||
]
|
||||
)
|
||||
uv_point_batch = cameras.transform_points(p_3d)
|
||||
self.assertClose(uv_point_batch, expected_res)
|
||||
combinations = product([0, 1], repeat=3)
|
||||
for i, combination in enumerate(combinations):
|
||||
cameras = self.setUpBatchCameras(combination)
|
||||
uv_point_batch = cameras.transform_points(p_3d)
|
||||
self.assertClose(uv_point_batch, expected_res[i])
|
||||
|
||||
uv_point_batch = cameras.transform_points(p_3d.repeat(1, 1, 1))
|
||||
self.assertClose(uv_point_batch, expected_res)
|
||||
uv_point_batch = cameras.transform_points(p_3d.repeat(1, 1, 1))
|
||||
self.assertClose(uv_point_batch, expected_res[i].repeat(1, 1, 1))
|
||||
|
||||
def test_cuda(self):
|
||||
"""
|
||||
@ -1573,34 +1658,56 @@ class TestFishEyeProjection(TestCaseMixin, unittest.TestCase):
|
||||
rep_3d = cameras.unproject_points(xy_depth)
|
||||
expected_res = torch.tensor(
|
||||
[
|
||||
[3.0000, 2.0000, 1.0000],
|
||||
[0.666667, 0.833333, 1.0000],
|
||||
],
|
||||
[[2.999442, 1.990583, 1.000000], [0.666728, 0.833142, 1.000000]],
|
||||
[[2.997338, 2.005411, 1.000000], [0.666859, 0.834456, 1.000000]],
|
||||
[[3.002090, 1.985229, 1.000000], [0.666537, 0.832025, 1.000000]],
|
||||
[[2.999999, 2.000000, 1.000000], [0.666667, 0.833333, 1.000000]],
|
||||
[[2.999442, 1.990583, 1.000000], [0.666728, 0.833142, 1.000000]],
|
||||
[[2.997338, 2.005411, 1.000000], [0.666859, 0.834456, 1.000000]],
|
||||
[[3.002090, 1.985229, 1.000000], [0.666537, 0.832025, 1.000000]],
|
||||
[[2.999999, 2.000000, 1.000000], [0.666667, 0.833333, 1.000000]],
|
||||
]
|
||||
)
|
||||
self.assertClose(rep_3d, expected_res)
|
||||
torch.set_printoptions(precision=6)
|
||||
combinations = product([0, 1], repeat=3)
|
||||
for i, combination in enumerate(combinations):
|
||||
cameras = FishEyeCameras(
|
||||
use_radial=combination[0],
|
||||
use_tangential=combination[1],
|
||||
use_thin_prism=combination[2],
|
||||
focal_length=focal,
|
||||
principal_point=principal_point,
|
||||
radial_params=radial_params,
|
||||
tangential_params=tangential_params,
|
||||
thin_prism_params=thin_prism_params,
|
||||
)
|
||||
rep_3d = cameras.unproject_points(xy_depth)
|
||||
self.assertClose(rep_3d, expected_res[i])
|
||||
rep_3d = cameras.unproject_points(xy_depth.repeat(3, 1, 1))
|
||||
self.assertClose(rep_3d, expected_res[i].repeat(3, 1, 1))
|
||||
|
||||
rep_3d = cameras.unproject_points(xy_depth.repeat(3, 1, 1))
|
||||
self.assertClose(rep_3d, expected_res.repeat(3, 1, 1))
|
||||
|
||||
# test case 2:
|
||||
# N transforms with points of (P, 3) -> (N, P, 3)
|
||||
# N transforms with points of (1, P, 3) -> (N, P, 3)
|
||||
cameras = FishEyeCameras(
|
||||
focal_length=focal.repeat(2, 1),
|
||||
principal_point=principal_point.repeat(2, 1),
|
||||
radial_params=radial_params.repeat(2, 1),
|
||||
tangential_params=tangential_params.repeat(2, 1),
|
||||
thin_prism_params=thin_prism_params.repeat(2, 1),
|
||||
)
|
||||
rep_3d = cameras.unproject_points(xy_depth)
|
||||
self.assertClose(rep_3d, expected_res.repeat(2, 1, 1))
|
||||
# test case 2:
|
||||
# N transforms with points of (P, 3) -> (N, P, 3)
|
||||
# N transforms with points of (1, P, 3) -> (N, P, 3)
|
||||
cameras = FishEyeCameras(
|
||||
use_radial=combination[0],
|
||||
use_tangential=combination[1],
|
||||
use_thin_prism=combination[2],
|
||||
focal_length=focal.repeat(2, 1),
|
||||
principal_point=principal_point.repeat(2, 1),
|
||||
radial_params=radial_params.repeat(2, 1),
|
||||
tangential_params=tangential_params.repeat(2, 1),
|
||||
thin_prism_params=thin_prism_params.repeat(2, 1),
|
||||
)
|
||||
rep_3d = cameras.unproject_points(xy_depth)
|
||||
self.assertClose(rep_3d, expected_res[i].repeat(2, 1, 1))
|
||||
|
||||
def test_unhandled_shape(self):
|
||||
"""
|
||||
Test error handling when shape of transforms
|
||||
and points are not expected.
|
||||
"""
|
||||
cameras = self.setUpBatchCameras()
|
||||
cameras = self.setUpBatchCameras(None)
|
||||
points = torch.rand(3, 3, 1)
|
||||
with self.assertRaises(ValueError):
|
||||
cameras.transform_points(points)
|
||||
@ -1608,7 +1715,7 @@ class TestFishEyeProjection(TestCaseMixin, unittest.TestCase):
|
||||
def test_getitem(self):
|
||||
# Check get item returns an instance of the same class
|
||||
# with all the same keys
|
||||
cam = self.setUpBatchCameras()
|
||||
cam = self.setUpBatchCameras(None)
|
||||
c0 = cam[0]
|
||||
self.assertTrue(isinstance(c0, FishEyeCameras))
|
||||
self.assertEqual(cam.__dict__.keys(), c0.__dict__.keys())
|
||||
|
Loading…
x
Reference in New Issue
Block a user