Join cameras as batch

Summary:
Function to join a list of camera objects into a single batched object.

FB: In the next diff I will remove the `concatenate_cameras` function in implicitron and update the callsites.

Reviewed By: nikhilaravi

Differential Revision: D33198209

fbshipit-source-id: 0c9f5f5df498a0def9dba756c984e6a946618158
Author: Jeremy Reizenstein (committed by Facebook GitHub Bot)
Date: 2022-01-21 05:25:23 -08:00
Parent: 9e2bc3a17f
Commit: 39bb2ce063

5 changed files with 187 additions and 9 deletions
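
For context, the new API can be used roughly as follows (a minimal sketch; the camera parameters are arbitrary illustration values, not taken from this diff):

import torch
from pytorch3d.renderer import PerspectiveCameras, join_cameras_as_batch

# Two camera batches of sizes 2 and 3 (R and T take their default values).
cams_a = PerspectiveCameras(focal_length=torch.rand(2, 2))
cams_b = PerspectiveCameras(focal_length=torch.rand(3, 2))

joined = join_cameras_as_batch([cams_a, cams_b])
print(len(joined))  # 5 -- tensor fields are concatenated along dim 0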

pytorch3d/renderer/__init__.py

@@ -10,7 +10,7 @@ from .blending import (
     sigmoid_alpha_blend,
     softmax_rgb_blend,
 )
-from .camera_utils import rotate_on_spot
+from .camera_utils import join_cameras_as_batch, rotate_on_spot
 from .cameras import OpenGLOrthographicCameras  # deprecated
 from .cameras import OpenGLPerspectiveCameras  # deprecated
 from .cameras import SfMOrthographicCameras  # deprecated
@@ -29,6 +29,7 @@ from .implicit import (
     AbsorptionOnlyRaymarcher,
     EmissionAbsorptionRaymarcher,
     GridRaysampler,
+    HarmonicEmbedding,
     ImplicitRenderer,
     MonteCarloRaysampler,
     NDCGridRaysampler,
@@ -37,7 +38,6 @@ from .implicit import (
     VolumeSampler,
     ray_bundle_to_ray_points,
     ray_bundle_variables_to_ray_points,
-    HarmonicEmbedding,
 )
 from .lighting import AmbientLights, DirectionalLights, PointLights, diffuse, specular
 from .materials import Materials

pytorch3d/renderer/camera_utils.py

@@ -4,11 +4,13 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-from typing import Tuple
+from typing import Sequence, Tuple
 
 import torch
 from pytorch3d.transforms import Transform3d
 
+from .cameras import CamerasBase
+
 
 def camera_to_eye_at_up(
     world_to_view_transform: Transform3d,
@@ -141,3 +143,65 @@ def rotate_on_spot(
     new_T = torch.matmul(new_R.transpose(1, 2), old_RT)[:, :, 0]
 
     return new_R, new_T
+
+
+def join_cameras_as_batch(cameras_list: Sequence[CamerasBase]) -> CamerasBase:
+    """
+    Create a batched cameras object by concatenating a list of input
+    cameras objects. All the tensor attributes will be joined along
+    the batch dimension.
+
+    Args:
+        cameras_list: List of camera classes all of the same type and
+            on the same device. Each represents one or more cameras.
+    Returns:
+        cameras: single batched cameras object of the same
+            type as all the objects in the input list.
+    """
+    # Get the type and fields to join from the first camera in the batch
+    c0 = cameras_list[0]
+    fields = c0._FIELDS
+    shared_fields = c0._SHARED_FIELDS
+
+    if not all(isinstance(c, CamerasBase) for c in cameras_list):
+        raise ValueError("cameras in cameras_list must inherit from CamerasBase")
+
+    if not all(type(c) is type(c0) for c in cameras_list[1:]):
+        raise ValueError("All cameras must be of the same type")
+
+    if not all(c.device == c0.device for c in cameras_list[1:]):
+        raise ValueError("All cameras in the batch must be on the same device")
+
+    # Concat the fields to make a batched tensor
+    kwargs = {}
+    kwargs["device"] = c0.device
+
+    for field in fields:
+        field_not_none = [(getattr(c, field) is not None) for c in cameras_list]
+        if not any(field_not_none):
+            continue
+        if not all(field_not_none):
+            raise ValueError(f"Attribute {field} is inconsistently present")
+
+        attrs_list = [getattr(c, field) for c in cameras_list]
+
+        if field in shared_fields:
+            # Only needs to be set once
+            if not all(a == attrs_list[0] for a in attrs_list):
+                raise ValueError(f"Attribute {field} is not constant across inputs")
+
+            # e.g. "in_ndc" is set as attribute "_in_ndc" on the class
+            # but provided as "in_ndc" in the input args
+            if field.startswith("_"):
+                field = field[1:]
+
+            kwargs[field] = attrs_list[0]
+        elif isinstance(attrs_list[0], torch.Tensor):
+            # In the init, all inputs will be converted to
+            # batched tensors before being set as attributes.
+            # Join the tensors along the batch dimension.
+            kwargs[field] = torch.cat(attrs_list, dim=0)
+        else:
+            raise ValueError(f"Field {field} type is not supported for batching")
+
+    return c0.__class__(**kwargs)
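
A sketch of the shared-field check above: "_in_ndc" is declared in _SHARED_FIELDS (see the cameras.py changes below), so joining an NDC camera with a screen-space camera is rejected rather than concatenated:

from pytorch3d.renderer import PerspectiveCameras
from pytorch3d.renderer.camera_utils import join_cameras_as_batch

ndc_cams = PerspectiveCameras(in_ndc=True)
screen_cams = PerspectiveCameras(in_ndc=False)

try:
    join_cameras_as_batch([ndc_cams, screen_cams])
except ValueError as err:
    print(err)  # Attribute _in_ndc is not constant across inputs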

pytorch3d/renderer/cameras.py

@@ -77,7 +77,12 @@ class CamerasBase(TensorProperties):
     # Used in __getitem__ to index the relevant fields
     # When creating a new camera, this should be set in the __init__
-    _FIELDS: Tuple = ()
+    _FIELDS: Tuple[str, ...] = ()
+
+    # Names of fields which are a constant property of the whole batch, rather
+    # than themselves a batch of data.
+    # When joining objects into a batch, they will have to agree.
+    _SHARED_FIELDS: Tuple[str, ...] = ()
 
     def get_projection_transform(self):
         """
@@ -499,6 +504,8 @@ class FoVPerspectiveCameras(CamerasBase):
         "degrees",
     )
 
+    _SHARED_FIELDS = ("degrees",)
+
     def __init__(
         self,
         znear=1.0,
@@ -997,6 +1004,8 @@ class PerspectiveCameras(CamerasBase):
         "image_size",
     )
 
+    _SHARED_FIELDS = ("_in_ndc",)
+
     def __init__(
         self,
         focal_length=1.0,
@@ -1047,6 +1056,12 @@
         else:
             self.image_size = None
 
+        # When focal length is provided as one value, expand to
+        # create (N, 2) shape tensor
+        if self.focal_length.ndim == 1:  # (N,)
+            self.focal_length = self.focal_length[:, None]  # (N, 1)
+        self.focal_length = self.focal_length.expand(-1, 2)  # (N, 2)
+
     def get_projection_transform(self, **kwargs) -> Transform3d:
         """
         Calculate the projection matrix using the
@@ -1227,6 +1242,8 @@ class OrthographicCameras(CamerasBase):
         "image_size",
     )
 
+    _SHARED_FIELDS = ("_in_ndc",)
+
     def __init__(
         self,
         focal_length=1.0,
@@ -1276,6 +1293,12 @@
         else:
             self.image_size = None
 
+        # When focal length is provided as one value, expand to
+        # create (N, 2) shape tensor
+        if self.focal_length.ndim == 1:  # (N,)
+            self.focal_length = self.focal_length[:, None]  # (N, 1)
+        self.focal_length = self.focal_length.expand(-1, 2)  # (N, 2)
+
     def get_projection_transform(self, **kwargs) -> Transform3d:
         """
         Calculate the projection matrix using
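
The focal-length normalization added above means scalar, (N,) and (N, 1) inputs are all stored as an (N, 2) tensor holding (fx, fy) per camera, which is what makes the plain torch.cat in join_cameras_as_batch safe. A small sketch (the shapes are the point; the values are arbitrary):

import torch
from pytorch3d.renderer import PerspectiveCameras

# A scalar focal length becomes a (1, 2) tensor: one camera, (fx, fy).
print(PerspectiveCameras(focal_length=5.0).focal_length.shape)  # torch.Size([1, 2])

# An (N, 1) focal length is expanded to (N, 2).
cams = PerspectiveCameras(focal_length=torch.full((4, 1), 5.0))
print(cams.focal_length.shape)  # torch.Size([4, 2])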

tests/test_camera_pixels.py

@@ -250,7 +250,4 @@ class TestPixels(TestCaseMixin, unittest.TestCase):
             ],
             dim=1,
         )
-        print(wanted)
-        print(camera_points[batch_idx])
-
         self.assertClose(camera_points[batch_idx], wanted)

tests/test_cameras.py

@@ -36,6 +36,7 @@ import unittest
 import numpy as np
 import torch
 from common_testing import TestCaseMixin
+from pytorch3d.renderer.camera_utils import join_cameras_as_batch
 from pytorch3d.renderer.cameras import (
     CamerasBase,
     FoVOrthographicCameras,
@@ -688,6 +689,99 @@ class TestCamerasCommon(TestCaseMixin, unittest.TestCase):
         else:
             self.assertTrue(val == val_clone)
 
+    def test_join_cameras_as_batch_errors(self):
+        cam0 = PerspectiveCameras(device="cuda:0")
+        cam1 = OrthographicCameras(device="cuda:0")
+
+        # Cameras not of the same type
+        with self.assertRaisesRegex(ValueError, "same type"):
+            join_cameras_as_batch([cam0, cam1])
+
+        cam2 = OrthographicCameras(device="cpu")
+        # Cameras not on the same device
+        with self.assertRaisesRegex(ValueError, "same device"):
+            join_cameras_as_batch([cam1, cam2])
+
+        cam3 = OrthographicCameras(in_ndc=False, device="cuda:0")
+        # Different coordinate systems -- all should be in ndc or in screen
+        with self.assertRaisesRegex(
+            ValueError, "Attribute _in_ndc is not constant across inputs"
+        ):
+            join_cameras_as_batch([cam1, cam3])
+
+    def join_cameras_as_batch_fov(self, camera_cls):
+        R0 = torch.randn((6, 3, 3))
+        R1 = torch.randn((3, 3, 3))
+        cam0 = camera_cls(znear=10.0, zfar=100.0, R=R0, device="cuda:0")
+        cam1 = camera_cls(znear=10.0, zfar=200.0, R=R1, device="cuda:0")
+
+        cam_batch = join_cameras_as_batch([cam0, cam1])
+        self.assertEqual(cam_batch._N, cam0._N + cam1._N)
+        self.assertEqual(cam_batch.device, cam0.device)
+        self.assertClose(cam_batch.R, torch.cat((R0, R1), dim=0).to(device="cuda:0"))
+
+    def join_cameras_as_batch(self, camera_cls):
+        R0 = torch.randn((6, 3, 3))
+        R1 = torch.randn((3, 3, 3))
+        p0 = torch.randn((6, 2, 1))
+        p1 = torch.randn((3, 2, 1))
+        f0 = 5.0
+        f1 = torch.randn(3, 2)
+        f2 = torch.randn(3, 1)
+        cam0 = camera_cls(
+            R=R0,
+            focal_length=f0,
+            principal_point=p0,
+        )
+        cam1 = camera_cls(
+            R=R1,
+            focal_length=f0,
+            principal_point=p1,
+        )
+        cam2 = camera_cls(
+            R=R1,
+            focal_length=f1,
+            principal_point=p1,
+        )
+        cam3 = camera_cls(
+            R=R1,
+            focal_length=f2,
+            principal_point=p1,
+        )
+        cam_batch = join_cameras_as_batch([cam0, cam1])
+        self.assertEqual(cam_batch._N, cam0._N + cam1._N)
+        self.assertEqual(cam_batch.device, cam0.device)
+        self.assertClose(cam_batch.R, torch.cat((R0, R1), dim=0))
+        self.assertClose(cam_batch.principal_point, torch.cat((p0, p1), dim=0))
+        self.assertEqual(cam_batch._in_ndc, cam0._in_ndc)
+
+        # Test one broadcasted value and one fixed value
+        # Focal length as (N,) in one camera and (N, 2) in the other
+        cam_batch = join_cameras_as_batch([cam0, cam2])
+        self.assertEqual(cam_batch._N, cam0._N + cam2._N)
+        self.assertClose(cam_batch.R, torch.cat((R0, R1), dim=0))
+        self.assertClose(
+            cam_batch.focal_length,
+            torch.cat([torch.tensor([[f0, f0]]).expand(6, -1), f1], dim=0),
+        )
+
+        # Focal length as (N, 1) in one camera and (N, 2) in the other
+        cam_batch = join_cameras_as_batch([cam2, cam3])
+        self.assertClose(
+            cam_batch.focal_length,
+            torch.cat([f1, f2.expand(-1, 2)], dim=0),
+        )
+
+    def test_join_batch_perspective(self):
+        self.join_cameras_as_batch_fov(FoVPerspectiveCameras)
+        self.join_cameras_as_batch(PerspectiveCameras)
+
+    def test_join_batch_orthographic(self):
+        self.join_cameras_as_batch_fov(FoVOrthographicCameras)
+        self.join_cameras_as_batch(OrthographicCameras)
+
 
 ############################################################
 #                  FoVPerspective Camera                   #
@@ -1055,7 +1149,7 @@ class TestOrthographicProjection(TestCaseMixin, unittest.TestCase):
         index = torch.tensor([1, 3, 5], dtype=torch.int64)
         c135 = cam[index]
         self.assertEqual(len(c135), 3)
-        self.assertClose(c135.focal_length, torch.tensor([5.0] * 3))
+        self.assertClose(c135.focal_length, torch.tensor([[5.0, 5.0]] * 3))
         self.assertClose(c135.R, R_matrix[[1, 3, 5], ...])
         self.assertClose(c135.principal_point, principal_point[[1, 3, 5], ...])
@@ -1131,7 +1225,7 @@ class TestPerspectiveProjection(TestCaseMixin, unittest.TestCase):
         index = torch.tensor([1, 3, 5], dtype=torch.int64)
         c135 = cam[index]
         self.assertEqual(len(c135), 3)
-        self.assertClose(c135.focal_length, torch.tensor([5.0] * 3))
+        self.assertClose(c135.focal_length, torch.tensor([[5.0, 5.0]] * 3))
         self.assertClose(c135.R, R_matrix[[1, 3, 5], ...])
         self.assertClose(c135.principal_point, principal_point[[1, 3, 5], ...])
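
For completeness, a sketch mirroring the FoV test above: "degrees" is a shared field on FoVPerspectiveCameras, so it is set once on the joined batch while the tensor fields (R, znear, zfar, ...) are concatenated:

import torch
from pytorch3d.renderer import FoVPerspectiveCameras, join_cameras_as_batch

cam0 = FoVPerspectiveCameras(znear=10.0, zfar=100.0, R=torch.randn(6, 3, 3))
cam1 = FoVPerspectiveCameras(znear=10.0, zfar=200.0, R=torch.randn(3, 3, 3))

joined = join_cameras_as_batch([cam0, cam1])
print(len(joined))         # 9
print(joined.degrees)      # True -- shared field, stored once for the batch
print(joined.znear.shape)  # torch.Size([9])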