camera refactoring

Summary:
Refactor cameras
* CamerasBase was enhanced with `transform_points_screen` that transforms projected points from NDC to screen space
* OpenGLPerspective, OpenGLOrthographic -> FoVPerspective, FoVOrthographic
* SfMPerspective, SfMOrthographic -> Perspective, Orthographic
* PerspectiveCamera can optionally be constructred with screen space parameters
* Note on Cameras and coordinate systems was added

Reviewed By: nikhilaravi

Differential Revision: D23168525

fbshipit-source-id: dd138e2b2cc7e0e0d9f34c45b8251c01266a2063
This commit is contained in:
Georgia Gkioxari
2020-08-20 22:20:41 -07:00
committed by Facebook GitHub Bot
parent 9242e7e65d
commit 57a22e7306
65 changed files with 896 additions and 279 deletions

View File

@@ -31,12 +31,16 @@ import unittest
import numpy as np
import torch
from common_testing import TestCaseMixin
from pytorch3d.renderer.cameras import OpenGLOrthographicCameras # deprecated
from pytorch3d.renderer.cameras import OpenGLPerspectiveCameras # deprecated
from pytorch3d.renderer.cameras import SfMOrthographicCameras # deprecated
from pytorch3d.renderer.cameras import SfMPerspectiveCameras # deprecated
from pytorch3d.renderer.cameras import (
CamerasBase,
OpenGLOrthographicCameras,
OpenGLPerspectiveCameras,
SfMOrthographicCameras,
SfMPerspectiveCameras,
FoVOrthographicCameras,
FoVPerspectiveCameras,
OrthographicCameras,
PerspectiveCameras,
camera_position_from_spherical_angles,
get_world_to_view_transform,
look_at_rotation,
@@ -109,6 +113,25 @@ def orthographic_project_naive(points, scale_xyz=(1.0, 1.0, 1.0)):
return points
def ndc_to_screen_points_naive(points, imsize):
"""
Transforms points from PyTorch3D's NDC space to screen space
Args:
points: (N, V, 3) representing padded points
imsize: (N, 2) image size = (width, height)
Returns:
(N, V, 3) tensor of transformed points
"""
imwidth, imheight = imsize.unbind(1)
imwidth = imwidth.view(-1, 1)
imheight = imheight.view(-1, 1)
x, y, z = points.unbind(2)
x = (1.0 - x) * (imwidth - 1) / 2.0
y = (1.0 - y) * (imheight - 1) / 2.0
return torch.stack((x, y, z), dim=2)
class TestCameraHelpers(TestCaseMixin, unittest.TestCase):
def setUp(self) -> None:
super().setUp()
@@ -359,6 +382,10 @@ class TestCamerasCommon(TestCaseMixin, unittest.TestCase):
OpenGLOrthographicCameras,
SfMOrthographicCameras,
SfMPerspectiveCameras,
FoVOrthographicCameras,
FoVPerspectiveCameras,
OrthographicCameras,
PerspectiveCameras,
):
cam = cam_type(R=R, T=T)
RT_class = cam.get_world_to_view_transform()
@@ -374,6 +401,10 @@ class TestCamerasCommon(TestCaseMixin, unittest.TestCase):
OpenGLOrthographicCameras,
SfMOrthographicCameras,
SfMPerspectiveCameras,
FoVOrthographicCameras,
FoVPerspectiveCameras,
OrthographicCameras,
PerspectiveCameras,
):
cam = cam_type(R=R, T=T)
C = cam.get_camera_center()
@@ -398,13 +429,53 @@ class TestCamerasCommon(TestCaseMixin, unittest.TestCase):
cam_params["bottom"] = -(torch.rand(batch_size)) * 0.2 - 0.9
cam_params["left"] = -(torch.rand(batch_size)) * 0.2 - 0.9
cam_params["right"] = torch.rand(batch_size) * 0.2 + 0.9
elif cam_type in (SfMOrthographicCameras, SfMPerspectiveCameras):
elif cam_type in (FoVPerspectiveCameras, FoVOrthographicCameras):
cam_params["znear"] = torch.rand(batch_size) * 10 + 0.1
cam_params["zfar"] = torch.rand(batch_size) * 4 + 1 + cam_params["znear"]
if cam_type == FoVPerspectiveCameras:
cam_params["fov"] = torch.rand(batch_size) * 60 + 30
cam_params["aspect_ratio"] = torch.rand(batch_size) * 0.5 + 0.5
else:
cam_params["max_y"] = torch.rand(batch_size) * 0.2 + 0.9
cam_params["min_y"] = -(torch.rand(batch_size)) * 0.2 - 0.9
cam_params["min_x"] = -(torch.rand(batch_size)) * 0.2 - 0.9
cam_params["max_x"] = torch.rand(batch_size) * 0.2 + 0.9
elif cam_type in (
SfMOrthographicCameras,
SfMPerspectiveCameras,
OrthographicCameras,
PerspectiveCameras,
):
cam_params["focal_length"] = torch.rand(batch_size) * 10 + 0.1
cam_params["principal_point"] = torch.randn((batch_size, 2))
else:
raise ValueError(str(cam_type))
return cam_type(**cam_params)
@staticmethod
def init_equiv_cameras_ndc_screen(cam_type: CamerasBase, batch_size: int):
T = torch.randn(batch_size, 3) * 0.03
T[:, 2] = 4
R = so3_exponential_map(torch.randn(batch_size, 3) * 3.0)
screen_cam_params = {"R": R, "T": T}
ndc_cam_params = {"R": R, "T": T}
if cam_type in (OrthographicCameras, PerspectiveCameras):
ndc_cam_params["focal_length"] = torch.rand((batch_size, 2)) * 3.0
ndc_cam_params["principal_point"] = torch.randn((batch_size, 2))
image_size = torch.randint(low=2, high=64, size=(batch_size, 2))
screen_cam_params["image_size"] = image_size
screen_cam_params["focal_length"] = (
ndc_cam_params["focal_length"] * image_size / 2.0
)
screen_cam_params["principal_point"] = (
(1.0 - ndc_cam_params["principal_point"]) * image_size / 2.0
)
else:
raise ValueError(str(cam_type))
return cam_type(**ndc_cam_params), cam_type(**screen_cam_params)
def test_unproject_points(self, batch_size=50, num_points=100):
"""
Checks that an unprojection of a randomly projected point cloud
@@ -416,6 +487,10 @@ class TestCamerasCommon(TestCaseMixin, unittest.TestCase):
OpenGLPerspectiveCameras,
OpenGLOrthographicCameras,
SfMPerspectiveCameras,
FoVOrthographicCameras,
FoVPerspectiveCameras,
OrthographicCameras,
PerspectiveCameras,
):
# init the cameras
cameras = TestCamerasCommon.init_random_cameras(cam_type, batch_size)
@@ -437,9 +512,14 @@ class TestCamerasCommon(TestCaseMixin, unittest.TestCase):
else:
matching_xyz = xyz_cam
# if we have OpenGL cameras
# if we have FoV (= OpenGL) cameras
# test for scaled_depth_input=True/False
if cam_type in (OpenGLPerspectiveCameras, OpenGLOrthographicCameras):
if cam_type in (
OpenGLPerspectiveCameras,
OpenGLOrthographicCameras,
FoVPerspectiveCameras,
FoVOrthographicCameras,
):
for scaled_depth_input in (True, False):
if scaled_depth_input:
xy_depth_ = xyz_proj
@@ -459,6 +539,56 @@ class TestCamerasCommon(TestCaseMixin, unittest.TestCase):
)
self.assertTrue(torch.allclose(xyz_unproj, matching_xyz, atol=1e-4))
def test_project_points_screen(self, batch_size=50, num_points=100):
"""
Checks that an unprojection of a randomly projected point cloud
stays the same.
"""
for cam_type in (
OpenGLOrthographicCameras,
OpenGLPerspectiveCameras,
SfMOrthographicCameras,
SfMPerspectiveCameras,
FoVOrthographicCameras,
FoVPerspectiveCameras,
OrthographicCameras,
PerspectiveCameras,
):
# init the cameras
cameras = TestCamerasCommon.init_random_cameras(cam_type, batch_size)
# xyz - the ground truth point cloud
xyz = torch.randn(batch_size, num_points, 3) * 0.3
# image size
image_size = torch.randint(low=2, high=64, size=(batch_size, 2))
# project points
xyz_project_ndc = cameras.transform_points(xyz)
xyz_project_screen = cameras.transform_points_screen(xyz, image_size)
# naive
xyz_project_screen_naive = ndc_to_screen_points_naive(
xyz_project_ndc, image_size
)
self.assertClose(xyz_project_screen, xyz_project_screen_naive)
def test_equiv_project_points(self, batch_size=50, num_points=100):
"""
Checks that NDC and screen cameras project points to ndc correctly.
Applies only to OrthographicCameras and PerspectiveCameras.
"""
for cam_type in (OrthographicCameras, PerspectiveCameras):
# init the cameras
(
ndc_cameras,
screen_cameras,
) = TestCamerasCommon.init_equiv_cameras_ndc_screen(cam_type, batch_size)
# xyz - the ground truth point cloud
xyz = torch.randn(batch_size, num_points, 3) * 0.3
# project points
xyz_ndc_cam = ndc_cameras.transform_points(xyz)
xyz_screen_cam = screen_cameras.transform_points(xyz)
self.assertClose(xyz_ndc_cam, xyz_screen_cam, atol=1e-6)
def test_clone(self, batch_size: int = 10):
"""
Checks the clone function of the cameras.
@@ -468,6 +598,10 @@ class TestCamerasCommon(TestCaseMixin, unittest.TestCase):
OpenGLPerspectiveCameras,
OpenGLOrthographicCameras,
SfMPerspectiveCameras,
FoVOrthographicCameras,
FoVPerspectiveCameras,
OrthographicCameras,
PerspectiveCameras,
):
cameras = TestCamerasCommon.init_random_cameras(cam_type, batch_size)
cameras = cameras.to(torch.device("cpu"))
@@ -483,11 +617,16 @@ class TestCamerasCommon(TestCaseMixin, unittest.TestCase):
self.assertTrue(val == val_clone)
class TestPerspectiveProjection(TestCaseMixin, unittest.TestCase):
############################################################
# FoVPerspective Camera #
############################################################
class TestFoVPerspectiveProjection(TestCaseMixin, unittest.TestCase):
def test_perspective(self):
far = 10.0
near = 1.0
cameras = OpenGLPerspectiveCameras(znear=near, zfar=far, fov=60.0)
cameras = FoVPerspectiveCameras(znear=near, zfar=far, fov=60.0)
P = cameras.get_projection_transform()
# vertices are at the far clipping plane so z gets mapped to 1.
vertices = torch.tensor([1, 2, far], dtype=torch.float32)
@@ -512,7 +651,7 @@ class TestPerspectiveProjection(TestCaseMixin, unittest.TestCase):
self.assertClose(v1.squeeze(), projected_verts)
def test_perspective_kwargs(self):
cameras = OpenGLPerspectiveCameras(znear=5.0, zfar=100.0, fov=0.0)
cameras = FoVPerspectiveCameras(znear=5.0, zfar=100.0, fov=0.0)
# Override defaults by passing in values to get_projection_transform
far = 10.0
P = cameras.get_projection_transform(znear=1.0, zfar=far, fov=60.0)
@@ -528,7 +667,7 @@ class TestPerspectiveProjection(TestCaseMixin, unittest.TestCase):
far = torch.tensor([10.0, 20.0], dtype=torch.float32)
near = 1.0
fov = torch.tensor(60.0)
cameras = OpenGLPerspectiveCameras(znear=near, zfar=far, fov=fov)
cameras = FoVPerspectiveCameras(znear=near, zfar=far, fov=fov)
P = cameras.get_projection_transform()
vertices = torch.tensor([1, 2, 10], dtype=torch.float32)
z1 = 1.0 # vertices at far clipping plane so z = 1.0
@@ -550,7 +689,7 @@ class TestPerspectiveProjection(TestCaseMixin, unittest.TestCase):
far = torch.tensor([10.0])
near = 1.0
fov = torch.tensor(60.0, requires_grad=True)
cameras = OpenGLPerspectiveCameras(znear=near, zfar=far, fov=fov)
cameras = FoVPerspectiveCameras(znear=near, zfar=far, fov=fov)
P = cameras.get_projection_transform()
vertices = torch.tensor([1, 2, 10], dtype=torch.float32)
vertices_batch = vertices[None, None, :]
@@ -566,7 +705,7 @@ class TestPerspectiveProjection(TestCaseMixin, unittest.TestCase):
def test_camera_class_init(self):
device = torch.device("cuda:0")
cam = OpenGLPerspectiveCameras(znear=10.0, zfar=(100.0, 200.0))
cam = FoVPerspectiveCameras(znear=10.0, zfar=(100.0, 200.0))
# Check broadcasting
self.assertTrue(cam.znear.shape == (2,))
@@ -585,7 +724,7 @@ class TestPerspectiveProjection(TestCaseMixin, unittest.TestCase):
self.assertTrue(new_cam.device == device)
def test_get_full_transform(self):
cam = OpenGLPerspectiveCameras()
cam = FoVPerspectiveCameras()
T = torch.tensor([0.0, 0.0, 1.0]).view(1, -1)
R = look_at_rotation(T)
P = cam.get_full_projection_transform(R=R, T=T)
@@ -597,7 +736,7 @@ class TestPerspectiveProjection(TestCaseMixin, unittest.TestCase):
# Check transform_points methods works with default settings for
# RT and P
far = 10.0
cam = OpenGLPerspectiveCameras(znear=1.0, zfar=far, fov=60.0)
cam = FoVPerspectiveCameras(znear=1.0, zfar=far, fov=60.0)
points = torch.tensor([1, 2, far], dtype=torch.float32)
points = points.view(1, 1, 3).expand(5, 10, -1)
projected_points = torch.tensor(
@@ -608,11 +747,16 @@ class TestPerspectiveProjection(TestCaseMixin, unittest.TestCase):
self.assertClose(new_points, projected_points)
class TestOpenGLOrthographicProjection(TestCaseMixin, unittest.TestCase):
############################################################
# FoVOrthographic Camera #
############################################################
class TestFoVOrthographicProjection(TestCaseMixin, unittest.TestCase):
def test_orthographic(self):
far = 10.0
near = 1.0
cameras = OpenGLOrthographicCameras(znear=near, zfar=far)
cameras = FoVOrthographicCameras(znear=near, zfar=far)
P = cameras.get_projection_transform()
vertices = torch.tensor([1, 2, far], dtype=torch.float32)
@@ -637,7 +781,7 @@ class TestOpenGLOrthographicProjection(TestCaseMixin, unittest.TestCase):
# applying the scale puts the z coordinate at the far clipping plane
# so the z is mapped to 1.0
projected_verts = torch.tensor([2, 1, 1], dtype=torch.float32)
cameras = OpenGLOrthographicCameras(znear=1.0, zfar=10.0, scale_xyz=scale)
cameras = FoVOrthographicCameras(znear=1.0, zfar=10.0, scale_xyz=scale)
P = cameras.get_projection_transform()
v1 = P.transform_points(vertices)
v2 = orthographic_project_naive(vertices, scale)
@@ -645,7 +789,7 @@ class TestOpenGLOrthographicProjection(TestCaseMixin, unittest.TestCase):
self.assertClose(v1, projected_verts[None, None])
def test_orthographic_kwargs(self):
cameras = OpenGLOrthographicCameras(znear=5.0, zfar=100.0)
cameras = FoVOrthographicCameras(znear=5.0, zfar=100.0)
far = 10.0
P = cameras.get_projection_transform(znear=1.0, zfar=far)
vertices = torch.tensor([1, 2, far], dtype=torch.float32)
@@ -657,7 +801,7 @@ class TestOpenGLOrthographicProjection(TestCaseMixin, unittest.TestCase):
def test_orthographic_mixed_inputs_broadcast(self):
far = torch.tensor([10.0, 20.0])
near = 1.0
cameras = OpenGLOrthographicCameras(znear=near, zfar=far)
cameras = FoVOrthographicCameras(znear=near, zfar=far)
P = cameras.get_projection_transform()
vertices = torch.tensor([1.0, 2.0, 10.0], dtype=torch.float32)
z2 = 1.0 / (20.0 - 1.0) * 10.0 + -1.0 / (20.0 - 1.0)
@@ -674,7 +818,7 @@ class TestOpenGLOrthographicProjection(TestCaseMixin, unittest.TestCase):
far = torch.tensor([10.0])
near = 1.0
scale = torch.tensor([[1.0, 1.0, 1.0]], requires_grad=True)
cameras = OpenGLOrthographicCameras(znear=near, zfar=far, scale_xyz=scale)
cameras = FoVOrthographicCameras(znear=near, zfar=far, scale_xyz=scale)
P = cameras.get_projection_transform()
vertices = torch.tensor([1.0, 2.0, 10.0], dtype=torch.float32)
vertices_batch = vertices[None, None, :]
@@ -694,9 +838,14 @@ class TestOpenGLOrthographicProjection(TestCaseMixin, unittest.TestCase):
self.assertClose(scale_grad, grad_scale)
class TestSfMOrthographicProjection(TestCaseMixin, unittest.TestCase):
############################################################
# Orthographic Camera #
############################################################
class TestOrthographicProjection(TestCaseMixin, unittest.TestCase):
def test_orthographic(self):
cameras = SfMOrthographicCameras()
cameras = OrthographicCameras()
P = cameras.get_projection_transform()
vertices = torch.randn([3, 4, 3], dtype=torch.float32)
@@ -711,9 +860,7 @@ class TestSfMOrthographicProjection(TestCaseMixin, unittest.TestCase):
focal_length_x = 10.0
focal_length_y = 15.0
cameras = SfMOrthographicCameras(
focal_length=((focal_length_x, focal_length_y),)
)
cameras = OrthographicCameras(focal_length=((focal_length_x, focal_length_y),))
P = cameras.get_projection_transform()
vertices = torch.randn([3, 4, 3], dtype=torch.float32)
@@ -730,9 +877,7 @@ class TestSfMOrthographicProjection(TestCaseMixin, unittest.TestCase):
self.assertClose(v1, projected_verts)
def test_orthographic_kwargs(self):
cameras = SfMOrthographicCameras(
focal_length=5.0, principal_point=((2.5, 2.5),)
)
cameras = OrthographicCameras(focal_length=5.0, principal_point=((2.5, 2.5),))
P = cameras.get_projection_transform(
focal_length=2.0, principal_point=((2.5, 3.5),)
)
@@ -745,9 +890,14 @@ class TestSfMOrthographicProjection(TestCaseMixin, unittest.TestCase):
self.assertClose(v1, projected_verts)
class TestSfMPerspectiveProjection(TestCaseMixin, unittest.TestCase):
############################################################
# Perspective Camera #
############################################################
class TestPerspectiveProjection(TestCaseMixin, unittest.TestCase):
def test_perspective(self):
cameras = SfMPerspectiveCameras()
cameras = PerspectiveCameras()
P = cameras.get_projection_transform()
vertices = torch.randn([3, 4, 3], dtype=torch.float32)
@@ -761,7 +911,7 @@ class TestSfMPerspectiveProjection(TestCaseMixin, unittest.TestCase):
p0x = 15.0
p0y = 30.0
cameras = SfMPerspectiveCameras(
cameras = PerspectiveCameras(
focal_length=((focal_length_x, focal_length_y),),
principal_point=((p0x, p0y),),
)
@@ -777,7 +927,7 @@ class TestSfMPerspectiveProjection(TestCaseMixin, unittest.TestCase):
self.assertClose(v3[..., :2], v2[..., :2])
def test_perspective_kwargs(self):
cameras = SfMPerspectiveCameras(focal_length=5.0, principal_point=((2.5, 2.5),))
cameras = PerspectiveCameras(focal_length=5.0, principal_point=((2.5, 2.5),))
P = cameras.get_projection_transform(
focal_length=2.0, principal_point=((2.5, 3.5),)
)