Mirror of https://github.com/facebookresearch/pytorch3d.git, synced 2025-12-20 22:30:35 +08:00
NDC/screen cameras API fix, compatibility with renderer
Summary: API fix for NDC/screen cameras and compatibility with PyTorch3D renderers. With this new fix:

* Users can define cameras and `transform_points` under any coordinate system conventions. The transformation applies the camera K and RT to the input points, without regard to PyTorch3D conventions. This makes cameras completely independent of the PyTorch3D renderer.
* Cameras can be defined either in NDC space or in screen space. Of the existing cameras, FoV cameras are in NDC space; Perspective/Orthographic cameras can be defined in NDC or screen space.
* The interface with PyTorch3D renderers happens through `transform_points_ndc`, which transforms points to NDC space and assumes that input points are provided according to PyTorch3D conventions.
* Similarly, `transform_points_screen` transforms points to screen space and again assumes that input points follow PyTorch3D conventions.
* For Orthographic/Perspective cameras defined in screen space, `get_ndc_camera_transform` allows points to be converted to NDC for use with the renderers.

Reviewed By: nikhilaravi

Differential Revision: D26932657

fbshipit-source-id: 1a964e3e7caa54d10c792cf39c4d527ba2fb2e79
Committed by: Facebook GitHub Bot
Parent: 9a14f54e8b
Commit: 0c32f094af
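Before the diff, a quick sketch of how the reworked API is meant to be exercised. This is a minimal illustrative sketch, not code from the commit; it assumes the public `pytorch3d.renderer` import path and uses made-up tensor values:

import torch
from pytorch3d.renderer import PerspectiveCameras

# Default camera: parameters interpreted in NDC space (in_ndc=True).
cameras = PerspectiveCameras(focal_length=1.0)

points = torch.rand(1, 5, 3) + 2.0  # arbitrary world points in front of the camera

# Applies only K and RT; the output lives in whatever space the camera
# was defined in (NDC here) and makes no PyTorch3D-convention assumptions.
proj = cameras.transform_points(points)

# Renderer interface: always lands in PyTorch3D's NDC space; input points
# are assumed to follow PyTorch3D conventions (+X left, +Y up).
ndc = cameras.transform_points_ndc(points)

# Screen space (+X right, +Y down, origin at the top left); an image size
# is needed for the NDC-to-screen scaling.
screen = cameras.transform_points_screen(points, image_size=((128, 128),))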
@@ -6,7 +6,7 @@
 import math
 import warnings
-from typing import Optional, Sequence, Tuple
+from typing import Optional, Sequence, Tuple, Union, List

 import numpy as np
 import torch
@@ -28,20 +28,20 @@ class CamerasBase(TensorProperties):

    For cameras, there are four different coordinate systems (or spaces)
    - World coordinate system: This is the system the object lives - the world.
-    - Camera view coordinate system: This is the system that has its origin on the image plane
+    - Camera view coordinate system: This is the system that has its origin on the camera
        and the Z-axis perpendicular to the image plane.
        In PyTorch3D, we assume that +X points left, and +Y points up and
        +Z points out from the image plane.
-        The transformation from world -> view happens after applying a rotation (R)
+        The transformation from world --> view happens after applying a rotation (R)
        and translation (T)
    - NDC coordinate system: This is the normalized coordinate system that confines
        in a volume the rendered part of the object or scene. Also known as view volume.
        Given the PyTorch3D convention, (+1, +1, znear) is the top left near corner,
        and (-1, -1, zfar) is the bottom right far corner of the volume.
-        The transformation from view -> NDC happens after applying the camera
-        projection matrix (P).
+        The transformation from view --> NDC happens after applying the camera
+        projection matrix (P) if defined in NDC space.
    - Screen coordinate system: This is another representation of the view volume with
-        the XY coordinates defined in pixel space instead of a normalized space.
+        the XY coordinates defined in image space instead of a normalized space.

    A better illustration of the coordinate systems can be found in
    pytorch3d/docs/notes/cameras.md.
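To make the corner conventions concrete, here is a small worked mapping from NDC corners to pixel coordinates, using the axis-flip-and-scale formula that appears in the pre-change `transform_points_screen` further down in this diff (the 4x4 image is chosen purely for illustration):

# PyTorch3D NDC: +X left, +Y up; (+1, +1) is the top left corner.
# Screen/image space: +X right, +Y down; (0, 0) is the top left pixel.
W = H = 4
for ndc_x, ndc_y in [(1.0, 1.0), (-1.0, -1.0), (0.0, 0.0)]:
    # Axis flip plus scale from [-1, 1] to [0, side - 1]:
    screen_x = (W - 1) / 2 * (1 - ndc_x)
    screen_y = (H - 1) / 2 * (1 - ndc_y)
    print((ndc_x, ndc_y), "->", (screen_x, screen_y))
# (1.0, 1.0)   -> (0.0, 0.0)    top left
# (-1.0, -1.0) -> (3.0, 3.0)    bottom right
# (0.0, 0.0)   -> (1.5, 1.5)    image center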
@@ -54,17 +54,21 @@ class CamerasBase(TensorProperties):
    - `get_full_projection_transform` which composes the projection
        transform (P) with the world-to-view transform (R, T)
    - `transform_points` which takes a set of input points in world coordinates and
-        projects to NDC coordinates ranging from [-1, -1, znear] to [+1, +1, zfar].
-    - `transform_points_screen` which takes a set of input points in world coordinates and
-        projects them to the screen coordinates ranging from
-        [0, 0, znear] to [W-1, H-1, zfar]
+        projects to the space the camera is defined in (NDC or screen)
+    - `get_ndc_camera_transform` which defines the transform from screen/NDC to
+        PyTorch3D's NDC space
+    - `transform_points_ndc` which takes a set of points in world coordinates and
+        projects them to PyTorch3D's NDC space
+    - `transform_points_screen` which takes a set of points in world coordinates and
+        projects them to screen space

    For each new camera, one should implement the `get_projection_transform`
-    routine that returns the mapping from camera view coordinates to NDC coordinates.
+    routine that returns the mapping from camera view coordinates to camera
+    coordinates (NDC or screen).

    Another useful function that is specific to each camera model is
-    `unproject_points` which sends points from NDC coordinates back to
-    camera view or world coordinates depending on the `world_coordinates`
+    `unproject_points` which sends points from camera coordinates (NDC or screen)
+    back to camera view or world coordinates depending on the `world_coordinates`
    boolean argument of the function.
    """
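The division of labor among these methods can be summarized in a rough sketch; this is a simplification for orientation only, not the verbatim library implementations:

# Rough call-chain sketch (orientation only):
#
# transform_points(p):
#     world -> view (R, T), then get_projection_transform()  ->  NDC or screen,
#     whichever space the camera was defined in.
#
# transform_points_ndc(p):
#     transform_points(p), then, if not self.in_ndc(),
#     get_ndc_camera_transform() to land in PyTorch3D's NDC space.
#
# transform_points_screen(p):
#     transform_points_ndc(p), then get_ndc_to_screen_transform(with_xyflip=True).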
@@ -84,7 +88,7 @@ class CamerasBase(TensorProperties):

    def unproject_points(self):
        """
-        Transform input points from NDC coordinates
+        Transform input points from camera coordinates (NDC or screen)
        to the world / camera coordinates.

        Each of the input points `xy_depth` of shape (..., 3) is
@@ -181,8 +185,10 @@ class CamerasBase(TensorProperties):

    def get_full_projection_transform(self, **kwargs) -> Transform3d:
        """
-        Return the full world-to-NDC transform composing the
-        world-to-view and view-to-NDC transforms.
+        Return the full world-to-camera transform composing the
+        world-to-view and view-to-camera transforms.
+        If camera is defined in NDC space, the projected points are in NDC space.
+        If camera is defined in screen space, the projected points are in screen space.

        Args:
            **kwargs: parameters for the projection transforms can be passed in
@@ -200,14 +206,70 @@ class CamerasBase(TensorProperties):
        self.R: torch.Tensor = kwargs.get("R", self.R)  # pyre-ignore[16]
        self.T: torch.Tensor = kwargs.get("T", self.T)  # pyre-ignore[16]
        world_to_view_transform = self.get_world_to_view_transform(R=self.R, T=self.T)
-        view_to_ndc_transform = self.get_projection_transform(**kwargs)
-        return world_to_view_transform.compose(view_to_ndc_transform)
+        view_to_proj_transform = self.get_projection_transform(**kwargs)
+        return world_to_view_transform.compose(view_to_proj_transform)

    def transform_points(
        self, points, eps: Optional[float] = None, **kwargs
    ) -> torch.Tensor:
        """
-        Transform input points from world to NDC space.
+        Transform input points from world to camera space with the
+        projection matrix defined by the camera.
+
+        For `CamerasBase.transform_points`, setting `eps > 0`
+        stabilizes gradients since it leads to avoiding division
+        by excessively low numbers for points close to the camera plane.

        Args:
            points: torch tensor of shape (..., 3).
            eps: If eps!=None, the argument is used to clamp the
                divisor in the homogeneous normalization of the points
                transformed to the ndc space. Please see
                `transforms.Transform3D.transform_points` for details.

-                For `CamerasBase.transform_points`, setting `eps > 0`
-                stabilizes gradients since it leads to avoiding division
-                by excessively low numbers for points close to the
-                camera plane.
-
        Returns
            new_points: transformed points with the same shape as the input.
        """
+        world_to_proj_transform = self.get_full_projection_transform(**kwargs)
+        return world_to_proj_transform.transform_points(points, eps=eps)
+
+    def get_ndc_camera_transform(self, **kwargs) -> Transform3d:
+        """
+        Returns the transform from camera projection space (screen or NDC) to NDC space.
+        For cameras that can be specified in screen space, this transform
+        allows points to be converted from screen to NDC space.
+        The default transform scales the points from [0, W-1]x[0, H-1] to [-1, 1].
+        This function should be modified per camera definitions if need be,
+        e.g. for Perspective/Orthographic cameras we provide a custom implementation.
+        This transform assumes PyTorch3D coordinate system conventions for
+        both the NDC space and the input points.
+
+        This transform interfaces with the PyTorch3D renderer which assumes
+        input points to the renderer to be in NDC space.
+        """
+        if self.in_ndc():
+            return Transform3d(device=self.device, dtype=torch.float32)
+        else:
+            # For custom cameras which can be defined in screen space,
+            # users might have to implement the screen to NDC transform based
+            # on the definition of the camera parameters.
+            # See PerspectiveCameras/OrthographicCameras for an example.
+            # We don't flip xy because we assume that world points are in PyTorch3D coordinates
+            # and thus conversion from screen to ndc is a mere scaling from image to [-1, 1] scale.
+            return get_screen_to_ndc_transform(self, with_xyflip=False, **kwargs)
+
+    def transform_points_ndc(
+        self, points, eps: Optional[float] = None, **kwargs
+    ) -> torch.Tensor:
+        """
+        Transforms points from PyTorch3D world/camera space to NDC space.
+        Input points follow the PyTorch3D coordinate system conventions: +X left, +Y up.
+        Output points are in NDC space: +X left, +Y up, origin at image center.
+
+        Args:
+            points: torch tensor of shape (..., 3).
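To see why clamping the homogeneous divisor helps, a small illustrative snippet; the clamp below is a sketch, and the library's exact clamping lives in `transforms.Transform3D.transform_points`:

import torch

# Perspective division divides by a depth-like homogeneous coordinate w.
# For a point nearly on the camera plane, w ~ 0 and both the output and
# its gradient explode.
w = torch.tensor([1e-12], requires_grad=True)
x = torch.tensor([0.5])

eps = 1e-6
w_clamped = w.clamp(min=eps)  # sketch of the clamp, not the library's exact code
y = x / w_clamped             # bounded by x / eps instead of ~5e11
y.backward()
print(y, w.grad)              # gradient is zero here because w < eps was clamped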
@@ -225,17 +287,22 @@ class CamerasBase(TensorProperties):
            new_points: transformed points with the same shape as the input.
        """
+        world_to_ndc_transform = self.get_full_projection_transform(**kwargs)
+        if not self.in_ndc():
+            to_ndc_transform = self.get_ndc_camera_transform(**kwargs)
+            world_to_ndc_transform = world_to_ndc_transform.compose(to_ndc_transform)
+
+        return world_to_ndc_transform.transform_points(points, eps=eps)

    def transform_points_screen(
-        self, points, image_size, eps: Optional[float] = None, **kwargs
+        self, points, eps: Optional[float] = None, **kwargs
    ) -> torch.Tensor:
        """
-        Transform input points from world to screen space.
+        Transforms points from PyTorch3D world/camera space to screen space.
+        Input points follow the PyTorch3D coordinate system conventions: +X left, +Y up.
+        Output points are in screen space: +X right, +Y down, origin at top left corner.

        Args:
-            points: torch tensor of shape (N, V, 3).
-            image_size: torch tensor of shape (N, 2)
+            points: torch tensor of shape (..., 3).
            eps: If eps!=None, the argument is used to clamp the
                divisor in the homogeneous normalization of the points
                transformed to the ndc space. Please see
@@ -249,25 +316,10 @@ class CamerasBase(TensorProperties):
        Returns
            new_points: transformed points with the same shape as the input.
        """
-
-        ndc_points = self.transform_points(points, eps=eps, **kwargs)
-
-        if not torch.is_tensor(image_size):
-            image_size = torch.tensor(
-                image_size, dtype=torch.int64, device=points.device
-            )
-        if (image_size < 1).any():
-            raise ValueError("Provided image size is invalid.")
-
-        image_width, image_height = image_size.unbind(1)
-        image_width = image_width.view(-1, 1)  # (N, 1)
-        image_height = image_height.view(-1, 1)  # (N, 1)
-
-        ndc_z = ndc_points[..., 2]
-        screen_x = (image_width - 1.0) / 2.0 * (1.0 - ndc_points[..., 0])
-        screen_y = (image_height - 1.0) / 2.0 * (1.0 - ndc_points[..., 1])
-
-        return torch.stack((screen_x, screen_y, ndc_z), dim=2)
+        points_ndc = self.transform_points_ndc(points, eps=eps, **kwargs)
+        return get_ndc_to_screen_transform(
+            self, with_xyflip=True, **kwargs
+        ).transform_points(points_ndc, eps=eps)

    def clone(self):
        """
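For square images, the removed manual arithmetic and the new transform-based path should agree; a quick standalone check with synthetic NDC points (pure recomputation of the formulas, not a call into the library):

import torch

# Synthetic NDC points and an image size; the old formula was
# screen_x = (W - 1) / 2 * (1 - ndc_x).
W = H = 256
ndc = torch.tensor([[1.0, 1.0, 0.5], [-1.0, -1.0, 0.5], [0.0, 0.0, 0.5]])

screen_x_old = (W - 1.0) / 2.0 * (1.0 - ndc[:, 0])
screen_y_old = (H - 1.0) / 2.0 * (1.0 - ndc[:, 1])

# The same mapping written as the new NDC-to-screen transform does it:
# flip the axis, scale by (min_side - 1) / 2, then shift to the image center.
scale = (min(W, H) - 1.0) / 2.0
screen_x_new = -ndc[:, 0] * scale + (W - 1.0) / 2.0
screen_y_new = -ndc[:, 1] * scale + (H - 1.0) / 2.0

assert torch.allclose(screen_x_old, screen_x_new)
assert torch.allclose(screen_y_old, screen_y_new)
# (+1, +1) in NDC lands at pixel (0, 0): the top left corner.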
@@ -280,9 +332,23 @@ class CamerasBase(TensorProperties):
    def is_perspective(self):
        raise NotImplementedError()

+    def in_ndc(self):
+        """
+        Specifies whether the camera is defined in NDC space
+        or in screen (image) space
+        """
+        raise NotImplementedError()
+
    def get_znear(self):
        return self.znear if hasattr(self, "znear") else None

+    def get_image_size(self):
+        """
+        Returns the image size, if provided, expected in the form of (height, width)
+        The image size is used for conversion of projected points to screen coordinates.
+        """
+        return self.image_size if hasattr(self, "image_size") else None
+

 ############################################################
 # Field of View Camera Classes #
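Under this contract, a new camera only has to declare which space it lives in; a minimal hypothetical subclass as a sketch (the `MyScreenCamera` name and its identity projection are invented for illustration and are not part of this diff):

import torch
from pytorch3d.renderer.cameras import CamerasBase
from pytorch3d.transforms import Transform3d

class MyScreenCamera(CamerasBase):
    """Toy camera whose parameters live in screen space."""

    def __init__(self, image_size, device="cpu"):
        super().__init__(device=device, image_size=image_size)

    def get_projection_transform(self, **kwargs) -> Transform3d:
        # Identity projection, purely for illustration.
        return Transform3d(device=self.device)

    def is_perspective(self):
        return False

    def in_ndc(self):
        # Returning False makes the default get_ndc_camera_transform
        # rescale screen points to PyTorch3D's NDC range for the renderer.
        return False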
@@ -501,8 +567,9 @@ class FoVPerspectiveCameras(CamerasBase):
        )

        # Transpose the projection matrix as PyTorch3D transforms use row vectors.
-        transform = Transform3d(device=self.device)
-        transform._matrix = K.transpose(1, 2).contiguous()
+        transform = Transform3d(
+            matrix=K.transpose(1, 2).contiguous(), device=self.device
+        )
        return transform

    def unproject_points(
@@ -552,6 +619,9 @@ class FoVPerspectiveCameras(CamerasBase):
    def is_perspective(self):
        return True

+    def in_ndc(self):
+        return True
+

 def OpenGLOrthographicCameras(
    znear=1.0,
@@ -726,8 +796,9 @@ class FoVOrthographicCameras(CamerasBase):
            kwargs.get("scale_xyz", self.scale_xyz),
        )

-        transform = Transform3d(device=self.device)
-        transform._matrix = K.transpose(1, 2).contiguous()
+        transform = Transform3d(
+            matrix=K.transpose(1, 2).contiguous(), device=self.device
+        )
        return transform

    def unproject_points(
@@ -773,15 +844,15 @@ class FoVOrthographicCameras(CamerasBase):
    def is_perspective(self):
        return False

+    def in_ndc(self):
+        return True
+

 ############################################################
 # MultiView Camera Classes #
 ############################################################
 """
-Note that the MultiView Cameras accept parameters in both
-screen and NDC space.
-If the user specifies `image_size` at construction time then
-we assume the parameters are in screen space.
+Note that the MultiView Cameras accept parameters in NDC space.
 """
@@ -819,30 +890,8 @@ class PerspectiveCameras(CamerasBase):
    transformation matrices using the multi-view geometry convention for
    perspective camera.

-    Parameters for this camera can be specified in NDC or in screen space.
-    If you wish to provide parameters in screen space, you NEED to provide
-    the image_size = (imwidth, imheight).
-    If you wish to provide parameters in NDC space, you should NOT provide
-    image_size. Providing valid image_size will trigger a screen space to
-    NDC space transformation in the camera.
-
-    For example, here is how to define cameras on the two spaces.
-
-    .. code-block:: python
-        # camera defined in screen space
-        cameras = PerspectiveCameras(
-            focal_length=((22.0, 15.0),),  # (fx_screen, fy_screen)
-            principal_point=((192.0, 128.0),),  # (px_screen, py_screen)
-            image_size=((256, 256),),  # (imwidth, imheight)
-        )
-
-        # the equivalent camera defined in NDC space
-        cameras = PerspectiveCameras(
-            focal_length=((0.171875, 0.11718),),  # fx = fx_screen / half_imwidth,
-                                                  # fy = fy_screen / half_imheight
-            principal_point=((-0.5, 0),),  # px = - (px_screen - half_imwidth) / half_imwidth,
-                                           # py = - (py_screen - half_imheight) / half_imheight
-        )
+    Parameters for this camera are specified in NDC if `in_ndc` is set to True.
+    If parameters are specified in screen space, `in_ndc` must be set to False.
    """

    def __init__(
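The removed docstring example still translates to the new API; a hedged sketch of the two equivalent constructions, using the same numbers as the old example but the new `in_ndc` flag (note that fx_screen / half_imwidth = 22.0 / 128 = 0.171875):

from pytorch3d.renderer import PerspectiveCameras

# Camera defined in screen space: flag it with in_ndc=False.
cameras_screen = PerspectiveCameras(
    focal_length=((22.0, 15.0),),          # (fx, fy) in pixels
    principal_point=((192.0, 128.0),),     # (px, py) in pixels
    image_size=((256, 256),),              # (height, width)
    in_ndc=False,
)

# The equivalent camera defined directly in NDC space (the default).
cameras_ndc = PerspectiveCameras(
    focal_length=((0.171875, 0.1171875),),  # (fx / half_w, fy / half_h)
    principal_point=((-0.5, 0.0),),         # (-(px - half_w) / half_w, ...)
)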
@@ -853,7 +902,8 @@ class PerspectiveCameras(CamerasBase):
        T: torch.Tensor = _T,
        K: Optional[torch.Tensor] = None,
        device: Device = "cpu",
-        image_size=((-1, -1),),
+        in_ndc: bool = True,
+        image_size: Optional[Union[List, Tuple, torch.Tensor]] = None,
    ) -> None:
        """
@@ -864,20 +914,20 @@ class PerspectiveCameras(CamerasBase):
            principal_point: xy coordinates of the center of
                the principal point of the camera in pixels.
                A tensor of shape (N, 2).
+            in_ndc: True if camera parameters are specified in NDC.
+                If camera parameters are in screen space, it must
+                be set to False.
            R: Rotation matrix of shape (N, 3, 3)
            T: Translation matrix of shape (N, 3)
            K: (optional) A calibration matrix of shape (N, 4, 4)
-                If provided, don't need focal_length, principal_point, image_size
-
+                If provided, don't need focal_length, principal_point
+            image_size: (height, width) of image size.
+                A tensor of shape (N, 2). Required for screen cameras.
            device: torch.device or string
-            image_size: If image_size = (imwidth, imheight) with imwidth, imheight > 0
-                is provided, the camera parameters are assumed to be in screen
-                space. They will be converted to NDC space.
-                If image_size is not provided, the parameters are assumed to
-                be in NDC space.
        """
        # The initializer formats all inputs to torch tensors and broadcasts
        # all the inputs to have the same batch dimension where necessary.
+        kwargs = {"image_size": image_size} if image_size is not None else {}
        super().__init__(
            device=device,
            focal_length=focal_length,
@@ -885,8 +935,14 @@ class PerspectiveCameras(CamerasBase):
            R=R,
            T=T,
            K=K,
-            image_size=image_size,
+            _in_ndc=in_ndc,
+            **kwargs,  # pyre-ignore
        )
+        if image_size is not None:
+            if (self.image_size < 1).any():  # pyre-ignore
+                raise ValueError("Image_size provided has invalid values")
+        else:
+            self.image_size = None

    def get_projection_transform(self, **kwargs) -> Transform3d:
        """
@@ -920,40 +976,86 @@ class PerspectiveCameras(CamerasBase):
                msg = "Expected K to have shape of (%r, 4, 4)"
                raise ValueError(msg % (self._N))
        else:
-            image_size = kwargs.get("image_size", self.image_size)
-            # if imwidth > 0, parameters are in screen space
-            image_size = image_size if image_size[0][0] > 0 else None
-
            K = _get_sfm_calibration_matrix(
                self._N,
                self.device,
                kwargs.get("focal_length", self.focal_length),
                kwargs.get("principal_point", self.principal_point),
                orthographic=False,
-                image_size=image_size,
            )

-        transform = Transform3d(device=self.device)
-        transform._matrix = K.transpose(1, 2).contiguous()
+        transform = Transform3d(
+            matrix=K.transpose(1, 2).contiguous(), device=self.device
+        )
        return transform

    def unproject_points(
        self, xy_depth: torch.Tensor, world_coordinates: bool = True, **kwargs
    ) -> torch.Tensor:
        if world_coordinates:
-            to_ndc_transform = self.get_full_projection_transform(**kwargs)
+            to_camera_transform = self.get_full_projection_transform(**kwargs)
        else:
-            to_ndc_transform = self.get_projection_transform(**kwargs)
+            to_camera_transform = self.get_projection_transform(**kwargs)

-        unprojection_transform = to_ndc_transform.inverse()
+        unprojection_transform = to_camera_transform.inverse()
        xy_inv_depth = torch.cat(
            (xy_depth[..., :2], 1.0 / xy_depth[..., 2:3]), dim=-1  # type: ignore
        )
        return unprojection_transform.transform_points(xy_inv_depth)

+    def get_principal_point(self, **kwargs) -> torch.Tensor:
+        """
+        Return the camera's principal point
+
+        Args:
+            **kwargs: parameters for the camera extrinsics can be passed in
+                as keyword arguments to override the default values
+                set in __init__.
+        """
+        proj_mat = self.get_projection_transform(**kwargs).get_matrix()
+        return proj_mat[:, 2, :2]
+
+    def get_ndc_camera_transform(self, **kwargs) -> Transform3d:
+        """
+        Returns the transform from camera projection space (screen or NDC) to NDC space.
+        If the camera is defined already in NDC space, the transform is identity.
+        For cameras defined in screen space, we adjust the principal point computation
+        which is defined in the image space (commonly) and scale the points to NDC space.
+
+        Important: This transform assumes PyTorch3D conventions for the input points,
+        i.e. +X left, +Y up.
+        """
+        if self.in_ndc():
+            ndc_transform = Transform3d(device=self.device, dtype=torch.float32)
+        else:
+            # when cameras are defined in screen/image space, the principal point is
+            # provided in the (+X right, +Y down), aka image, coordinate system.
+            # Since input points are defined in the PyTorch3D system (+X left, +Y up),
+            # we need to adjust for the principal point transform.
+            pr_point_fix = torch.zeros(
+                (self._N, 4, 4), device=self.device, dtype=torch.float32
+            )
+            pr_point_fix[:, 0, 0] = 1.0
+            pr_point_fix[:, 1, 1] = 1.0
+            pr_point_fix[:, 2, 2] = 1.0
+            pr_point_fix[:, 3, 3] = 1.0
+            pr_point_fix[:, :2, 3] = -2.0 * self.get_principal_point(**kwargs)
+            pr_point_fix_transform = Transform3d(
+                matrix=pr_point_fix.transpose(1, 2).contiguous(), device=self.device
+            )
+            screen_to_ndc_transform = get_screen_to_ndc_transform(
+                self, with_xyflip=False, **kwargs
+            )
+            ndc_transform = pr_point_fix_transform.compose(screen_to_ndc_transform)
+
+        return ndc_transform
+
    def is_perspective(self):
        return True

+    def in_ndc(self):
+        return self._in_ndc
+

 def SfMOrthographicCameras(
    focal_length=1.0,
@@ -989,29 +1091,8 @@ class OrthographicCameras(CamerasBase):
    transformation matrices using the multi-view geometry convention for
    orthographic camera.

-    Parameters for this camera can be specified in NDC or in screen space.
-    If you wish to provide parameters in screen space, you NEED to provide
-    the image_size = (imwidth, imheight).
-    If you wish to provide parameters in NDC space, you should NOT provide
-    image_size. Providing valid image_size will trigger a screen space to
-    NDC space transformation in the camera.
-
-    For example, here is how to define cameras on the two spaces.
-
-    .. code-block:: python
-        # camera defined in screen space
-        cameras = OrthographicCameras(
-            focal_length=((22.0, 15.0),),  # (fx, fy)
-            principal_point=((192.0, 128.0),),  # (px, py)
-            image_size=((256, 256),),  # (imwidth, imheight)
-        )
-
-        # the equivalent camera defined in NDC space
-        cameras = OrthographicCameras(
-            focal_length=((0.171875, 0.11718),),  # := (fx / half_imwidth, fy / half_imheight)
-            principal_point=((-0.5, 0),),  # := (- (px - half_imwidth) / half_imwidth,
-                                           #     - (py - half_imheight) / half_imheight)
-        )
+    Parameters for this camera are specified in NDC if `in_ndc` is set to True.
+    If parameters are specified in screen space, `in_ndc` must be set to False.
    """

    def __init__(
@@ -1022,7 +1103,8 @@ class OrthographicCameras(CamerasBase):
        T: torch.Tensor = _T,
        K: Optional[torch.Tensor] = None,
        device: Device = "cpu",
-        image_size=((-1, -1),),
+        in_ndc: bool = True,
+        image_size: Optional[torch.Tensor] = None,
    ) -> None:
        """
@@ -1033,19 +1115,19 @@ class OrthographicCameras(CamerasBase):
            principal_point: xy coordinates of the center of
                the principal point of the camera in pixels.
                A tensor of shape (N, 2).
+            in_ndc: True if camera parameters are specified in NDC.
+                If False, then camera parameters are in screen space.
            R: Rotation matrix of shape (N, 3, 3)
            T: Translation matrix of shape (N, 3)
            K: (optional) A calibration matrix of shape (N, 4, 4)
                If provided, don't need focal_length, principal_point, image_size
+            image_size: (height, width) of image size.
+                A tensor of shape (N, 2). Required for screen cameras.
            device: torch.device or string
-            image_size: If image_size = (imwidth, imheight) with imwidth, imheight > 0
-                is provided, the camera parameters are assumed to be in screen
-                space. They will be converted to NDC space.
-                If image_size is not provided, the parameters are assumed to
-                be in NDC space.
        """
        # The initializer formats all inputs to torch tensors and broadcasts
        # all the inputs to have the same batch dimension where necessary.
+        kwargs = {"image_size": image_size} if image_size is not None else {}
        super().__init__(
            device=device,
            focal_length=focal_length,
@@ -1053,8 +1135,14 @@ class OrthographicCameras(CamerasBase):
            R=R,
            T=T,
            K=K,
-            image_size=image_size,
+            _in_ndc=in_ndc,
+            **kwargs,  # pyre-ignore
        )
+        if image_size is not None:
+            if (self.image_size < 1).any():  # pyre-ignore
+                raise ValueError("Image_size provided has invalid values")
+        else:
+            self.image_size = None

    def get_projection_transform(self, **kwargs) -> Transform3d:
        """
@@ -1088,37 +1176,83 @@ class OrthographicCameras(CamerasBase):
                msg = "Expected K to have shape of (%r, 4, 4)"
                raise ValueError(msg % (self._N))
        else:
-            image_size = kwargs.get("image_size", self.image_size)
-            # if imwidth > 0, parameters are in screen space
-            image_size = image_size if image_size[0][0] > 0 else None
-
            K = _get_sfm_calibration_matrix(
                self._N,
                self.device,
                kwargs.get("focal_length", self.focal_length),
                kwargs.get("principal_point", self.principal_point),
                orthographic=True,
-                image_size=image_size,
            )

-        transform = Transform3d(device=self.device)
-        transform._matrix = K.transpose(1, 2).contiguous()
+        transform = Transform3d(
+            matrix=K.transpose(1, 2).contiguous(), device=self.device
+        )
        return transform

    def unproject_points(
        self, xy_depth: torch.Tensor, world_coordinates: bool = True, **kwargs
    ) -> torch.Tensor:
        if world_coordinates:
-            to_ndc_transform = self.get_full_projection_transform(**kwargs)
+            to_camera_transform = self.get_full_projection_transform(**kwargs)
        else:
-            to_ndc_transform = self.get_projection_transform(**kwargs)
+            to_camera_transform = self.get_projection_transform(**kwargs)

-        unprojection_transform = to_ndc_transform.inverse()
+        unprojection_transform = to_camera_transform.inverse()
        return unprojection_transform.transform_points(xy_depth)

+    def get_principal_point(self, **kwargs) -> torch.Tensor:
+        """
+        Return the camera's principal point
+
+        Args:
+            **kwargs: parameters for the camera extrinsics can be passed in
+                as keyword arguments to override the default values
+                set in __init__.
+        """
+        proj_mat = self.get_projection_transform(**kwargs).get_matrix()
+        return proj_mat[:, 3, :2]
+
+    def get_ndc_camera_transform(self, **kwargs) -> Transform3d:
+        """
+        Returns the transform from camera projection space (screen or NDC) to NDC space.
+        If the camera is defined already in NDC space, the transform is identity.
+        For cameras defined in screen space, we adjust the principal point computation
+        which is defined in the image space (commonly) and scale the points to NDC space.
+
+        Important: This transform assumes PyTorch3D conventions for the input points,
+        i.e. +X left, +Y up.
+        """
+        if self.in_ndc():
+            ndc_transform = Transform3d(device=self.device, dtype=torch.float32)
+        else:
+            # when cameras are defined in screen/image space, the principal point is
+            # provided in the (+X right, +Y down), aka image, coordinate system.
+            # Since input points are defined in the PyTorch3D system (+X left, +Y up),
+            # we need to adjust for the principal point transform.
+            pr_point_fix = torch.zeros(
+                (self._N, 4, 4), device=self.device, dtype=torch.float32
+            )
+            pr_point_fix[:, 0, 0] = 1.0
+            pr_point_fix[:, 1, 1] = 1.0
+            pr_point_fix[:, 2, 2] = 1.0
+            pr_point_fix[:, 3, 3] = 1.0
+            pr_point_fix[:, :2, 3] = -2.0 * self.get_principal_point(**kwargs)
+            pr_point_fix_transform = Transform3d(
+                matrix=pr_point_fix.transpose(1, 2).contiguous(), device=self.device
+            )
+            screen_to_ndc_transform = get_screen_to_ndc_transform(
+                self, with_xyflip=False, **kwargs
+            )
+            ndc_transform = pr_point_fix_transform.compose(screen_to_ndc_transform)
+
+        return ndc_transform
+
    def is_perspective(self):
        return False

+    def in_ndc(self):
+        return self._in_ndc
+

 ################################################
 # Helper functions for cameras #
@@ -1131,20 +1265,16 @@ def _get_sfm_calibration_matrix(
    focal_length,
    principal_point,
    orthographic: bool = False,
-    image_size=None,
 ) -> torch.Tensor:
    """
    Returns a calibration matrix of a perspective/orthographic camera.

    Args:
        N: Number of cameras.
-        focal_length: Focal length of the camera in world units.
+        focal_length: Focal length of the camera.
        principal_point: xy coordinates of the center of
            the principal point of the camera in pixels.
        orthographic: Boolean specifying if the camera is orthographic or not
-        image_size: (Optional) Specifying the image_size = (imwidth, imheight).
-            If not None, the camera parameters are assumed to be in screen space
-            and are transformed to NDC space.

    The calibration matrix `K` is set up as follows:

@@ -1188,22 +1318,6 @@ def _get_sfm_calibration_matrix(
    px, py = principal_point.unbind(1)

-    if image_size is not None:
-        if not torch.is_tensor(image_size):
-            image_size = torch.tensor(image_size, device=device)
-        imwidth, imheight = image_size.unbind(1)
-        # make sure imwidth, imheight are valid (>0)
-        if (imwidth < 1).any() or (imheight < 1).any():
-            raise ValueError(
-                "Camera parameters provided in screen space. Image width or height invalid."
-            )
-        half_imwidth = imwidth / 2.0
-        half_imheight = imheight / 2.0
-        fx = fx / half_imwidth
-        fy = fy / half_imheight
-        px = -(px - half_imwidth) / half_imwidth
-        py = -(py - half_imheight) / half_imheight
-
    K = fx.new_zeros(N, 4, 4)
    K[:, 0, 0] = fx
    K[:, 1, 1] = fy
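The removed block above is where the screen-to-NDC parameter conversion used to live; working through it once by hand with the numbers from the old docstring examples (pure arithmetic, no library calls):

# Screen-space parameters from the docstring examples above.
imwidth = imheight = 256.0
fx, fy = 22.0, 15.0
px, py = 192.0, 128.0

half_imwidth = imwidth / 2.0    # 128.0
half_imheight = imheight / 2.0  # 128.0

fx_ndc = fx / half_imwidth                      # 22.0 / 128 = 0.171875
fy_ndc = fy / half_imheight                     # 15.0 / 128 = 0.1171875
px_ndc = -(px - half_imwidth) / half_imwidth    # -(192 - 128) / 128 = -0.5
py_ndc = -(py - half_imheight) / half_imheight  # -(128 - 128) / 128 =  0.0

print(fx_ndc, fy_ndc, px_ndc, py_ndc)  # 0.171875 0.1171875 -0.5 0.0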
@@ -1419,3 +1533,103 @@ def look_at_view_transform(
    R = look_at_rotation(C, at, up, device=device)
    T = -torch.bmm(R.transpose(1, 2), C[:, :, None])[:, :, 0]
    return R, T
+
+
+def get_ndc_to_screen_transform(
+    cameras, with_xyflip: bool = False, **kwargs
+) -> Transform3d:
+    """
+    PyTorch3D NDC to screen conversion.
+    Conversion from PyTorch3D's NDC space (+X left, +Y up) to screen/image space
+    (+X right, +Y down, origin top left).
+
+    Args:
+        cameras
+        with_xyflip: flips x- and y-axis if set to True.
+    Optional kwargs:
+        image_size: ((height, width),) specifying the height, width
+        of the image. If not provided, it reads it from cameras.
+
+    We represent the NDC to screen conversion as a Transform3d
+    with projection matrix
+
+    K = [
+        [s, 0, 0, cx],
+        [0, s, 0, cy],
+        [0, 0, 1, 0],
+        [0, 0, 0, 1],
+    ]
+
+    """
+    # We require the image size, which is necessary for the transform
+    image_size = kwargs.get("image_size", cameras.get_image_size())
+    if image_size is None:
+        msg = "For NDC to screen conversion, image_size=(height, width) needs to be specified."
+        raise ValueError(msg)
+
+    K = torch.zeros((cameras._N, 4, 4), device=cameras.device, dtype=torch.float32)
+    if not torch.is_tensor(image_size):
+        image_size = torch.tensor(image_size, device=cameras.device)
+    image_size = image_size.view(-1, 2)  # of shape (1 or B)x2
+    height, width = image_size.unbind(1)
+
+    # For non square images, we scale the points such that smallest side
+    # has range [-1, 1] and the largest side has range [-u, u], with u > 1.
+    # This convention is consistent with the PyTorch3D renderer
+    scale = (image_size.min(dim=1).values - 1.0) / 2.0
+
+    K[:, 0, 0] = scale
+    K[:, 1, 1] = scale
+    K[:, 0, 3] = -1.0 * (width - 1.0) / 2.0
+    K[:, 1, 3] = -1.0 * (height - 1.0) / 2.0
+    K[:, 2, 2] = 1.0
+    K[:, 3, 3] = 1.0
+
+    # Transpose the projection matrix as PyTorch3D transforms use row vectors.
+    transform = Transform3d(
+        matrix=K.transpose(1, 2).contiguous(), device=cameras.device
+    )
+
+    if with_xyflip:
+        # flip x, y axis
+        xyflip = torch.eye(4, device=cameras.device, dtype=torch.float32)
+        xyflip[0, 0] = -1.0
+        xyflip[1, 1] = -1.0
+        xyflip = xyflip.view(1, 4, 4).expand(cameras._N, -1, -1)
+        xyflip_transform = Transform3d(
+            matrix=xyflip.transpose(1, 2).contiguous(), device=cameras.device
+        )
+        transform = transform.compose(xyflip_transform)
+    return transform
+
+
+def get_screen_to_ndc_transform(
+    cameras, with_xyflip: bool = False, **kwargs
+) -> Transform3d:
+    """
+    Screen to PyTorch3D NDC conversion.
+    Conversion from screen/image space (+X right, +Y down, origin top left)
+    to PyTorch3D's NDC space (+X left, +Y up).
+
+    Args:
+        cameras
+        with_xyflip: flips x- and y-axis if set to True.
+    Optional kwargs:
+        image_size: ((height, width),) specifying the height, width
+        of the image. If not provided, it reads it from cameras.
+
+    We represent the screen to NDC conversion as a Transform3d
+    with projection matrix
+
+    K = [
+        [1/s, 0, 0, cx/s],
+        [ 0, 1/s, 0, cy/s],
+        [ 0, 0, 1, 0],
+        [ 0, 0, 0, 1],
+    ]
+
+    """
+    transform = get_ndc_to_screen_transform(
+        cameras, with_xyflip=with_xyflip, **kwargs
+    ).inverse()
+    return transform
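Because the scale comes from the smaller image side, a non-square image maps its long axis beyond [-1, 1]; a hedged sketch of what that looks like, using a default NDC camera and an invented 100x200 image size (the helper is the one added in this diff):

import torch
from pytorch3d.renderer import PerspectiveCameras
from pytorch3d.renderer.cameras import get_ndc_to_screen_transform

cameras = PerspectiveCameras()  # NDC camera, batch size 1

# Two opposite NDC corners.
ndc_pts = torch.tensor([[[1.0, 1.0, 1.0], [-1.0, -1.0, 1.0]]])

to_screen = get_ndc_to_screen_transform(
    cameras, with_xyflip=True, image_size=((100, 200),)  # (height, width)
)
screen_pts = to_screen.transform_points(ndc_pts)
# With height=100 < width=200, scale = (100 - 1) / 2, so x in [-1, 1] only
# covers the central 100 pixels of the 200-pixel-wide image: the corners
# land at roughly (50, 0) and (149, 99) rather than at (0, 0) and (199, 99).
print(screen_pts)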
@@ -73,8 +73,7 @@ class MeshRasterizer(nn.Module):
        Args:
            cameras: A cameras object which has a `transform_points` method
                which returns the transformed points after applying the
-                world-to-view and view-to-screen
-                transformations.
+                world-to-view and view-to-ndc transformations.
            raster_settings: the parameters for rasterization. This should be a
                named tuple.
@@ -100,8 +99,8 @@ class MeshRasterizer(nn.Module):
            vertex coordinates in world space.

        Returns:
-            meshes_screen: a Meshes object with the vertex positions in screen
-            space
+            meshes_proj: a Meshes object with the vertex positions projected
+            in NDC space

        NOTE: keeping this as a separate function for readability but it could
        be moved into forward.
@@ -126,12 +125,14 @@ class MeshRasterizer(nn.Module):
        verts_view = cameras.get_world_to_view_transform(**kwargs).transform_points(
            verts_world, eps=eps
        )
-        verts_screen = cameras.get_projection_transform(**kwargs).transform_points(
-            verts_view, eps=eps
-        )
-        verts_screen[..., 2] = verts_view[..., 2]
-        meshes_screen = meshes_world.update_padded(new_verts_padded=verts_screen)
-        return meshes_screen
+        # view to NDC transform
+        to_ndc_transform = cameras.get_ndc_camera_transform(**kwargs)
+        projection_transform = cameras.get_projection_transform(**kwargs).compose(to_ndc_transform)
+        verts_ndc = projection_transform.transform_points(verts_view, eps=eps)
+
+        verts_ndc[..., 2] = verts_view[..., 2]
+        meshes_ndc = meshes_world.update_padded(new_verts_padded=verts_ndc)
+        return meshes_ndc

    def forward(self, meshes_world, **kwargs) -> Fragments:
        """
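One subtlety worth noting: the projected z is overwritten with the view-space depth, so the rasterizer's zbuf stays in view units rather than a nonlinear NDC depth. A small illustrative computation; the depth formula below is the standard FoV-style perspective mapping, used here only to show the nonlinearity, not quoted from this diff:

import torch

# A perspective projection maps view-space depth z nonlinearly, e.g. for
# znear=1.0, zfar=100.0: z_ndc = zfar * (z - znear) / ((zfar - znear) * z).
znear, zfar = 1.0, 100.0
z = torch.tensor([1.0, 2.0, 50.0, 100.0])
z_ndc = zfar * (z - znear) / ((zfar - znear) * z)
print(z_ndc)  # tensor([0.0000, 0.5051, 0.9899, 1.0000]), bunched near the far plane

# Overwriting the projected z with the view-space z (as the transform above
# does) keeps the rasterizer's zbuf in interpretable view-space depth units.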
@@ -141,7 +142,7 @@ class MeshRasterizer(nn.Module):
        Returns:
            Fragments: Rasterization outputs as a named tuple.
        """
-        meshes_screen = self.transform(meshes_world, **kwargs)
+        meshes_proj = self.transform(meshes_world, **kwargs)
        raster_settings = kwargs.get("raster_settings", self.raster_settings)

        # By default, turn on clip_barycentric_coords if blur_radius > 0.
@@ -166,7 +167,7 @@ class MeshRasterizer(nn.Module):
        z_clip = None if not perspective_correct or znear is None else znear / 2

        pix_to_face, zbuf, bary_coords, dists = rasterize_meshes(
-            meshes_screen,
+            meshes_proj,
            image_size=raster_settings.image_size,
            blur_radius=raster_settings.blur_radius,
            faces_per_pixel=raster_settings.faces_per_pixel,
@@ -55,8 +55,7 @@ class PointsRasterizer(nn.Module):
        """
        cameras: A cameras object which has a `transform_points` method
            which returns the transformed points after applying the
-            world-to-view and view-to-screen
-            transformations.
+            world-to-view and view-to-ndc transformations.
        raster_settings: the parameters for rasterization. This should be a
            named tuple.
@@ -76,8 +75,8 @@ class PointsRasterizer(nn.Module):
        point_clouds: a set of point clouds

        Returns:
-            points_screen: the points with the vertex positions in screen
-            space
+            points_proj: the points with positions projected
+            in NDC space

        NOTE: keeping this as a separate function for readability but it could
        be moved into forward.
@@ -93,14 +92,17 @@ class PointsRasterizer(nn.Module):
        # TODO: Remove this line when the convention for the z coordinate in
        # the rasterizer is decided. i.e. retain z in view space or transform
        # to a different range.
+        eps = kwargs.get("eps", None)
        pts_view = cameras.get_world_to_view_transform(**kwargs).transform_points(
-            pts_world
+            pts_world, eps=eps
        )
-        pts_screen = cameras.get_projection_transform(**kwargs).transform_points(
-            pts_view
-        )
-        pts_screen[..., 2] = pts_view[..., 2]
-        point_clouds = point_clouds.update_padded(pts_screen)
+        # view to NDC transform
+        to_ndc_transform = cameras.get_ndc_camera_transform(**kwargs)
+        projection_transform = cameras.get_projection_transform(**kwargs).compose(to_ndc_transform)
+        pts_ndc = projection_transform.transform_points(pts_view, eps=eps)
+
+        pts_ndc[..., 2] = pts_view[..., 2]
+        point_clouds = point_clouds.update_padded(pts_ndc)
        return point_clouds

    def to(self, device):
@@ -115,10 +117,10 @@ class PointsRasterizer(nn.Module):
        Returns:
            PointFragments: Rasterization outputs as a named tuple.
        """
-        points_screen = self.transform(point_clouds, **kwargs)
+        points_proj = self.transform(point_clouds, **kwargs)
        raster_settings = kwargs.get("raster_settings", self.raster_settings)
        idx, zbuf, dists2 = rasterize_points(
-            points_screen,
+            points_proj,
            image_size=raster_settings.image_size,
            radius=raster_settings.radius,
            points_per_pixel=raster_settings.points_per_pixel,