Join cameras as batch

Summary:
Function to join a list of cameras objects into a single batched object.

FB: In the next diff I will remove the `concatenate_cameras` function in implicitron and update the callsites.

Reviewed By: nikhilaravi

Differential Revision: D33198209

fbshipit-source-id: 0c9f5f5df498a0def9dba756c984e6a946618158
This commit is contained in:
Jeremy Reizenstein
2022-01-21 05:25:23 -08:00
committed by Facebook GitHub Bot
parent 9e2bc3a17f
commit 39bb2ce063
5 changed files with 187 additions and 9 deletions

View File

@@ -250,7 +250,4 @@ class TestPixels(TestCaseMixin, unittest.TestCase):
],
dim=1,
)
print(wanted)
print(camera_points[batch_idx])
self.assertClose(camera_points[batch_idx], wanted)

View File

@@ -36,6 +36,7 @@ import unittest
import numpy as np
import torch
from common_testing import TestCaseMixin
from pytorch3d.renderer.camera_utils import join_cameras_as_batch
from pytorch3d.renderer.cameras import (
CamerasBase,
FoVOrthographicCameras,
@@ -688,6 +689,99 @@ class TestCamerasCommon(TestCaseMixin, unittest.TestCase):
else:
self.assertTrue(val == val_clone)
def test_join_cameras_as_batch_errors(self):
cam0 = PerspectiveCameras(device="cuda:0")
cam1 = OrthographicCameras(device="cuda:0")
# Cameras not of the same type
with self.assertRaisesRegex(ValueError, "same type"):
join_cameras_as_batch([cam0, cam1])
cam2 = OrthographicCameras(device="cpu")
# Cameras not on the same device
with self.assertRaisesRegex(ValueError, "same device"):
join_cameras_as_batch([cam1, cam2])
cam3 = OrthographicCameras(in_ndc=False, device="cuda:0")
# Different coordinate systems -- all should be in ndc or in screen
with self.assertRaisesRegex(
ValueError, "Attribute _in_ndc is not constant across inputs"
):
join_cameras_as_batch([cam1, cam3])
def join_cameras_as_batch_fov(self, camera_cls):
R0 = torch.randn((6, 3, 3))
R1 = torch.randn((3, 3, 3))
cam0 = camera_cls(znear=10.0, zfar=100.0, R=R0, device="cuda:0")
cam1 = camera_cls(znear=10.0, zfar=200.0, R=R1, device="cuda:0")
cam_batch = join_cameras_as_batch([cam0, cam1])
self.assertEqual(cam_batch._N, cam0._N + cam1._N)
self.assertEqual(cam_batch.device, cam0.device)
self.assertClose(cam_batch.R, torch.cat((R0, R1), dim=0).to(device="cuda:0"))
def join_cameras_as_batch(self, camera_cls):
R0 = torch.randn((6, 3, 3))
R1 = torch.randn((3, 3, 3))
p0 = torch.randn((6, 2, 1))
p1 = torch.randn((3, 2, 1))
f0 = 5.0
f1 = torch.randn(3, 2)
f2 = torch.randn(3, 1)
cam0 = camera_cls(
R=R0,
focal_length=f0,
principal_point=p0,
)
cam1 = camera_cls(
R=R1,
focal_length=f0,
principal_point=p1,
)
cam2 = camera_cls(
R=R1,
focal_length=f1,
principal_point=p1,
)
cam3 = camera_cls(
R=R1,
focal_length=f2,
principal_point=p1,
)
cam_batch = join_cameras_as_batch([cam0, cam1])
self.assertEqual(cam_batch._N, cam0._N + cam1._N)
self.assertEqual(cam_batch.device, cam0.device)
self.assertClose(cam_batch.R, torch.cat((R0, R1), dim=0))
self.assertClose(cam_batch.principal_point, torch.cat((p0, p1), dim=0))
self.assertEqual(cam_batch._in_ndc, cam0._in_ndc)
# Test one broadcasted value and one fixed value
# Focal length as (N,) in one camera and (N, 2) in the other
cam_batch = join_cameras_as_batch([cam0, cam2])
self.assertEqual(cam_batch._N, cam0._N + cam2._N)
self.assertClose(cam_batch.R, torch.cat((R0, R1), dim=0))
self.assertClose(
cam_batch.focal_length,
torch.cat([torch.tensor([[f0, f0]]).expand(6, -1), f1], dim=0),
)
# Focal length as (N, 1) in one camera and (N, 2) in the other
cam_batch = join_cameras_as_batch([cam2, cam3])
self.assertClose(
cam_batch.focal_length,
torch.cat([f1, f2.expand(-1, 2)], dim=0),
)
def test_join_batch_perspective(self):
self.join_cameras_as_batch_fov(FoVPerspectiveCameras)
self.join_cameras_as_batch(PerspectiveCameras)
def test_join_batch_orthographic(self):
self.join_cameras_as_batch_fov(FoVOrthographicCameras)
self.join_cameras_as_batch(OrthographicCameras)
############################################################
# FoVPerspective Camera #
@@ -1055,7 +1149,7 @@ class TestOrthographicProjection(TestCaseMixin, unittest.TestCase):
index = torch.tensor([1, 3, 5], dtype=torch.int64)
c135 = cam[index]
self.assertEqual(len(c135), 3)
self.assertClose(c135.focal_length, torch.tensor([5.0] * 3))
self.assertClose(c135.focal_length, torch.tensor([[5.0, 5.0]] * 3))
self.assertClose(c135.R, R_matrix[[1, 3, 5], ...])
self.assertClose(c135.principal_point, principal_point[[1, 3, 5], ...])
@@ -1131,7 +1225,7 @@ class TestPerspectiveProjection(TestCaseMixin, unittest.TestCase):
index = torch.tensor([1, 3, 5], dtype=torch.int64)
c135 = cam[index]
self.assertEqual(len(c135), 3)
self.assertClose(c135.focal_length, torch.tensor([5.0] * 3))
self.assertClose(c135.focal_length, torch.tensor([[5.0, 5.0]] * 3))
self.assertClose(c135.R, R_matrix[[1, 3, 5], ...])
self.assertClose(c135.principal_point, principal_point[[1, 3, 5], ...])