Conversion from OpenCV cameras

Summary: Implements a conversion function between OpenCV and PyTorch3D cameras.

Reviewed By: patricklabatut

Differential Revision: D28992470

fbshipit-source-id: dbcc9f213ec293c2f6938261c704aea09aad3c90
This commit is contained in:
David Novotny 2021-06-21 05:02:46 -07:00 committed by Facebook GitHub Bot
parent b2ac2655b3
commit 8006842f2a
4 changed files with 1450 additions and 0 deletions

View File

@ -1,5 +1,6 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
from .camera_conversions import cameras_from_opencv_projection
from .ico_sphere import ico_sphere from .ico_sphere import ico_sphere
from .torus import torus from .torus import torus

View File

@ -0,0 +1,70 @@
import torch
from ..renderer import PerspectiveCameras
from ..transforms import so3_exponential_map
def cameras_from_opencv_projection(
rvec: torch.Tensor,
tvec: torch.Tensor,
camera_matrix: torch.Tensor,
image_size: torch.Tensor,
) -> PerspectiveCameras:
"""
Converts a batch of OpenCV-conventioned cameras parametrized with the
axis-angle rotation vectors `rvec`, translation vectors `tvec`, and the camera
calibration matrices `camera_matrix` to `PerspectiveCameras` in PyTorch3D
convention.
More specifically, the conversion is carried out such that a projection
of a 3D shape to the OpenCV-conventioned screen of size `image_size` results
in the same image as a projection with the corresponding PyTorch3D camera
to the NDC screen convention of PyTorch3D.
More specifically, the OpenCV convention projects points to the OpenCV screen
space as follows:
```
x_screen_opencv = camera_matrix @ (exp(rvec) @ x_world + tvec)
```
followed by the homogenization of `x_screen_opencv`.
Note:
The parameters `rvec, tvec, camera_matrix` correspond e.g. to the inputs
of `cv2.projectPoints`, or to the ouputs of `cv2.calibrateCamera`.
Args:
rvec: A batch of axis-angle rotation vectors of shape `(N, 3)`.
tvec: A batch of translation vectors of shape `(N, 3)`.
camera_matrix: A batch of camera calibration matrices of shape `(N, 3, 3)`.
image_size: A tensor of shape `(N, 2)` containing the sizes of the images
(height, width) attached to each camera.
Returns:
cameras_pytorch3d: A batch of `N` cameras in the PyTorch3D convention.
"""
R = so3_exponential_map(rvec)
focal_length = torch.stack([camera_matrix[:, 0, 0], camera_matrix[:, 1, 1]], dim=-1)
principal_point = camera_matrix[:, :2, 2]
# Retype the image_size correctly and flip to width, height.
image_size_wh = image_size.to(R).flip(dims=(1,))
# Get the PyTorch3D focal length and principal point.
focal_pytorch3d = focal_length / (0.5 * image_size_wh)
p0_pytorch3d = -(principal_point / (0.5 * image_size_wh) - 1)
# For R, T we flip x, y axes (opencv screen space has an opposite
# orientation of screen axes).
# We also transpose R (opencv multiplies points from the opposite=left side).
R_pytorch3d = R.permute(0, 2, 1)
T_pytorch3d = tvec.clone()
R_pytorch3d[:, :, :2] *= -1
T_pytorch3d[:, :2] *= -1
return PerspectiveCameras(
R=R_pytorch3d,
T=T_pytorch3d,
focal_length=focal_pytorch3d,
principal_point=p0_pytorch3d,
)

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,149 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import json
import unittest
import numpy as np
import torch
from common_testing import TestCaseMixin, get_tests_dir
from pytorch3d.ops import eyes
from pytorch3d.transforms import so3_exponential_map, so3_log_map
from pytorch3d.utils import (
cameras_from_opencv_projection,
)
DATA_DIR = get_tests_dir() / "data"
def _coords_opencv_screen_to_pytorch3d_ndc(xy_opencv, image_size):
"""
Converts the OpenCV screen coordinates `xy_opencv` to PyTorch3D NDC coordinates.
"""
xy_pytorch3d = -(2.0 * xy_opencv / image_size.flip(dims=(1,))[:, None] - 1.0)
return xy_pytorch3d
def cv2_project_points(pts, rvec, tvec, camera_matrix):
"""
Reproduces the `cv2.projectPoints` function from OpenCV using PyTorch.
"""
R = so3_exponential_map(rvec)
pts_proj_3d = (
camera_matrix.bmm(R.bmm(pts.permute(0, 2, 1)) + tvec[:, :, None])
).permute(0, 2, 1)
depth = pts_proj_3d[..., 2:]
pts_proj_2d = pts_proj_3d[..., :2] / depth
return pts_proj_2d
class TestCameraConversions(TestCaseMixin, unittest.TestCase):
def setUp(self) -> None:
super().setUp()
torch.manual_seed(42)
np.random.seed(42)
def test_cv2_project_points(self):
"""
Tests that the local implementation of cv2_project_points gives the same
restults OpenCV's `cv2.projectPoints`. The check is done against a set
of precomputed results `cv_project_points_precomputed`.
"""
with open(DATA_DIR / "cv_project_points_precomputed.json", "r") as f:
cv_project_points_precomputed = json.load(f)
for test_case in cv_project_points_precomputed:
_pts_proj = cv2_project_points(
**{
k: torch.tensor(test_case[k])[None]
for k in ("pts", "rvec", "tvec", "camera_matrix")
}
)
pts_proj = torch.tensor(test_case["pts_proj"])[None]
self.assertClose(_pts_proj, pts_proj, atol=1e-4)
def test_opencv_conversion(self):
"""
Tests that the cameras converted from opencv to pytorch3d convention
return correct projections of random 3D points. The check is done
against a set of results precomuted using `cv2.projectPoints` function.
"""
image_size = [[480, 640]] * 4
R = [
[
[1.0, 0.0, 0.0],
[0.0, 1.0, 0.0],
[0.0, 0.0, 1.0],
],
[
[1.0, 0.0, 0.0],
[0.0, 0.0, -1.0],
[0.0, 1.0, 0.0],
],
[
[0.0, 0.0, 1.0],
[1.0, 0.0, 0.0],
[0.0, 1.0, 0.0],
],
[
[0.0, 0.0, 1.0],
[1.0, 0.0, 0.0],
[0.0, 1.0, 0.0],
],
]
tvec = [
[0.0, 0.0, 3.0],
[0.3, -0.3, 3.0],
[-0.15, 0.1, 4.0],
[0.0, 0.0, 4.0],
]
focal_length = [
[100.0, 100.0],
[115.0, 115.0],
[105.0, 105.0],
[120.0, 120.0],
]
principal_point = [
[240, 320],
[240.5, 320.3],
[241, 318],
[242, 322],
]
principal_point, focal_length, R, tvec, image_size = [
torch.FloatTensor(x)
for x in (principal_point, focal_length, R, tvec, image_size)
]
camera_matrix = eyes(dim=3, N=4)
camera_matrix[:, 0, 0], camera_matrix[:, 1, 1] = (
focal_length[:, 0],
focal_length[:, 1],
)
camera_matrix[:, :2, 2] = principal_point
rvec = so3_log_map(R)
pts = torch.nn.functional.normalize(torch.randn(4, 1000, 3), dim=-1)
# project the 3D points with the opencv projection function
pts_proj_opencv = cv2_project_points(pts, rvec, tvec, camera_matrix)
# make the pytorch3d cameras
cameras_opencv_to_pytorch3d = cameras_from_opencv_projection(
rvec, tvec, camera_matrix, image_size
)
# project the 3D points with converted cameras
pts_proj_pytorch3d = cameras_opencv_to_pytorch3d.transform_points(pts)[..., :2]
# convert the opencv-projected points to pytorch3d screen coords
pts_proj_opencv_in_pytorch3d_screen = _coords_opencv_screen_to_pytorch3d_ndc(
pts_proj_opencv, image_size
)
# compare to the cached projected points
self.assertClose(
pts_proj_opencv_in_pytorch3d_screen, pts_proj_pytorch3d, atol=1e-5
)