mirror of https://github.com/facebookresearch/pytorch3d.git, synced 2025-12-20 22:30:35 +08:00
camera refactoring
Summary: Refactor cameras

* `CamerasBase` was enhanced with `transform_points_screen`, which transforms projected points from NDC to screen space
* OpenGLPerspective, OpenGLOrthographic -> FoVPerspective, FoVOrthographic
* SfMPerspective, SfMOrthographic -> Perspective, Orthographic
* PerspectiveCameras can optionally be constructed with screen space parameters
* A note on Cameras and coordinate systems was added

Reviewed By: nikhilaravi

Differential Revision: D23168525

fbshipit-source-id: dd138e2b2cc7e0e0d9f34c45b8251c01266a2063
committed by Facebook GitHub Bot
parent 9242e7e65d
commit 57a22e7306
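For downstream code, a minimal migration sketch of the renames described above (a hedged example; the parameter values are illustrative, not taken from this commit):

    from pytorch3d.renderer import FoVPerspectiveCameras, PerspectiveCameras

    # FoVPerspectiveCameras replaces OpenGLPerspectiveCameras and keeps the same
    # field-of-view style parameters (znear, zfar, aspect_ratio, fov).
    fov_cameras = FoVPerspectiveCameras(znear=0.1, zfar=100.0, fov=60.0)

    # PerspectiveCameras replaces SfMPerspectiveCameras; passing image_size marks
    # the focal length and principal point as screen-space parameters.
    screen_cameras = PerspectiveCameras(
        focal_length=((22.0, 15.0),),
        principal_point=((192.0, 128.0),),
        image_size=((256, 256),),
    )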
@@ -10,7 +10,7 @@ from pytorch3d.renderer import (
     HardPhongShader,
     MeshRasterizer,
     MeshRenderer,
-    OpenGLPerspectiveCameras,
+    FoVPerspectiveCameras,
     PointLights,
     RasterizationSettings,
     TexturesVertex,
@@ -125,7 +125,7 @@ class ShapeNetBase(torch.utils.data.Dataset):
         meshes.textures = TexturesVertex(
             verts_features=torch.ones_like(meshes.verts_padded(), device=device)
         )
-        cameras = kwargs.get("cameras", OpenGLPerspectiveCameras()).to(device)
+        cameras = kwargs.get("cameras", FoVPerspectiveCameras()).to(device)
         if len(cameras) != 1 and len(cameras) % len(meshes) != 0:
             raise ValueError("Mismatch between batch dims of cameras and meshes.")
         if len(cameras) > 1:
@@ -6,11 +6,15 @@ from .blending import (
     sigmoid_alpha_blend,
     softmax_rgb_blend,
 )
+from .cameras import OpenGLOrthographicCameras  # deprecated
+from .cameras import OpenGLPerspectiveCameras  # deprecated
+from .cameras import SfMOrthographicCameras  # deprecated
+from .cameras import SfMPerspectiveCameras  # deprecated
 from .cameras import (
-    OpenGLOrthographicCameras,
-    OpenGLPerspectiveCameras,
-    SfMOrthographicCameras,
-    SfMPerspectiveCameras,
+    FoVOrthographicCameras,
+    FoVPerspectiveCameras,
+    OrthographicCameras,
+    PerspectiveCameras,
     camera_position_from_spherical_angles,
     get_world_to_view_transform,
     look_at_rotation,
@@ -1,6 +1,7 @@
 # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

 import math
+import warnings
 from typing import Optional, Sequence, Tuple

 import numpy as np
@@ -20,23 +21,43 @@ class CamerasBase(TensorProperties):
     """
     `CamerasBase` implements a base class for all cameras.

+    For cameras, there are four different coordinate systems (or spaces)
+    - World coordinate system: This is the system the object lives - the world.
+    - Camera view coordinate system: This is the system that has its origin on the image plane
+        and the Z-axis perpendicular to the image plane.
+        In PyTorch3D, we assume that +X points left, and +Y points up and
+        +Z points out from the image plane.
+        The transformation from world -> view happens after applying a rotation (R)
+        and translation (T)
+    - NDC coordinate system: This is the normalized coordinate system that confines
+        in a volume the rendered part of the object or scene. Also known as view volume.
+        Given the PyTorch3D convention, (+1, +1, znear) is the top left near corner,
+        and (-1, -1, zfar) is the bottom right far corner of the volume.
+        The transformation from view -> NDC happens after applying the camera
+        projection matrix (P).
+    - Screen coordinate system: This is another representation of the view volume with
+        the XY coordinates defined in pixel space instead of a normalized space.
+
+    A better illustration of the coordinate systems can be found in pytorch3d/docs/notes/cameras.md.
+
     It defines methods that are common to all camera models:
         - `get_camera_center` that returns the optical center of the camera in
             world coordinates
         - `get_world_to_view_transform` which returns a 3D transform from
-            world coordinates to the camera coordinates
+            world coordinates to the camera view coordinates (R, T)
         - `get_full_projection_transform` which composes the projection
-            transform with the world-to-view transform
-        - `transform_points` which takes a set of input points and
-            projects them onto a 2D camera plane.
+            transform (P) with the world-to-view transform (R, T)
+        - `transform_points` which takes a set of input points in world coordinates and
+            projects to NDC coordinates ranging from [-1, -1, znear] to [+1, +1, zfar].
+        - `transform_points_screen` which takes a set of input points in world coordinates and
+            projects them to the screen coordinates ranging from [0, 0, znear] to [W-1, H-1, zfar]

     For each new camera, one should implement the `get_projection_transform`
-    routine that returns the mapping from camera coordinates in world units
-    to the screen coordinates.
+    routine that returns the mapping from camera view coordinates to NDC coordinates.

     Another useful function that is specific to each camera model is
-    `unproject_points` which sends points from screen coordinates back to
-    camera or world coordinates depending on the `world_coordinates`
+    `unproject_points` which sends points from NDC coordinates back to
+    camera view or world coordinates depending on the `world_coordinates`
     boolean argument of the function.
     """
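To make the four spaces in the docstring above concrete, a small sketch of how the transforms chain on any CamerasBase subclass (FoVPerspectiveCameras is used here only for illustration; the point values are arbitrary):

    import torch
    from pytorch3d.renderer import FoVPerspectiveCameras

    cameras = FoVPerspectiveCameras()
    points_world = torch.rand(1, 5, 3) + torch.tensor([0.0, 0.0, 3.0])

    # world -> camera view: applies only the extrinsics (R, T)
    points_view = cameras.get_world_to_view_transform().transform_points(points_world)
    # world -> NDC: applies (R, T) followed by the projection matrix (P)
    points_ndc = cameras.transform_points(points_world)
    # world -> screen: NDC xy rescaled onto a W x H pixel grid
    points_screen = cameras.transform_points_screen(points_world, image_size=((128, 128),))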
@@ -56,7 +77,7 @@ class CamerasBase(TensorProperties):

     def unproject_points(self):
         """
-        Transform input points in screen coodinates
+        Transform input points from NDC coordinates
         to the world / camera coordinates.

         Each of the input points `xy_depth` of shape (..., 3) is
@@ -74,7 +95,7 @@ class CamerasBase(TensorProperties):

             cameras = # camera object derived from CamerasBase
             xyz = # 3D points of shape (batch_size, num_points, 3)
-            # transform xyz to the camera coordinates
+            # transform xyz to the camera view coordinates
             xyz_cam = cameras.get_world_to_view_transform().transform_points(xyz)
             # extract the depth of each point as the 3rd coord of xyz_cam
             depth = xyz_cam[:, :, 2:]
@@ -94,7 +115,7 @@ class CamerasBase(TensorProperties):
             world_coordinates: If `True`, unprojects the points back to world
                 coordinates using the camera extrinsics `R` and `T`.
                 `False` ignores `R` and `T` and unprojects to
-                the camera coordinates.
+                the camera view coordinates.

         Returns
             new_points: unprojected points with the same shape as `xy_depth`.
@@ -141,7 +162,7 @@ class CamerasBase(TensorProperties):
             lighting calculations.

         Returns:
-            T: a Transform3d object which represents a batch of transforms
+            A Transform3d object which represents a batch of transforms
             of shape (N, 3, 3)
         """
         self.R = kwargs.get("R", self.R)  # pyre-ignore[16]
@@ -151,8 +172,8 @@ class CamerasBase(TensorProperties):

     def get_full_projection_transform(self, **kwargs) -> Transform3d:
         """
-        Return the full world-to-screen transform composing the
-        world-to-view and view-to-screen transforms.
+        Return the full world-to-NDC transform composing the
+        world-to-view and view-to-NDC transforms.

         Args:
             **kwargs: parameters for the projection transforms can be passed in
@@ -164,26 +185,26 @@ class CamerasBase(TensorProperties):
             lighting calculations.

         Returns:
-            T: a Transform3d object which represents a batch of transforms
+            a Transform3d object which represents a batch of transforms
             of shape (N, 3, 3)
         """
         self.R = kwargs.get("R", self.R)  # pyre-ignore[16]
         self.T = kwargs.get("T", self.T)  # pyre-ignore[16]
         world_to_view_transform = self.get_world_to_view_transform(R=self.R, T=self.T)
-        view_to_screen_transform = self.get_projection_transform(**kwargs)
-        return world_to_view_transform.compose(view_to_screen_transform)
+        view_to_ndc_transform = self.get_projection_transform(**kwargs)
+        return world_to_view_transform.compose(view_to_ndc_transform)

     def transform_points(
         self, points, eps: Optional[float] = None, **kwargs
     ) -> torch.Tensor:
         """
-        Transform input points from world to screen space.
+        Transform input points from world to NDC space.

         Args:
             points: torch tensor of shape (..., 3).
             eps: If eps!=None, the argument is used to clamp the
                 divisor in the homogeneous normalization of the points
-                transformed to the screen space. Plese see
+                transformed to the ndc space. Please see
                 `transforms.Transform3D.transform_points` for details.

             For `CamerasBase.transform_points`, setting `eps > 0`
@@ -194,8 +215,50 @@ class CamerasBase(TensorProperties):
         Returns
             new_points: transformed points with the same shape as the input.
         """
-        world_to_screen_transform = self.get_full_projection_transform(**kwargs)
-        return world_to_screen_transform.transform_points(points, eps=eps)
+        world_to_ndc_transform = self.get_full_projection_transform(**kwargs)
+        return world_to_ndc_transform.transform_points(points, eps=eps)
+
+    def transform_points_screen(
+        self, points, image_size, eps: Optional[float] = None, **kwargs
+    ) -> torch.Tensor:
+        """
+        Transform input points from world to screen space.
+
+        Args:
+            points: torch tensor of shape (N, V, 3).
+            image_size: torch tensor of shape (N, 2)
+            eps: If eps!=None, the argument is used to clamp the
+                divisor in the homogeneous normalization of the points
+                transformed to the ndc space. Please see
+                `transforms.Transform3D.transform_points` for details.
+
+                For `CamerasBase.transform_points`, setting `eps > 0`
+                stabilizes gradients since it leads to avoiding division
+                by excessively low numbers for points close to the
+                camera plane.
+
+        Returns
+            new_points: transformed points with the same shape as the input.
+        """
+
+        ndc_points = self.transform_points(points, eps=eps, **kwargs)
+
+        if not torch.is_tensor(image_size):
+            image_size = torch.tensor(
+                image_size, dtype=torch.int64, device=points.device
+            )
+        if (image_size < 1).any():
+            raise ValueError("Provided image size is invalid.")
+
+        image_width, image_height = image_size.unbind(1)
+        image_width = image_width.view(-1, 1)  # (N, 1)
+        image_height = image_height.view(-1, 1)  # (N, 1)
+
+        ndc_z = ndc_points[..., 2]
+        screen_x = (image_width - 1.0) / 2.0 * (1.0 - ndc_points[..., 0])
+        screen_y = (image_height - 1.0) / 2.0 * (1.0 - ndc_points[..., 1])
+
+        return torch.stack((screen_x, screen_y, ndc_z), dim=2)

     def clone(self):
         """
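The screen mapping added above is a per-axis affine flip of the NDC xy coordinates; a hand-checked numeric sketch of the formula (the image size is hypothetical):

    # screen_x = (W - 1) / 2 * (1 - ndc_x): ndc_x = +1 (left, in the PyTorch3D
    # convention) lands on pixel 0 and ndc_x = -1 on pixel W - 1; z is unchanged.
    W = 64
    for ndc_x, expected in [(1.0, 0.0), (-1.0, 63.0), (0.0, 31.5)]:
        screen_x = (W - 1.0) / 2.0 * (1.0 - ndc_x)
        assert screen_x == expected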
@@ -206,21 +269,56 @@
         return super().clone(other)


-########################
-# Specific camera classes
-########################
+############################################################
+#               Field of View Camera Classes               #
+############################################################


-class OpenGLPerspectiveCameras(CamerasBase):
+def OpenGLPerspectiveCameras(
+    znear=1.0,
+    zfar=100.0,
+    aspect_ratio=1.0,
+    fov=60.0,
+    degrees: bool = True,
+    R=r,
+    T=t,
+    device="cpu",
+):
+    """
+    OpenGLPerspectiveCameras has been DEPRECATED. Use FoVPerspectiveCameras instead.
+    Preserving OpenGLPerspectiveCameras for backward compatibility.
+    """
+
+    warnings.warn(
+        """OpenGLPerspectiveCameras is deprecated,
+        Use FoVPerspectiveCameras instead.
+        OpenGLPerspectiveCameras will be removed in future releases.""",
+        PendingDeprecationWarning,
+    )
+
+    return FoVPerspectiveCameras(
+        znear=znear,
+        zfar=zfar,
+        aspect_ratio=aspect_ratio,
+        fov=fov,
+        degrees=degrees,
+        R=R,
+        T=T,
+        device=device,
+    )
+
+
+class FoVPerspectiveCameras(CamerasBase):
     """
     A class which stores a batch of parameters to generate a batch of
-    projection matrices using the OpenGL convention for a perspective camera.
+    projection matrices by specifying the field of view.
+    The definition of the parameters follow the OpenGL perspective camera.

     The extrinsics of the camera (R and T matrices) can also be set in the
     initializer or passed in to `get_full_projection_transform` to get
-    the full transformation from world -> screen.
+    the full transformation from world -> ndc.

-    The `transform_points` method calculates the full world -> screen transform
+    The `transform_points` method calculates the full world -> ndc transform
     and then applies it to the input points.

     The transforms can also be returned separately as Transform3d objects.
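A sketch of what the deprecation shim above means for callers (note that PendingDeprecationWarning is ignored by default, hence the filter):

    import warnings
    from pytorch3d.renderer import FoVPerspectiveCameras, OpenGLPerspectiveCameras

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        cameras = OpenGLPerspectiveCameras(fov=45.0)

    # The old name now returns an instance of the new class instead of failing.
    assert isinstance(cameras, FoVPerspectiveCameras)
    assert any(issubclass(w.category, PendingDeprecationWarning) for w in caught)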
@@ -267,8 +365,11 @@ class OpenGLPerspectiveCameras(CamerasBase):

     def get_projection_transform(self, **kwargs) -> Transform3d:
         """
-        Calculate the OpenGL perpective projection matrix with a symmetric
+        Calculate the perspective projection matrix with a symmetric
         viewing frustrum. Use column major order.
+        The viewing frustrum will be projected into ndc, s.t.
+        (max_x, max_y) -> (+1, +1)
+        (min_x, min_y) -> (-1, -1)

         Args:
             **kwargs: parameters for the projection can be passed in as keyword
@@ -276,14 +377,14 @@ class OpenGLPerspectiveCameras(CamerasBase):

         Return:
             P: a Transform3d object which represents a batch of projection
-            matrices of shape (N, 3, 3)
+            matrices of shape (N, 4, 4)

         .. code-block:: python

             f1 = -(far + near)/(far-near)
             f2 = -2*far*near/(far-near)
-            h1 = (top + bottom)/(top - bottom)
-            w1 = (right + left)/(right - left)
+            h1 = (max_y + min_y)/(max_y - min_y)
+            w1 = (max_x + min_x)/(max_x - min_x)
             tanhalffov = tan((fov/2))
             s1 = 1/tanhalffov
             s2 = 1/(tanhalffov * (aspect_ratio))
@@ -310,10 +411,10 @@
         if not torch.is_tensor(fov):
             fov = torch.tensor(fov, device=self.device)
         tanHalfFov = torch.tan((fov / 2))
-        top = tanHalfFov * znear
-        bottom = -top
-        right = top * aspect_ratio
-        left = -right
+        max_y = tanHalfFov * znear
+        min_y = -max_y
+        max_x = max_y * aspect_ratio
+        min_x = -max_x

         # NOTE: In OpenGL the projection matrix changes the handedness of the
         # coordinate frame. i.e the NDC space postive z direction is the
@@ -323,28 +424,19 @@
         # so the z sign is 1.0.
         z_sign = 1.0

-        P[:, 0, 0] = 2.0 * znear / (right - left)
-        P[:, 1, 1] = 2.0 * znear / (top - bottom)
-        P[:, 0, 2] = (right + left) / (right - left)
-        P[:, 1, 2] = (top + bottom) / (top - bottom)
+        P[:, 0, 0] = 2.0 * znear / (max_x - min_x)
+        P[:, 1, 1] = 2.0 * znear / (max_y - min_y)
+        P[:, 0, 2] = (max_x + min_x) / (max_x - min_x)
+        P[:, 1, 2] = (max_y + min_y) / (max_y - min_y)
         P[:, 3, 2] = z_sign * ones

-        # NOTE: This part of the matrix is for z renormalization in OpenGL
-        # which maps the z to [-1, 1]. This won't work yet as the torch3d
-        # rasterizer ignores faces which have z < 0.
-        # P[:, 2, 2] = z_sign * (far + near) / (far - near)
-        # P[:, 2, 3] = -2.0 * far * near / (far - near)
-        # P[:, 3, 2] = z_sign * torch.ones((N))
-
         # NOTE: This maps the z coordinate from [0, 1] where z = 0 if the point
         # is at the near clipping plane and z = 1 when the point is at the far
-        # clipping plane. This replaces the OpenGL z normalization to [-1, 1]
-        # until rasterization is changed to clip at z = -1.
+        # clipping plane.
         P[:, 2, 2] = z_sign * zfar / (zfar - znear)
         P[:, 2, 3] = -(zfar * znear) / (zfar - znear)

-        # OpenGL uses column vectors so need to transpose the projection matrix
-        # as torch3d uses row vectors.
+        # Transpose the projection matrix as PyTorch3d transforms use row vectors.
         transform = Transform3d(device=self.device)
         transform._matrix = P.transpose(1, 2).contiguous()
         return transform
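As a hedged sanity check of the frustum arithmetic above (hand-computed, using a 90 degree fov so that tan(fov/2) is approximately 1):

    import math

    znear, zfar, aspect_ratio = 1.0, 100.0, 1.0
    tan_half_fov = math.tan(math.radians(90.0) / 2.0)
    max_y = tan_half_fov * znear
    max_x = max_y * aspect_ratio
    min_x, min_y = -max_x, -max_y

    # P[0, 0] = 2 * znear / (max_x - min_x) is ~1.0 for a 90 degree fov
    assert abs(2.0 * znear / (max_x - min_x) - 1.0) < 1e-6

    # After the perspective divide, z_ndc = a + b / z with the entries below,
    # which maps znear -> 0 and zfar -> 1 as the NOTE in the diff describes.
    a = zfar / (zfar - znear)             # P[2, 2]
    b = -(zfar * znear) / (zfar - znear)  # P[2, 3]
    assert abs(a + b / znear) < 1e-6
    assert abs(a + b / zfar - 1.0) < 1e-6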
@@ -357,7 +449,7 @@ class OpenGLPerspectiveCameras(CamerasBase):
         **kwargs
     ) -> torch.Tensor:
         """>!
-        OpenGL cameras further allow for passing depth in world units
+        FoV cameras further allow for passing depth in world units
         (`scaled_depth_input=False`) or in the [0, 1]-normalized units
         (`scaled_depth_input=True`)
@@ -367,11 +459,11 @@ class OpenGLPerspectiveCameras(CamerasBase):
             the world units.
         """

-        # obtain the relevant transformation to screen
+        # obtain the relevant transformation to ndc
         if world_coordinates:
-            to_screen_transform = self.get_full_projection_transform()
+            to_ndc_transform = self.get_full_projection_transform()
         else:
-            to_screen_transform = self.get_projection_transform()
+            to_ndc_transform = self.get_projection_transform()

         if scaled_depth_input:
             # the input is scaled depth, so we don't have to do anything
@@ -390,45 +482,84 @@ class OpenGLPerspectiveCameras(CamerasBase):
         xy_sdepth = torch.cat((xy_depth[..., 0:2], sdepth), dim=-1)

         # unproject with inverse of the projection
-        unprojection_transform = to_screen_transform.inverse()
+        unprojection_transform = to_ndc_transform.inverse()
         return unprojection_transform.transform_points(xy_sdepth)


-class OpenGLOrthographicCameras(CamerasBase):
+def OpenGLOrthographicCameras(
+    znear=1.0,
+    zfar=100.0,
+    top=1.0,
+    bottom=-1.0,
+    left=-1.0,
+    right=1.0,
+    scale_xyz=((1.0, 1.0, 1.0),),  # (1, 3)
+    R=r,
+    T=t,
+    device="cpu",
+):
+    """
+    OpenGLOrthographicCameras has been DEPRECATED. Use FoVOrthographicCameras instead.
+    Preserving OpenGLOrthographicCameras for backward compatibility.
+    """
+
+    warnings.warn(
+        """OpenGLOrthographicCameras is deprecated,
+        Use FoVOrthographicCameras instead.
+        OpenGLOrthographicCameras will be removed in future releases.""",
+        PendingDeprecationWarning,
+    )
+
+    return FoVOrthographicCameras(
+        znear=znear,
+        zfar=zfar,
+        max_y=top,
+        min_y=bottom,
+        max_x=right,
+        min_x=left,
+        scale_xyz=scale_xyz,
+        R=R,
+        T=T,
+        device=device,
+    )
+
+
+class FoVOrthographicCameras(CamerasBase):
     """
     A class which stores a batch of parameters to generate a batch of
-    transformation matrices using the OpenGL convention for orthographic camera.
+    projection matrices by specifying the field of view.
+    The definition of the parameters follow the OpenGL orthographic camera.
     """

     def __init__(
         self,
         znear=1.0,
         zfar=100.0,
-        top=1.0,
-        bottom=-1.0,
-        left=-1.0,
-        right=1.0,
+        max_y=1.0,
+        min_y=-1.0,
+        max_x=1.0,
+        min_x=-1.0,
         scale_xyz=((1.0, 1.0, 1.0),),  # (1, 3)
         R=r,
         T=t,
         device="cpu",
     ):
         """
-        __init__(self, znear, zfar, top, bottom, left, right, scale_xyz, R, T, device) -> None  # noqa
+        __init__(self, znear, zfar, max_y, min_y, max_x, min_x, scale_xyz, R, T, device) -> None  # noqa

         Args:
             znear: near clipping plane of the view frustrum.
             zfar: far clipping plane of the view frustrum.
-            top: position of the top of the screen.
-            bottom: position of the bottom of the screen.
-            left: position of the left of the screen.
-            right: position of the right of the screen.
+            max_y: maximum y coordinate of the frustrum.
+            min_y: minimum y coordinate of the frustrum.
+            max_x: maximum x coordinate of the frustrum.
+            min_x: minimum x coordinate of the frustrum.
             scale_xyz: scale factors for each axis of shape (N, 3).
             R: Rotation matrix of shape (N, 3, 3).
             T: Translation of shape (N, 3).
             device: torch.device or string.

-        Only need to set left, right, top, bottom for viewing frustrums
+        Only need to set min_x, max_x, min_y, max_y for viewing frustrums
         which are non symmetric about the origin.
         """
         # The initializer formats all inputs to torch tensors and broadcasts
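A small usage sketch of the renamed orthographic frustum parameters (the values are illustrative):

    from pytorch3d.renderer import FoVOrthographicCameras

    # top/bottom/left/right become max_y/min_y/max_x/min_x; only an asymmetric
    # frustum needs all four changed from their defaults.
    cameras = FoVOrthographicCameras(
        znear=0.5,
        zfar=10.0,
        max_y=2.0,
        min_y=-1.0,
        max_x=1.5,
        min_x=-1.5,
    )
    P = cameras.get_projection_transform()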
@@ -437,10 +568,10 @@ class OpenGLOrthographicCameras(CamerasBase):
             device=device,
             znear=znear,
             zfar=zfar,
-            top=top,
-            bottom=bottom,
-            left=left,
-            right=right,
+            max_y=max_y,
+            min_y=min_y,
+            max_x=max_x,
+            min_x=min_x,
             scale_xyz=scale_xyz,
             R=R,
             T=T,
@@ -448,7 +579,7 @@ class OpenGLOrthographicCameras(CamerasBase):

     def get_projection_transform(self, **kwargs) -> Transform3d:
         """
-        Calculate the OpenGL orthographic projection matrix.
+        Calculate the orthographic projection matrix.
         Use column major order.

         Args:
@@ -456,16 +587,16 @@ class OpenGLOrthographicCameras(CamerasBase):
                 override the default values set in __init__.
         Return:
             P: a Transform3d object which represents a batch of projection
-            matrices of shape (N, 3, 3)
+            matrices of shape (N, 4, 4)

         .. code-block:: python

-            scale_x = 2/(right - left)
-            scale_y = 2/(top - bottom)
-            scale_z = 2/(far-near)
-            mid_x = (right + left)/(right - left)
-            mix_y = (top + bottom)/(top - bottom)
-            mid_z = (far + near)/(far-near)
+            scale_x = 2 / (max_x - min_x)
+            scale_y = 2 / (max_y - min_y)
+            scale_z = 2 / (far - near)
+            mid_x = (max_x + min_x) / (max_x - min_x)
+            mid_y = (max_y + min_y) / (max_y - min_y)
+            mid_z = (far + near) / (far - near)

             P = [
                 [scale_x, 0, 0, -mid_x],
@@ -476,10 +607,10 @@ class OpenGLOrthographicCameras(CamerasBase):
         """
         znear = kwargs.get("znear", self.znear)  # pyre-ignore[16]
         zfar = kwargs.get("zfar", self.zfar)  # pyre-ignore[16]
-        left = kwargs.get("left", self.left)  # pyre-ignore[16]
-        right = kwargs.get("right", self.right)  # pyre-ignore[16]
-        top = kwargs.get("top", self.top)  # pyre-ignore[16]
-        bottom = kwargs.get("bottom", self.bottom)  # pyre-ignore[16]
+        max_x = kwargs.get("max_x", self.max_x)  # pyre-ignore[16]
+        min_x = kwargs.get("min_x", self.min_x)  # pyre-ignore[16]
+        max_y = kwargs.get("max_y", self.max_y)  # pyre-ignore[16]
+        min_y = kwargs.get("min_y", self.min_y)  # pyre-ignore[16]
         scale_xyz = kwargs.get("scale_xyz", self.scale_xyz)  # pyre-ignore[16]

         P = torch.zeros((self._N, 4, 4), dtype=torch.float32, device=self.device)
@@ -489,10 +620,10 @@ class OpenGLOrthographicCameras(CamerasBase):
         # right handed coordinate system throughout.
         z_sign = +1.0

-        P[:, 0, 0] = (2.0 / (right - left)) * scale_xyz[:, 0]
-        P[:, 1, 1] = (2.0 / (top - bottom)) * scale_xyz[:, 1]
-        P[:, 0, 3] = -(right + left) / (right - left)
-        P[:, 1, 3] = -(top + bottom) / (top - bottom)
+        P[:, 0, 0] = (2.0 / (max_x - min_x)) * scale_xyz[:, 0]
+        P[:, 1, 1] = (2.0 / (max_y - min_y)) * scale_xyz[:, 1]
+        P[:, 0, 3] = -(max_x + min_x) / (max_x - min_x)
+        P[:, 1, 3] = -(max_y + min_y) / (max_y - min_y)
         P[:, 3, 3] = ones

         # NOTE: This maps the z coordinate to the range [0, 1] and replaces the
@@ -500,12 +631,6 @@ class OpenGLOrthographicCameras(CamerasBase):
         P[:, 2, 2] = z_sign * (1.0 / (zfar - znear)) * scale_xyz[:, 2]
         P[:, 2, 3] = -znear / (zfar - znear)

-        # NOTE: This part of the matrix is for z renormalization in OpenGL.
-        # The z is mapped to the range [-1, 1] but this won't work yet in
-        # pytorch3d as the rasterizer ignores faces which have z < 0.
-        # P[:, 2, 2] = z_sign * (2.0 / (far - near)) * scale[:, 2]
-        # P[:, 2, 3] = -(far + near) / (far - near)
-
         transform = Transform3d(device=self.device)
         transform._matrix = P.transpose(1, 2).contiguous()
         return transform
@@ -518,7 +643,7 @@ class OpenGLOrthographicCameras(CamerasBase):
         **kwargs
     ) -> torch.Tensor:
         """>!
-        OpenGL cameras further allow for passing depth in world units
+        FoV cameras further allow for passing depth in world units
         (`scaled_depth_input=False`) or in the [0, 1]-normalized units
         (`scaled_depth_input=True`)
@@ -529,9 +654,9 @@ class OpenGLOrthographicCameras(CamerasBase):
         """

         if world_coordinates:
-            to_screen_transform = self.get_full_projection_transform(**kwargs.copy())
+            to_ndc_transform = self.get_full_projection_transform(**kwargs.copy())
         else:
-            to_screen_transform = self.get_projection_transform(**kwargs.copy())
+            to_ndc_transform = self.get_projection_transform(**kwargs.copy())

         if scaled_depth_input:
             # the input depth is already scaled
@@ -547,22 +672,88 @@ class OpenGLOrthographicCameras(CamerasBase):
         # cat xy and scaled depth
         xy_sdepth = torch.cat((xy_depth[..., :2], scaled_depth), dim=-1)
         # finally invert the transform
-        unprojection_transform = to_screen_transform.inverse()
+        unprojection_transform = to_ndc_transform.inverse()
         return unprojection_transform.transform_points(xy_sdepth)


-class SfMPerspectiveCameras(CamerasBase):
+############################################################
+#                 MultiView Camera Classes                 #
+############################################################
+"""
+Note that the MultiView Cameras accept parameters in both
+screen and NDC space.
+If the user specifies `image_size` at construction time then
+we assume the parameters are in screen space.
+"""
+
+
+def SfMPerspectiveCameras(
+    focal_length=1.0, principal_point=((0.0, 0.0),), R=r, T=t, device="cpu"
+):
+    """
+    SfMPerspectiveCameras has been DEPRECATED. Use PerspectiveCameras instead.
+    Preserving SfMPerspectiveCameras for backward compatibility.
+    """
+
+    warnings.warn(
+        """SfMPerspectiveCameras is deprecated,
+        Use PerspectiveCameras instead.
+        SfMPerspectiveCameras will be removed in future releases.""",
+        PendingDeprecationWarning,
+    )
+
+    return PerspectiveCameras(
+        focal_length=focal_length,
+        principal_point=principal_point,
+        R=R,
+        T=T,
+        device=device,
+    )
+
+
+class PerspectiveCameras(CamerasBase):
     """
     A class which stores a batch of parameters to generate a batch of
     transformation matrices using the multi-view geometry convention for
     perspective camera.
+
+    Parameters for this camera can be specified in NDC or in screen space.
+    If you wish to provide parameters in screen space, you NEED to provide
+    the image_size = (imwidth, imheight).
+    If you wish to provide parameters in NDC space, you should NOT provide
+    image_size. Providing a valid image_size will trigger a screen space to
+    NDC space transformation in the camera.
+
+    For example, here is how to define cameras on the two spaces.
+
+    .. code-block:: python
+        # camera defined in screen space
+        cameras = PerspectiveCameras(
+            focal_length=((22.0, 15.0),),  # (fx_screen, fy_screen)
+            principal_point=((192.0, 128.0),),  # (px_screen, py_screen)
+            image_size=((256, 256),),  # (imwidth, imheight)
+        )
+
+        # the equivalent camera defined in NDC space
+        cameras = PerspectiveCameras(
+            focal_length=((0.171875, 0.1171875),),  # fx = fx_screen / half_imwidth,
+                                                    # fy = fy_screen / half_imheight
+            principal_point=((-0.5, 0),),  # px = - (px_screen - half_imwidth) / half_imwidth,
+                                           # py = - (py_screen - half_imheight) / half_imheight
+        )
     """

     def __init__(
-        self, focal_length=1.0, principal_point=((0.0, 0.0),), R=r, T=t, device="cpu"
+        self,
+        focal_length=1.0,
+        principal_point=((0.0, 0.0),),
+        R=r,
+        T=t,
+        device="cpu",
+        image_size=((-1, -1),),
     ):
         """
-        __init__(self, focal_length, principal_point, R, T, device) -> None
+        __init__(self, focal_length, principal_point, R, T, device, image_size) -> None

         Args:
             focal_length: Focal length of the camera in world units.
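A hedged check that the two constructions in the docstring above agree, using the exact conversion the class performs (22 / 128 = 0.171875 and 15 / 128 = 0.1171875):

    import torch
    from pytorch3d.renderer import PerspectiveCameras

    screen_cam = PerspectiveCameras(
        focal_length=((22.0, 15.0),),
        principal_point=((192.0, 128.0),),
        image_size=((256, 256),),
    )
    ndc_cam = PerspectiveCameras(
        focal_length=((22.0 / 128.0, 15.0 / 128.0),),  # fx, fy over half image size
        principal_point=((-0.5, 0.0),),
    )

    xyz = torch.rand(1, 10, 3) + torch.tensor([0.0, 0.0, 2.0])
    assert torch.allclose(screen_cam.transform_points(xyz), ndc_cam.transform_points(xyz))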
@@ -574,6 +765,11 @@ class SfMPerspectiveCameras(CamerasBase):
             R: Rotation matrix of shape (N, 3, 3)
             T: Translation matrix of shape (N, 3)
             device: torch.device or string
+            image_size: If image_size = (imwidth, imheight) with imwidth, imheight > 0
+                is provided, the camera parameters are assumed to be in screen
+                space. They will be converted to NDC space.
+                If image_size is not provided, the parameters are assumed to
+                be in NDC space.
         """
         # The initializer formats all inputs to torch tensors and broadcasts
         # all the inputs to have the same batch dimension where necessary.
@@ -583,6 +779,7 @@ class SfMPerspectiveCameras(CamerasBase):
             principal_point=principal_point,
             R=R,
             T=T,
+            image_size=image_size,
         )

     def get_projection_transform(self, **kwargs) -> Transform3d:
@@ -615,9 +812,20 @@ class SfMPerspectiveCameras(CamerasBase):
         principal_point = kwargs.get("principal_point", self.principal_point)
         # pyre-ignore[16]
         focal_length = kwargs.get("focal_length", self.focal_length)
+        # pyre-ignore[16]
+        image_size = kwargs.get("image_size", self.image_size)
+
+        # if imwidth > 0, parameters are in screen space
+        in_screen = image_size[0][0] > 0
+        image_size = image_size if in_screen else None

         P = _get_sfm_calibration_matrix(
-            self._N, self.device, focal_length, principal_point, False
+            self._N,
+            self.device,
+            focal_length,
+            principal_point,
+            orthographic=False,
+            image_size=image_size,
         )

         transform = Transform3d(device=self.device)
@@ -628,29 +836,83 @@ class SfMOrthographicCameras(CamerasBase):
         self, xy_depth: torch.Tensor, world_coordinates: bool = True, **kwargs
     ) -> torch.Tensor:
         if world_coordinates:
-            to_screen_transform = self.get_full_projection_transform(**kwargs)
+            to_ndc_transform = self.get_full_projection_transform(**kwargs)
         else:
-            to_screen_transform = self.get_projection_transform(**kwargs)
+            to_ndc_transform = self.get_projection_transform(**kwargs)

-        unprojection_transform = to_screen_transform.inverse()
+        unprojection_transform = to_ndc_transform.inverse()
         xy_inv_depth = torch.cat(
             (xy_depth[..., :2], 1.0 / xy_depth[..., 2:3]), dim=-1  # type: ignore
         )
         return unprojection_transform.transform_points(xy_inv_depth)


-class SfMOrthographicCameras(CamerasBase):
+def SfMOrthographicCameras(
+    focal_length=1.0, principal_point=((0.0, 0.0),), R=r, T=t, device="cpu"
+):
+    """
+    SfMOrthographicCameras has been DEPRECATED. Use OrthographicCameras instead.
+    Preserving SfMOrthographicCameras for backward compatibility.
+    """
+
+    warnings.warn(
+        """SfMOrthographicCameras is deprecated,
+        Use OrthographicCameras instead.
+        SfMOrthographicCameras will be removed in future releases.""",
+        PendingDeprecationWarning,
+    )
+
+    return OrthographicCameras(
+        focal_length=focal_length,
+        principal_point=principal_point,
+        R=R,
+        T=T,
+        device=device,
+    )
+
+
+class OrthographicCameras(CamerasBase):
     """
     A class which stores a batch of parameters to generate a batch of
     transformation matrices using the multi-view geometry convention for
     orthographic camera.
+
+    Parameters for this camera can be specified in NDC or in screen space.
+    If you wish to provide parameters in screen space, you NEED to provide
+    the image_size = (imwidth, imheight).
+    If you wish to provide parameters in NDC space, you should NOT provide
+    image_size. Providing a valid image_size will trigger a screen space to
+    NDC space transformation in the camera.
+
+    For example, here is how to define cameras on the two spaces.
+
+    .. code-block:: python
+        # camera defined in screen space
+        cameras = OrthographicCameras(
+            focal_length=((22.0, 15.0),),  # (fx, fy)
+            principal_point=((192.0, 128.0),),  # (px, py)
+            image_size=((256, 256),),  # (imwidth, imheight)
+        )
+
+        # the equivalent camera defined in NDC space
+        cameras = OrthographicCameras(
+            focal_length=((0.171875, 0.1171875),),  # := (fx / half_imwidth, fy / half_imheight)
+            principal_point=((-0.5, 0),),  # := (- (px - half_imwidth) / half_imwidth,
+                                           #     - (py - half_imheight) / half_imheight)
+        )
     """

     def __init__(
-        self, focal_length=1.0, principal_point=((0.0, 0.0),), R=r, T=t, device="cpu"
+        self,
+        focal_length=1.0,
+        principal_point=((0.0, 0.0),),
+        R=r,
+        T=t,
+        device="cpu",
+        image_size=((-1, -1),),
     ):
         """
-        __init__(self, focal_length, principal_point, R, T, device) -> None
+        __init__(self, focal_length, principal_point, R, T, device, image_size) -> None

         Args:
             focal_length: Focal length of the camera in world units.
@@ -662,6 +924,11 @@ class SfMOrthographicCameras(CamerasBase):
             R: Rotation matrix of shape (N, 3, 3)
             T: Translation matrix of shape (N, 3)
             device: torch.device or string
+            image_size: If image_size = (imwidth, imheight) with imwidth, imheight > 0
+                is provided, the camera parameters are assumed to be in screen
+                space. They will be converted to NDC space.
+                If image_size is not provided, the parameters are assumed to
+                be in NDC space.
         """
         # The initializer formats all inputs to torch tensors and broadcasts
         # all the inputs to have the same batch dimension where necessary.
@@ -671,6 +938,7 @@ class SfMOrthographicCameras(CamerasBase):
             principal_point=principal_point,
             R=R,
             T=T,
+            image_size=image_size,
         )

     def get_projection_transform(self, **kwargs) -> Transform3d:
@@ -703,9 +971,20 @@ class SfMOrthographicCameras(CamerasBase):
         principal_point = kwargs.get("principal_point", self.principal_point)
         # pyre-ignore[16]
         focal_length = kwargs.get("focal_length", self.focal_length)
+        # pyre-ignore[16]
+        image_size = kwargs.get("image_size", self.image_size)
+
+        # if imwidth > 0, parameters are in screen space
+        in_screen = image_size[0][0] > 0
+        image_size = image_size if in_screen else None

         P = _get_sfm_calibration_matrix(
-            self._N, self.device, focal_length, principal_point, True
+            self._N,
+            self.device,
+            focal_length,
+            principal_point,
+            orthographic=True,
+            image_size=image_size,
         )

         transform = Transform3d(device=self.device)
@@ -716,17 +995,26 @@ class SfMOrthographicCameras(CamerasBase):
         self, xy_depth: torch.Tensor, world_coordinates: bool = True, **kwargs
     ) -> torch.Tensor:
         if world_coordinates:
-            to_screen_transform = self.get_full_projection_transform(**kwargs)
+            to_ndc_transform = self.get_full_projection_transform(**kwargs)
         else:
-            to_screen_transform = self.get_projection_transform(**kwargs)
+            to_ndc_transform = self.get_projection_transform(**kwargs)

-        unprojection_transform = to_screen_transform.inverse()
+        unprojection_transform = to_ndc_transform.inverse()
         return unprojection_transform.transform_points(xy_depth)


-# SfMCameras helper
+################################################
+#         Helper functions for cameras         #
+################################################


 def _get_sfm_calibration_matrix(
-    N, device, focal_length, principal_point, orthographic: bool
+    N,
+    device,
+    focal_length,
+    principal_point,
+    orthographic: bool = False,
+    image_size=None,
 ) -> torch.Tensor:
     """
     Returns a calibration matrix of a perspective/orthograpic camera.
@@ -736,6 +1024,10 @@ def _get_sfm_calibration_matrix(
         focal_length: Focal length of the camera in world units.
         principal_point: xy coordinates of the center of
             the principal point of the camera in pixels.
+        orthographic: Boolean specifying if the camera is orthographic or not
+        image_size: (Optional) Specifying the image_size = (imwidth, imheight).
+            If not None, the camera parameters are assumed to be in screen space
+            and are transformed to NDC space.

     The calibration matrix `K` is set up as follows:
@@ -769,7 +1061,7 @@ def _get_sfm_calibration_matrix(
     if not torch.is_tensor(focal_length):
         focal_length = torch.tensor(focal_length, device=device)

-    if len(focal_length.shape) in (0, 1) or focal_length.shape[1] == 1:
+    if focal_length.ndim in (0, 1) or focal_length.shape[1] == 1:
         fx = fy = focal_length
     else:
         fx, fy = focal_length.unbind(1)
@@ -779,6 +1071,22 @@ def _get_sfm_calibration_matrix(

     px, py = principal_point.unbind(1)

+    if image_size is not None:
+        if not torch.is_tensor(image_size):
+            image_size = torch.tensor(image_size, device=device)
+        imwidth, imheight = image_size.unbind(1)
+        # make sure imwidth, imheight are valid (>0)
+        if (imwidth < 1).any() or (imheight < 1).any():
+            raise ValueError(
+                "Camera parameters provided in screen space. Image width or height invalid."
+            )
+        half_imwidth = imwidth / 2.0
+        half_imheight = imheight / 2.0
+        fx = fx / half_imwidth
+        fy = fy / half_imheight
+        px = -(px - half_imwidth) / half_imwidth
+        py = -(py - half_imheight) / half_imheight
+
     K = fx.new_zeros(N, 4, 4)
     K[:, 0, 0] = fx
     K[:, 1, 1] = fy
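The screen-to-NDC conversion added above can be checked by hand; a standalone numeric sketch with a hypothetical 256 x 256 image:

    fx, fy = 22.0, 15.0      # screen-space focal lengths in pixels
    px, py = 192.0, 128.0    # screen-space principal point in pixels
    half_w = half_h = 256 / 2.0

    fx_ndc = fx / half_w               # 0.171875
    fy_ndc = fy / half_h               # 0.1171875
    px_ndc = -(px - half_w) / half_w   # -0.5
    py_ndc = -(py - half_h) / half_h   # 0.0

    assert (fx_ndc, fy_ndc, px_ndc, py_ndc) == (0.171875, 0.1171875, -0.5, 0.0)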