mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
Add PyTorch3D->OpenCV camera parameter conversion.
Summary: This diff implements the inverse of D28992470 (8006842f2a
): a function to extract OpenCV convention camera parameters from a PyTorch3D `PerspectiveCameras` object. This is the first part of the new PyTorch3d<>OpenCV<>Pulsar conversion functions.
Reviewed By: patricklabatut
Differential Revision: D29278411
fbshipit-source-id: 68d4555b508dbe8685d8239443f839d194cc2484
This commit is contained in:
parent
e4039aa570
commit
da9974b416
@ -4,7 +4,10 @@
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from .camera_conversions import cameras_from_opencv_projection
|
||||
from .camera_conversions import (
|
||||
cameras_from_opencv_projection,
|
||||
opencv_from_cameras_projection,
|
||||
)
|
||||
from .ico_sphere import ico_sphere
|
||||
from .torus import torus
|
||||
|
||||
|
@ -4,10 +4,12 @@
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from ..renderer import PerspectiveCameras
|
||||
from ..transforms import so3_exponential_map
|
||||
from ..transforms import so3_exponential_map, so3_log_map
|
||||
|
||||
|
||||
def cameras_from_opencv_projection(
|
||||
@ -35,7 +37,7 @@ def cameras_from_opencv_projection(
|
||||
followed by the homogenization of `x_screen_opencv`.
|
||||
|
||||
Note:
|
||||
The parameters `rvec, tvec, camera_matrix` correspond e.g. to the inputs
|
||||
The parameters `rvec, tvec, camera_matrix` correspond, e.g., to the inputs
|
||||
of `cv2.projectPoints`, or to the ouputs of `cv2.calibrateCamera`.
|
||||
|
||||
Args:
|
||||
@ -74,3 +76,51 @@ def cameras_from_opencv_projection(
|
||||
focal_length=focal_pytorch3d,
|
||||
principal_point=p0_pytorch3d,
|
||||
)
|
||||
|
||||
|
||||
def opencv_from_cameras_projection(
|
||||
cameras: PerspectiveCameras,
|
||||
image_size: torch.Tensor,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Converts a batch of `PerspectiveCameras` into OpenCV-convention
|
||||
axis-angle rotation vectors `rvec`, translation vectors `tvec`, and the camera
|
||||
calibration matrices `camera_matrix`. This operation is exactly the inverse
|
||||
of `cameras_from_opencv_projection`.
|
||||
|
||||
Note:
|
||||
The parameters `rvec, tvec, camera_matrix` correspond, e.g., to the inputs
|
||||
of `cv2.projectPoints`, or to the ouputs of `cv2.calibrateCamera`.
|
||||
|
||||
Args:
|
||||
cameras: A batch of `N` cameras in the PyTorch3D convention.
|
||||
image_size: A tensor of shape `(N, 2)` containing the sizes of the images
|
||||
(height, width) attached to each camera.
|
||||
|
||||
Returns:
|
||||
rvec: A batch of axis-angle rotation vectors of shape `(N, 3)`.
|
||||
tvec: A batch of translation vectors of shape `(N, 3)`.
|
||||
camera_matrix: A batch of camera calibration matrices of shape `(N, 3, 3)`.
|
||||
"""
|
||||
R_pytorch3d = cameras.R
|
||||
T_pytorch3d = cameras.T
|
||||
focal_pytorch3d = cameras.focal_length
|
||||
p0_pytorch3d = cameras.principal_point
|
||||
T_pytorch3d[:, :2] *= -1 # pyre-ignore
|
||||
R_pytorch3d[:, :, :2] *= -1 # pyre-ignore
|
||||
tvec = T_pytorch3d.clone() # pyre-ignore
|
||||
R = R_pytorch3d.permute(0, 2, 1) # pyre-ignore
|
||||
|
||||
# Retype the image_size correctly and flip to width, height.
|
||||
image_size_wh = image_size.to(R).flip(dims=(1,))
|
||||
|
||||
principal_point = (-p0_pytorch3d + 1.0) * (0.5 * image_size_wh) # pyre-ignore
|
||||
focal_length = focal_pytorch3d * (0.5 * image_size_wh)
|
||||
|
||||
camera_matrix = torch.zeros_like(R)
|
||||
camera_matrix[:, :2, 2] = principal_point
|
||||
camera_matrix[:, 2, 2] = 1.0
|
||||
camera_matrix[:, 0, 0] = focal_length[:, 0]
|
||||
camera_matrix[:, 1, 1] = focal_length[:, 1]
|
||||
rvec = so3_log_map(R)
|
||||
return rvec, tvec, camera_matrix
|
||||
|
@ -15,6 +15,7 @@ from pytorch3d.ops import eyes
|
||||
from pytorch3d.transforms import so3_exponential_map, so3_log_map
|
||||
from pytorch3d.utils import (
|
||||
cameras_from_opencv_projection,
|
||||
opencv_from_cameras_projection,
|
||||
)
|
||||
|
||||
DATA_DIR = get_tests_dir() / "data"
|
||||
@ -151,3 +152,11 @@ class TestCameraConversions(TestCaseMixin, unittest.TestCase):
|
||||
self.assertClose(
|
||||
pts_proj_opencv_in_pytorch3d_screen, pts_proj_pytorch3d, atol=1e-5
|
||||
)
|
||||
|
||||
# Check the inverse.
|
||||
rvec_i, tvec_i, camera_matrix_i = opencv_from_cameras_projection(
|
||||
cameras_opencv_to_pytorch3d, image_size
|
||||
)
|
||||
self.assertClose(rvec, rvec_i)
|
||||
self.assertClose(tvec, tvec_i)
|
||||
self.assertClose(camera_matrix, camera_matrix_i)
|
||||
|
Loading…
x
Reference in New Issue
Block a user