Mirror of https://github.com/facebookresearch/pytorch3d.git
Synced 2025-12-20 06:10:34 +08:00
Join cameras as batch
Summary: Add a function to join a list of camera objects into a single batched object. FB: in the next diff I will remove the `concatenate_cameras` function in implicitron and update the callsites.

Reviewed By: nikhilaravi
Differential Revision: D33198209
fbshipit-source-id: 0c9f5f5df498a0def9dba756c984e6a946618158
Committed by: Facebook GitHub Bot
Parent: 9e2bc3a17f
Commit: 39bb2ce063
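For orientation, a minimal usage sketch of the new function (the camera values below are illustrative, not taken from the diff; the import from pytorch3d.renderer relies on the re-export added in the first hunk):

import torch
from pytorch3d.renderer import PerspectiveCameras, join_cameras_as_batch

# Two single-camera objects of the same type, on the same device
cam1 = PerspectiveCameras(focal_length=torch.tensor([1.0]), T=torch.zeros(1, 3))
cam2 = PerspectiveCameras(focal_length=torch.tensor([2.0]), T=torch.ones(1, 3))

# One PerspectiveCameras object with batch size 2: tensor fields such as
# R, T and focal_length are concatenated along the batch dimension
batch = join_cameras_as_batch([cam1, cam2])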
pytorch3d/renderer/__init__.py

@@ -10,7 +10,7 @@ from .blending import (
     sigmoid_alpha_blend,
     softmax_rgb_blend,
 )
-from .camera_utils import rotate_on_spot
+from .camera_utils import join_cameras_as_batch, rotate_on_spot
 from .cameras import OpenGLOrthographicCameras  # deprecated
 from .cameras import OpenGLPerspectiveCameras  # deprecated
 from .cameras import SfMOrthographicCameras  # deprecated
@@ -29,6 +29,7 @@ from .implicit import (
     AbsorptionOnlyRaymarcher,
     EmissionAbsorptionRaymarcher,
     GridRaysampler,
+    HarmonicEmbedding,
     ImplicitRenderer,
     MonteCarloRaysampler,
     NDCGridRaysampler,
@@ -37,7 +38,6 @@ from .implicit import (
     VolumeSampler,
     ray_bundle_to_ray_points,
     ray_bundle_variables_to_ray_points,
-    HarmonicEmbedding,
 )
 from .lighting import AmbientLights, DirectionalLights, PointLights, diffuse, specular
 from .materials import Materials
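The two hunks above simply move HarmonicEmbedding to its alphabetical position; together with the first hunk, the new function becomes importable from the package root alongside the existing rotate_on_spot. A quick check (hypothetical session):

>>> from pytorch3d.renderer import join_cameras_as_batch, rotate_on_spot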
pytorch3d/renderer/camera_utils.py

@@ -4,11 +4,13 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-from typing import Tuple
+from typing import Sequence, Tuple
 
 import torch
 from pytorch3d.transforms import Transform3d
 
+from .cameras import CamerasBase
+
 
 def camera_to_eye_at_up(
     world_to_view_transform: Transform3d,
@@ -141,3 +143,65 @@ def rotate_on_spot(
     new_T = torch.matmul(new_R.transpose(1, 2), old_RT)[:, :, 0]
 
     return new_R, new_T
+
+
+def join_cameras_as_batch(cameras_list: Sequence[CamerasBase]) -> CamerasBase:
+    """
+    Create a batched cameras object by concatenating a list of input
+    cameras objects. All the tensor attributes will be joined along
+    the batch dimension.
+
+    Args:
+        cameras_list: List of camera objects all of the same type and
+            on the same device. Each represents one or more cameras.
+    Returns:
+        cameras: single batched cameras object of the same
+            type as all the objects in the input list.
+    """
+    # Get the type and fields to join from the first camera in the batch
+    c0 = cameras_list[0]
+    fields = c0._FIELDS
+    shared_fields = c0._SHARED_FIELDS
+
+    if not all(isinstance(c, CamerasBase) for c in cameras_list):
+        raise ValueError("cameras in cameras_list must inherit from CamerasBase")
+
+    if not all(type(c) is type(c0) for c in cameras_list[1:]):
+        raise ValueError("All cameras must be of the same type")
+
+    if not all(c.device == c0.device for c in cameras_list[1:]):
+        raise ValueError("All cameras in the batch must be on the same device")
+
+    # Concat the fields to make a batched tensor
+    kwargs = {}
+    kwargs["device"] = c0.device
+
+    for field in fields:
+        field_not_none = [(getattr(c, field) is not None) for c in cameras_list]
+        if not any(field_not_none):
+            continue
+        if not all(field_not_none):
+            raise ValueError(f"Attribute {field} is inconsistently present")
+
+        attrs_list = [getattr(c, field) for c in cameras_list]
+
+        if field in shared_fields:
+            # Only needs to be set once
+            if not all(a == attrs_list[0] for a in attrs_list):
+                raise ValueError(f"Attribute {field} is not constant across inputs")
+
+            # e.g. "in_ndc" is set as attribute "_in_ndc" on the class,
+            # but provided as "in_ndc" in the input args
+            if field.startswith("_"):
+                field = field[1:]
+
+            kwargs[field] = attrs_list[0]
+        elif isinstance(attrs_list[0], torch.Tensor):
+            # In the init, all inputs will be converted to
+            # batched tensors before being set as attributes;
+            # join as a tensor along the batch dimension
+            kwargs[field] = torch.cat(attrs_list, dim=0)
+        else:
+            raise ValueError(f"Field {field} type is not supported for batching")
+
+    return c0.__class__(**kwargs)
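To make the validation rules concrete, a sketch of the two main failure modes (the camera settings are made-up examples for illustration):

from pytorch3d.renderer import (
    FoVPerspectiveCameras,
    PerspectiveCameras,
    join_cameras_as_batch,
)

# 1. Mixing camera types fails the type check
try:
    join_cameras_as_batch([PerspectiveCameras(), FoVPerspectiveCameras()])
except ValueError as err:
    print(err)  # "All cameras must be of the same type"

# 2. Shared fields must agree across inputs: _in_ndc is a batch-wide constant
ndc = PerspectiveCameras(in_ndc=True)
screen = PerspectiveCameras(in_ndc=False, image_size=((128, 128),))
try:
    join_cameras_as_batch([ndc, screen])
except ValueError as err:
    print(err)  # "Attribute _in_ndc is not constant across inputs"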
pytorch3d/renderer/cameras.py

@@ -77,7 +77,12 @@ class CamerasBase(TensorProperties):
 
     # Used in __getitem__ to index the relevant fields
     # When creating a new camera, this should be set in the __init__
-    _FIELDS: Tuple = ()
+    _FIELDS: Tuple[str, ...] = ()
+
+    # Names of fields which are a constant property of the whole batch, rather
+    # than themselves a batch of data.
+    # When joining objects into a batch, they will have to agree.
+    _SHARED_FIELDS: Tuple[str, ...] = ()
 
     def get_projection_transform(self):
         """
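For illustration, a hypothetical subclass showing how the two class attributes are meant to be declared (MyCameras and my_flag are invented for this sketch, not part of the codebase):

from typing import Tuple

class MyCameras(CamerasBase):
    # Per-camera tensor fields: indexed by __getitem__ and concatenated
    # along dim 0 by join_cameras_as_batch
    _FIELDS: Tuple[str, ...] = ("R", "T", "focal_length")

    # Batch-wide constants: join_cameras_as_batch checks that these agree
    # across all inputs and sets them once on the joined object
    _SHARED_FIELDS: Tuple[str, ...] = ("my_flag",)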
@@ -499,6 +504,8 @@ class FoVPerspectiveCameras(CamerasBase):
         "degrees",
     )
 
+    _SHARED_FIELDS = ("degrees",)
+
     def __init__(
         self,
         znear=1.0,
@@ -997,6 +1004,8 @@ class PerspectiveCameras(CamerasBase):
         "image_size",
     )
 
+    _SHARED_FIELDS = ("_in_ndc",)
+
     def __init__(
         self,
         focal_length=1.0,
@@ -1047,6 +1056,12 @@ class PerspectiveCameras(CamerasBase):
         else:
             self.image_size = None
 
+        # When focal length is provided as one value, expand to
+        # create (N, 2) shape tensor
+        if self.focal_length.ndim == 1:  # (N,)
+            self.focal_length = self.focal_length[:, None]  # (N, 1)
+        self.focal_length = self.focal_length.expand(-1, 2)  # (N, 2)
+
     def get_projection_transform(self, **kwargs) -> Transform3d:
         """
         Calculate the projection matrix using the
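The effect of this normalization in the constructor, sketched (the shapes in the comments are the expected results):

import torch
from pytorch3d.renderer import PerspectiveCameras

# A per-camera scalar focal length of shape (N,) = (2,) ...
cams = PerspectiveCameras(focal_length=torch.tensor([1.5, 2.0]))

# ... is now stored expanded to (N, 2), one value for each of (fx, fy)
print(cams.focal_length.shape)  # torch.Size([2, 2])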
@@ -1227,6 +1242,8 @@ class OrthographicCameras(CamerasBase):
         "image_size",
     )
 
+    _SHARED_FIELDS = ("_in_ndc",)
+
     def __init__(
         self,
         focal_length=1.0,
@@ -1276,6 +1293,12 @@ class OrthographicCameras(CamerasBase):
         else:
             self.image_size = None
 
+        # When focal length is provided as one value, expand to
+        # create (N, 2) shape tensor
+        if self.focal_length.ndim == 1:  # (N,)
+            self.focal_length = self.focal_length[:, None]  # (N, 1)
+        self.focal_length = self.focal_length.expand(-1, 2)  # (N, 2)
+
     def get_projection_transform(self, **kwargs) -> Transform3d:
         """
         Calculate the projection matrix using