Add PyTorch3D->OpenCV camera parameter conversion.

Summary: This diff implements the inverse of D28992470 (8006842f2a): a function to extract OpenCV convention camera parameters from a PyTorch3D `PerspectiveCameras` object. This is the first part of the new PyTorch3d<>OpenCV<>Pulsar conversion functions.

Reviewed By: patricklabatut

Differential Revision: D29278411

fbshipit-source-id: 68d4555b508dbe8685d8239443f839d194cc2484
This commit is contained in:
Christoph Lassner 2021-06-23 14:37:10 -07:00 committed by Facebook GitHub Bot
parent e4039aa570
commit da9974b416
3 changed files with 65 additions and 3 deletions

View File

@ -4,7 +4,10 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from .camera_conversions import cameras_from_opencv_projection
from .camera_conversions import (
cameras_from_opencv_projection,
opencv_from_cameras_projection,
)
from .ico_sphere import ico_sphere
from .torus import torus

View File

@ -4,10 +4,12 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import Tuple
import torch
from ..renderer import PerspectiveCameras
from ..transforms import so3_exponential_map
from ..transforms import so3_exponential_map, so3_log_map
def cameras_from_opencv_projection(
@ -35,7 +37,7 @@ def cameras_from_opencv_projection(
followed by the homogenization of `x_screen_opencv`.
Note:
The parameters `rvec, tvec, camera_matrix` correspond e.g. to the inputs
The parameters `rvec, tvec, camera_matrix` correspond, e.g., to the inputs
of `cv2.projectPoints`, or to the ouputs of `cv2.calibrateCamera`.
Args:
@ -74,3 +76,51 @@ def cameras_from_opencv_projection(
focal_length=focal_pytorch3d,
principal_point=p0_pytorch3d,
)
def opencv_from_cameras_projection(
cameras: PerspectiveCameras,
image_size: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Converts a batch of `PerspectiveCameras` into OpenCV-convention
axis-angle rotation vectors `rvec`, translation vectors `tvec`, and the camera
calibration matrices `camera_matrix`. This operation is exactly the inverse
of `cameras_from_opencv_projection`.
Note:
The parameters `rvec, tvec, camera_matrix` correspond, e.g., to the inputs
of `cv2.projectPoints`, or to the ouputs of `cv2.calibrateCamera`.
Args:
cameras: A batch of `N` cameras in the PyTorch3D convention.
image_size: A tensor of shape `(N, 2)` containing the sizes of the images
(height, width) attached to each camera.
Returns:
rvec: A batch of axis-angle rotation vectors of shape `(N, 3)`.
tvec: A batch of translation vectors of shape `(N, 3)`.
camera_matrix: A batch of camera calibration matrices of shape `(N, 3, 3)`.
"""
R_pytorch3d = cameras.R
T_pytorch3d = cameras.T
focal_pytorch3d = cameras.focal_length
p0_pytorch3d = cameras.principal_point
T_pytorch3d[:, :2] *= -1 # pyre-ignore
R_pytorch3d[:, :, :2] *= -1 # pyre-ignore
tvec = T_pytorch3d.clone() # pyre-ignore
R = R_pytorch3d.permute(0, 2, 1) # pyre-ignore
# Retype the image_size correctly and flip to width, height.
image_size_wh = image_size.to(R).flip(dims=(1,))
principal_point = (-p0_pytorch3d + 1.0) * (0.5 * image_size_wh) # pyre-ignore
focal_length = focal_pytorch3d * (0.5 * image_size_wh)
camera_matrix = torch.zeros_like(R)
camera_matrix[:, :2, 2] = principal_point
camera_matrix[:, 2, 2] = 1.0
camera_matrix[:, 0, 0] = focal_length[:, 0]
camera_matrix[:, 1, 1] = focal_length[:, 1]
rvec = so3_log_map(R)
return rvec, tvec, camera_matrix

View File

@ -15,6 +15,7 @@ from pytorch3d.ops import eyes
from pytorch3d.transforms import so3_exponential_map, so3_log_map
from pytorch3d.utils import (
cameras_from_opencv_projection,
opencv_from_cameras_projection,
)
DATA_DIR = get_tests_dir() / "data"
@ -151,3 +152,11 @@ class TestCameraConversions(TestCaseMixin, unittest.TestCase):
self.assertClose(
pts_proj_opencv_in_pytorch3d_screen, pts_proj_pytorch3d, atol=1e-5
)
# Check the inverse.
rvec_i, tvec_i, camera_matrix_i = opencv_from_cameras_projection(
cameras_opencv_to_pytorch3d, image_size
)
self.assertClose(rvec, rvec_i)
self.assertClose(tvec, tvec_i)
self.assertClose(camera_matrix, camera_matrix_i)