mirror of https://github.com/facebookresearch/pytorch3d.git
Join cameras as batch
Summary: Function to join a list of camera objects into a single batched object.

FB: In the next diff I will remove the `concatenate_cameras` function in implicitron and update the callsites.

Reviewed By: nikhilaravi

Differential Revision: D33198209

fbshipit-source-id: 0c9f5f5df498a0def9dba756c984e6a946618158
This commit is contained in: parent 9e2bc3a17f · commit 39bb2ce063
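Below is a minimal usage sketch of the new API. It is not part of the commit; it assumes a PyTorch3D build that includes this change, and the camera values are made up for illustration.

import torch
from pytorch3d.renderer import PerspectiveCameras
from pytorch3d.renderer.camera_utils import join_cameras_as_batch

# Two camera objects of the same type, on the same device; their batch
# sizes may differ.
cams_a = PerspectiveCameras(focal_length=2.0, R=torch.eye(3)[None].expand(4, -1, -1))
cams_b = PerspectiveCameras(focal_length=torch.rand(2, 2))

# All tensor fields are concatenated along the batch dimension: N = 4 + 2 = 6.
batch = join_cameras_as_batch([cams_a, cams_b])
print(len(batch))  # 6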
pytorch3d/renderer/__init__.py

@@ -10,7 +10,7 @@ from .blending import (
     sigmoid_alpha_blend,
     softmax_rgb_blend,
 )
-from .camera_utils import rotate_on_spot
+from .camera_utils import join_cameras_as_batch, rotate_on_spot
 from .cameras import OpenGLOrthographicCameras  # deprecated
 from .cameras import OpenGLPerspectiveCameras  # deprecated
 from .cameras import SfMOrthographicCameras  # deprecated
@@ -29,6 +29,7 @@ from .implicit import (
     AbsorptionOnlyRaymarcher,
     EmissionAbsorptionRaymarcher,
     GridRaysampler,
+    HarmonicEmbedding,
     ImplicitRenderer,
     MonteCarloRaysampler,
     NDCGridRaysampler,
@@ -37,7 +38,6 @@ from .implicit import (
     VolumeSampler,
     ray_bundle_to_ray_points,
     ray_bundle_variables_to_ray_points,
-    HarmonicEmbedding,
 )
 from .lighting import AmbientLights, DirectionalLights, PointLights, diffuse, specular
 from .materials import Materials
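The camera_utils import change above also re-exports the new function at the package level, so once this commit is in, both import paths work:

from pytorch3d.renderer import join_cameras_as_batch
from pytorch3d.renderer.camera_utils import join_cameras_as_batch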
pytorch3d/renderer/camera_utils.py

@@ -4,11 +4,13 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-from typing import Tuple
+from typing import Sequence, Tuple
 
 import torch
 from pytorch3d.transforms import Transform3d
 
+from .cameras import CamerasBase
+
 
 def camera_to_eye_at_up(
     world_to_view_transform: Transform3d,
@@ -141,3 +143,65 @@ def rotate_on_spot(
     new_T = torch.matmul(new_R.transpose(1, 2), old_RT)[:, :, 0]
 
     return new_R, new_T
+
+
+def join_cameras_as_batch(cameras_list: Sequence[CamerasBase]) -> CamerasBase:
+    """
+    Create a batched cameras object by concatenating a list of input
+    cameras objects. All the tensor attributes will be joined along
+    the batch dimension.
+
+    Args:
+        cameras_list: List of camera classes all of the same type and
+            on the same device. Each represents one or more cameras.
+    Returns:
+        cameras: single batched cameras object of the same
+            type as all the objects in the input list.
+    """
+    # Get the type and fields to join from the first camera in the batch
+    c0 = cameras_list[0]
+    fields = c0._FIELDS
+    shared_fields = c0._SHARED_FIELDS
+
+    if not all(isinstance(c, CamerasBase) for c in cameras_list):
+        raise ValueError("cameras in cameras_list must inherit from CamerasBase")
+
+    if not all(type(c) is type(c0) for c in cameras_list[1:]):
+        raise ValueError("All cameras must be of the same type")
+
+    if not all(c.device == c0.device for c in cameras_list[1:]):
+        raise ValueError("All cameras in the batch must be on the same device")
+
+    # Concat the fields to make a batched tensor
+    kwargs = {}
+    kwargs["device"] = c0.device
+
+    for field in fields:
+        field_not_none = [(getattr(c, field) is not None) for c in cameras_list]
+        if not any(field_not_none):
+            continue
+        if not all(field_not_none):
+            raise ValueError(f"Attribute {field} is inconsistently present")
+
+        attrs_list = [getattr(c, field) for c in cameras_list]
+
+        if field in shared_fields:
+            # Only needs to be set once
+            if not all(a == attrs_list[0] for a in attrs_list):
+                raise ValueError(f"Attribute {field} is not constant across inputs")
+
+            # e.g. "in_ndc" is set as attribute "_in_ndc" on the class
+            # but provided as "in_ndc" in the input args
+            if field.startswith("_"):
+                field = field[1:]
+
+            kwargs[field] = attrs_list[0]
+        elif isinstance(attrs_list[0], torch.Tensor):
+            # In the init, all inputs will be converted to
+            # batched tensors before set as attributes
+            # Join as a tensor along the batch dimension
+            kwargs[field] = torch.cat(attrs_list, dim=0)
+        else:
+            raise ValueError(f"Field {field} type is not supported for batching")
+
+    return c0.__class__(**kwargs)
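For illustration, a small sketch of how the shared-field check in join_cameras_as_batch behaves (not part of the diff; the camera values are arbitrary). "_in_ndc" is declared as a shared field in cameras.py below, so NDC and screen-space cameras cannot be mixed in one batch:

from pytorch3d.renderer import PerspectiveCameras
from pytorch3d.renderer.camera_utils import join_cameras_as_batch

ndc_cam = PerspectiveCameras(focal_length=1.0)  # in_ndc=True by default
screen_cam = PerspectiveCameras(focal_length=100.0, in_ndc=False)

# The shared field "_in_ndc" differs between the two inputs, so this raises:
try:
    join_cameras_as_batch([ndc_cam, screen_cam])
except ValueError as err:
    print(err)  # Attribute _in_ndc is not constant across inputs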
pytorch3d/renderer/cameras.py

@@ -77,7 +77,12 @@ class CamerasBase(TensorProperties):
 
     # Used in __getitem__ to index the relevant fields
     # When creating a new camera, this should be set in the __init__
-    _FIELDS: Tuple = ()
+    _FIELDS: Tuple[str, ...] = ()
+
+    # Names of fields which are a constant property of the whole batch, rather
+    # than themselves a batch of data.
+    # When joining objects into a batch, they will have to agree.
+    _SHARED_FIELDS: Tuple[str, ...] = ()
 
     def get_projection_transform(self):
         """
@@ -499,6 +504,8 @@ class FoVPerspectiveCameras(CamerasBase):
         "degrees",
     )
 
+    _SHARED_FIELDS = ("degrees",)
+
     def __init__(
         self,
         znear=1.0,
@@ -997,6 +1004,8 @@ class PerspectiveCameras(CamerasBase):
         "image_size",
     )
 
+    _SHARED_FIELDS = ("_in_ndc",)
+
     def __init__(
         self,
         focal_length=1.0,
@@ -1047,6 +1056,12 @@ class PerspectiveCameras(CamerasBase):
         else:
             self.image_size = None
 
+        # When focal length is provided as one value, expand to
+        # create (N, 2) shape tensor
+        if self.focal_length.ndim == 1:  # (N,)
+            self.focal_length = self.focal_length[:, None]  # (N, 1)
+        self.focal_length = self.focal_length.expand(-1, 2)  # (N, 2)
+
     def get_projection_transform(self, **kwargs) -> Transform3d:
         """
         Calculate the projection matrix using the
@@ -1227,6 +1242,8 @@ class OrthographicCameras(CamerasBase):
         "image_size",
     )
 
+    _SHARED_FIELDS = ("_in_ndc",)
+
     def __init__(
         self,
         focal_length=1.0,
@@ -1276,6 +1293,12 @@ class OrthographicCameras(CamerasBase):
         else:
             self.image_size = None
 
+        # When focal length is provided as one value, expand to
+        # create (N, 2) shape tensor
+        if self.focal_length.ndim == 1:  # (N,)
+            self.focal_length = self.focal_length[:, None]  # (N, 1)
+        self.focal_length = self.focal_length.expand(-1, 2)  # (N, 2)
+
     def get_projection_transform(self, **kwargs) -> Transform3d:
         """
         Calculate the projection matrix using
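The focal-length expansion added to both __init__ methods above means focal_length is always stored as an (N, 2) tensor, whether it was given as a single value, an (N,) tensor, or an (N, 1) tensor. A quick sketch (not part of the diff):

import torch
from pytorch3d.renderer import PerspectiveCameras

cam = PerspectiveCameras(focal_length=torch.full((3,), 5.0))  # given as (N,)
print(cam.focal_length.shape)  # torch.Size([3, 2])

This is also why the indexing tests below now expect [[5.0, 5.0]] * 3 where they previously expected [5.0] * 3.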
@@ -250,7 +250,4 @@ class TestPixels(TestCaseMixin, unittest.TestCase):
             ],
             dim=1,
         )
-
-        print(wanted)
-        print(camera_points[batch_idx])
         self.assertClose(camera_points[batch_idx], wanted)
tests/test_cameras.py

@@ -36,6 +36,7 @@ import unittest
 import numpy as np
 import torch
 from common_testing import TestCaseMixin
+from pytorch3d.renderer.camera_utils import join_cameras_as_batch
 from pytorch3d.renderer.cameras import (
     CamerasBase,
     FoVOrthographicCameras,
@@ -688,6 +689,99 @@ class TestCamerasCommon(TestCaseMixin, unittest.TestCase):
         else:
             self.assertTrue(val == val_clone)
 
+    def test_join_cameras_as_batch_errors(self):
+        cam0 = PerspectiveCameras(device="cuda:0")
+        cam1 = OrthographicCameras(device="cuda:0")
+
+        # Cameras not of the same type
+        with self.assertRaisesRegex(ValueError, "same type"):
+            join_cameras_as_batch([cam0, cam1])
+
+        cam2 = OrthographicCameras(device="cpu")
+        # Cameras not on the same device
+        with self.assertRaisesRegex(ValueError, "same device"):
+            join_cameras_as_batch([cam1, cam2])
+
+        cam3 = OrthographicCameras(in_ndc=False, device="cuda:0")
+        # Different coordinate systems -- all should be in ndc or in screen
+        with self.assertRaisesRegex(
+            ValueError, "Attribute _in_ndc is not constant across inputs"
+        ):
+            join_cameras_as_batch([cam1, cam3])
+
+    def join_cameras_as_batch_fov(self, camera_cls):
+        R0 = torch.randn((6, 3, 3))
+        R1 = torch.randn((3, 3, 3))
+        cam0 = camera_cls(znear=10.0, zfar=100.0, R=R0, device="cuda:0")
+        cam1 = camera_cls(znear=10.0, zfar=200.0, R=R1, device="cuda:0")
+
+        cam_batch = join_cameras_as_batch([cam0, cam1])
+
+        self.assertEqual(cam_batch._N, cam0._N + cam1._N)
+        self.assertEqual(cam_batch.device, cam0.device)
+        self.assertClose(cam_batch.R, torch.cat((R0, R1), dim=0).to(device="cuda:0"))
+
+    def join_cameras_as_batch(self, camera_cls):
+        R0 = torch.randn((6, 3, 3))
+        R1 = torch.randn((3, 3, 3))
+        p0 = torch.randn((6, 2, 1))
+        p1 = torch.randn((3, 2, 1))
+        f0 = 5.0
+        f1 = torch.randn(3, 2)
+        f2 = torch.randn(3, 1)
+        cam0 = camera_cls(
+            R=R0,
+            focal_length=f0,
+            principal_point=p0,
+        )
+        cam1 = camera_cls(
+            R=R1,
+            focal_length=f0,
+            principal_point=p1,
+        )
+        cam2 = camera_cls(
+            R=R1,
+            focal_length=f1,
+            principal_point=p1,
+        )
+        cam3 = camera_cls(
+            R=R1,
+            focal_length=f2,
+            principal_point=p1,
+        )
+        cam_batch = join_cameras_as_batch([cam0, cam1])
+
+        self.assertEqual(cam_batch._N, cam0._N + cam1._N)
+        self.assertEqual(cam_batch.device, cam0.device)
+        self.assertClose(cam_batch.R, torch.cat((R0, R1), dim=0))
+        self.assertClose(cam_batch.principal_point, torch.cat((p0, p1), dim=0))
+        self.assertEqual(cam_batch._in_ndc, cam0._in_ndc)
+
+        # Test one broadcasted value and one fixed value
+        # Focal length as (N,) in one camera and (N, 2) in the other
+        cam_batch = join_cameras_as_batch([cam0, cam2])
+        self.assertEqual(cam_batch._N, cam0._N + cam2._N)
+        self.assertClose(cam_batch.R, torch.cat((R0, R1), dim=0))
+        self.assertClose(
+            cam_batch.focal_length,
+            torch.cat([torch.tensor([[f0, f0]]).expand(6, -1), f1], dim=0),
+        )
+
+        # Focal length as (N, 1) in one camera and (N, 2) in the other
+        cam_batch = join_cameras_as_batch([cam2, cam3])
+        self.assertClose(
+            cam_batch.focal_length,
+            torch.cat([f1, f2.expand(-1, 2)], dim=0),
+        )
+
+    def test_join_batch_perspective(self):
+        self.join_cameras_as_batch_fov(FoVPerspectiveCameras)
+        self.join_cameras_as_batch(PerspectiveCameras)
+
+    def test_join_batch_orthographic(self):
+        self.join_cameras_as_batch_fov(FoVOrthographicCameras)
+        self.join_cameras_as_batch(OrthographicCameras)
+
+
 ############################################################
 #                   FoVPerspective Camera                   #
@@ -1055,7 +1149,7 @@ class TestOrthographicProjection(TestCaseMixin, unittest.TestCase):
         index = torch.tensor([1, 3, 5], dtype=torch.int64)
         c135 = cam[index]
         self.assertEqual(len(c135), 3)
-        self.assertClose(c135.focal_length, torch.tensor([5.0] * 3))
+        self.assertClose(c135.focal_length, torch.tensor([[5.0, 5.0]] * 3))
         self.assertClose(c135.R, R_matrix[[1, 3, 5], ...])
         self.assertClose(c135.principal_point, principal_point[[1, 3, 5], ...])
 
@@ -1131,7 +1225,7 @@ class TestPerspectiveProjection(TestCaseMixin, unittest.TestCase):
         index = torch.tensor([1, 3, 5], dtype=torch.int64)
         c135 = cam[index]
         self.assertEqual(len(c135), 3)
-        self.assertClose(c135.focal_length, torch.tensor([5.0] * 3))
+        self.assertClose(c135.focal_length, torch.tensor([[5.0, 5.0]] * 3))
         self.assertClose(c135.R, R_matrix[[1, 3, 5], ...])
         self.assertClose(c135.principal_point, principal_point[[1, 3, 5], ...])