Add OpenCV camera conversion; fix bug for camera unified PyTorch3D interface.

Summary: This commit adds a new camera conversion function for OpenCV style parameters to Pulsar parameters to the library. Using this function it addresses a bug reported here: https://fb.workplace.com/groups/629644647557365/posts/1079637302558095, by using the PyTorch3D->OpenCV->Pulsar chain instead of the original direct conversion function. Both conversions are well-tested and an additional test for the full chain has been added, resulting in a more reliable solution requiring less code.

Reviewed By: patricklabatut

Differential Revision: D29322106

fbshipit-source-id: 13df13c2e48f628f75d9f44f19ff7f1646fb7ebd
This commit is contained in:
Christoph Lassner
2021-07-10 01:05:36 -07:00
committed by Facebook GitHub Bot
parent fef5bcd8f9
commit 75432a0695
8 changed files with 275 additions and 32 deletions

Binary file not shown.

Before

Width:  |  Height:  |  Size: 1.9 KiB

After

Width:  |  Height:  |  Size: 1.9 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 1.9 KiB

After

Width:  |  Height:  |  Size: 1.9 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 2.1 KiB

After

Width:  |  Height:  |  Size: 2.1 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 2.1 KiB

After

Width:  |  Height:  |  Size: 2.1 KiB

View File

@@ -12,10 +12,12 @@ import numpy as np
import torch
from common_testing import TestCaseMixin, get_tests_dir
from pytorch3d.ops import eyes
from pytorch3d.renderer.points.pulsar import Renderer as PulsarRenderer
from pytorch3d.transforms import so3_exp_map, so3_log_map
from pytorch3d.utils import (
cameras_from_opencv_projection,
opencv_from_cameras_projection,
pulsar_from_opencv_projection,
)
@@ -111,6 +113,9 @@ class TestCameraConversions(TestCaseMixin, unittest.TestCase):
[105.0, 105.0],
[120.0, 120.0],
]
# These values are in y, x format, but they should be in x, y format.
# The tests work like this because they only test for consistency,
# but this format is misleading.
principal_point = [
[240, 320],
[240.5, 320.3],
@@ -160,3 +165,80 @@ class TestCameraConversions(TestCaseMixin, unittest.TestCase):
self.assertClose(R, R_i)
self.assertClose(tvec, tvec_i)
self.assertClose(camera_matrix, camera_matrix_i)
def test_pulsar_conversion(self):
"""
Tests that the cameras converted from opencv to pulsar convention
return correct projections of random 3D points. The check is done
against a set of results precomputed using `cv2.projectPoints` function.
"""
image_size = [[480, 640]]
R = [
[
[1.0, 0.0, 0.0],
[0.0, 1.0, 0.0],
[0.0, 0.0, 1.0],
],
[
[0.1968, -0.6663, -0.7192],
[0.7138, -0.4055, 0.5710],
[-0.6721, -0.6258, 0.3959],
],
]
tvec = [
[10.0, 10.0, 3.0],
[-0.0, -0.0, 20.0],
]
focal_length = [
[100.0, 100.0],
[10.0, 10.0],
]
principal_point = [
[320, 240],
[320, 240],
]
principal_point, focal_length, R, tvec, image_size = [
torch.FloatTensor(x)
for x in (principal_point, focal_length, R, tvec, image_size)
]
camera_matrix = eyes(dim=3, N=2)
camera_matrix[:, 0, 0] = focal_length[:, 0]
camera_matrix[:, 1, 1] = focal_length[:, 1]
camera_matrix[:, :2, 2] = principal_point
rvec = so3_log_map(R)
pts = torch.tensor(
[[[0.0, 0.0, 120.0]], [[0.0, 0.0, 120.0]]], dtype=torch.float32
)
radii = torch.tensor([[1e-5], [1e-5]], dtype=torch.float32)
col = torch.zeros((2, 1, 1), dtype=torch.float32)
# project the 3D points with the opencv projection function
pts_proj_opencv = cv2_project_points(pts, rvec, tvec, camera_matrix)
pulsar_cam = pulsar_from_opencv_projection(
R, tvec, camera_matrix, image_size, znear=100.0
)
pulsar_rend = PulsarRenderer(
640, 480, 1, right_handed_system=False, n_channels=1
)
rendered = torch.flip(
pulsar_rend(
pts,
col,
radii,
pulsar_cam,
1e-5,
max_depth=150.0,
min_depth=100.0,
),
dims=(1,),
)
for batch_id in range(2):
point_pos = torch.where(rendered[batch_id] == rendered[batch_id].min())
point_pos = point_pos[1][0], point_pos[0][0]
self.assertLess(
torch.abs(point_pos[0] - pts_proj_opencv[batch_id, 0, 0]), 2
)
self.assertLess(
torch.abs(point_pos[1] - pts_proj_opencv[batch_id, 0, 1]), 2
)