Update cameras to accept projection matrix as input

Summary: To initialize the Cameras classes we currently require the principal point, focal length, and other parameters to be specified, from which we calculate the intrinsic matrix. In some cases the matrix may be directly available, e.g. from a dataset and the associated metadata for an image.

Reviewed By: nikhilaravi

Differential Revision: D24489509

fbshipit-source-id: 1b411f19c5f6c8074bcfbf613f3339d5e242c119
Dave Schnizlein 2020-10-30 08:52:13 -07:00 committed by Facebook GitHub Bot
parent 6f4697bc1b
commit 36fb257ef1
3 changed files with 210 additions and 133 deletions
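
A usage sketch of the new K argument (assuming the pytorch3d.renderer import path; the matrix values below are placeholders, not a real calibration):

    import torch
    from pytorch3d.renderer import PerspectiveCameras

    # K obtained directly, e.g. from dataset metadata
    K = torch.eye(4).unsqueeze(0)   # (N, 4, 4) with N = 1
    R = torch.eye(3).unsqueeze(0)   # (N, 3, 3)
    T = torch.zeros(1, 3)           # (N, 3)
    cameras = PerspectiveCameras(K=K, R=R, T=T)
    # K is used as-is; focal_length and principal_point are not needed
    transform = cameras.get_projection_transform()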


@@ -70,7 +70,7 @@ class CamerasBase(TensorProperties):
arguments to override the default values set in `__init__`.
Return:
P: a `Transform3d` object which represents a batch of projection
a `Transform3d` object which represents a batch of projection
matrices of shape (N, 3, 3)
"""
raise NotImplementedError()
@@ -333,10 +333,10 @@ class FoVPerspectiveCameras(CamerasBase):
degrees: bool = True,
R=_R,
T=_T,
K=None,
device="cpu",
):
"""
__init__(self, znear, zfar, aspect_ratio, fov, degrees, R, T, device) -> None # noqa
Args:
znear: near clipping plane of the view frustum.
@@ -346,6 +346,8 @@ class FoVPerspectiveCameras(CamerasBase):
degrees: bool, set to True if fov is specified in degrees.
R: Rotation matrix of shape (N, 3, 3)
T: Translation matrix of shape (N, 3)
K: (optional) A calibration matrix of shape (N, 4, 4)
If provided, znear, zfar, fov, aspect_ratio, and degrees are not needed.
device: torch.device or string
"""
# The initializer formats all inputs to torch tensors and broadcasts
@@ -358,55 +360,29 @@ class FoVPerspectiveCameras(CamerasBase):
fov=fov,
R=R,
T=T,
K=K,
)
# No need to convert to tensor or broadcast.
self.degrees = degrees
def get_projection_transform(self, **kwargs) -> Transform3d:
def compute_projection_matrix(
self, znear, zfar, fov, aspect_ratio, degrees
) -> torch.Tensor:
"""
Calculate the perspective projection matrix with a symmetric
viewing frustum. Use column major order.
The viewing frustum will be projected into ndc, s.t.
(max_x, max_y) -> (+1, +1)
(min_x, min_y) -> (-1, -1)
Compute the calibration matrix K of shape (N, 4, 4)
Args:
**kwargs: parameters for the projection can be passed in as keyword
arguments to override the default values set in `__init__`.
znear: near clipping plane of the view frustum.
zfar: far clipping plane of the view frustum.
fov: field of view angle of the camera.
aspect_ratio: ratio of screen_width/screen_height.
degrees: bool, set to True if fov is specified in degrees.
Return:
P: a Transform3d object which represents a batch of projection
matrices of shape (N, 4, 4)
.. code-block:: python
h1 = (max_y + min_y)/(max_y - min_y)
w1 = (max_x + min_x)/(max_x - min_x)
tanhalffov = tan((fov/2))
s1 = 1/tanhalffov
s2 = 1/(tanhalffov * (aspect_ratio))
# To map z to the range [0, 1] use:
f1 = far / (far - near)
f2 = -(far * near) / (far - near)
# Projection matrix
P = [
[s1, 0, w1, 0],
[0, s2, h1, 0],
[0, 0, f1, f2],
[0, 0, 1, 0],
]
Returns:
torch.FloatTensor of the calibration matrix with shape (N, 4, 4)
"""
znear = kwargs.get("znear", self.znear) # pyre-ignore[16]
zfar = kwargs.get("zfar", self.zfar) # pyre-ignore[16]
fov = kwargs.get("fov", self.fov) # pyre-ignore[16]
# pyre-ignore[16]
aspect_ratio = kwargs.get("aspect_ratio", self.aspect_ratio)
degrees = kwargs.get("degrees", self.degrees)
P = torch.zeros((self._N, 4, 4), device=self.device, dtype=torch.float32)
K = torch.zeros((self._N, 4, 4), device=self.device, dtype=torch.float32)
ones = torch.ones((self._N), dtype=torch.float32, device=self.device)
if degrees:
fov = (np.pi / 180) * fov
@@ -427,21 +403,73 @@ class FoVPerspectiveCameras(CamerasBase):
# so the z sign is 1.0.
z_sign = 1.0
P[:, 0, 0] = 2.0 * znear / (max_x - min_x)
P[:, 1, 1] = 2.0 * znear / (max_y - min_y)
P[:, 0, 2] = (max_x + min_x) / (max_x - min_x)
P[:, 1, 2] = (max_y + min_y) / (max_y - min_y)
P[:, 3, 2] = z_sign * ones
K[:, 0, 0] = 2.0 * znear / (max_x - min_x)
K[:, 1, 1] = 2.0 * znear / (max_y - min_y)
K[:, 0, 2] = (max_x + min_x) / (max_x - min_x)
K[:, 1, 2] = (max_y + min_y) / (max_y - min_y)
K[:, 3, 2] = z_sign * ones
# NOTE: This maps the z coordinate to [0, 1] where z = 0 if the point
# is at the near clipping plane and z = 1 when the point is at the far
# clipping plane.
P[:, 2, 2] = z_sign * zfar / (zfar - znear)
P[:, 2, 3] = -(zfar * znear) / (zfar - znear)
K[:, 2, 2] = z_sign * zfar / (zfar - znear)
K[:, 2, 3] = -(zfar * znear) / (zfar - znear)
return K
def get_projection_transform(self, **kwargs) -> Transform3d:
"""
Calculate the perspective projection matrix with a symmetric
viewing frustum. Use column major order.
The viewing frustum will be projected into ndc, s.t.
(max_x, max_y) -> (+1, +1)
(min_x, min_y) -> (-1, -1)
Args:
**kwargs: parameters for the projection can be passed in as keyword
arguments to override the default values set in `__init__`.
Return:
a Transform3d object which represents a batch of projection
matrices of shape (N, 4, 4)
.. code-block:: python
h1 = (max_y + min_y)/(max_y - min_y)
w1 = (max_x + min_x)/(max_x - min_x)
tanhalffov = tan((fov/2))
s1 = 1/tanhalffov
s2 = 1/(tanhalffov * (aspect_ratio))
# To map z to the range [0, 1] use:
f1 = far / (far - near)
f2 = -(far * near) / (far - near)
# Projection matrix
K = [
[s1, 0, w1, 0],
[0, s2, h1, 0],
[0, 0, f1, f2],
[0, 0, 1, 0],
]
"""
K = kwargs.get("K", self.K) # pyre-ignore[16]
if K is not None:
if K.shape != (self._N, 4, 4):
msg = "Expected K to have shape of (%r, 4, 4)"
raise ValueError(msg % (self._N))
else:
K = self.compute_projection_matrix(
kwargs.get("znear", self.znear), # pyre-ignore[16]
kwargs.get("zfar", self.zfar), # pyre-ignore[16]
kwargs.get("fov", self.fov), # pyre-ignore[16]
kwargs.get("aspect_ratio", self.aspect_ratio), # pyre-ignore[16]
kwargs.get("degrees", self.degrees),
)
# Transpose the projection matrix as PyTorch3D transforms use row vectors.
transform = Transform3d(device=self.device)
transform._matrix = P.transpose(1, 2).contiguous()
transform._matrix = K.transpose(1, 2).contiguous()
return transform
def unproject_points(
@@ -473,12 +501,12 @@ class FoVPerspectiveCameras(CamerasBase):
xy_sdepth = xy_depth
else:
# parse out important values from the projection matrix
P_matrix = self.get_projection_transform(**kwargs.copy()).get_matrix()
# parse out f1, f2 from P_matrix
K_matrix = self.get_projection_transform(**kwargs.copy()).get_matrix()
# parse out f1, f2 from K_matrix
unsqueeze_shape = [1] * xy_depth.dim()
unsqueeze_shape[0] = P_matrix.shape[0]
f1 = P_matrix[:, 2, 2].reshape(unsqueeze_shape)
f2 = P_matrix[:, 3, 2].reshape(unsqueeze_shape)
unsqueeze_shape[0] = K_matrix.shape[0]
f1 = K_matrix[:, 2, 2].reshape(unsqueeze_shape)
f2 = K_matrix[:, 3, 2].reshape(unsqueeze_shape)
# get the scaled depth
sdepth = (f1 * xy_depth[..., 2:3] + f2) / xy_depth[..., 2:3]
# concatenate xy + scaled depth
@@ -545,10 +573,10 @@ class FoVOrthographicCameras(CamerasBase):
scale_xyz=((1.0, 1.0, 1.0),), # (1, 3)
R=_R,
T=_T,
K=None,
device="cpu",
):
"""
__init__(self, znear, zfar, max_y, min_y, max_x, min_x, scale_xyz, R, T, device) -> None # noqa
Args:
znear: near clipping plane of the view frustum.
@@ -560,6 +588,8 @@ class FoVOrthographicCameras(CamerasBase):
scale_xyz: scale factors for each axis of shape (N, 3).
R: Rotation matrix of shape (N, 3, 3).
T: Translation of shape (N, 3).
K: (optional) A calibration matrix of shape (N, 4, 4)
If provided, znear, zfar, max_y, min_y, max_x, min_x, and scale_xyz are not needed.
device: torch.device or string.
Only need to set min_x, max_x, min_y, max_y for viewing frustums
@@ -578,8 +608,44 @@ class FoVOrthographicCameras(CamerasBase):
scale_xyz=scale_xyz,
R=R,
T=T,
K=K,
)
def compute_projection_matrix(
self, znear, zfar, max_x, min_x, max_y, min_y, scale_xyz
) -> torch.Tensor:
"""
Compute the calibration matrix K of shape (N, 4, 4)
Args:
znear: near clipping plane of the view frustum.
zfar: far clipping plane of the view frustum.
max_x: maximum x coordinate of the frustum.
min_x: minimum x coordinate of the frustum.
max_y: maximum y coordinate of the frustum.
min_y: minimum y coordinate of the frustum.
scale_xyz: scale factors for each axis of shape (N, 3).
"""
K = torch.zeros((self._N, 4, 4), dtype=torch.float32, device=self.device)
ones = torch.ones((self._N), dtype=torch.float32, device=self.device)
# NOTE: OpenGL flips handedness of coordinate system between camera
# space and NDC space so z sign is -ve. In PyTorch3D we maintain a
# right handed coordinate system throughout.
z_sign = +1.0
K[:, 0, 0] = (2.0 / (max_x - min_x)) * scale_xyz[:, 0]
K[:, 1, 1] = (2.0 / (max_y - min_y)) * scale_xyz[:, 1]
K[:, 0, 3] = -(max_x + min_x) / (max_x - min_x)
K[:, 1, 3] = -(max_y + min_y) / (max_y - min_y)
K[:, 3, 3] = ones
# NOTE: This maps the z coordinate to the range [0, 1] and replaces
# the OpenGL z normalization to [-1, 1]
K[:, 2, 2] = z_sign * (1.0 / (zfar - znear)) * scale_xyz[:, 2]
K[:, 2, 3] = -znear / (zfar - znear)
return K
def get_projection_transform(self, **kwargs) -> Transform3d:
"""
Calculate the orthographic projection matrix.
@@ -589,7 +655,7 @@ class FoVOrthographicCameras(CamerasBase):
**kwargs: parameters for the projection can be passed in to
override the default values set in __init__.
Return:
P: a Transform3d object which represents a batch of projection
a Transform3d object which represents a batch of projection
matrices of shape (N, 4, 4)
.. code-block:: python
@@ -601,41 +667,31 @@ class FoVOrthographicCameras(CamerasBase):
mid_y = (max_y + min_y) / (max_y - min_y)
mid_z = (far + near) / (far - near)
P = [
K = [
[scale_x, 0, 0, -mid_x],
[0, scale_y, 0, -mid_y],
[0, 0, -scale_z, -mid_z],
[0, 0, 0, 1],
]
"""
znear = kwargs.get("znear", self.znear) # pyre-ignore[16]
zfar = kwargs.get("zfar", self.zfar) # pyre-ignore[16]
max_x = kwargs.get("max_x", self.max_x) # pyre-ignore[16]
min_x = kwargs.get("min_x", self.min_x) # pyre-ignore[16]
max_y = kwargs.get("max_y", self.max_y) # pyre-ignore[16]
min_y = kwargs.get("min_y", self.min_y) # pyre-ignore[16]
scale_xyz = kwargs.get("scale_xyz", self.scale_xyz) # pyre-ignore[16]
P = torch.zeros((self._N, 4, 4), dtype=torch.float32, device=self.device)
ones = torch.ones((self._N), dtype=torch.float32, device=self.device)
# NOTE: OpenGL flips handedness of coordinate system between camera
# space and NDC space so z sign is -ve. In PyTorch3D we maintain a
# right handed coordinate system throughout.
z_sign = +1.0
P[:, 0, 0] = (2.0 / (max_x - min_x)) * scale_xyz[:, 0]
P[:, 1, 1] = (2.0 / (max_y - min_y)) * scale_xyz[:, 1]
P[:, 0, 3] = -(max_x + min_x) / (max_x - min_x)
P[:, 1, 3] = -(max_y + min_y) / (max_y - min_y)
P[:, 3, 3] = ones
# NOTE: This maps the z coordinate to the range [0, 1] and replaces
# the OpenGL z normalization to [-1, 1]
P[:, 2, 2] = z_sign * (1.0 / (zfar - znear)) * scale_xyz[:, 2]
P[:, 2, 3] = -znear / (zfar - znear)
K = kwargs.get("K", self.K) # pyre-ignore[16]
if K is not None:
if K.shape != (self._N, 4, 4):
msg = "Expected K to have shape of (%r, 4, 4)"
raise ValueError(msg % (self._N))
else:
K = self.compute_projection_matrix(
kwargs.get("znear", self.znear), # pyre-ignore[16]
kwargs.get("zfar", self.zfar), # pyre-ignore[16]
kwargs.get("max_x", self.max_x), # pyre-ignore[16]
kwargs.get("min_x", self.min_x), # pyre-ignore[16]
kwargs.get("max_y", self.max_y), # pyre-ignore[16]
kwargs.get("min_y", self.min_y), # pyre-ignore[16]
kwargs.get("scale_xyz", self.scale_xyz), # pyre-ignore[16]
)
transform = Transform3d(device=self.device)
transform._matrix = P.transpose(1, 2).contiguous()
transform._matrix = K.transpose(1, 2).contiguous()
return transform
def unproject_points(
@@ -666,11 +722,11 @@ class FoVOrthographicCameras(CamerasBase):
xy_sdepth = xy_depth
else:
# we have to obtain the scaled depth first
P = self.get_projection_transform(**kwargs).get_matrix()
unsqueeze_shape = [1] * P.dim()
unsqueeze_shape[0] = P.shape[0]
mid_z = P[:, 3, 2].reshape(unsqueeze_shape)
scale_z = P[:, 2, 2].reshape(unsqueeze_shape)
K = self.get_projection_transform(**kwargs).get_matrix()
unsqueeze_shape = [1] * K.dim()
unsqueeze_shape[0] = K.shape[0]
mid_z = K[:, 3, 2].reshape(unsqueeze_shape)
scale_z = K[:, 2, 2].reshape(unsqueeze_shape)
scaled_depth = scale_z * xy_depth[..., 2:3] + mid_z
# cat xy and scaled depth
xy_sdepth = torch.cat((xy_depth[..., :2], scaled_depth), dim=-1)
@@ -752,11 +808,11 @@ class PerspectiveCameras(CamerasBase):
principal_point=((0.0, 0.0),),
R=_R,
T=_T,
K=None,
device="cpu",
image_size=((-1, -1),),
):
"""
__init__(self, focal_length, principal_point, R, T, device, image_size) -> None
Args:
focal_length: Focal length of the camera in world units.
@@ -767,6 +823,9 @@ class PerspectiveCameras(CamerasBase):
A tensor of shape (N, 2).
R: Rotation matrix of shape (N, 3, 3)
T: Translation matrix of shape (N, 3)
K: (optional) A calibration matrix of shape (N, 4, 4)
If provided, focal_length, principal_point, and image_size are not needed.
device: torch.device or string
image_size: If image_size = (imwidth, imheight) with imwidth, imheight > 0
is provided, the camera parameters are assumed to be in screen
@@ -782,6 +841,7 @@ class PerspectiveCameras(CamerasBase):
principal_point=principal_point,
R=R,
T=T,
K=K,
image_size=image_size,
)
@@ -795,7 +855,7 @@ class PerspectiveCameras(CamerasBase):
arguments to override the default values set in __init__.
Returns:
P: A `Transform3d` object with a batch of `N` projection transforms.
A `Transform3d` object with a batch of `N` projection transforms.
.. code-block:: python
@@ -804,35 +864,35 @@ class PerspectiveCameras(CamerasBase):
px = principal_point[:, 0]
py = principal_point[:, 1]
P = [
K = [
[fx, 0, px, 0],
[0, fy, py, 0],
[0, 0, 0, 1],
[0, 0, 1, 0],
]
"""
# pyre-ignore[16]
principal_point = kwargs.get("principal_point", self.principal_point)
# pyre-ignore[16]
focal_length = kwargs.get("focal_length", self.focal_length)
# pyre-ignore[16]
image_size = kwargs.get("image_size", self.image_size)
K = kwargs.get("K", self.K) # pyre-ignore[16]
if K is not None:
if K.shape != (self._N, 4, 4):
msg = "Expected K to have shape of (%r, 4, 4)"
raise ValueError(msg % (self._N))
else:
# pyre-ignore[16]
image_size = kwargs.get("image_size", self.image_size)
# if imwidth > 0, parameters are in screen space
image_size = image_size if image_size[0][0] > 0 else None
# if imwidth > 0, parameters are in screen space
in_screen = image_size[0][0] > 0
image_size = image_size if in_screen else None
P = _get_sfm_calibration_matrix(
self._N,
self.device,
focal_length,
principal_point,
orthographic=False,
image_size=image_size,
)
K = _get_sfm_calibration_matrix(
self._N,
self.device,
kwargs.get("focal_length", self.focal_length), # pyre-ignore[16]
kwargs.get("principal_point", self.principal_point), # pyre-ignore[16]
orthographic=False,
image_size=image_size,
)
transform = Transform3d(device=self.device)
transform._matrix = P.transpose(1, 2).contiguous()
transform._matrix = K.transpose(1, 2).contiguous()
return transform
def unproject_points(
@@ -911,11 +971,11 @@ class OrthographicCameras(CamerasBase):
principal_point=((0.0, 0.0),),
R=_R,
T=_T,
K=None,
device="cpu",
image_size=((-1, -1),),
):
"""
__init__(self, focal_length, principal_point, R, T, device, image_size) -> None
Args:
focal_length: Focal length of the camera in world units.
@@ -926,6 +986,8 @@ class OrthographicCameras(CamerasBase):
A tensor of shape (N, 2).
R: Rotation matrix of shape (N, 3, 3)
T: Translation matrix of shape (N, 3)
K: (optional) A calibration matrix of shape (N, 4, 4)
If provided, focal_length, principal_point, and image_size are not needed.
device: torch.device or string
image_size: If image_size = (imwidth, imheight) with imwidth, imheight > 0
is provided, the camera parameters are assumed to be in screen
@@ -941,6 +1003,7 @@ class OrthographicCameras(CamerasBase):
principal_point=principal_point,
R=R,
T=T,
K=K,
image_size=image_size,
)
@@ -954,7 +1017,7 @@ class OrthographicCameras(CamerasBase):
arguments to override the default values set in __init__.
Returns:
P: A `Transform3d` object with a batch of `N` projection transforms.
A `Transform3d` object with a batch of `N` projection transforms.
.. code-block:: python
@@ -963,35 +1026,35 @@ class OrthographicCameras(CamerasBase):
px = principal_point[:,0]
py = principal_point[:,1]
P = [
K = [
[fx, 0, 0, px],
[0, fy, 0, py],
[0, 0, 1, 0],
[0, 0, 0, 1],
]
"""
# pyre-ignore[16]
principal_point = kwargs.get("principal_point", self.principal_point)
# pyre-ignore[16]
focal_length = kwargs.get("focal_length", self.focal_length)
# pyre-ignore[16]
image_size = kwargs.get("image_size", self.image_size)
K = kwargs.get("K", self.K) # pyre-ignore[16]
if K is not None:
if K.shape != (self._N, 4, 4):
msg = "Expected K to have shape of (%r, 4, 4)"
raise ValueError(msg % (self._N))
else:
# pyre-ignore[16]
image_size = kwargs.get("image_size", self.image_size)
# if imwidth > 0, parameters are in screen space
image_size = image_size if image_size[0][0] > 0 else None
# if imwidth > 0, parameters are in screen space
in_screen = image_size[0][0] > 0
image_size = image_size if in_screen else None
P = _get_sfm_calibration_matrix(
self._N,
self.device,
focal_length,
principal_point,
orthographic=True,
image_size=image_size,
)
K = _get_sfm_calibration_matrix(
self._N,
self.device,
kwargs.get("focal_length", self.focal_length), # pyre-ignore[16]
kwargs.get("principal_point", self.principal_point), # pyre-ignore[16]
orthographic=True,
image_size=image_size,
)
transform = Transform3d(device=self.device)
transform._matrix = P.transpose(1, 2).contiguous()
transform._matrix = K.transpose(1, 2).contiguous()
return transform
def unproject_points(

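To make the FoV perspective formula above concrete, here is a minimal sketch (assuming the pytorch3d.renderer import path and a symmetric frustum with aspect_ratio=1, so w1 = h1 = 0 and s1 = s2) that rebuilds K from the docstring and checks it against get_projection_transform():

    import math
    import torch
    from pytorch3d.renderer import FoVPerspectiveCameras

    fov, znear, zfar = 60.0, 1.0, 100.0
    cameras = FoVPerspectiveCameras(znear=znear, zfar=zfar, fov=fov)
    s = 1.0 / math.tan(math.radians(fov) / 2.0)  # s1 == s2 when aspect_ratio == 1
    f1 = zfar / (zfar - znear)
    f2 = -(zfar * znear) / (zfar - znear)
    K = torch.tensor([[[s, 0.0, 0.0, 0.0],
                       [0.0, s, 0.0, 0.0],
                       [0.0, 0.0, f1, f2],
                       [0.0, 0.0, 1.0, 0.0]]])
    # get_matrix() returns the row-vector (transposed) form of the matrix
    P = cameras.get_projection_transform().get_matrix()
    assert torch.allclose(P, K.transpose(1, 2), atol=1e-5)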

@@ -106,7 +106,7 @@ class TensorProperties(nn.Module):
# set as attributes anything else e.g. strings, bools
args_to_broadcast = {}
for k, v in kwargs.items():
if isinstance(v, (str, bool)):
if v is None or isinstance(v, (str, bool)):
setattr(self, k, v)
elif isinstance(v, BROADCAST_TYPES):
args_to_broadcast[k] = v
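
The effect of this one-line change in TensorProperties, sketched with the camera classes above: a K left at its None default is stored as a plain attribute instead of being pushed through tensor broadcasting:

    from pytorch3d.renderer import PerspectiveCameras

    cam = PerspectiveCameras()  # K defaults to None
    print(cam.K)                # None -- set via setattr, not broadcast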

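Before the test diff below, a sketch of the shape validation the new code paths perform (the error text follows the msg format in the diff above):

    import torch
    from pytorch3d.renderer import PerspectiveCameras

    cam = PerspectiveCameras(K=torch.randn(2, 4, 4))
    cam.get_projection_transform()  # ok: K has shape (N, 4, 4) with N = 2

    bad = PerspectiveCameras(K=torch.randn(2, 3, 3))
    # bad.get_projection_transform() would raise
    # ValueError: Expected K to have shape of (2, 4, 4)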

@@ -449,6 +449,20 @@ class TestCameraHelpers(TestCaseMixin, unittest.TestCase):
class TestCamerasCommon(TestCaseMixin, unittest.TestCase):
def test_K(self, batch_size=10):
T = torch.randn(batch_size, 3)
R = random_rotations(batch_size)
K = torch.randn(batch_size, 4, 4)
for cam_type in (
FoVOrthographicCameras,
FoVPerspectiveCameras,
OrthographicCameras,
PerspectiveCameras,
):
cam = cam_type(R=R, T=T, K=K)
cam.get_projection_transform()
# Just checking that we don't crash or anything
def test_view_transform_class_method(self):
T = torch.tensor([0.0, 0.0, -1.0], requires_grad=True).view(1, -1)
R = look_at_rotation(T)