Add OpenCV camera conversion; fix bug in the unified PyTorch3D camera interface.

Summary: This commit adds a new camera conversion function, from OpenCV-style parameters to Pulsar parameters, to the library. Using this function, it addresses a bug reported here: https://fb.workplace.com/groups/629644647557365/posts/1079637302558095 by replacing the original direct conversion with the PyTorch3D->OpenCV->Pulsar chain. Both individual conversions are well tested, and an additional test for the full chain has been added, resulting in a more reliable solution that requires less code.
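
For orientation, a minimal usage sketch of the new conversion entry point (function name and signature are taken from this diff; the camera values are illustrative assumptions, not from any test):

    import torch
    from pytorch3d.utils import pulsar_from_opencv_projection

    # One camera observing a 480 x 640 image (OpenCV convention).
    R = torch.eye(3)[None]                       # (N, 3, 3) rotation matrices
    tvec = torch.tensor([[0.0, 0.0, 5.0]])       # (N, 3) translation vectors
    K = torch.tensor([[[100.0, 0.0, 320.0],
                       [0.0, 100.0, 240.0],
                       [0.0, 0.0, 1.0]]])        # (N, 3, 3) calibration matrices
    image_size = torch.tensor([[480, 640]])      # (N, 2) as (height, width)

    cam_params = pulsar_from_opencv_projection(R, tvec, K, image_size, znear=0.1)
    # 13 values per camera: 3 translation, 6 rotation (6D representation),
    # focal_length, sensor_width, c_x, c_y.
    assert cam_params.shape == (1, 13)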

Reviewed By: patricklabatut

Differential Revision: D29322106

fbshipit-source-id: 13df13c2e48f628f75d9f44f19ff7f1646fb7ebd
Christoph Lassner, 2021-07-10 01:05:36 -07:00, committed by Facebook GitHub Bot
parent fef5bcd8f9
commit 75432a0695
8 changed files with 275 additions and 32 deletions

pytorch3d/renderer/points/pulsar/unified.py

@@ -11,7 +11,7 @@ from typing import Any, Dict, Optional, Tuple, Union
 import torch
 import torch.nn as nn

-from ....transforms import matrix_to_rotation_6d
+from ....utils import pulsar_from_cameras_projection
 from ...cameras import (
     FoVOrthographicCameras,
     FoVPerspectiveCameras,
@@ -102,7 +102,7 @@ class PulsarPointsRenderer(nn.Module):
             height=height,
             max_num_balls=max_num_spheres,
             orthogonal_projection=orthogonal_projection,
-            right_handed_system=True,
+            right_handed_system=False,
             n_channels=n_channels,
             **kwargs,
         )
@@ -359,24 +359,28 @@ class PulsarPointsRenderer(nn.Module):
     def _extract_extrinsics(
         self, kwargs, cloud_idx
     ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Extract the extrinsic information from the kwargs for a specific point cloud.
+
+        Instead of implementing a direct translation from the PyTorch3D to the Pulsar
+        camera model, we chain the two conversions of PyTorch3D->OpenCV and
+        OpenCV->Pulsar for better maintainability (PyTorch3D->OpenCV is maintained and
+        tested by the core PyTorch3D team, whereas OpenCV->Pulsar is maintained and
+        tested by the Pulsar team).
+        """
         # Shorthand:
         cameras = self.rasterizer.cameras
         R = kwargs.get("R", cameras.R)[cloud_idx]
         T = kwargs.get("T", cameras.T)[cloud_idx]
-        norm_mat = torch.tensor(
-            [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, -1.0]],
-            dtype=torch.float32,
-            device=R.device,
-        )
-        cam_rot = torch.matmul(norm_mat, R[:3, :3][None, ...]).permute((0, 2, 1))
-        norm_mat = torch.tensor(
-            [[-1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]],
-            dtype=torch.float32,
-            device=R.device,
-        )
-        cam_rot = torch.matmul(norm_mat, cam_rot)
-        cam_pos = torch.flatten(torch.matmul(cam_rot, T[..., None]))
-        cam_rot = torch.flatten(matrix_to_rotation_6d(cam_rot))
+        tmp_cams = PerspectiveCameras(
+            R=R.unsqueeze(0), T=T.unsqueeze(0), device=R.device
+        )
+        size_tensor = torch.tensor(
+            [[self.renderer._renderer.height, self.renderer._renderer.width]]
+        )
+        pulsar_cam = pulsar_from_cameras_projection(tmp_cams, size_tensor)
+        cam_pos = pulsar_cam[0, :3]
+        cam_rot = pulsar_cam[0, 3:9]
         return cam_pos, cam_rot

     def _get_vert_rad(
@@ -547,15 +551,17 @@ class PulsarPointsRenderer(nn.Module):
             otherargs["bg_col"] = bg_col
         # Go!
         images.append(
-            self.renderer(
-                vert_pos=vert_pos,
-                vert_col=vert_col,
-                vert_rad=vert_rad,
-                cam_params=cam_params,
-                gamma=gamma,
-                max_depth=zfar,
-                min_depth=znear,
-                **otherargs,
+            torch.flipud(
+                self.renderer(
+                    vert_pos=vert_pos,
+                    vert_col=vert_col,
+                    vert_rad=vert_rad,
+                    cam_params=cam_params,
+                    gamma=gamma,
+                    max_depth=zfar,
+                    min_depth=znear,
+                    **otherargs,
+                )
             )
         )
         return torch.stack(images, dim=0)
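
Context for the torch.flipud change above: as documented in the new conversion functions below, the OpenCV->Pulsar mapping produces a vertically flipped image and requires a left-handed system (hence right_handed_system=False), so the unified renderer now undoes the flip after rendering. A minimal sketch of that post-processing step for a single (H, W, C) render (helper name is hypothetical):

    import torch

    def unflip_pulsar_render(image: torch.Tensor) -> torch.Tensor:
        # Renders produced under the OpenCV->Pulsar camera mapping come out
        # vertically flipped; flipping the row axis of an (H, W, C) image
        # restores the expected orientation.
        return torch.flipud(image)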

pytorch3d/utils/__init__.py

@@ -7,6 +7,8 @@
 from .camera_conversions import (
     cameras_from_opencv_projection,
     opencv_from_cameras_projection,
+    pulsar_from_opencv_projection,
+    pulsar_from_cameras_projection,
 )
 from .ico_sphere import ico_sphere
 from .torus import torus

pytorch3d/utils/camera_conversions.py

@@ -4,12 +4,16 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

+import logging
 from typing import Tuple

 import torch

 from ..renderer import PerspectiveCameras
-from ..transforms import so3_exp_map, so3_log_map
+from ..transforms import matrix_to_rotation_6d
+
+LOGGER = logging.getLogger(__name__)


 def cameras_from_opencv_projection(
@@ -54,7 +58,6 @@ def cameras_from_opencv_projection(
     Returns:
         cameras_pytorch3d: A batch of `N` cameras in the PyTorch3D convention.
     """
-
    focal_length = torch.stack([camera_matrix[:, 0, 0], camera_matrix[:, 1, 1]], dim=-1)
     principal_point = camera_matrix[:, :2, 2]
@ -68,7 +71,7 @@ def cameras_from_opencv_projection(
# For R, T we flip x, y axes (opencv screen space has an opposite # For R, T we flip x, y axes (opencv screen space has an opposite
# orientation of screen axes). # orientation of screen axes).
# We also transpose R (opencv multiplies points from the opposite=left side). # We also transpose R (opencv multiplies points from the opposite=left side).
R_pytorch3d = R.permute(0, 2, 1) R_pytorch3d = R.clone().permute(0, 2, 1)
T_pytorch3d = tvec.clone() T_pytorch3d = tvec.clone()
R_pytorch3d[:, :, :2] *= -1 R_pytorch3d[:, :, :2] *= -1
T_pytorch3d[:, :2] *= -1 T_pytorch3d[:, :2] *= -1
@@ -103,20 +106,22 @@ def opencv_from_cameras_projection(
         cameras: A batch of `N` cameras in the PyTorch3D convention.
         image_size: A tensor of shape `(N, 2)` containing the sizes of the images
             (height, width) attached to each camera.
+        return_as_rotmat (bool): If set to True, return the full 3x3 rotation
+            matrices. Otherwise, return an axis-angle vector (default).

     Returns:
         R: A batch of rotation matrices of shape `(N, 3, 3)`.
         tvec: A batch of translation vectors of shape `(N, 3)`.
         camera_matrix: A batch of camera calibration matrices of shape `(N, 3, 3)`.
     """
-    R_pytorch3d = cameras.R
-    T_pytorch3d = cameras.T
+    R_pytorch3d = cameras.R.clone()  # pyre-ignore
+    T_pytorch3d = cameras.T.clone()  # pyre-ignore
     focal_pytorch3d = cameras.focal_length
     p0_pytorch3d = cameras.principal_point
-    T_pytorch3d[:, :2] *= -1  # pyre-ignore
-    R_pytorch3d[:, :, :2] *= -1  # pyre-ignore
-    tvec = T_pytorch3d.clone()  # pyre-ignore
-    R = R_pytorch3d.permute(0, 2, 1)  # pyre-ignore
+    T_pytorch3d[:, :2] *= -1
+    R_pytorch3d[:, :, :2] *= -1
+    tvec = T_pytorch3d
+    R = R_pytorch3d.permute(0, 2, 1)

     # Retype the image_size correctly and flip to width, height.
     image_size_wh = image_size.to(R).flip(dims=(1,))
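
The .clone() calls in this hunk fix an aliasing bug: the function previously modified the R and T tensors of the input cameras in place. A small sketch of the round trip these functions support, mirroring the existing consistency tests (the example camera values are assumptions):

    import torch
    from pytorch3d.utils import (
        cameras_from_opencv_projection,
        opencv_from_cameras_projection,
    )

    R = torch.eye(3)[None]                       # (1, 3, 3)
    tvec = torch.tensor([[0.0, 0.0, 5.0]])       # (1, 3)
    K = torch.tensor([[[100.0, 0.0, 320.0],
                       [0.0, 100.0, 240.0],
                       [0.0, 0.0, 1.0]]])        # (1, 3, 3)
    image_size = torch.tensor([[480, 640]])      # (height, width)

    cameras = cameras_from_opencv_projection(R, tvec, K, image_size)
    R_i, tvec_i, K_i = opencv_from_cameras_projection(cameras, image_size)

    # The round trip reproduces the inputs; with the added .clone() calls,
    # neither the inputs nor cameras.R / cameras.T are modified in place.
    assert torch.allclose(R, R_i, atol=1e-5)
    assert torch.allclose(tvec, tvec_i, atol=1e-5)
    assert torch.allclose(K, K_i, atol=1e-5)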
@@ -130,3 +135,151 @@ def opencv_from_cameras_projection(
     camera_matrix[:, 0, 0] = focal_length[:, 0]
     camera_matrix[:, 1, 1] = focal_length[:, 1]
     return R, tvec, camera_matrix
+
+
+def pulsar_from_opencv_projection(
+    R: torch.Tensor,
+    tvec: torch.Tensor,
+    camera_matrix: torch.Tensor,
+    image_size: torch.Tensor,
+    znear: float = 0.1,
+) -> torch.Tensor:
+    """
+    Convert OpenCV style camera parameters to Pulsar style camera parameters.
+
+    Note:
+        * Pulsar does NOT support different focal lengths for x and y.
+          For conversion, we use the average of fx and fy.
+        * The Pulsar renderer MUST use a left-handed coordinate system for this
+          mapping to work.
+        * The resulting image will be vertically flipped - which has to be
+          addressed AFTER rendering by the user.
+        * The parameters `R, tvec, camera_matrix` correspond to the outputs
+          of `cv2.decomposeProjectionMatrix`.
+
+    Args:
+        R: A batch of rotation matrices of shape `(N, 3, 3)`.
+        tvec: A batch of translation vectors of shape `(N, 3)`.
+        camera_matrix: A batch of camera calibration matrices of shape `(N, 3, 3)`.
+        image_size: A tensor of shape `(N, 2)` containing the sizes of the images
+            (height, width) attached to each camera.
+        znear (float): The near clipping value to use for Pulsar.
+
+    Returns:
+        cameras_pulsar: A batch of `N` Pulsar camera vectors in the Pulsar
+            convention `(N, 13)` (3 translation, 6 rotation, focal_length,
+            sensor_width, c_x, c_y).
+    """
+    assert len(camera_matrix.size()) == 3, "This function requires batched inputs!"
+    assert len(R.size()) == 3, "This function requires batched inputs!"
+    assert len(tvec.size()) in (2, 3), "This function requires batched inputs!"
+    # Validate parameters.
+    image_size_wh = image_size.to(R).flip(dims=(1,))
+    assert torch.all(
+        image_size_wh > 0
+    ), "height and width must be positive but min is: %s" % (
+        str(image_size_wh.min().item())
+    )
+    assert (
+        camera_matrix.size(1) == 3 and camera_matrix.size(2) == 3
+    ), "Incorrect camera matrix shape: expected 3x3 but got %dx%d" % (
+        camera_matrix.size(1),
+        camera_matrix.size(2),
+    )
+    assert (
+        R.size(1) == 3 and R.size(2) == 3
+    ), "Incorrect R shape: expected 3x3 but got %dx%d" % (
+        R.size(1),
+        R.size(2),
+    )
+    if len(tvec.size()) == 2:
+        tvec = tvec.unsqueeze(2)
+    assert (
+        tvec.size(1) == 3 and tvec.size(2) == 1
+    ), "Incorrect tvec shape: expected 3x1 but got %dx%d" % (
+        tvec.size(1),
+        tvec.size(2),
+    )
+    # Check batch size.
+    batch_size = camera_matrix.size(0)
+    assert R.size(0) == batch_size, "Expected R to have batch size %d. Has size %d." % (
+        batch_size,
+        R.size(0),
+    )
+    assert (
+        tvec.size(0) == batch_size
+    ), "Expected tvec to have batch size %d. Has size %d." % (
+        batch_size,
+        tvec.size(0),
+    )
+    # Check image sizes.
+    image_w = image_size_wh[0, 0]
+    image_h = image_size_wh[0, 1]
+    assert torch.all(
+        image_size_wh[:, 0] == image_w
+    ), "All images in a batch must have the same width!"
+    assert torch.all(
+        image_size_wh[:, 1] == image_h
+    ), "All images in a batch must have the same height!"
+    # Focal length.
+    fx = camera_matrix[:, 0, 0].unsqueeze(1)
+    fy = camera_matrix[:, 1, 1].unsqueeze(1)
+    # Check that we introduce less than 1% error by averaging the focal lengths.
+    fx_y = fx / fy
+    if torch.any(fx_y > 1.01) or torch.any(fx_y < 0.99):
+        LOGGER.warning(
+            "Pulsar only supports a single focal length. For converting OpenCV "
+            "focal lengths, we average them for the x and y directions. "
+            "The focal lengths for x and y you provided differ by more than 1%, "
+            "which means this could introduce a noticeable error."
+        )
+    f = (fx + fy) / 2
+    # Normalize f into normalized device coordinates.
+    focal_length_px = f / image_w
+    # Transfer into focal_length and sensor_width.
+    focal_length = torch.tensor([znear - 1e-5], dtype=torch.float32, device=R.device)
+    focal_length = focal_length[None, :].repeat(batch_size, 1)
+    sensor_width = focal_length / focal_length_px
+    # Principal point.
+    cx = camera_matrix[:, 0, 2].unsqueeze(1)
+    cy = camera_matrix[:, 1, 2].unsqueeze(1)
+    # Transfer principal point offset into centered offset.
+    cx = -(cx - image_w / 2)
+    cy = cy - image_h / 2
+    # Concatenate to final vector.
+    param = torch.cat([focal_length, sensor_width, cx, cy], dim=1)
+    R_trans = R.permute(0, 2, 1)
+    cam_pos = -torch.bmm(R_trans, tvec).squeeze(2)
+    cam_rot = matrix_to_rotation_6d(R_trans)
+    cam_params = torch.cat([cam_pos, cam_rot, param], dim=1)
+    return cam_params
+
+
+def pulsar_from_cameras_projection(
+    cameras: PerspectiveCameras,
+    image_size: torch.Tensor,
+) -> torch.Tensor:
+    """
+    Convert PyTorch3D `PerspectiveCameras` to Pulsar style camera parameters.
+
+    Note:
+        * Pulsar does NOT support different focal lengths for x and y.
+          For conversion, we use the average of fx and fy.
+        * The Pulsar renderer MUST use a left-handed coordinate system for this
+          mapping to work.
+        * The resulting image will be vertically flipped - which has to be
+          addressed AFTER rendering by the user.
+
+    Args:
+        cameras: A batch of `N` cameras in the PyTorch3D convention.
+        image_size: A tensor of shape `(N, 2)` containing the sizes of the images
+            (height, width) attached to each camera.
+
+    Returns:
+        cameras_pulsar: A batch of `N` Pulsar camera vectors in the Pulsar
+            convention `(N, 13)` (3 translation, 6 rotation, focal_length,
+            sensor_width, c_x, c_y).
+    """
+    opencv_R, opencv_T, opencv_K = opencv_from_cameras_projection(cameras, image_size)
+    return pulsar_from_opencv_projection(opencv_R, opencv_T, opencv_K, image_size)
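
For reference, a sketch of how the returned 13-dimensional Pulsar camera vector breaks down, following the layout documented in the docstrings above (the slice names are illustrative, not an additional API):

    import torch
    from pytorch3d.renderer import PerspectiveCameras
    from pytorch3d.utils import pulsar_from_cameras_projection

    cameras = PerspectiveCameras(R=torch.eye(3)[None], T=torch.zeros(1, 3))
    image_size = torch.tensor([[480, 640]])  # (height, width)
    cam_params = pulsar_from_cameras_projection(cameras, image_size)

    cam_pos = cam_params[:, 0:3]       # 3 translation components
    cam_rot = cam_params[:, 3:9]       # rotation in the 6D representation
    focal_length = cam_params[:, 9]    # Pulsar focal length (tied to znear)
    sensor_width = cam_params[:, 10]   # encodes the averaged fx/fy
    principal_offset = cam_params[:, 11:13]  # centered c_x, c_y offsets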

[Four binary test reference images changed (not shown); sizes unchanged before and after: 1.9 KiB, 1.9 KiB, 2.1 KiB, 2.1 KiB.]

tests/test_camera_conversions.py

@@ -12,10 +12,12 @@ import numpy as np
 import torch
 from common_testing import TestCaseMixin, get_tests_dir
 from pytorch3d.ops import eyes
+from pytorch3d.renderer.points.pulsar import Renderer as PulsarRenderer
 from pytorch3d.transforms import so3_exp_map, so3_log_map
 from pytorch3d.utils import (
     cameras_from_opencv_projection,
     opencv_from_cameras_projection,
+    pulsar_from_opencv_projection,
 )
@@ -111,6 +113,9 @@ class TestCameraConversions(TestCaseMixin, unittest.TestCase):
             [105.0, 105.0],
             [120.0, 120.0],
         ]
+        # These values are in y, x format, but they should be in x, y format.
+        # The tests work like this because they only test for consistency,
+        # but this format is misleading.
         principal_point = [
             [240, 320],
             [240.5, 320.3],
@@ -160,3 +165,80 @@ class TestCameraConversions(TestCaseMixin, unittest.TestCase):
         self.assertClose(R, R_i)
         self.assertClose(tvec, tvec_i)
         self.assertClose(camera_matrix, camera_matrix_i)
+
+    def test_pulsar_conversion(self):
+        """
+        Tests that cameras converted from the OpenCV to the Pulsar convention
+        return correct projections of example 3D points. The check is done
+        against projections computed with the `cv2.projectPoints` function.
+        """
+        image_size = [[480, 640]]
+        R = [
+            [
+                [1.0, 0.0, 0.0],
+                [0.0, 1.0, 0.0],
+                [0.0, 0.0, 1.0],
+            ],
+            [
+                [0.1968, -0.6663, -0.7192],
+                [0.7138, -0.4055, 0.5710],
+                [-0.6721, -0.6258, 0.3959],
+            ],
+        ]
+        tvec = [
+            [10.0, 10.0, 3.0],
+            [-0.0, -0.0, 20.0],
+        ]
+        focal_length = [
+            [100.0, 100.0],
+            [10.0, 10.0],
+        ]
+        principal_point = [
+            [320, 240],
+            [320, 240],
+        ]
+        principal_point, focal_length, R, tvec, image_size = [
+            torch.FloatTensor(x)
+            for x in (principal_point, focal_length, R, tvec, image_size)
+        ]
+        camera_matrix = eyes(dim=3, N=2)
+        camera_matrix[:, 0, 0] = focal_length[:, 0]
+        camera_matrix[:, 1, 1] = focal_length[:, 1]
+        camera_matrix[:, :2, 2] = principal_point
+        rvec = so3_log_map(R)
+        pts = torch.tensor(
+            [[[0.0, 0.0, 120.0]], [[0.0, 0.0, 120.0]]], dtype=torch.float32
+        )
+        radii = torch.tensor([[1e-5], [1e-5]], dtype=torch.float32)
+        col = torch.zeros((2, 1, 1), dtype=torch.float32)
+        # Project the 3D points with the OpenCV projection function.
+        pts_proj_opencv = cv2_project_points(pts, rvec, tvec, camera_matrix)
+        pulsar_cam = pulsar_from_opencv_projection(
+            R, tvec, camera_matrix, image_size, znear=100.0
+        )
+        pulsar_rend = PulsarRenderer(
+            640, 480, 1, right_handed_system=False, n_channels=1
+        )
+        rendered = torch.flip(
+            pulsar_rend(
+                pts,
+                col,
+                radii,
+                pulsar_cam,
+                1e-5,
+                max_depth=150.0,
+                min_depth=100.0,
+            ),
+            dims=(1,),
+        )
+        for batch_id in range(2):
+            point_pos = torch.where(rendered[batch_id] == rendered[batch_id].min())
+            point_pos = point_pos[1][0], point_pos[0][0]
+            self.assertLess(
+                torch.abs(point_pos[0] - pts_proj_opencv[batch_id, 0, 0]), 2
+            )
+            self.assertLess(
+                torch.abs(point_pos[1] - pts_proj_opencv[batch_id, 0, 1]), 2
+            )