mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
Join cameras as batch
Summary: Function to join a list of cameras objects into a single batched object. FB: In the next diff I will remove the `concatenate_cameras` function in implicitron and update the callsites. Reviewed By: nikhilaravi Differential Revision: D33198209 fbshipit-source-id: 0c9f5f5df498a0def9dba756c984e6a946618158
This commit is contained in:
parent
9e2bc3a17f
commit
39bb2ce063
@ -10,7 +10,7 @@ from .blending import (
|
||||
sigmoid_alpha_blend,
|
||||
softmax_rgb_blend,
|
||||
)
|
||||
from .camera_utils import rotate_on_spot
|
||||
from .camera_utils import join_cameras_as_batch, rotate_on_spot
|
||||
from .cameras import OpenGLOrthographicCameras # deprecated
|
||||
from .cameras import OpenGLPerspectiveCameras # deprecated
|
||||
from .cameras import SfMOrthographicCameras # deprecated
|
||||
@ -29,6 +29,7 @@ from .implicit import (
|
||||
AbsorptionOnlyRaymarcher,
|
||||
EmissionAbsorptionRaymarcher,
|
||||
GridRaysampler,
|
||||
HarmonicEmbedding,
|
||||
ImplicitRenderer,
|
||||
MonteCarloRaysampler,
|
||||
NDCGridRaysampler,
|
||||
@ -37,7 +38,6 @@ from .implicit import (
|
||||
VolumeSampler,
|
||||
ray_bundle_to_ray_points,
|
||||
ray_bundle_variables_to_ray_points,
|
||||
HarmonicEmbedding,
|
||||
)
|
||||
from .lighting import AmbientLights, DirectionalLights, PointLights, diffuse, specular
|
||||
from .materials import Materials
|
||||
|
@ -4,11 +4,13 @@
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from typing import Tuple
|
||||
from typing import Sequence, Tuple
|
||||
|
||||
import torch
|
||||
from pytorch3d.transforms import Transform3d
|
||||
|
||||
from .cameras import CamerasBase
|
||||
|
||||
|
||||
def camera_to_eye_at_up(
|
||||
world_to_view_transform: Transform3d,
|
||||
@ -141,3 +143,65 @@ def rotate_on_spot(
|
||||
new_T = torch.matmul(new_R.transpose(1, 2), old_RT)[:, :, 0]
|
||||
|
||||
return new_R, new_T
|
||||
|
||||
|
||||
def join_cameras_as_batch(cameras_list: Sequence[CamerasBase]) -> CamerasBase:
|
||||
"""
|
||||
Create a batched cameras object by concatenating a list of input
|
||||
cameras objects. All the tensor attributes will be joined along
|
||||
the batch dimension.
|
||||
|
||||
Args:
|
||||
cameras_list: List of camera classes all of the same type and
|
||||
on the same device. Each represents one or more cameras.
|
||||
Returns:
|
||||
cameras: single batched cameras object of the same
|
||||
type as all the objects in the input list.
|
||||
"""
|
||||
# Get the type and fields to join from the first camera in the batch
|
||||
c0 = cameras_list[0]
|
||||
fields = c0._FIELDS
|
||||
shared_fields = c0._SHARED_FIELDS
|
||||
|
||||
if not all(isinstance(c, CamerasBase) for c in cameras_list):
|
||||
raise ValueError("cameras in cameras_list must inherit from CamerasBase")
|
||||
|
||||
if not all(type(c) is type(c0) for c in cameras_list[1:]):
|
||||
raise ValueError("All cameras must be of the same type")
|
||||
|
||||
if not all(c.device == c0.device for c in cameras_list[1:]):
|
||||
raise ValueError("All cameras in the batch must be on the same device")
|
||||
|
||||
# Concat the fields to make a batched tensor
|
||||
kwargs = {}
|
||||
kwargs["device"] = c0.device
|
||||
|
||||
for field in fields:
|
||||
field_not_none = [(getattr(c, field) is not None) for c in cameras_list]
|
||||
if not any(field_not_none):
|
||||
continue
|
||||
if not all(field_not_none):
|
||||
raise ValueError(f"Attribute {field} is inconsistently present")
|
||||
|
||||
attrs_list = [getattr(c, field) for c in cameras_list]
|
||||
|
||||
if field in shared_fields:
|
||||
# Only needs to be set once
|
||||
if not all(a == attrs_list[0] for a in attrs_list):
|
||||
raise ValueError(f"Attribute {field} is not constant across inputs")
|
||||
|
||||
# e.g. "in_ndc" is set as attribute "_in_ndc" on the class
|
||||
# but provided as "in_ndc" in the input args
|
||||
if field.startswith("_"):
|
||||
field = field[1:]
|
||||
|
||||
kwargs[field] = attrs_list[0]
|
||||
elif isinstance(attrs_list[0], torch.Tensor):
|
||||
# In the init, all inputs will be converted to
|
||||
# batched tensors before set as attributes
|
||||
# Join as a tensor along the batch dimension
|
||||
kwargs[field] = torch.cat(attrs_list, dim=0)
|
||||
else:
|
||||
raise ValueError(f"Field {field} type is not supported for batching")
|
||||
|
||||
return c0.__class__(**kwargs)
|
||||
|
@ -77,7 +77,12 @@ class CamerasBase(TensorProperties):
|
||||
|
||||
# Used in __getitem__ to index the relevant fields
|
||||
# When creating a new camera, this should be set in the __init__
|
||||
_FIELDS: Tuple = ()
|
||||
_FIELDS: Tuple[str, ...] = ()
|
||||
|
||||
# Names of fields which are a constant property of the whole batch, rather
|
||||
# than themselves a batch of data.
|
||||
# When joining objects into a batch, they will have to agree.
|
||||
_SHARED_FIELDS: Tuple[str, ...] = ()
|
||||
|
||||
def get_projection_transform(self):
|
||||
"""
|
||||
@ -499,6 +504,8 @@ class FoVPerspectiveCameras(CamerasBase):
|
||||
"degrees",
|
||||
)
|
||||
|
||||
_SHARED_FIELDS = ("degrees",)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
znear=1.0,
|
||||
@ -997,6 +1004,8 @@ class PerspectiveCameras(CamerasBase):
|
||||
"image_size",
|
||||
)
|
||||
|
||||
_SHARED_FIELDS = ("_in_ndc",)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
focal_length=1.0,
|
||||
@ -1047,6 +1056,12 @@ class PerspectiveCameras(CamerasBase):
|
||||
else:
|
||||
self.image_size = None
|
||||
|
||||
# When focal length is provided as one value, expand to
|
||||
# create (N, 2) shape tensor
|
||||
if self.focal_length.ndim == 1: # (N,)
|
||||
self.focal_length = self.focal_length[:, None] # (N, 1)
|
||||
self.focal_length = self.focal_length.expand(-1, 2) # (N, 2)
|
||||
|
||||
def get_projection_transform(self, **kwargs) -> Transform3d:
|
||||
"""
|
||||
Calculate the projection matrix using the
|
||||
@ -1227,6 +1242,8 @@ class OrthographicCameras(CamerasBase):
|
||||
"image_size",
|
||||
)
|
||||
|
||||
_SHARED_FIELDS = ("_in_ndc",)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
focal_length=1.0,
|
||||
@ -1276,6 +1293,12 @@ class OrthographicCameras(CamerasBase):
|
||||
else:
|
||||
self.image_size = None
|
||||
|
||||
# When focal length is provided as one value, expand to
|
||||
# create (N, 2) shape tensor
|
||||
if self.focal_length.ndim == 1: # (N,)
|
||||
self.focal_length = self.focal_length[:, None] # (N, 1)
|
||||
self.focal_length = self.focal_length.expand(-1, 2) # (N, 2)
|
||||
|
||||
def get_projection_transform(self, **kwargs) -> Transform3d:
|
||||
"""
|
||||
Calculate the projection matrix using
|
||||
|
@ -250,7 +250,4 @@ class TestPixels(TestCaseMixin, unittest.TestCase):
|
||||
],
|
||||
dim=1,
|
||||
)
|
||||
|
||||
print(wanted)
|
||||
print(camera_points[batch_idx])
|
||||
self.assertClose(camera_points[batch_idx], wanted)
|
||||
|
@ -36,6 +36,7 @@ import unittest
|
||||
import numpy as np
|
||||
import torch
|
||||
from common_testing import TestCaseMixin
|
||||
from pytorch3d.renderer.camera_utils import join_cameras_as_batch
|
||||
from pytorch3d.renderer.cameras import (
|
||||
CamerasBase,
|
||||
FoVOrthographicCameras,
|
||||
@ -688,6 +689,99 @@ class TestCamerasCommon(TestCaseMixin, unittest.TestCase):
|
||||
else:
|
||||
self.assertTrue(val == val_clone)
|
||||
|
||||
def test_join_cameras_as_batch_errors(self):
|
||||
cam0 = PerspectiveCameras(device="cuda:0")
|
||||
cam1 = OrthographicCameras(device="cuda:0")
|
||||
|
||||
# Cameras not of the same type
|
||||
with self.assertRaisesRegex(ValueError, "same type"):
|
||||
join_cameras_as_batch([cam0, cam1])
|
||||
|
||||
cam2 = OrthographicCameras(device="cpu")
|
||||
# Cameras not on the same device
|
||||
with self.assertRaisesRegex(ValueError, "same device"):
|
||||
join_cameras_as_batch([cam1, cam2])
|
||||
|
||||
cam3 = OrthographicCameras(in_ndc=False, device="cuda:0")
|
||||
# Different coordinate systems -- all should be in ndc or in screen
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, "Attribute _in_ndc is not constant across inputs"
|
||||
):
|
||||
join_cameras_as_batch([cam1, cam3])
|
||||
|
||||
def join_cameras_as_batch_fov(self, camera_cls):
|
||||
R0 = torch.randn((6, 3, 3))
|
||||
R1 = torch.randn((3, 3, 3))
|
||||
cam0 = camera_cls(znear=10.0, zfar=100.0, R=R0, device="cuda:0")
|
||||
cam1 = camera_cls(znear=10.0, zfar=200.0, R=R1, device="cuda:0")
|
||||
|
||||
cam_batch = join_cameras_as_batch([cam0, cam1])
|
||||
|
||||
self.assertEqual(cam_batch._N, cam0._N + cam1._N)
|
||||
self.assertEqual(cam_batch.device, cam0.device)
|
||||
self.assertClose(cam_batch.R, torch.cat((R0, R1), dim=0).to(device="cuda:0"))
|
||||
|
||||
def join_cameras_as_batch(self, camera_cls):
|
||||
R0 = torch.randn((6, 3, 3))
|
||||
R1 = torch.randn((3, 3, 3))
|
||||
p0 = torch.randn((6, 2, 1))
|
||||
p1 = torch.randn((3, 2, 1))
|
||||
f0 = 5.0
|
||||
f1 = torch.randn(3, 2)
|
||||
f2 = torch.randn(3, 1)
|
||||
cam0 = camera_cls(
|
||||
R=R0,
|
||||
focal_length=f0,
|
||||
principal_point=p0,
|
||||
)
|
||||
cam1 = camera_cls(
|
||||
R=R1,
|
||||
focal_length=f0,
|
||||
principal_point=p1,
|
||||
)
|
||||
cam2 = camera_cls(
|
||||
R=R1,
|
||||
focal_length=f1,
|
||||
principal_point=p1,
|
||||
)
|
||||
cam3 = camera_cls(
|
||||
R=R1,
|
||||
focal_length=f2,
|
||||
principal_point=p1,
|
||||
)
|
||||
cam_batch = join_cameras_as_batch([cam0, cam1])
|
||||
|
||||
self.assertEqual(cam_batch._N, cam0._N + cam1._N)
|
||||
self.assertEqual(cam_batch.device, cam0.device)
|
||||
self.assertClose(cam_batch.R, torch.cat((R0, R1), dim=0))
|
||||
self.assertClose(cam_batch.principal_point, torch.cat((p0, p1), dim=0))
|
||||
self.assertEqual(cam_batch._in_ndc, cam0._in_ndc)
|
||||
|
||||
# Test one broadcasted value and one fixed value
|
||||
# Focal length as (N,) in one camera and (N, 2) in the other
|
||||
cam_batch = join_cameras_as_batch([cam0, cam2])
|
||||
self.assertEqual(cam_batch._N, cam0._N + cam2._N)
|
||||
self.assertClose(cam_batch.R, torch.cat((R0, R1), dim=0))
|
||||
self.assertClose(
|
||||
cam_batch.focal_length,
|
||||
torch.cat([torch.tensor([[f0, f0]]).expand(6, -1), f1], dim=0),
|
||||
)
|
||||
|
||||
# Focal length as (N, 1) in one camera and (N, 2) in the other
|
||||
cam_batch = join_cameras_as_batch([cam2, cam3])
|
||||
self.assertClose(
|
||||
cam_batch.focal_length,
|
||||
torch.cat([f1, f2.expand(-1, 2)], dim=0),
|
||||
)
|
||||
|
||||
def test_join_batch_perspective(self):
|
||||
self.join_cameras_as_batch_fov(FoVPerspectiveCameras)
|
||||
self.join_cameras_as_batch(PerspectiveCameras)
|
||||
|
||||
def test_join_batch_orthographic(self):
|
||||
self.join_cameras_as_batch_fov(FoVOrthographicCameras)
|
||||
self.join_cameras_as_batch(OrthographicCameras)
|
||||
|
||||
|
||||
############################################################
|
||||
# FoVPerspective Camera #
|
||||
@ -1055,7 +1149,7 @@ class TestOrthographicProjection(TestCaseMixin, unittest.TestCase):
|
||||
index = torch.tensor([1, 3, 5], dtype=torch.int64)
|
||||
c135 = cam[index]
|
||||
self.assertEqual(len(c135), 3)
|
||||
self.assertClose(c135.focal_length, torch.tensor([5.0] * 3))
|
||||
self.assertClose(c135.focal_length, torch.tensor([[5.0, 5.0]] * 3))
|
||||
self.assertClose(c135.R, R_matrix[[1, 3, 5], ...])
|
||||
self.assertClose(c135.principal_point, principal_point[[1, 3, 5], ...])
|
||||
|
||||
@ -1131,7 +1225,7 @@ class TestPerspectiveProjection(TestCaseMixin, unittest.TestCase):
|
||||
index = torch.tensor([1, 3, 5], dtype=torch.int64)
|
||||
c135 = cam[index]
|
||||
self.assertEqual(len(c135), 3)
|
||||
self.assertClose(c135.focal_length, torch.tensor([5.0] * 3))
|
||||
self.assertClose(c135.focal_length, torch.tensor([[5.0, 5.0]] * 3))
|
||||
self.assertClose(c135.R, R_matrix[[1, 3, 5], ...])
|
||||
self.assertClose(c135.principal_point, principal_point[[1, 3, 5], ...])
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user