NDC/screen cameras API fix, compatibility with renderer

Summary:
API fix for NDC/screen cameras and compatibility with PyTorch3D renderers.

With this new fix:
* Users can define cameras and call `transform_points` under any coordinate system convention. The transformation applies the camera `K` and `RT` to the input points, regardless of PyTorch3D conventions. This makes cameras completely independent of the PyTorch3D renderer.

* Cameras can be defined either in NDC space or screen space. For existing ones, FoV cameras are in NDC space. Perspective/Orthographic can be defined in NDC or screen space.

* The interface with PyTorch3D renderers happens through `transform_points_ndc` which transforms points to the NDC space and assumes that input points are provided according to PyTorch3D conventions.

* Similarly, `transform_points_screen` transforms points to screen space and again assumes that input points are under PyTorch3D conventions.

* For Orthographic/Perspective cameras defined in screen space, `get_ndc_camera_transform` allows points to be converted to NDC for use with the renderers.

Reviewed By: nikhilaravi

Differential Revision: D26932657

fbshipit-source-id: 1a964e3e7caa54d10c792cf39c4d527ba2fb2e79
Georgia Gkioxari 2021-08-02 01:00:03 -07:00 committed by Facebook GitHub Bot
parent 9a14f54e8b
commit 0c32f094af
6 changed files with 503 additions and 223 deletions

@ -13,7 +13,8 @@ This is the system the object/scene lives - the world.
* **Camera view coordinate system**
This is the system that has its origin on the image plane and the `Z`-axis perpendicular to the image plane. In PyTorch3D, we assume that `+X` points left, and `+Y` points up and `+Z` points out from the image plane. The transformation from world to view happens after applying a rotation (`R`) and translation (`T`).
* **NDC coordinate system**
This is the normalized coordinate system that confines in a volume the rendered part of the object/scene. Also known as view volume. Under the PyTorch3D convention, `(+1, +1, znear)` is the top left near corner, and `(-1, -1, zfar)` is the bottom right far corner of the volume. The transformation from view to NDC happens after applying the camera projection matrix (`P`).
This is the normalized coordinate system that confines in a volume the rendered part of the object/scene. Also known as view volume. Under the PyTorch3D convention, `(+1, +1, znear)` is the top left near corner, and `(-1, -1, zfar)` is the bottom right far corner of the volume. For non-square volumes, the side of the volume in `XY` with the smallest length ranges from `[-1, 1]` while the larger side from `[-s, s]`, where `s` is the aspect ratio and `s > 1` (larger divided by smaller side).
The transformation from view to NDC happens after applying the camera projection matrix (`P`).
* **Screen coordinate system**
This is another representation of the view volume with the `XY` coordinates defined in pixel space instead of a normalized space.
@ -22,47 +23,78 @@ An illustration of the 4 coordinate systems is shown below
## Defining Cameras in PyTorch3D
Cameras in PyTorch3D transform an object/scene from world to NDC by first transforming the object/scene to view (via transforms `R` and `T`) and then projecting the 3D object/scene to NDC (via the projection matrix `P`, else known as camera matrix). Thus, the camera parameters in `P` are assumed to be in NDC space. If the user has camera parameters in screen space, which is a common use case, the parameters should transformed to NDC (see below for an example)
Cameras in PyTorch3D transform an object/scene from world to a projection space by first transforming the object/scene to view (via transforms `R` and `T`) and then projecting the 3D object/scene to a normalized space via the projection matrix `P = K[R | T]`, where `K` is the intrinsic matrix. The camera parameters in `K` define the normalized space. If users define the camera parameters in NDC space, then the transform projects points to NDC. If the camera parameters are defined in screen space, the transformed points are in screen space.
We describe the camera types in PyTorch3D and the convention for the camera parameters provided at construction time.
Note that the base `CamerasBase` class makes no assumptions about the coordinate systems. All the above transforms are geometric transforms defined purely by `R`, `T` and `K`. This means that users can define cameras in any coordinate system and for any transforms. The method `transform_points` will apply `K`, `R` and `T` to the input points as a simple matrix transformation. However, if users wish to use cameras with the PyTorch3D renderer, they need to abide by PyTorch3D's coordinate system assumptions (read below).
We provide instantiations of common camera types in PyTorch3D and how users can flexibly define the projection space below.
## Interfacing with the PyTorch3D Renderer
The PyTorch3D renderer for both meshes and point clouds assumes that the camera transformed points, meaning the points passed as input to the rasterizer, are in PyTorch3D's NDC space. So to get the expected rendering outcome, users need to make sure that their 3D input data and cameras abide by these PyTorch3D coordinate system assumptions. The PyTorch3D coordinate system assumes `+X: left`, `+Y: up` and `+Z: from us to scene` (right-handed). Confusion about coordinate systems is common, so we advise that you spend some time understanding your data and the coordinate system it lives in, and transform it accordingly before using the PyTorch3D renderer.
Examples of cameras and how they interface with the PyTorch3D renderer can be found in our tutorials.
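As a minimal sketch of such a conversion, assume the input points are expressed in a camera frame with `+X` right, `+Y` down and `+Z` into the scene (a common computer-vision convention; this is an assumption about the source data, not something PyTorch3D prescribes). Negating `X` and `Y` then yields points in the PyTorch3D convention. The helper name `to_pytorch3d_convention` is hypothetical and not part of the library.

```python
from typing import List, Tuple

Point = Tuple[float, float, float]


def to_pytorch3d_convention(points: List[Point]) -> List[Point]:
    """Map points from (+X right, +Y down, +Z into scene)
    to the PyTorch3D convention (+X left, +Y up, +Z into scene).

    Flipping two axes keeps the system right-handed.
    """
    return [(-x, -y, z) for (x, y, z) in points]


# Hypothetical input points in the assumed source convention.
points_cv = [(1.0, 2.0, 5.0), (-0.5, 0.25, 3.0)]
points_p3d = to_pytorch3d_convention(points_cv)
print(points_p3d)
```

The same flip is its own inverse, so applying it twice returns the original points. Camera extrinsics expressed in the source convention would need a consistent conversion as well; this sketch only covers the points.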
### Camera Types
All cameras inherit from `CamerasBase` which is a base class for all cameras. PyTorch3D provides four different camera types. The `CamerasBase` defines methods that are common to all camera models:
* `get_camera_center` that returns the optical center of the camera in world coordinates
* `get_world_to_view_transform` which returns a 3D transform from world coordinates to the camera view coordinates (R, T)
* `get_full_projection_transform` which composes the projection transform (P) with the world-to-view transform (R, T)
* `get_world_to_view_transform` which returns a 3D transform from world coordinates to the camera view coordinates `(R, T)`
* `get_full_projection_transform` which composes the projection transform (`K`) with the world-to-view transform `(R, T)`
* `transform_points` which takes a set of input points in world coordinates and projects to NDC coordinates ranging from [-1, -1, znear] to [+1, +1, zfar].
* `get_ndc_camera_transform` which defines the conversion to PyTorch3D's NDC space and is called when interfacing with the PyTorch3D renderer. If the camera is defined in NDC space, then the identity transform is returned. If the camera is defined in screen space, the conversion from screen to NDC is returned. Users who define their own camera in screen space need to provide the screen-to-NDC conversion themselves. We provide examples for the `PerspectiveCameras` and `OrthographicCameras`.
* `transform_points_ndc` which takes a set of points in world coordinates and projects them to PyTorch3D's NDC space
* `transform_points_screen` which takes a set of input points in world coordinates and projects them to the screen coordinates ranging from [0, 0, znear] to [W-1, H-1, zfar]
Users can easily customize their own cameras. For each new camera, users should implement the `get_projection_transform` routine that returns the mapping `P` from camera view coordinates to NDC coordinates.
#### FoVPerspectiveCameras, FoVOrthographicCameras
These two cameras follow the OpenGL convention for perspective and orthographic cameras respectively. The user provides the near `znear` and far `zfar` field which confines the view volume in the `Z` axis. The view volume in the `XY` plane is defined by field of view angle (`fov`) in the case of `FoVPerspectiveCameras` and by `min_x, min_y, max_x, max_y` in the case of `FoVOrthographicCameras`.
These cameras are by default in NDC space.
#### PerspectiveCameras, OrthographicCameras
These two cameras follow the Multi-View Geometry convention for cameras. The user provides the focal length (`fx`, `fy`) and the principal point (`px`, `py`). For example, `camera = PerspectiveCameras(focal_length=((fx, fy),), principal_point=((px, py),))`
As mentioned above, the focal length and principal point are used to convert a point `(X, Y, Z)` from view coordinates to NDC coordinates, as follows
The camera projection of a 3D point `(X, Y, Z)` in view coordinates to a point `(x, y, z)` in projection space (either NDC or screen) is
```
# for perspective
x_ndc = fx * X / Z + px
y_ndc = fy * Y / Z + py
z_ndc = 1 / Z
# for perspective camera
x = fx * X / Z + px
y = fy * Y / Z + py
z = 1 / Z
# for orthographic
x_ndc = fx * X + px
y_ndc = fy * Y + py
z_ndc = Z
# for orthographic camera
x = fx * X + px
y = fy * Y + py
z = Z
```
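The projection equations above can be checked with a small plain-Python sketch. The function names `project_perspective` and `project_orthographic` are hypothetical, and the numeric values are made up for illustration; whether the result is in NDC or screen space depends only on the space in which `fx`, `fy`, `px`, `py` are given.

```python
from typing import Tuple

Point = Tuple[float, float, float]


def project_perspective(
    X: float, Y: float, Z: float, fx: float, fy: float, px: float, py: float
) -> Point:
    """Perspective projection of a view-space point; z stores inverse depth."""
    return (fx * X / Z + px, fy * Y / Z + py, 1.0 / Z)


def project_orthographic(
    X: float, Y: float, Z: float, fx: float, fy: float, px: float, py: float
) -> Point:
    """Orthographic projection of a view-space point; z is passed through."""
    return (fx * X + px, fy * Y + py, Z)


# Illustrative point in view coordinates and illustrative camera parameters.
persp = project_perspective(1.0, 2.0, 4.0, fx=2.0, fy=2.0, px=0.5, py=-0.5)
ortho = project_orthographic(1.0, 2.0, 4.0, fx=2.0, fy=2.0, px=0.5, py=-0.5)
print(persp)  # (1.0, 0.5, 0.25)
print(ortho)  # (2.5, 3.5, 4.0)
```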
Commonly, users have access to the focal length (`fx_screen`, `fy_screen`) and the principal point (`px_screen`, `py_screen`) in screen space. In that case, to construct the camera the user needs to additionally provide the `image_size = ((image_width, image_height),)`. More precisely, `camera = PerspectiveCameras(focal_length=((fx_screen, fy_screen),), principal_point=((px_screen, py_screen),), image_size = ((image_width, image_height),))`. Internally, the camera parameters are converted from screen to NDC as follows:
The user can define the camera parameters in NDC or in screen space. Screen space camera parameters are common and for that case the user needs to set `in_ndc` to `False` and also provide the `image_size=(height, width)` of the screen, aka the image.
The `get_ndc_camera_transform` provides the transform from screen to NDC space in PyTorch3D. Note that the screen space assumes that the principal point is provided in the space with `+X right`, `+Y down` and origin at the top left corner of the image. To convert to NDC we need to account for the scaling of the normalized space as well as the change in `XY` direction.
Below are examples of equivalent `PerspectiveCameras` instantiations in NDC and screen space, respectively.
```python
# NDC space camera
fcl_ndc = (1.2,)
prp_ndc = ((0.2, 0.5),)
cameras_ndc = PerspectiveCameras(focal_length=fcl_ndc, principal_point=prp_ndc)
# Screen space camera
image_size = ((128, 256),) # (h, w)
fcl_screen = (76.2,) # fcl_ndc * (min(image_size) - 1) / 2
prp_screen = ((114.8, 31.75), ) # (w - 1) / 2 - px_ndc * (min(image_size) - 1) / 2, (h - 1) / 2 - py_ndc * (min(image_size) - 1) / 2
cameras_screen = PerspectiveCameras(focal_length=fcl_screen, principal_point=prp_screen, in_ndc=False, image_size=image_size)
```
The relationship between screen and NDC specifications of a camera's `focal_length` and `principal_point` is given by the following equations, where `s = min(image_width, image_height)`.
The transformation of x and y coordinates between screen and NDC is exactly the same as for px and py.
```
fx = fx_screen * 2.0 / image_width
fy = fy_screen * 2.0 / image_height
fx_ndc = fx_screen * 2.0 / (s - 1)
fy_ndc = fy_screen * 2.0 / (s - 1)
px = - (px_screen - image_width / 2.0) * 2.0 / image_width
py = - (py_screen - image_height / 2.0) * 2.0/ image_height
px_ndc = - (px_screen - (image_width - 1) / 2.0) * 2.0 / (s - 1)
py_ndc = - (py_screen - (image_height - 1) / 2.0) * 2.0 / (s - 1)
```
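As a worked check of the `(s - 1)` form of these equations, the sketch below converts the screen-space parameters from the `PerspectiveCameras` example (`image_size = (128, 256)`, focal length `76.2`, principal point `(114.8, 31.75)`) back to NDC and recovers `1.2` and `(0.2, 0.5)` up to floating-point error. The function names are hypothetical helpers written in plain Python; they are not part of PyTorch3D.

```python
import math
from typing import Tuple

Params = Tuple[float, float, float, float]


def screen_to_ndc_params(
    fx_screen: float, fy_screen: float, px_screen: float, py_screen: float,
    image_height: int, image_width: int,
) -> Params:
    """Convert focal length and principal point from screen to NDC space."""
    s = min(image_width, image_height)
    fx_ndc = fx_screen * 2.0 / (s - 1)
    fy_ndc = fy_screen * 2.0 / (s - 1)
    px_ndc = -(px_screen - (image_width - 1) / 2.0) * 2.0 / (s - 1)
    py_ndc = -(py_screen - (image_height - 1) / 2.0) * 2.0 / (s - 1)
    return fx_ndc, fy_ndc, px_ndc, py_ndc


def ndc_to_screen_params(
    fx_ndc: float, fy_ndc: float, px_ndc: float, py_ndc: float,
    image_height: int, image_width: int,
) -> Params:
    """Inverse of screen_to_ndc_params, obtained by solving the same equations."""
    s = min(image_width, image_height)
    half = (s - 1) / 2.0
    fx_screen = fx_ndc * half
    fy_screen = fy_ndc * half
    px_screen = (image_width - 1) / 2.0 - px_ndc * half
    py_screen = (image_height - 1) / 2.0 - py_ndc * half
    return fx_screen, fy_screen, px_screen, py_screen


# Values taken from the screen space camera in the example above: (h, w) = (128, 256).
ndc = screen_to_ndc_params(76.2, 76.2, 114.8, 31.75, image_height=128, image_width=256)
screen = ndc_to_screen_params(*ndc, image_height=128, image_width=256)

expected_ndc = (1.2, 1.2, 0.2, 0.5)
print(all(math.isclose(a, b, abs_tol=1e-9) for a, b in zip(ndc, expected_ndc)))
```

Because `x` and `y` coordinates transform the same way as `px` and `py`, the principal-point lines of these helpers can also be applied to individual point coordinates.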

@ -6,7 +6,7 @@
import math
import warnings
from typing import Optional, Sequence, Tuple
from typing import Optional, Sequence, Tuple, Union, List
import numpy as np
import torch
@ -28,20 +28,20 @@ class CamerasBase(TensorProperties):
For cameras, there are four different coordinate systems (or spaces)
- World coordinate system: This is the system the object lives - the world.
- Camera view coordinate system: This is the system that has its origin on the image plane
- Camera view coordinate system: This is the system that has its origin on the camera
and the Z-axis perpendicular to the image plane.
In PyTorch3D, we assume that +X points left, and +Y points up and
+Z points out from the image plane.
The transformation from world -> view happens after applying a rotation (R)
The transformation from world --> view happens after applying a rotation (R)
and translation (T)
- NDC coordinate system: This is the normalized coordinate system that confines
in a volume the rendered part of the object or scene. Also known as view volume.
Given the PyTorch3D convention, (+1, +1, znear) is the top left near corner,
and (-1, -1, zfar) is the bottom right far corner of the volume.
The transformation from view -> NDC happens after applying the camera
projection matrix (P).
The transformation from view --> NDC happens after applying the camera
projection matrix (P) if defined in NDC space.
- Screen coordinate system: This is another representation of the view volume with
the XY coordinates defined in pixel space instead of a normalized space.
the XY coordinates defined in image space instead of a normalized space.
A better illustration of the coordinate systems can be found in
pytorch3d/docs/notes/cameras.md.
@ -54,17 +54,21 @@ class CamerasBase(TensorProperties):
- `get_full_projection_transform` which composes the projection
transform (P) with the world-to-view transform (R, T)
- `transform_points` which takes a set of input points in world coordinates and
projects to NDC coordinates ranging from [-1, -1, znear] to [+1, +1, zfar].
- `transform_points_screen` which takes a set of input points in world coordinates and
projects them to the screen coordinates ranging from
[0, 0, znear] to [W-1, H-1, zfar]
projects to the space the camera is defined in (NDC or screen)
- `get_ndc_camera_transform` which defines the transform from screen/NDC to
PyTorch3D's NDC space
- `transform_points_ndc` which takes a set of points in world coordinates and
projects them to PyTorch3D's NDC space
- `transform_points_screen` which takes a set of points in world coordinates and
projects them to screen space
For each new camera, one should implement the `get_projection_transform`
routine that returns the mapping from camera view coordinates to NDC coordinates.
routine that returns the mapping from camera view coordinates to camera
coordinates (NDC or screen).
Another useful function that is specific to each camera model is
`unproject_points` which sends points from NDC coordinates back to
camera view or world coordinates depending on the `world_coordinates`
`unproject_points` which sends points from camera coordinates (NDC or screen)
back to camera view or world coordinates depending on the `world_coordinates`
boolean argument of the function.
"""
@ -84,7 +88,7 @@ class CamerasBase(TensorProperties):
def unproject_points(self):
"""
Transform input points from NDC coordinates
Transform input points from camera coordinates (NDC or screen)
to the world / camera coordinates.
Each of the input points `xy_depth` of shape (..., 3) is
@ -181,8 +185,10 @@ class CamerasBase(TensorProperties):
def get_full_projection_transform(self, **kwargs) -> Transform3d:
"""
Return the full world-to-NDC transform composing the
world-to-view and view-to-NDC transforms.
Return the full world-to-camera transform composing the
world-to-view and view-to-camera transforms.
If camera is defined in NDC space, the projected points are in NDC space.
If camera is defined in screen space, the projected points are in screen space.
Args:
**kwargs: parameters for the projection transforms can be passed in
@ -200,14 +206,70 @@ class CamerasBase(TensorProperties):
self.R: torch.Tensor = kwargs.get("R", self.R) # pyre-ignore[16]
self.T: torch.Tensor = kwargs.get("T", self.T) # pyre-ignore[16]
world_to_view_transform = self.get_world_to_view_transform(R=self.R, T=self.T)
view_to_ndc_transform = self.get_projection_transform(**kwargs)
return world_to_view_transform.compose(view_to_ndc_transform)
view_to_proj_transform = self.get_projection_transform(**kwargs)
return world_to_view_transform.compose(view_to_proj_transform)
def transform_points(
self, points, eps: Optional[float] = None, **kwargs
) -> torch.Tensor:
"""
Transform input points from world to NDC space.
Transform input points from world to camera space with the
projection matrix defined by the camera.
For `CamerasBase.transform_points`, setting `eps > 0`
stabilizes gradients since it leads to avoiding division
by excessively low numbers for points close to the camera plane.
Args:
points: torch tensor of shape (..., 3).
eps: If eps!=None, the argument is used to clamp the
divisor in the homogeneous normalization of the points
transformed to the ndc space. Please see
`transforms.Transform3D.transform_points` for details.
For `CamerasBase.transform_points`, setting `eps > 0`
stabilizes gradients since it leads to avoiding division
by excessively low numbers for points close to the
camera plane.
Returns
new_points: transformed points with the same shape as the input.
"""
world_to_proj_transform = self.get_full_projection_transform(**kwargs)
return world_to_proj_transform.transform_points(points, eps=eps)
def get_ndc_camera_transform(self, **kwargs) -> Transform3d:
"""
Returns the transform from camera projection space (screen or NDC) to NDC space.
For cameras that can be specified in screen space, this transform
allows points to be converted from screen to NDC space.
The default transform scales the points from [0, W-1]x[0, H-1] to [-1, 1].
This function should be modified per camera definitions if need be,
e.g. for Perspective/Orthographic cameras we provide a custom implementation.
This transform assumes PyTorch3D coordinate system conventions for
both the NDC space and the input points.
This transform interfaces with the PyTorch3D renderer which assumes
input points to the renderer to be in NDC space.
"""
if self.in_ndc():
return Transform3d(device=self.device, dtype=torch.float32)
else:
# For custom cameras which can be defined in screen space,
# users might have to implement the screen to NDC transform based
# on the definition of the camera parameters.
# See PerspectiveCameras/OrthographicCameras for an example.
# We don't flip xy because we assume that world points are in PyTorch3D coordinates
# and thus conversion from screen to ndc is a mere scaling from image to [-1, 1] scale.
return get_screen_to_ndc_transform(self, with_xyflip=False, **kwargs)
def transform_points_ndc(
self, points, eps: Optional[float] = None, **kwargs
) -> torch.Tensor:
"""
Transforms points from PyTorch3D world/camera space to NDC space.
Input points follow the PyTorch3D coordinate system conventions: +X left, +Y up.
Output points are in NDC space: +X left, +Y up, origin at image center.
Args:
points: torch tensor of shape (..., 3).
@ -225,17 +287,22 @@ class CamerasBase(TensorProperties):
new_points: transformed points with the same shape as the input.
"""
world_to_ndc_transform = self.get_full_projection_transform(**kwargs)
if not self.in_ndc():
to_ndc_transform = self.get_ndc_camera_transform(**kwargs)
world_to_ndc_transform = world_to_ndc_transform.compose(to_ndc_transform)
return world_to_ndc_transform.transform_points(points, eps=eps)
def transform_points_screen(
self, points, image_size, eps: Optional[float] = None, **kwargs
self, points, eps: Optional[float] = None, **kwargs
) -> torch.Tensor:
"""
Transform input points from world to screen space.
Transforms points from PyTorch3D world/camera space to screen space.
Input points follow the PyTorch3D coordinate system conventions: +X left, +Y up.
Output points are in screen space: +X right, +Y down, origin at top left corner.
Args:
points: torch tensor of shape (N, V, 3).
image_size: torch tensor of shape (N, 2)
points: torch tensor of shape (..., 3).
eps: If eps!=None, the argument is used to clamp the
divisor in the homogeneous normalization of the points
transformed to the ndc space. Please see
@ -249,25 +316,10 @@ class CamerasBase(TensorProperties):
Returns
new_points: transformed points with the same shape as the input.
"""
ndc_points = self.transform_points(points, eps=eps, **kwargs)
if not torch.is_tensor(image_size):
image_size = torch.tensor(
image_size, dtype=torch.int64, device=points.device
)
if (image_size < 1).any():
raise ValueError("Provided image size is invalid.")
image_width, image_height = image_size.unbind(1)
image_width = image_width.view(-1, 1) # (N, 1)
image_height = image_height.view(-1, 1) # (N, 1)
ndc_z = ndc_points[..., 2]
screen_x = (image_width - 1.0) / 2.0 * (1.0 - ndc_points[..., 0])
screen_y = (image_height - 1.0) / 2.0 * (1.0 - ndc_points[..., 1])
return torch.stack((screen_x, screen_y, ndc_z), dim=2)
points_ndc = self.transform_points_ndc(points, eps=eps, **kwargs)
return get_ndc_to_screen_transform(
self, with_xyflip=True, **kwargs
).transform_points(points_ndc, eps=eps)
def clone(self):
"""
@ -280,9 +332,23 @@ class CamerasBase(TensorProperties):
def is_perspective(self):
raise NotImplementedError()
def in_ndc(self):
"""
Specifies whether the camera is defined in NDC space
or in screen (image) space
"""
raise NotImplementedError()
def get_znear(self):
return self.znear if hasattr(self, "znear") else None
def get_image_size(self):
"""
Returns the image size, if provided, expected in the form of (height, width)
The image size is used for conversion of projected points to screen coordinates.
"""
return self.image_size if hasattr(self, "image_size") else None
############################################################
# Field of View Camera Classes #
@ -501,8 +567,9 @@ class FoVPerspectiveCameras(CamerasBase):
)
# Transpose the projection matrix as PyTorch3D transforms use row vectors.
transform = Transform3d(device=self.device)
transform._matrix = K.transpose(1, 2).contiguous()
transform = Transform3d(
matrix=K.transpose(1, 2).contiguous(), device=self.device
)
return transform
def unproject_points(
@ -552,6 +619,9 @@ class FoVPerspectiveCameras(CamerasBase):
def is_perspective(self):
return True
def in_ndc(self):
return True
def OpenGLOrthographicCameras(
znear=1.0,
@ -726,8 +796,9 @@ class FoVOrthographicCameras(CamerasBase):
kwargs.get("scale_xyz", self.scale_xyz),
)
transform = Transform3d(device=self.device)
transform._matrix = K.transpose(1, 2).contiguous()
transform = Transform3d(
matrix=K.transpose(1, 2).contiguous(), device=self.device
)
return transform
def unproject_points(
@ -773,15 +844,15 @@ class FoVOrthographicCameras(CamerasBase):
def is_perspective(self):
return False
def in_ndc(self):
return True
############################################################
# MultiView Camera Classes #
############################################################
"""
Note that the MultiView Cameras accept parameters in both
screen and NDC space.
If the user specifies `image_size` at construction time then
we assume the parameters are in screen space.
Note that the MultiView Cameras accept parameters in NDC space.
"""
@ -819,30 +890,8 @@ class PerspectiveCameras(CamerasBase):
transformation matrices using the multi-view geometry convention for
perspective camera.
Parameters for this camera can be specified in NDC or in screen space.
If you wish to provide parameters in screen space, you NEED to provide
the image_size = (imwidth, imheight).
If you wish to provide parameters in NDC space, you should NOT provide
image_size. Providing valid image_size will trigger a screen space to
NDC space transformation in the camera.
For example, here is how to define cameras on the two spaces.
.. code-block:: python
# camera defined in screen space
cameras = PerspectiveCameras(
focal_length=((22.0, 15.0),), # (fx_screen, fy_screen)
principal_point=((192.0, 128.0),), # (px_screen, py_screen)
image_size=((256, 256),), # (imwidth, imheight)
)
# the equivalent camera defined in NDC space
cameras = PerspectiveCameras(
focal_length=((0.17875, 0.11718),), # fx = fx_screen / half_imwidth,
# fy = fy_screen / half_imheight
principal_point=((-0.5, 0),), # px = - (px_screen - half_imwidth) / half_imwidth,
# py = - (py_screen - half_imheight) / half_imheight
)
Parameters for this camera are specified in NDC if `in_ndc` is set to True.
If parameters are specified in screen space, `in_ndc` must be set to False.
"""
def __init__(
@ -853,7 +902,8 @@ class PerspectiveCameras(CamerasBase):
T: torch.Tensor = _T,
K: Optional[torch.Tensor] = None,
device: Device = "cpu",
image_size=((-1, -1),),
in_ndc: bool = True,
image_size: Optional[Union[List, Tuple, torch.Tensor]] = None,
) -> None:
"""
@ -864,20 +914,20 @@ class PerspectiveCameras(CamerasBase):
principal_point: xy coordinates of the center of
the principal point of the camera in pixels.
A tensor of shape (N, 2).
in_ndc: True if camera parameters are specified in NDC.
If camera parameters are in screen space, it must
be set to False.
R: Rotation matrix of shape (N, 3, 3)
T: Translation matrix of shape (N, 3)
K: (optional) A calibration matrix of shape (N, 4, 4)
If provided, don't need focal_length, principal_point, image_size
If provided, don't need focal_length, principal_point
image_size: (height, width) of image size.
A tensor of shape (N, 2). Required for screen cameras.
device: torch.device or string
image_size: If image_size = (imwidth, imheight) with imwidth, imheight > 0
is provided, the camera parameters are assumed to be in screen
space. They will be converted to NDC space.
If image_size is not provided, the parameters are assumed to
be in NDC space.
"""
# The initializer formats all inputs to torch tensors and broadcasts
# all the inputs to have the same batch dimension where necessary.
kwargs = {"image_size": image_size} if image_size is not None else {}
super().__init__(
device=device,
focal_length=focal_length,
@ -885,8 +935,14 @@ class PerspectiveCameras(CamerasBase):
R=R,
T=T,
K=K,
image_size=image_size,
_in_ndc=in_ndc,
**kwargs, # pyre-ignore
)
if image_size is not None:
if (self.image_size < 1).any(): # pyre-ignore
raise ValueError("Image_size provided has invalid values")
else:
self.image_size = None
def get_projection_transform(self, **kwargs) -> Transform3d:
"""
@ -920,40 +976,86 @@ class PerspectiveCameras(CamerasBase):
msg = "Expected K to have shape of (%r, 4, 4)"
raise ValueError(msg % (self._N))
else:
image_size = kwargs.get("image_size", self.image_size)
# if imwidth > 0, parameters are in screen space
image_size = image_size if image_size[0][0] > 0 else None
K = _get_sfm_calibration_matrix(
self._N,
self.device,
kwargs.get("focal_length", self.focal_length),
kwargs.get("principal_point", self.principal_point),
orthographic=False,
image_size=image_size,
)
transform = Transform3d(device=self.device)
transform._matrix = K.transpose(1, 2).contiguous()
transform = Transform3d(
matrix=K.transpose(1, 2).contiguous(), device=self.device
)
return transform
def unproject_points(
self, xy_depth: torch.Tensor, world_coordinates: bool = True, **kwargs
) -> torch.Tensor:
if world_coordinates:
to_ndc_transform = self.get_full_projection_transform(**kwargs)
to_camera_transform = self.get_full_projection_transform(**kwargs)
else:
to_ndc_transform = self.get_projection_transform(**kwargs)
to_camera_transform = self.get_projection_transform(**kwargs)
unprojection_transform = to_ndc_transform.inverse()
unprojection_transform = to_camera_transform.inverse()
xy_inv_depth = torch.cat(
(xy_depth[..., :2], 1.0 / xy_depth[..., 2:3]), dim=-1 # type: ignore
)
return unprojection_transform.transform_points(xy_inv_depth)
def get_principal_point(self, **kwargs) -> torch.Tensor:
"""
Return the camera's principal point
Args:
**kwargs: parameters for the camera extrinsics can be passed in
as keyword arguments to override the default values
set in __init__.
"""
proj_mat = self.get_projection_transform(**kwargs).get_matrix()
return proj_mat[:, 2, :2]
def get_ndc_camera_transform(self, **kwargs) -> Transform3d:
"""
Returns the transform from camera projection space (screen or NDC) to NDC space.
If the camera is defined already in NDC space, the transform is identity.
For cameras defined in screen space, we adjust the principal point computation
which is defined in the image space (commonly) and scale the points to NDC space.
Important: This transform assumes PyTorch3D conventions for the input points,
i.e. +X left, +Y up.
"""
if self.in_ndc():
ndc_transform = Transform3d(device=self.device, dtype=torch.float32)
else:
# when cameras are defined in screen/image space, the principal point is
# provided in the (+X right, +Y down), aka image, coordinate system.
# Since input points are defined in the PyTorch3D system (+X left, +Y up),
# we need to adjust for the principal point transform.
pr_point_fix = torch.zeros(
(self._N, 4, 4), device=self.device, dtype=torch.float32
)
pr_point_fix[:, 0, 0] = 1.0
pr_point_fix[:, 1, 1] = 1.0
pr_point_fix[:, 2, 2] = 1.0
pr_point_fix[:, 3, 3] = 1.0
pr_point_fix[:, :2, 3] = -2.0 * self.get_principal_point(**kwargs)
pr_point_fix_transform = Transform3d(
matrix=pr_point_fix.transpose(1, 2).contiguous(), device=self.device
)
screen_to_ndc_transform = get_screen_to_ndc_transform(
self, with_xyflip=False, **kwargs
)
ndc_transform = pr_point_fix_transform.compose(screen_to_ndc_transform)
return ndc_transform
def is_perspective(self):
return True
def in_ndc(self):
return self._in_ndc
def SfMOrthographicCameras(
focal_length=1.0,
@ -989,29 +1091,8 @@ class OrthographicCameras(CamerasBase):
transformation matrices using the multi-view geometry convention for
orthographic camera.
Parameters for this camera can be specified in NDC or in screen space.
If you wish to provide parameters in screen space, you NEED to provide
the image_size = (imwidth, imheight).
If you wish to provide parameters in NDC space, you should NOT provide
image_size. Providing valid image_size will trigger a screen space to
NDC space transformation in the camera.
For example, here is how to define cameras on the two spaces.
.. code-block:: python
# camera defined in screen space
cameras = OrthographicCameras(
focal_length=((22.0, 15.0),), # (fx, fy)
principal_point=((192.0, 128.0),), # (px, py)
image_size=((256, 256),), # (imwidth, imheight)
)
# the equivalent camera defined in NDC space
cameras = OrthographicCameras(
focal_length=((0.17875, 0.11718),), # := (fx / half_imwidth, fy / half_imheight)
principal_point=((-0.5, 0),), # := (- (px - half_imwidth) / half_imwidth,
- (py - half_imheight) / half_imheight)
)
Parameters for this camera are specified in NDC if `in_ndc` is set to True.
If parameters are specified in screen space, `in_ndc` must be set to False.
"""
def __init__(
@ -1022,7 +1103,8 @@ class OrthographicCameras(CamerasBase):
T: torch.Tensor = _T,
K: Optional[torch.Tensor] = None,
device: Device = "cpu",
image_size=((-1, -1),),
in_ndc: bool = True,
image_size: Optional[torch.Tensor] = None,
) -> None:
"""
@ -1033,19 +1115,19 @@ class OrthographicCameras(CamerasBase):
principal_point: xy coordinates of the center of
the principal point of the camera in pixels.
A tensor of shape (N, 2).
in_ndc: True if camera parameters are specified in NDC.
If False, then camera parameters are in screen space.
R: Rotation matrix of shape (N, 3, 3)
T: Translation matrix of shape (N, 3)
K: (optional) A calibration matrix of shape (N, 4, 4)
If provided, don't need focal_length, principal_point, image_size
image_size: (height, width) of image size.
A tensor of shape (N, 2). Required for screen cameras.
device: torch.device or string
image_size: If image_size = (imwidth, imheight) with imwidth, imheight > 0
is provided, the camera parameters are assumed to be in screen
space. They will be converted to NDC space.
If image_size is not provided, the parameters are assumed to
be in NDC space.
"""
# The initializer formats all inputs to torch tensors and broadcasts
# all the inputs to have the same batch dimension where necessary.
kwargs = {"image_size": image_size} if image_size is not None else {}
super().__init__(
device=device,
focal_length=focal_length,
@ -1053,8 +1135,14 @@ class OrthographicCameras(CamerasBase):
R=R,
T=T,
K=K,
image_size=image_size,
_in_ndc=in_ndc,
**kwargs, # pyre-ignore
)
if image_size is not None:
if (self.image_size < 1).any(): # pyre-ignore
raise ValueError("Image_size provided has invalid values")
else:
self.image_size = None
def get_projection_transform(self, **kwargs) -> Transform3d:
"""
@ -1088,37 +1176,83 @@ class OrthographicCameras(CamerasBase):
msg = "Expected K to have shape of (%r, 4, 4)"
raise ValueError(msg % (self._N))
else:
image_size = kwargs.get("image_size", self.image_size)
# if imwidth > 0, parameters are in screen space
image_size = image_size if image_size[0][0] > 0 else None
K = _get_sfm_calibration_matrix(
self._N,
self.device,
kwargs.get("focal_length", self.focal_length),
kwargs.get("principal_point", self.principal_point),
orthographic=True,
image_size=image_size,
)
transform = Transform3d(device=self.device)
transform._matrix = K.transpose(1, 2).contiguous()
transform = Transform3d(
matrix=K.transpose(1, 2).contiguous(), device=self.device
)
return transform
def unproject_points(
self, xy_depth: torch.Tensor, world_coordinates: bool = True, **kwargs
) -> torch.Tensor:
if world_coordinates:
to_ndc_transform = self.get_full_projection_transform(**kwargs)
to_camera_transform = self.get_full_projection_transform(**kwargs)
else:
to_ndc_transform = self.get_projection_transform(**kwargs)
to_camera_transform = self.get_projection_transform(**kwargs)
unprojection_transform = to_ndc_transform.inverse()
unprojection_transform = to_camera_transform.inverse()
return unprojection_transform.transform_points(xy_depth)
def get_principal_point(self, **kwargs) -> torch.Tensor:
"""
Return the camera's principal point
Args:
**kwargs: parameters for the camera extrinsics can be passed in
as keyword arguments to override the default values
set in __init__.
"""
proj_mat = self.get_projection_transform(**kwargs).get_matrix()
return proj_mat[:, 3, :2]
def get_ndc_camera_transform(self, **kwargs) -> Transform3d:
"""
Returns the transform from camera projection space (screen or NDC) to NDC space.
If the camera is defined already in NDC space, the transform is identity.
For cameras defined in screen space, we first adjust for the principal point,
which is (commonly) defined in image space, and then scale the points to NDC space.
Important: This transform assumes PyTorch3D conventions for the input points,
i.e. +X left, +Y up.
"""
if self.in_ndc():
ndc_transform = Transform3d(device=self.device, dtype=torch.float32)
else:
# when cameras are defined in screen/image space, the principal point is
# provided in the (+X right, +Y down), aka image, coordinate system.
# Since input points are defined in the PyTorch3D system (+X left, +Y up),
# we need to adjust for the principal point transform.
pr_point_fix = torch.zeros(
(self._N, 4, 4), device=self.device, dtype=torch.float32
)
pr_point_fix[:, 0, 0] = 1.0
pr_point_fix[:, 1, 1] = 1.0
pr_point_fix[:, 2, 2] = 1.0
pr_point_fix[:, 3, 3] = 1.0
pr_point_fix[:, :2, 3] = -2.0 * self.get_principal_point(**kwargs)
pr_point_fix_transform = Transform3d(
matrix=pr_point_fix.transpose(1, 2).contiguous(), device=self.device
)
screen_to_ndc_transform = get_screen_to_ndc_transform(
self, with_xyflip=False, **kwargs
)
ndc_transform = pr_point_fix_transform.compose(screen_to_ndc_transform)
return ndc_transform
def is_perspective(self):
return False
def in_ndc(self):
return self._in_ndc
################################################
# Helper functions for cameras #
@ -1131,20 +1265,16 @@ def _get_sfm_calibration_matrix(
focal_length,
principal_point,
orthographic: bool = False,
image_size=None,
) -> torch.Tensor:
"""
Returns a calibration matrix of a perspective/orthographic camera.
Args:
N: Number of cameras.
focal_length: Focal length of the camera in world units.
focal_length: Focal length of the camera.
principal_point: xy coordinates of the center of
the principal point of the camera in pixels.
orthographic: Boolean specifying if the camera is orthographic or not
image_size: (Optional) Specifying the image_size = (imwidth, imheight).
If not None, the camera parameters are assumed to be in screen space
and are transformed to NDC space.
The calibration matrix `K` is set up as follows:
@ -1188,22 +1318,6 @@ def _get_sfm_calibration_matrix(
px, py = principal_point.unbind(1)
if image_size is not None:
if not torch.is_tensor(image_size):
image_size = torch.tensor(image_size, device=device)
imwidth, imheight = image_size.unbind(1)
# make sure imwidth, imheight are valid (>0)
if (imwidth < 1).any() or (imheight < 1).any():
raise ValueError(
"Camera parameters provided in screen space. Image width or height invalid."
)
half_imwidth = imwidth / 2.0
half_imheight = imheight / 2.0
fx = fx / half_imwidth
fy = fy / half_imheight
px = -(px - half_imwidth) / half_imwidth
py = -(py - half_imheight) / half_imheight
K = fx.new_zeros(N, 4, 4)
K[:, 0, 0] = fx
K[:, 1, 1] = fy
@ -1419,3 +1533,103 @@ def look_at_view_transform(
R = look_at_rotation(C, at, up, device=device)
T = -torch.bmm(R.transpose(1, 2), C[:, :, None])[:, :, 0]
return R, T
def get_ndc_to_screen_transform(
cameras, with_xyflip: bool = False, **kwargs
) -> Transform3d:
"""
PyTorch3D NDC to screen conversion.
Conversion from PyTorch3D's NDC space (+X left, +Y up) to screen/image space
(+X right, +Y down, origin top left).
Args:
cameras
with_xyflip: flips x- and y-axis if set to True.
Optional kwargs:
image_size: ((height, width),) specifying the height, width
of the image. If not provided, it reads it from cameras.
We represent the NDC to screen conversion as a Transform3d
with projection matrix
K = [
[s, 0, 0, cx],
[0, s, 0, cy],
[0, 0, 1, 0],
[0, 0, 0, 1],
]
"""
# We require the image size, which is necessary for the transform
image_size = kwargs.get("image_size", cameras.get_image_size())
if image_size is None:
msg = "For NDC to screen conversion, image_size=(height, width) needs to be specified."
raise ValueError(msg)
K = torch.zeros((cameras._N, 4, 4), device=cameras.device, dtype=torch.float32)
if not torch.is_tensor(image_size):
image_size = torch.tensor(image_size, device=cameras.device)
image_size = image_size.view(-1, 2) # of shape (1 or B)x2
height, width = image_size.unbind(1)
# For non-square images, we scale the points such that the smaller side
# has range [-1, 1] and the larger side has range [-u, u], with u > 1.
# This convention is consistent with the PyTorch3D renderer.
scale = (image_size.min(dim=1).values - 1.0) / 2.0
K[:, 0, 0] = scale
K[:, 1, 1] = scale
K[:, 0, 3] = -1.0 * (width - 1.0) / 2.0
K[:, 1, 3] = -1.0 * (height - 1.0) / 2.0
K[:, 2, 2] = 1.0
K[:, 3, 3] = 1.0
# Transpose the projection matrix as PyTorch3D transforms use row vectors.
transform = Transform3d(
matrix=K.transpose(1, 2).contiguous(), device=cameras.device
)
if with_xyflip:
# flip x, y axis
xyflip = torch.eye(4, device=cameras.device, dtype=torch.float32)
xyflip[0, 0] = -1.0
xyflip[1, 1] = -1.0
xyflip = xyflip.view(1, 4, 4).expand(cameras._N, -1, -1)
xyflip_transform = Transform3d(
matrix=xyflip.transpose(1, 2).contiguous(), device=cameras.device
)
transform = transform.compose(xyflip_transform)
return transform
def get_screen_to_ndc_transform(
cameras, with_xyflip: bool = False, **kwargs
) -> Transform3d:
"""
Screen to PyTorch3D NDC conversion.
Conversion from screen/image space (+X right, +Y down, origin top left)
to PyTorch3D's NDC space (+X left, +Y up).
Args:
cameras
with_xyflip: flips x- and y-axis if set to True.
Optional kwargs:
image_size: ((height, width),) specifying the height, width
of the image. If not provided, it reads it from cameras.
We represent the screen to NDC conversion as a Transform3d
with projection matrix
K = [
[1/s, 0, 0, cx/s],
[ 0, 1/s, 0, cy/s],
[ 0, 0, 1, 0],
[ 0, 0, 0, 1],
]
"""
transform = get_ndc_to_screen_transform(
cameras, with_xyflip=with_xyflip, **kwargs
).inverse()
return transform
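
The x/y arithmetic behind the NDC-to-screen conversion can be sketched without PyTorch3D. The snippet below is a minimal pure-Python illustration for a single point, assuming the flipped variant (the matrix `K` followed by the x/y flip), which is the same arithmetic as `ndc_to_screen_points_naive` in the tests of this commit. The helper names `ndc_to_screen_xy` and `screen_to_ndc_xy` are hypothetical and not part of the library.

```python
def ndc_to_screen_xy(x_ndc, y_ndc, height, width):
    """Map one PyTorch3D NDC point (+X left, +Y up) to pixel coordinates
    (+X right, +Y down, origin at the top-left pixel)."""
    # The smaller image side spans [-1, 1] in NDC.
    scale = (min(height, width) - 1.0) / 2.0
    x_screen = -scale * x_ndc + (width - 1.0) / 2.0
    y_screen = -scale * y_ndc + (height - 1.0) / 2.0
    return x_screen, y_screen


def screen_to_ndc_xy(x_screen, y_screen, height, width):
    """Inverse of ndc_to_screen_xy."""
    scale = (min(height, width) - 1.0) / 2.0
    x_ndc = ((width - 1.0) / 2.0 - x_screen) / scale
    y_ndc = ((height - 1.0) / 2.0 - y_screen) / scale
    return x_ndc, y_ndc


# Square 5x5 image: NDC (+1, +1) lands on the top-left pixel.
print(ndc_to_screen_xy(1.0, 1.0, height=5, width=5))  # (0.0, 0.0)
# Non-square 3x5 image: the wider side spans [-2, 2] in NDC.
print(ndc_to_screen_xy(2.0, 1.0, height=3, width=5))  # (0.0, 0.0)
```

Note that the scale is derived from the smaller image side, so for non-square images the larger side covers an NDC range wider than [-1, 1].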


@ -73,8 +73,7 @@ class MeshRasterizer(nn.Module):
Args:
cameras: A cameras object which has a `transform_points` method
which returns the transformed points after applying the
world-to-view and view-to-screen
transformations.
world-to-view and view-to-ndc transformations.
raster_settings: the parameters for rasterization. This should be a
named tuple.
@ -100,8 +99,8 @@ class MeshRasterizer(nn.Module):
vertex coordinates in world space.
Returns:
meshes_screen: a Meshes object with the vertex positions in screen
space
meshes_proj: a Meshes object with the vertex positions projected
in NDC space
NOTE: keeping this as a separate function for readability but it could
be moved into forward.
@ -126,12 +125,14 @@ class MeshRasterizer(nn.Module):
verts_view = cameras.get_world_to_view_transform(**kwargs).transform_points(
verts_world, eps=eps
)
verts_screen = cameras.get_projection_transform(**kwargs).transform_points(
verts_view, eps=eps
)
verts_screen[..., 2] = verts_view[..., 2]
meshes_screen = meshes_world.update_padded(new_verts_padded=verts_screen)
return meshes_screen
# view to NDC transform
to_ndc_transform = cameras.get_ndc_camera_transform(**kwargs)
projection_transform = cameras.get_projection_transform(**kwargs).compose(to_ndc_transform)
verts_ndc = projection_transform.transform_points(verts_view, eps=eps)
verts_ndc[..., 2] = verts_view[..., 2]
meshes_ndc = meshes_world.update_padded(new_verts_padded=verts_ndc)
return meshes_ndc
def forward(self, meshes_world, **kwargs) -> Fragments:
"""
@ -141,7 +142,7 @@ class MeshRasterizer(nn.Module):
Returns:
Fragments: Rasterization outputs as a named tuple.
"""
meshes_screen = self.transform(meshes_world, **kwargs)
meshes_proj = self.transform(meshes_world, **kwargs)
raster_settings = kwargs.get("raster_settings", self.raster_settings)
# By default, turn on clip_barycentric_coords if blur_radius > 0.
@ -166,7 +167,7 @@ class MeshRasterizer(nn.Module):
z_clip = None if not perspective_correct or znear is None else znear / 2
pix_to_face, zbuf, bary_coords, dists = rasterize_meshes(
meshes_screen,
meshes_proj,
image_size=raster_settings.image_size,
blur_radius=raster_settings.blur_radius,
faces_per_pixel=raster_settings.faces_per_pixel,
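
The order of composition in `transform` matters: `a.compose(b)` applies `a` first and then `b`, and PyTorch3D transforms act on row vectors, which is why the matrices in this commit are stored transposed. The following is a minimal sketch of that convention using plain Python lists and 2D homogeneous coordinates; it uses nothing from the library, and `matmul` and `apply` are hypothetical helpers.

```python
def matmul(a, b):
    """Multiply two matrices given as nested lists."""
    return [
        [sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
        for i in range(len(a))
    ]


def apply(point_xy, matrix):
    """Apply a 3x3 row-vector transform to a 2D point: [x, y, 1] @ matrix."""
    x, y, w = matmul([[point_xy[0], point_xy[1], 1.0]], matrix)[0]
    return (x / w, y / w)


# Row-vector convention: the translation lives in the last row.
scale_by_2 = [[2.0, 0.0, 0.0], [0.0, 2.0, 0.0], [0.0, 0.0, 1.0]]
shift_x_by_1 = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [1.0, 0.0, 1.0]]

# "scale.compose(shift)" corresponds to scale @ shift: scale first, then shift.
scale_then_shift = matmul(scale_by_2, shift_x_by_1)
shift_then_scale = matmul(shift_x_by_1, scale_by_2)

print(apply((1.0, 1.0), scale_then_shift))  # (3.0, 2.0)
print(apply((1.0, 1.0), shift_then_scale))  # (4.0, 2.0)
```

In the rasterizer this means the camera projection is applied to the view-space points first, and the result is then mapped to NDC by the transform returned from `get_ndc_camera_transform`.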


@ -55,8 +55,7 @@ class PointsRasterizer(nn.Module):
"""
cameras: A cameras object which has a `transform_points` method
which returns the transformed points after applying the
world-to-view and view-to-screen
transformations.
world-to-view and view-to-ndc transformations.
raster_settings: the parameters for rasterization. This should be a
named tuple.
@ -76,8 +75,8 @@ class PointsRasterizer(nn.Module):
point_clouds: a set of point clouds
Returns:
points_screen: the points with the vertex positions in screen
space
points_proj: the points with positions projected
in NDC space
NOTE: keeping this as a separate function for readability but it could
be moved into forward.
@ -93,14 +92,17 @@ class PointsRasterizer(nn.Module):
# TODO: Remove this line when the convention for the z coordinate in
# the rasterizer is decided. i.e. retain z in view space or transform
# to a different range.
eps = kwargs.get("eps", None)
pts_view = cameras.get_world_to_view_transform(**kwargs).transform_points(
pts_world
pts_world, eps=eps
)
pts_screen = cameras.get_projection_transform(**kwargs).transform_points(
pts_view
)
pts_screen[..., 2] = pts_view[..., 2]
point_clouds = point_clouds.update_padded(pts_screen)
# view to NDC transform
to_ndc_transform = cameras.get_ndc_camera_transform(**kwargs)
projection_transform = cameras.get_projection_transform(**kwargs).compose(to_ndc_transform)
pts_ndc = projection_transform.transform_points(pts_view, eps=eps)
pts_ndc[..., 2] = pts_view[..., 2]
point_clouds = point_clouds.update_padded(pts_ndc)
return point_clouds
def to(self, device):
@ -115,10 +117,10 @@ class PointsRasterizer(nn.Module):
Returns:
PointFragments: Rasterization outputs as a named tuple.
"""
points_screen = self.transform(point_clouds, **kwargs)
points_proj = self.transform(point_clouds, **kwargs)
raster_settings = kwargs.get("raster_settings", self.raster_settings)
idx, zbuf, dists2 = rasterize_points(
points_screen,
points_proj,
image_size=raster_settings.image_size,
radius=raster_settings.radius,
points_per_pixel=raster_settings.points_per_pixel,


@ -124,17 +124,23 @@ def ndc_to_screen_points_naive(points, imsize):
Transforms points from PyTorch3D's NDC space to screen space
Args:
points: (N, V, 3) representing padded points
imsize: (N, 2) image size = (width, height)
imsize: (N, 2) image size = (height, width)
Returns:
(N, V, 3) tensor of transformed points
"""
imwidth, imheight = imsize.unbind(1)
imwidth = imwidth.view(-1, 1)
imheight = imheight.view(-1, 1)
height, width = imsize.unbind(1)
width = width.view(-1, 1)
half_width = (width - 1.0) / 2.0
height = height.view(-1, 1)
half_height = (height - 1.0) / 2.0
scale = (
half_width * (height > width).float() + half_height * (height <= width).float()
)
x, y, z = points.unbind(2)
x = (1.0 - x) * (imwidth - 1) / 2.0
y = (1.0 - y) * (imheight - 1) / 2.0
x = -scale * x + half_width
y = -scale * y + half_height
return torch.stack((x, y, z), dim=2)
@ -513,17 +519,23 @@ class TestCamerasCommon(TestCaseMixin, unittest.TestCase):
screen_cam_params = {"R": R, "T": T}
ndc_cam_params = {"R": R, "T": T}
if cam_type in (OrthographicCameras, PerspectiveCameras):
ndc_cam_params["focal_length"] = torch.rand((batch_size, 2)) * 3.0
ndc_cam_params["principal_point"] = torch.randn((batch_size, 2))
fcl = torch.rand((batch_size, 2)) * 3.0 + 0.1
prc = torch.randn((batch_size, 2)) * 0.2
# (height, width)
image_size = torch.randint(low=2, high=64, size=(batch_size, 2))
# scale
scale = (image_size.min(dim=1, keepdim=True).values - 1.0) / 2.0
ndc_cam_params["focal_length"] = fcl
ndc_cam_params["principal_point"] = prc
ndc_cam_params["image_size"] = image_size
screen_cam_params["image_size"] = image_size
screen_cam_params["focal_length"] = (
ndc_cam_params["focal_length"] * image_size / 2.0
)
screen_cam_params["focal_length"] = fcl * scale
screen_cam_params["principal_point"] = (
(1.0 - ndc_cam_params["principal_point"]) * image_size / 2.0
)
image_size[:, [1, 0]] - 1.0
) / 2.0 - prc * scale
screen_cam_params["in_ndc"] = False
else:
raise ValueError(str(cam_type))
return cam_type(**ndc_cam_params), cam_type(**screen_cam_params)
@ -611,17 +623,22 @@ class TestCamerasCommon(TestCaseMixin, unittest.TestCase):
# init the cameras
cameras = init_random_cameras(cam_type, batch_size)
# xyz - the ground truth point cloud
xyz = torch.randn(batch_size, num_points, 3) * 0.3
xy = torch.randn(batch_size, num_points, 2) * 2.0 - 1.0
z = torch.randn(batch_size, num_points, 1) * 3.0 + 1.0
xyz = torch.cat((xy, z), dim=2)
# image size
image_size = torch.randint(low=2, high=64, size=(batch_size, 2))
image_size = torch.randint(low=32, high=64, size=(batch_size, 2))
# project points
xyz_project_ndc = cameras.transform_points(xyz)
xyz_project_screen = cameras.transform_points_screen(xyz, image_size)
xyz_project_ndc = cameras.transform_points_ndc(xyz)
xyz_project_screen = cameras.transform_points_screen(
xyz, image_size=image_size
)
# naive
xyz_project_screen_naive = ndc_to_screen_points_naive(
xyz_project_ndc, image_size
)
self.assertClose(xyz_project_screen, xyz_project_screen_naive)
# we set atol to 1e-4, remember that screen points are in [0, W-1]x[0, H-1] space
self.assertClose(xyz_project_screen, xyz_project_screen_naive, atol=1e-4)
def test_equiv_project_points(self, batch_size=50, num_points=100):
"""
@ -634,12 +651,15 @@ class TestCamerasCommon(TestCaseMixin, unittest.TestCase):
ndc_cameras,
screen_cameras,
) = TestCamerasCommon.init_equiv_cameras_ndc_screen(cam_type, batch_size)
# xyz - the ground truth point cloud
xyz = torch.randn(batch_size, num_points, 3) * 0.3
# xyz - the ground truth point cloud in Py3D space
xy = torch.randn(batch_size, num_points, 2) * 0.3
z = torch.rand(batch_size, num_points, 1) + 3.0 + 0.1
xyz = torch.cat((xy, z), dim=2)
# project points
xyz_ndc_cam = ndc_cameras.transform_points(xyz)
xyz_screen_cam = screen_cameras.transform_points(xyz)
self.assertClose(xyz_ndc_cam, xyz_screen_cam, atol=1e-6)
xyz_ndc = ndc_cameras.transform_points_ndc(xyz)
xyz_screen = screen_cameras.transform_points_ndc(xyz)
# check correctness
self.assertClose(xyz_ndc, xyz_screen, atol=1e-5)
def test_clone(self, batch_size: int = 10):
"""
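
The equivalence exercised by `init_equiv_cameras_ndc_screen` and `test_equiv_project_points` can be checked by hand for a single orthographic camera. The sketch below is pure Python and only covers the x coordinate. It assumes that the orthographic projection maps a view-space `X` to `fx * X + px`, and it mimics the principal point fix and the unflipped screen-to-NDC step from `get_ndc_camera_transform`. All helper names are hypothetical.

```python
def ndc_to_screen_params(focal_ndc, principal_ndc, height, width):
    """Convert (fx, fy) and (px, py) from NDC units to pixels,
    following the formulas used in init_equiv_cameras_ndc_screen."""
    scale = (min(height, width) - 1.0) / 2.0
    focal_screen = (focal_ndc[0] * scale, focal_ndc[1] * scale)
    principal_screen = (
        (width - 1.0) / 2.0 - principal_ndc[0] * scale,
        (height - 1.0) / 2.0 - principal_ndc[1] * scale,
    )
    return focal_screen, principal_screen


def ndc_camera_x(x_view, fx_ndc, px_ndc):
    """x in NDC for a camera defined in NDC (assumed form: fx * X + px)."""
    return fx_ndc * x_view + px_ndc


def screen_camera_x_in_ndc(x_view, fx_screen, px_screen, height, width):
    """x in NDC for a camera defined in screen space."""
    scale = (min(height, width) - 1.0) / 2.0
    x_proj = fx_screen * x_view + px_screen      # projection with screen params
    x_fixed = x_proj - 2.0 * px_screen           # principal point fix
    return (x_fixed + (width - 1.0) / 2.0) / scale  # unflipped screen -> NDC


# A 5x9 (height x width) image and an arbitrary NDC camera.
focal_s, principal_s = ndc_to_screen_params((0.5, 0.25), (0.25, -0.5), 5, 9)
print(focal_s, principal_s)  # (1.0, 0.5) (3.5, 3.0)
print(ndc_camera_x(0.5, 0.5, 0.25))  # 0.5
print(screen_camera_x_in_ndc(0.5, focal_s[0], principal_s[0], 5, 9))  # 0.5
```

For a 512x512 image, an NDC camera with focal length 1 and principal point 0 maps to a screen-space focal length and principal point of `(512 - 1) / 2 = 255.5`, the values used in the updated render test.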


@ -255,9 +255,20 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
device=device,
R=R,
T=T,
principal_point=((256.0, 256.0),),
focal_length=((256.0, 256.0),),
principal_point=(
(
(512.0 - 1.0) / 2.0,
(512.0 - 1.0) / 2.0,
),
),
focal_length=(
(
(512.0 - 1.0) / 2.0,
(512.0 - 1.0) / 2.0,
),
),
image_size=((512, 512),),
in_ndc=False,
)
rasterizer = MeshRasterizer(
cameras=cameras, raster_settings=raster_settings