camera refactoring

Summary:
Refactor cameras
* CamerasBase was enhanced with `transform_points_screen`, which transforms projected points from NDC to screen space
* OpenGLPerspective, OpenGLOrthographic -> FoVPerspective, FoVOrthographic
* SfMPerspective, SfMOrthographic -> Perspective, Orthographic
* PerspectiveCameras can optionally be constructed with screen-space parameters (see the migration sketch below)
* A note on cameras and coordinate systems was added to the docs
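For reference, a minimal migration sketch (hedged: only the class and method names shown in this diff are assumed; the tensor values are illustrative):

    import torch
    from pytorch3d.renderer import FoVPerspectiveCameras, PerspectiveCameras

    # Formerly OpenGLPerspectiveCameras; same parameters, new name.
    cameras = FoVPerspectiveCameras(znear=0.1, zfar=100.0, fov=60.0)

    # Formerly SfMPerspectiveCameras; passing image_size flags screen-space
    # parameters, which the camera converts to NDC internally.
    screen_cameras = PerspectiveCameras(
        focal_length=((22.0, 15.0),),
        principal_point=((192.0, 128.0),),
        image_size=((256, 256),),  # (imwidth, imheight)
    )

    # New in this diff: project world points straight to screen coordinates.
    points = torch.rand(1, 8, 3) + torch.tensor([0.0, 0.0, 1.0])  # (N, V, 3)
    screen_points = cameras.transform_points_screen(
        points, image_size=torch.tensor([[128, 128]])  # (N, 2)
    )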

Reviewed By: nikhilaravi

Differential Revision: D23168525

fbshipit-source-id: dd138e2b2cc7e0e0d9f34c45b8251c01266a2063
Author: Georgia Gkioxari
Date: 2020-08-20 22:20:41 -07:00
Committed by: Facebook GitHub Bot
Parent: 9242e7e65d
Commit: 57a22e7306
65 changed files with 896 additions and 279 deletions


@@ -10,7 +10,7 @@ from pytorch3d.renderer import (
HardPhongShader,
MeshRasterizer,
MeshRenderer,
OpenGLPerspectiveCameras,
FoVPerspectiveCameras,
PointLights,
RasterizationSettings,
TexturesVertex,
@@ -125,7 +125,7 @@ class ShapeNetBase(torch.utils.data.Dataset):
meshes.textures = TexturesVertex(
verts_features=torch.ones_like(meshes.verts_padded(), device=device)
)
cameras = kwargs.get("cameras", OpenGLPerspectiveCameras()).to(device)
cameras = kwargs.get("cameras", FoVPerspectiveCameras()).to(device)
if len(cameras) != 1 and len(cameras) % len(meshes) != 0:
raise ValueError("Mismatch between batch dims of cameras and meshes.")
if len(cameras) > 1:


@@ -6,11 +6,15 @@ from .blending import (
sigmoid_alpha_blend,
softmax_rgb_blend,
)
from .cameras import OpenGLOrthographicCameras # deprecated
from .cameras import OpenGLPerspectiveCameras # deprecated
from .cameras import SfMOrthographicCameras # deprecated
from .cameras import SfMPerspectiveCameras # deprecated
from .cameras import (
OpenGLOrthographicCameras,
OpenGLPerspectiveCameras,
SfMOrthographicCameras,
SfMPerspectiveCameras,
FoVOrthographicCameras,
FoVPerspectiveCameras,
OrthographicCameras,
PerspectiveCameras,
camera_position_from_spherical_angles,
get_world_to_view_transform,
look_at_rotation,


@@ -1,6 +1,7 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import math
import warnings
from typing import Optional, Sequence, Tuple
import numpy as np
@@ -20,23 +21,43 @@ class CamerasBase(TensorProperties):
"""
`CamerasBase` implements a base class for all cameras.
For cameras, there are four different coordinate systems (or spaces)
- World coordinate system: This is the system in which the object/scene lives - the world.
- Camera view coordinate system: This is the system that has its origin on the image plane
and the Z-axis perpendicular to the image plane.
In PyTorch3D, we assume that +X points left, +Y points up and
+Z points out from the image plane.
The transformation from world -> view happens after applying a rotation (R)
and a translation (T).
- NDC coordinate system: This is the normalized coordinate system that confines
the rendered part of the object or scene in a volume. Also known as the view volume.
Given the PyTorch3D convention, (+1, +1, znear) is the top left near corner,
and (-1, -1, zfar) is the bottom right far corner of the volume.
The transformation from view -> NDC happens after applying the camera
projection matrix (P).
- Screen coordinate system: This is another representation of the view volume with
the XY coordinates defined in pixel space instead of a normalized space.
A better illustration of the coordinate systems can be found in pytorch3d/docs/notes/cameras.md.
It defines methods that are common to all camera models:
- `get_camera_center` that returns the optical center of the camera in
world coordinates
- `get_world_to_view_transform` which returns a 3D transform from
world coordinates to the camera coordinates
world coordinates to the camera view coordinates (R, T)
- `get_full_projection_transform` which composes the projection
transform with the world-to-view transform
- `transform_points` which takes a set of input points and
projects them onto a 2D camera plane.
transform (P) with the world-to-view transform (R, T)
- `transform_points` which takes a set of input points in world coordinates and
projects them to NDC coordinates ranging from [-1, -1, znear] to [+1, +1, zfar].
- `transform_points_screen` which takes a set of input points in world coordinates and
projects them to screen coordinates ranging from [0, 0, znear] to [W-1, H-1, zfar].
For each new camera, one should implement the `get_projection_transform`
routine that returns the mapping from camera coordinates in world units
to the screen coordinates.
routine that returns the mapping from camera view coordinates to NDC coordinates.
Another useful function that is specific to each camera model is
`unproject_points` which sends points from screen coordinates back to
camera or world coordinates depending on the `world_coordinates`
`unproject_points` which sends points from NDC coordinates back to
camera view or world coordinates depending on the `world_coordinates`
boolean argument of the function.
"""
@@ -56,7 +77,7 @@ class CamerasBase(TensorProperties):
def unproject_points(self):
"""
Transform input points in screen coordinates
Transform input points from NDC coordinates
to the world / camera coordinates.
Each of the input points `xy_depth` of shape (..., 3) is
@@ -74,7 +95,7 @@ class CamerasBase(TensorProperties):
cameras = # camera object derived from CamerasBase
xyz = # 3D points of shape (batch_size, num_points, 3)
# transform xyz to the camera coordinates
# transform xyz to the camera view coordinates
xyz_cam = cameras.get_world_to_view_transform().transform_points(xyz)
# extract the depth of each point as the 3rd coord of xyz_cam
depth = xyz_cam[:, :, 2:]
@@ -94,7 +115,7 @@ class CamerasBase(TensorProperties):
world_coordinates: If `True`, unprojects the points back to world
coordinates using the camera extrinsics `R` and `T`.
`False` ignores `R` and `T` and unprojects to
the camera coordinates.
the camera view coordinates.
Returns
new_points: unprojected points with the same shape as `xy_depth`.
@@ -141,7 +162,7 @@ class CamerasBase(TensorProperties):
lighting calculations.
Returns:
T: a Transform3d object which represents a batch of transforms
A Transform3d object which represents a batch of transforms
of shape (N, 3, 3)
"""
self.R = kwargs.get("R", self.R) # pyre-ignore[16]
@@ -151,8 +172,8 @@ class CamerasBase(TensorProperties):
def get_full_projection_transform(self, **kwargs) -> Transform3d:
"""
Return the full world-to-screen transform composing the
world-to-view and view-to-screen transforms.
Return the full world-to-NDC transform composing the
world-to-view and view-to-NDC transforms.
Args:
**kwargs: parameters for the projection transforms can be passed in
@@ -164,26 +185,26 @@ class CamerasBase(TensorProperties):
lighting calculations.
Returns:
T: a Transform3d object which represents a batch of transforms
a Transform3d object which represents a batch of transforms
of shape (N, 3, 3)
"""
self.R = kwargs.get("R", self.R) # pyre-ignore[16]
self.T = kwargs.get("T", self.T) # pyre-ignore[16]
world_to_view_transform = self.get_world_to_view_transform(R=self.R, T=self.T)
view_to_screen_transform = self.get_projection_transform(**kwargs)
return world_to_view_transform.compose(view_to_screen_transform)
view_to_ndc_transform = self.get_projection_transform(**kwargs)
return world_to_view_transform.compose(view_to_ndc_transform)
def transform_points(
self, points, eps: Optional[float] = None, **kwargs
) -> torch.Tensor:
"""
Transform input points from world to screen space.
Transform input points from world to NDC space.
Args:
points: torch tensor of shape (..., 3).
eps: If eps!=None, the argument is used to clamp the
divisor in the homogeneous normalization of the points
transformed to the screen space. Plese see
transformed to the ndc space. Please see
`transforms.Transform3D.transform_points` for details.
For `CamerasBase.transform_points`, setting `eps > 0`
@@ -194,8 +215,50 @@ class CamerasBase(TensorProperties):
Returns
new_points: transformed points with the same shape as the input.
"""
world_to_screen_transform = self.get_full_projection_transform(**kwargs)
return world_to_screen_transform.transform_points(points, eps=eps)
world_to_ndc_transform = self.get_full_projection_transform(**kwargs)
return world_to_ndc_transform.transform_points(points, eps=eps)
def transform_points_screen(
self, points, image_size, eps: Optional[float] = None, **kwargs
) -> torch.Tensor:
"""
Transform input points from world to screen space.
Args:
points: torch tensor of shape (N, V, 3).
image_size: torch tensor of shape (N, 2)
eps: If eps!=None, the argument is used to clamp the
divisor in the homogeneous normalization of the points
transformed to the ndc space. Please see
`transforms.Transform3D.transform_points` for details.
For `CamerasBase.transform_points`, setting `eps > 0`
stabilizes gradients since it leads to avoiding division
by excessively low numbers for points close to the
camera plane.
Returns
new_points: transformed points with the same shape as the input.
"""
ndc_points = self.transform_points(points, eps=eps, **kwargs)
if not torch.is_tensor(image_size):
image_size = torch.tensor(
image_size, dtype=torch.int64, device=points.device
)
if (image_size < 1).any():
raise ValueError("Provided image size is invalid.")
image_width, image_height = image_size.unbind(1)
image_width = image_width.view(-1, 1) # (N, 1)
image_height = image_height.view(-1, 1) # (N, 1)
ndc_z = ndc_points[..., 2]
screen_x = (image_width - 1.0) / 2.0 * (1.0 - ndc_points[..., 0])
screen_y = (image_height - 1.0) / 2.0 * (1.0 - ndc_points[..., 1])
return torch.stack((screen_x, screen_y, ndc_z), dim=2)
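A quick worked check of the NDC -> screen mapping above, assuming a 64 x 64 image (the numbers follow directly from the two formulas):

    import torch

    ndc = torch.tensor([[[+1.0, +1.0, 1.0],    # NDC top left
                         [-1.0, -1.0, 1.0],    # NDC bottom right
                         [ 0.0,  0.0, 1.0]]])  # NDC center
    W = H = 64.0
    screen_x = (W - 1.0) / 2.0 * (1.0 - ndc[..., 0])  # -> 0.0, 63.0, 31.5
    screen_y = (H - 1.0) / 2.0 * (1.0 - ndc[..., 1])  # -> 0.0, 63.0, 31.5
    # NDC (+1, +1) lands on pixel (0, 0); NDC (-1, -1) on (W-1, H-1).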
def clone(self):
"""
@@ -206,21 +269,56 @@ class CamerasBase(TensorProperties):
return super().clone(other)
########################
# Specific camera classes
########################
############################################################
# Field of View Camera Classes #
############################################################
class OpenGLPerspectiveCameras(CamerasBase):
def OpenGLPerspectiveCameras(
znear=1.0,
zfar=100.0,
aspect_ratio=1.0,
fov=60.0,
degrees: bool = True,
R=r,
T=t,
device="cpu",
):
"""
OpenGLPerspectiveCameras has been DEPRECATED. Use FoVPerspectiveCameras instead.
Preserving OpenGLPerspectiveCameras for backward compatibility.
"""
warnings.warn(
"""OpenGLPerspectiveCameras is deprecated,
Use FoVPerspectiveCameras instead.
OpenGLPerspectiveCameras will be removed in future releases.""",
PendingDeprecationWarning,
)
return FoVPerspectiveCameras(
znear=znear,
zfar=zfar,
aspect_ratio=aspect_ratio,
fov=fov,
degrees=degrees,
R=R,
T=T,
device=device,
)
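A sketch of what a caller of the shim sees (illustrative; PendingDeprecationWarning is silenced by default, so it is enabled explicitly here):

    import warnings
    from pytorch3d.renderer import OpenGLPerspectiveCameras

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        cameras = OpenGLPerspectiveCameras(fov=45.0)

    # The shim returns the new class and raises the deprecation warning.
    assert type(cameras).__name__ == "FoVPerspectiveCameras"
    assert any(issubclass(w.category, PendingDeprecationWarning) for w in caught)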
class FoVPerspectiveCameras(CamerasBase):
"""
A class which stores a batch of parameters to generate a batch of
projection matrices using the OpenGL convention for a perspective camera.
projection matrices by specifying the field of view.
The definition of the parameters follows the OpenGL perspective camera.
The extrinsics of the camera (R and T matrices) can also be set in the
initializer or passed in to `get_full_projection_transform` to get
the full transformation from world -> screen.
the full transformation from world -> ndc.
The `transform_points` method calculates the full world -> screen transform
The `transform_points` method calculates the full world -> ndc transform
and then applies it to the input points.
The transforms can also be returned separately as Transform3d objects.
@@ -267,8 +365,11 @@ class OpenGLPerspectiveCameras(CamerasBase):
def get_projection_transform(self, **kwargs) -> Transform3d:
"""
Calculate the OpenGL perspective projection matrix with a symmetric
Calculate the perspective projection matrix with a symmetric
viewing frustum. Use column major order.
The viewing frustum will be projected into ndc, s.t.
(max_x, max_y) -> (+1, +1)
(min_x, min_y) -> (-1, -1)
Args:
**kwargs: parameters for the projection can be passed in as keyword
@@ -276,14 +377,14 @@ class OpenGLPerspectiveCameras(CamerasBase):
Return:
P: a Transform3d object which represents a batch of projection
matrices of shape (N, 3, 3)
matrices of shape (N, 4, 4)
.. code-block:: python
f1 = -(far + near)/(far - near)
f2 = -2*far*near/(far-near)
h1 = (top + bottom)/(top - bottom)
w1 = (right + left)/(right - left)
h1 = (max_y + min_y)/(max_y - min_y)
w1 = (max_x + min_x)/(max_x - min_x)
tanhalffov = tan((fov/2))
s1 = 1/tanhalffov
s2 = 1/(tanhalffov * (aspect_ratio))
@@ -310,10 +411,10 @@ class OpenGLPerspectiveCameras(CamerasBase):
if not torch.is_tensor(fov):
fov = torch.tensor(fov, device=self.device)
tanHalfFov = torch.tan((fov / 2))
top = tanHalfFov * znear
bottom = -top
right = top * aspect_ratio
left = -right
max_y = tanHalfFov * znear
min_y = -max_y
max_x = max_y * aspect_ratio
min_x = -max_x
# NOTE: In OpenGL the projection matrix changes the handedness of the
# coordinate frame. i.e. the NDC space positive z direction is the
@@ -323,28 +424,19 @@ class OpenGLPerspectiveCameras(CamerasBase):
# so the z sign is 1.0.
z_sign = 1.0
P[:, 0, 0] = 2.0 * znear / (right - left)
P[:, 1, 1] = 2.0 * znear / (top - bottom)
P[:, 0, 2] = (right + left) / (right - left)
P[:, 1, 2] = (top + bottom) / (top - bottom)
P[:, 0, 0] = 2.0 * znear / (max_x - min_x)
P[:, 1, 1] = 2.0 * znear / (max_y - min_y)
P[:, 0, 2] = (max_x + min_x) / (max_x - min_x)
P[:, 1, 2] = (max_y + min_y) / (max_y - min_y)
P[:, 3, 2] = z_sign * ones
# NOTE: This part of the matrix is for z renormalization in OpenGL
# which maps the z to [-1, 1]. This won't work yet as the torch3d
# rasterizer ignores faces which have z < 0.
# P[:, 2, 2] = z_sign * (far + near) / (far - near)
# P[:, 2, 3] = -2.0 * far * near / (far - near)
# P[:, 3, 2] = z_sign * torch.ones((N))
# NOTE: This maps the z coordinate from [0, 1] where z = 0 if the point
# is at the near clipping plane and z = 1 when the point is at the far
# clipping plane. This replaces the OpenGL z normalization to [-1, 1]
# until rasterization is changed to clip at z = -1.
# clipping plane.
P[:, 2, 2] = z_sign * zfar / (zfar - znear)
P[:, 2, 3] = -(zfar * znear) / (zfar - znear)
# OpenGL uses column vectors so need to transpose the projection matrix
# as torch3d uses row vectors.
# Transpose the projection matrix as PyTorch3d transforms use row vectors.
transform = Transform3d(device=self.device)
transform._matrix = P.transpose(1, 2).contiguous()
return transform
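A numeric sanity check of the resulting matrix under the defaults znear=1.0, zfar=100.0, aspect_ratio=1.0, fov=60.0 (a sketch; note `get_matrix` returns the transposed, row-vector form):

    import math
    import torch
    from pytorch3d.renderer import FoVPerspectiveCameras

    M = FoVPerspectiveCameras().get_projection_transform().get_matrix()  # (1, 4, 4)

    s1 = 1.0 / math.tan(math.radians(60.0) / 2.0)  # ~= 1.7321
    assert torch.isclose(M[0, 0, 0], torch.tensor(s1))            # 2*znear/(max_x - min_x)
    assert torch.isclose(M[0, 1, 1], torch.tensor(s1))            # 2*znear/(max_y - min_y)
    assert torch.isclose(M[0, 2, 2], torch.tensor(100.0 / 99.0))  # zfar/(zfar - znear)
    assert torch.isclose(M[0, 2, 3], torch.tensor(1.0))           # z_sign, transposed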
@@ -357,7 +449,7 @@ class OpenGLPerspectiveCameras(CamerasBase):
**kwargs
) -> torch.Tensor:
""">!
OpenGL cameras further allow for passing depth in world units
FoV cameras further allow for passing depth in world units
(`scaled_depth_input=False`) or in the [0, 1]-normalized units
(`scaled_depth_input=True`)
@@ -367,11 +459,11 @@ class OpenGLPerspectiveCameras(CamerasBase):
the world units.
"""
# obtain the relevant transformation to screen
# obtain the relevant transformation to ndc
if world_coordinates:
to_screen_transform = self.get_full_projection_transform()
to_ndc_transform = self.get_full_projection_transform()
else:
to_screen_transform = self.get_projection_transform()
to_ndc_transform = self.get_projection_transform()
if scaled_depth_input:
# the input is scaled depth, so we don't have to do anything
@@ -390,45 +482,84 @@ class OpenGLPerspectiveCameras(CamerasBase):
xy_sdepth = torch.cat((xy_depth[..., 0:2], sdepth), dim=-1)
# unproject with inverse of the projection
unprojection_transform = to_screen_transform.inverse()
unprojection_transform = to_ndc_transform.inverse()
return unprojection_transform.transform_points(xy_sdepth)
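A round-trip sketch mirroring the docstring recipe (illustrative; depth is passed in world units, i.e. scaled_depth_input=False):

    import torch
    from pytorch3d.renderer import FoVPerspectiveCameras

    cameras = FoVPerspectiveCameras()  # R=I, T=0, so world == view space here
    xyz = torch.rand(1, 4, 3) + torch.tensor([0.0, 0.0, 2.0])  # keep z > znear

    depth = cameras.get_world_to_view_transform().transform_points(xyz)[..., 2:]
    ndc = cameras.transform_points(xyz)
    xy_depth = torch.cat((ndc[..., :2], depth), dim=-1)

    xyz_back = cameras.unproject_points(
        xy_depth, world_coordinates=True, scaled_depth_input=False
    )
    assert torch.allclose(xyz, xyz_back, atol=1e-4)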
class OpenGLOrthographicCameras(CamerasBase):
def OpenGLOrthographicCameras(
znear=1.0,
zfar=100.0,
top=1.0,
bottom=-1.0,
left=-1.0,
right=1.0,
scale_xyz=((1.0, 1.0, 1.0),), # (1, 3)
R=r,
T=t,
device="cpu",
):
"""
OpenGLOrthographicCameras has been DEPRECATED. Use FoVOrthographicCameras instead.
Preserving OpenGLOrthographicCameras for backward compatibility.
"""
warnings.warn(
"""OpenGLOrthographicCameras is deprecated,
Use FoVOrthographicCameras instead.
OpenGLOrthographicCameras will be removed in future releases.""",
PendingDeprecationWarning,
)
return FoVOrthographicCameras(
znear=znear,
zfar=zfar,
max_y=top,
min_y=bottom,
max_x=right,
min_x=left,
scale_xyz=scale_xyz,
R=R,
T=T,
device=device,
)
class FoVOrthographicCameras(CamerasBase):
"""
A class which stores a batch of parameters to generate a batch of
transformation matrices using the OpenGL convention for orthographic camera.
projection matrices by specifying the field of view.
The definition of the parameters follows the OpenGL orthographic camera.
"""
def __init__(
self,
znear=1.0,
zfar=100.0,
top=1.0,
bottom=-1.0,
left=-1.0,
right=1.0,
max_y=1.0,
min_y=-1.0,
max_x=1.0,
min_x=-1.0,
scale_xyz=((1.0, 1.0, 1.0),), # (1, 3)
R=r,
T=t,
device="cpu",
):
"""
__init__(self, znear, zfar, top, bottom, left, right, scale_xyz, R, T, device) -> None # noqa
__init__(self, znear, zfar, max_y, min_y, max_x, min_x, scale_xyz, R, T, device) -> None # noqa
Args:
znear: near clipping plane of the view frustum.
zfar: far clipping plane of the view frustum.
top: position of the top of the screen.
bottom: position of the bottom of the screen.
left: position of the left of the screen.
right: position of the right of the screen.
max_y: maximum y coordinate of the frustum.
min_y: minimum y coordinate of the frustum.
max_x: maximum x coordinate of the frustum.
min_x: minimum x coordinate of the frustum.
scale_xyz: scale factors for each axis of shape (N, 3).
R: Rotation matrix of shape (N, 3, 3).
T: Translation of shape (N, 3).
device: torch.device or string.
Only need to set left, right, top, bottom for viewing frustums
Only need to set min_x, max_x, min_y, max_y for viewing frustums
which are non-symmetric about the origin.
"""
# The initializer formats all inputs to torch tensors and broadcasts
@@ -437,10 +568,10 @@ class OpenGLOrthographicCameras(CamerasBase):
device=device,
znear=znear,
zfar=zfar,
top=top,
bottom=bottom,
left=left,
right=right,
max_y=max_y,
min_y=min_y,
max_x=max_x,
min_x=min_x,
scale_xyz=scale_xyz,
R=R,
T=T,
@@ -448,7 +579,7 @@ class OpenGLOrthographicCameras(CamerasBase):
def get_projection_transform(self, **kwargs) -> Transform3d:
"""
Calculate the OpenGL orthographic projection matrix.
Calculate the orthographic projection matrix.
Use column major order.
Args:
@@ -456,16 +587,16 @@ class OpenGLOrthographicCameras(CamerasBase):
override the default values set in __init__.
Return:
P: a Transform3d object which represents a batch of projection
matrices of shape (N, 3, 3)
matrices of shape (N, 4, 4)
.. code-block:: python
scale_x = 2/(right - left)
scale_y = 2/(top - bottom)
scale_z = 2/(far - near)
mid_x = (right + left)/(right - left)
mid_y = (top + bottom)/(top - bottom)
mid_z = (far + near)/(far - near)
scale_x = 2 / (max_x - min_x)
scale_y = 2 / (max_y - min_y)
scale_z = 2 / (far - near)
mid_x = (max_x + min_x) / (max_x - min_x)
mid_y = (max_y + min_y) / (max_y - min_y)
mid_z = (far + near) / (far - near)
P = [
[scale_x, 0, 0, -mid_x],
@@ -476,10 +607,10 @@ class OpenGLOrthographicCameras(CamerasBase):
"""
znear = kwargs.get("znear", self.znear) # pyre-ignore[16]
zfar = kwargs.get("zfar", self.zfar) # pyre-ignore[16]
left = kwargs.get("left", self.left) # pyre-ignore[16]
right = kwargs.get("right", self.right) # pyre-ignore[16]
top = kwargs.get("top", self.top) # pyre-ignore[16]
bottom = kwargs.get("bottom", self.bottom) # pyre-ignore[16]
max_x = kwargs.get("max_x", self.max_x) # pyre-ignore[16]
min_x = kwargs.get("min_x", self.min_x) # pyre-ignore[16]
max_y = kwargs.get("max_y", self.max_y) # pyre-ignore[16]
min_y = kwargs.get("min_y", self.min_y) # pyre-ignore[16]
scale_xyz = kwargs.get("scale_xyz", self.scale_xyz) # pyre-ignore[16]
P = torch.zeros((self._N, 4, 4), dtype=torch.float32, device=self.device)
@@ -489,10 +620,10 @@ class OpenGLOrthographicCameras(CamerasBase):
# right handed coordinate system throughout.
z_sign = +1.0
P[:, 0, 0] = (2.0 / (right - left)) * scale_xyz[:, 0]
P[:, 1, 1] = (2.0 / (top - bottom)) * scale_xyz[:, 1]
P[:, 0, 3] = -(right + left) / (right - left)
P[:, 1, 3] = -(top + bottom) / (top - bottom)
P[:, 0, 0] = (2.0 / (max_x - min_x)) * scale_xyz[:, 0]
P[:, 1, 1] = (2.0 / (max_y - min_y)) * scale_xyz[:, 1]
P[:, 0, 3] = -(max_x + min_x) / (max_x - min_x)
P[:, 1, 3] = -(max_y + min_y) / (max_y - min_y)
P[:, 3, 3] = ones
# NOTE: This maps the z coordinate to the range [0, 1] and replaces the
@@ -500,12 +631,6 @@ class OpenGLOrthographicCameras(CamerasBase):
P[:, 2, 2] = z_sign * (1.0 / (zfar - znear)) * scale_xyz[:, 2]
P[:, 2, 3] = -znear / (zfar - znear)
# NOTE: This part of the matrix is for z renormalization in OpenGL.
# The z is mapped to the range [-1, 1] but this won't work yet in
# pytorch3d as the rasterizer ignores faces which have z < 0.
# P[:, 2, 2] = z_sign * (2.0 / (far - near)) * scale[:, 2]
# P[:, 2, 3] = -(far + near) / (far - near)
transform = Transform3d(device=self.device)
transform._matrix = P.transpose(1, 2).contiguous()
return transform
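The same kind of numeric check for the orthographic matrix under the default unit frustum (znear=1.0, zfar=100.0, x and y in [-1, 1]; the stored matrix is transposed):

    import torch
    from pytorch3d.renderer import FoVOrthographicCameras

    M = FoVOrthographicCameras().get_projection_transform().get_matrix()  # (1, 4, 4)

    assert torch.isclose(M[0, 0, 0], torch.tensor(1.0))          # 2/(max_x - min_x)
    assert torch.isclose(M[0, 1, 1], torch.tensor(1.0))          # 2/(max_y - min_y)
    assert torch.isclose(M[0, 2, 2], torch.tensor(1.0 / 99.0))   # 1/(zfar - znear)
    assert torch.isclose(M[0, 3, 2], torch.tensor(-1.0 / 99.0))  # -znear/(zfar - znear)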
@@ -518,7 +643,7 @@ class OpenGLOrthographicCameras(CamerasBase):
**kwargs
) -> torch.Tensor:
""">!
OpenGL cameras further allow for passing depth in world units
FoV cameras further allow for passing depth in world units
(`scaled_depth_input=False`) or in the [0, 1]-normalized units
(`scaled_depth_input=True`)
@@ -529,9 +654,9 @@ class OpenGLOrthographicCameras(CamerasBase):
"""
if world_coordinates:
to_screen_transform = self.get_full_projection_transform(**kwargs.copy())
to_ndc_transform = self.get_full_projection_transform(**kwargs.copy())
else:
to_screen_transform = self.get_projection_transform(**kwargs.copy())
to_ndc_transform = self.get_projection_transform(**kwargs.copy())
if scaled_depth_input:
# the input depth is already scaled
@@ -547,22 +672,88 @@ class OpenGLOrthographicCameras(CamerasBase):
# cat xy and scaled depth
xy_sdepth = torch.cat((xy_depth[..., :2], scaled_depth), dim=-1)
# finally invert the transform
unprojection_transform = to_screen_transform.inverse()
unprojection_transform = to_ndc_transform.inverse()
return unprojection_transform.transform_points(xy_sdepth)
class SfMPerspectiveCameras(CamerasBase):
############################################################
# MultiView Camera Classes #
############################################################
"""
Note that the MultiView Cameras accept parameters in both
screen and NDC space.
If the user specifies `image_size` at construction time then
we assume the parameters are in screen space.
"""
def SfMPerspectiveCameras(
focal_length=1.0, principal_point=((0.0, 0.0),), R=r, T=t, device="cpu"
):
"""
SfMPerspectiveCameras has been DEPRECATED. Use PerspectiveCameras instead.
Preserving SfMPerspectiveCameras for backward compatibility.
"""
warnings.warn(
"""SfMPerspectiveCameras is deprecated,
Use PerspectiveCameras instead.
SfMPerspectiveCameras will be removed in future releases.""",
PendingDeprecationWarning,
)
return PerspectiveCameras(
focal_length=focal_length,
principal_point=principal_point,
R=R,
T=T,
device=device,
)
class PerspectiveCameras(CamerasBase):
"""
A class which stores a batch of parameters to generate a batch of
transformation matrices using the multi-view geometry convention for
perspective camera.
Parameters for this camera can be specified in NDC or in screen space.
If you wish to provide parameters in screen space, you NEED to provide
the image_size = (imwidth, imheight).
If you wish to provide parameters in NDC space, you should NOT provide
image_size. Providing a valid image_size will trigger a screen space to
NDC space transformation in the camera.
For example, here is how to define cameras in the two spaces.
.. code-block:: python
# camera defined in screen space
cameras = PerspectiveCameras(
focal_length=((22.0, 15.0),), # (fx_screen, fy_screen)
principal_point=((192.0, 128.0),), # (px_screen, py_screen)
image_size=((256, 256),), # (imwidth, imheight)
)
# the equivalent camera defined in NDC space
cameras = PerspectiveCameras(
focal_length=((0.171875, 0.1171875),), # fx = fx_screen / half_imwidth,
# fy = fy_screen / half_imheight
principal_point=((-0.5, 0),), # px = - (px_screen - half_imwidth) / half_imwidth,
# py = - (py_screen - half_imheight) / half_imheight
)
"""
def __init__(
self, focal_length=1.0, principal_point=((0.0, 0.0),), R=r, T=t, device="cpu"
self,
focal_length=1.0,
principal_point=((0.0, 0.0),),
R=r,
T=t,
device="cpu",
image_size=((-1, -1),),
):
"""
__init__(self, focal_length, principal_point, R, T, device) -> None
__init__(self, focal_length, principal_point, R, T, device, image_size) -> None
Args:
focal_length: Focal length of the camera in world units.
@@ -574,6 +765,11 @@ class SfMPerspectiveCameras(CamerasBase):
R: Rotation matrix of shape (N, 3, 3)
T: Translation matrix of shape (N, 3)
device: torch.device or string
image_size: If image_size = (imwidth, imheight) with imwidth, imheight > 0
is provided, the camera parameters are assumed to be in screen
space. They will be converted to NDC space.
If image_size is not provided, the parameters are assumed to
be in NDC space.
"""
# The initializer formats all inputs to torch tensors and broadcasts
# all the inputs to have the same batch dimension where necessary.
@@ -583,6 +779,7 @@ class SfMPerspectiveCameras(CamerasBase):
principal_point=principal_point,
R=R,
T=T,
image_size=image_size,
)
def get_projection_transform(self, **kwargs) -> Transform3d:
@@ -615,9 +812,20 @@ class SfMPerspectiveCameras(CamerasBase):
principal_point = kwargs.get("principal_point", self.principal_point)
# pyre-ignore[16]
focal_length = kwargs.get("focal_length", self.focal_length)
# pyre-ignore[16]
image_size = kwargs.get("image_size", self.image_size)
# if imwidth > 0, parameters are in screen space
in_screen = image_size[0][0] > 0
image_size = image_size if in_screen else None
P = _get_sfm_calibration_matrix(
self._N, self.device, focal_length, principal_point, False
self._N,
self.device,
focal_length,
principal_point,
orthographic=False,
image_size=image_size,
)
transform = Transform3d(device=self.device)
@@ -628,29 +836,83 @@ class SfMPerspectiveCameras(CamerasBase):
self, xy_depth: torch.Tensor, world_coordinates: bool = True, **kwargs
) -> torch.Tensor:
if world_coordinates:
to_screen_transform = self.get_full_projection_transform(**kwargs)
to_ndc_transform = self.get_full_projection_transform(**kwargs)
else:
to_screen_transform = self.get_projection_transform(**kwargs)
to_ndc_transform = self.get_projection_transform(**kwargs)
unprojection_transform = to_screen_transform.inverse()
unprojection_transform = to_ndc_transform.inverse()
xy_inv_depth = torch.cat(
(xy_depth[..., :2], 1.0 / xy_depth[..., 2:3]), dim=-1 # type: ignore
)
return unprojection_transform.transform_points(xy_inv_depth)
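The screen-space and NDC constructions from the class docstring should produce the same calibration matrix; a small verification sketch (illustrative):

    import torch
    from pytorch3d.renderer import PerspectiveCameras

    screen_cam = PerspectiveCameras(
        focal_length=((22.0, 15.0),),
        principal_point=((192.0, 128.0),),
        image_size=((256, 256),),
    )
    ndc_cam = PerspectiveCameras(
        focal_length=((22.0 / 128.0, 15.0 / 128.0),),        # / half_imwidth, / half_imheight
        principal_point=((-(192.0 - 128.0) / 128.0, 0.0),),
    )
    M_screen = screen_cam.get_projection_transform().get_matrix()
    M_ndc = ndc_cam.get_projection_transform().get_matrix()
    assert torch.allclose(M_screen, M_ndc, atol=1e-6)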
class SfMOrthographicCameras(CamerasBase):
def SfMOrthographicCameras(
focal_length=1.0, principal_point=((0.0, 0.0),), R=r, T=t, device="cpu"
):
"""
SfMOrthographicCameras has been DEPRECATED. Use OrthographicCameras instead.
Preserving SfMOrthographicCameras for backward compatibility.
"""
warnings.warn(
"""SfMOrthographicCameras is deprecated,
Use OrthographicCameras instead.
SfMOrthographicCameras will be removed in future releases.""",
PendingDeprecationWarning,
)
return OrthographicCameras(
focal_length=focal_length,
principal_point=principal_point,
R=R,
T=T,
device=device,
)
class OrthographicCameras(CamerasBase):
"""
A class which stores a batch of parameters to generate a batch of
transformation matrices using the multi-view geometry convention for
orthographic camera.
Parameters for this camera can be specified in NDC or in screen space.
If you wish to provide parameters in screen space, you NEED to provide
the image_size = (imwidth, imheight).
If you wish to provide parameters in NDC space, you should NOT provide
image_size. Providing a valid image_size will trigger a screen space to
NDC space transformation in the camera.
For example, here is how to define cameras in the two spaces.
.. code-block:: python
# camera defined in screen space
cameras = OrthographicCameras(
focal_length=((22.0, 15.0),), # (fx, fy)
principal_point=((192.0, 128.0),), # (px, py)
image_size=((256, 256),), # (imwidth, imheight)
)
# the equivalent camera defined in NDC space
cameras = OrthographicCameras(
focal_length=((0.171875, 0.1171875),), # := (fx / half_imwidth, fy / half_imheight)
principal_point=((-0.5, 0),), # := (- (px - half_imwidth) / half_imwidth,
- (py - half_imheight) / half_imheight)
)
"""
def __init__(
self, focal_length=1.0, principal_point=((0.0, 0.0),), R=r, T=t, device="cpu"
self,
focal_length=1.0,
principal_point=((0.0, 0.0),),
R=r,
T=t,
device="cpu",
image_size=((-1, -1),),
):
"""
__init__(self, focal_length, principal_point, R, T, device) -> None
__init__(self, focal_length, principal_point, R, T, device, image_size) -> None
Args:
focal_length: Focal length of the camera in world units.
@@ -662,6 +924,11 @@ class SfMOrthographicCameras(CamerasBase):
R: Rotation matrix of shape (N, 3, 3)
T: Translation matrix of shape (N, 3)
device: torch.device or string
image_size: If image_size = (imwidth, imheight) with imwidth, imheight > 0
is provided, the camera parameters are assumed to be in screen
space. They will be converted to NDC space.
If image_size is not provided, the parameters are assumed to
be in NDC space.
"""
# The initializer formats all inputs to torch tensors and broadcasts
# all the inputs to have the same batch dimension where necessary.
@@ -671,6 +938,7 @@ class SfMOrthographicCameras(CamerasBase):
principal_point=principal_point,
R=R,
T=T,
image_size=image_size,
)
def get_projection_transform(self, **kwargs) -> Transform3d:
@@ -703,9 +971,20 @@ class SfMOrthographicCameras(CamerasBase):
principal_point = kwargs.get("principal_point", self.principal_point)
# pyre-ignore[16]
focal_length = kwargs.get("focal_length", self.focal_length)
# pyre-ignore[16]
image_size = kwargs.get("image_size", self.image_size)
# if imwidth > 0, parameters are in screen space
in_screen = image_size[0][0] > 0
image_size = image_size if in_screen else None
P = _get_sfm_calibration_matrix(
self._N, self.device, focal_length, principal_point, True
self._N,
self.device,
focal_length,
principal_point,
orthographic=True,
image_size=image_size,
)
transform = Transform3d(device=self.device)
@@ -716,17 +995,26 @@ class SfMOrthographicCameras(CamerasBase):
self, xy_depth: torch.Tensor, world_coordinates: bool = True, **kwargs
) -> torch.Tensor:
if world_coordinates:
to_screen_transform = self.get_full_projection_transform(**kwargs)
to_ndc_transform = self.get_full_projection_transform(**kwargs)
else:
to_screen_transform = self.get_projection_transform(**kwargs)
to_ndc_transform = self.get_projection_transform(**kwargs)
unprojection_transform = to_screen_transform.inverse()
unprojection_transform = to_ndc_transform.inverse()
return unprojection_transform.transform_points(xy_depth)
# SfMCameras helper
################################################
# Helper functions for cameras #
################################################
def _get_sfm_calibration_matrix(
N, device, focal_length, principal_point, orthographic: bool
N,
device,
focal_length,
principal_point,
orthographic: bool = False,
image_size=None,
) -> torch.Tensor:
"""
Returns a calibration matrix of a perspective/orthographic camera.
@@ -736,6 +1024,10 @@ def _get_sfm_calibration_matrix(
focal_length: Focal length of the camera in world units.
principal_point: xy coordinates of the center of
the principal point of the camera in pixels.
orthographic: Boolean specifying if the camera is orthographic or not
image_size: (Optional) Specifying the image_size = (imwidth, imheight).
If not None, the camera parameters are assumed to be in screen space
and are transformed to NDC space.
The calibration matrix `K` is set up as follows:
@@ -769,7 +1061,7 @@ def _get_sfm_calibration_matrix(
if not torch.is_tensor(focal_length):
focal_length = torch.tensor(focal_length, device=device)
if len(focal_length.shape) in (0, 1) or focal_length.shape[1] == 1:
if focal_length.ndim in (0, 1) or focal_length.shape[1] == 1:
fx = fy = focal_length
else:
fx, fy = focal_length.unbind(1)
@@ -779,6 +1071,22 @@ def _get_sfm_calibration_matrix(
px, py = principal_point.unbind(1)
if image_size is not None:
if not torch.is_tensor(image_size):
image_size = torch.tensor(image_size, device=device)
imwidth, imheight = image_size.unbind(1)
# make sure imwidth, imheight are valid (>0)
if (imwidth < 1).any() or (imheight < 1).any():
raise ValueError(
"Camera parameters provided in screen space. Image width or height invalid."
)
half_imwidth = imwidth / 2.0
half_imheight = imheight / 2.0
fx = fx / half_imwidth
fy = fy / half_imheight
px = -(px - half_imwidth) / half_imwidth
py = -(py - half_imheight) / half_imheight
K = fx.new_zeros(N, 4, 4)
K[:, 0, 0] = fx
K[:, 1, 1] = fy
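Plugging the PerspectiveCameras docstring example into this conversion gives a worked check of the four formulas:

    fx, fy, px, py = 22.0, 15.0, 192.0, 128.0  # screen-space parameters
    half_w = half_h = 256.0 / 2.0              # 128.0 for a 256 x 256 image
    fx_ndc = fx / half_w                       # 0.171875
    fy_ndc = fy / half_h                       # 0.1171875
    px_ndc = -(px - half_w) / half_w           # -0.5
    py_ndc = -(py - half_h) / half_h           # 0.0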