Add OpenCV camera conversion; fix bug for camera unified PyTorch3D interface.
Summary: This commit adds a new camera conversion function for OpenCV style parameters to Pulsar parameters to the library. Using this function it addresses a bug reported here: https://fb.workplace.com/groups/629644647557365/posts/1079637302558095, by using the PyTorch3D->OpenCV->Pulsar chain instead of the original direct conversion function. Both conversions are well-tested and an additional test for the full chain has been added, resulting in a more reliable solution requiring less code. Reviewed By: patricklabatut Differential Revision: D29322106 fbshipit-source-id: 13df13c2e48f628f75d9f44f19ff7f1646fb7ebd
@ -11,7 +11,7 @@ from typing import Any, Dict, Optional, Tuple, Union
 | 
			
		||||
import torch
 | 
			
		||||
import torch.nn as nn
 | 
			
		||||
 | 
			
		||||
from ....transforms import matrix_to_rotation_6d
 | 
			
		||||
from ....utils import pulsar_from_cameras_projection
 | 
			
		||||
from ...cameras import (
 | 
			
		||||
    FoVOrthographicCameras,
 | 
			
		||||
    FoVPerspectiveCameras,
 | 
			
		||||
@ -102,7 +102,7 @@ class PulsarPointsRenderer(nn.Module):
 | 
			
		||||
            height=height,
 | 
			
		||||
            max_num_balls=max_num_spheres,
 | 
			
		||||
            orthogonal_projection=orthogonal_projection,
 | 
			
		||||
            right_handed_system=True,
 | 
			
		||||
            right_handed_system=False,
 | 
			
		||||
            n_channels=n_channels,
 | 
			
		||||
            **kwargs,
 | 
			
		||||
        )
 | 
			
		||||
@ -359,24 +359,28 @@ class PulsarPointsRenderer(nn.Module):
 | 
			
		||||
    def _extract_extrinsics(
 | 
			
		||||
        self, kwargs, cloud_idx
 | 
			
		||||
    ) -> Tuple[torch.Tensor, torch.Tensor]:
 | 
			
		||||
        """
 | 
			
		||||
        Extract the extrinsic information from the kwargs for a specific point cloud.
 | 
			
		||||
 | 
			
		||||
        Instead of implementing a direct translation from the PyTorch3D to the Pulsar
 | 
			
		||||
        camera model, we chain the two conversions of PyTorch3D->OpenCV and
 | 
			
		||||
        OpenCV->Pulsar for better maintainability (PyTorch3D->OpenCV is maintained and
 | 
			
		||||
        tested by the core PyTorch3D team, whereas OpenCV->Pulsar is maintained and
 | 
			
		||||
        tested by the Pulsar team).
 | 
			
		||||
        """
 | 
			
		||||
        # Shorthand:
 | 
			
		||||
        cameras = self.rasterizer.cameras
 | 
			
		||||
        R = kwargs.get("R", cameras.R)[cloud_idx]
 | 
			
		||||
        T = kwargs.get("T", cameras.T)[cloud_idx]
 | 
			
		||||
        norm_mat = torch.tensor(
 | 
			
		||||
            [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, -1.0]],
 | 
			
		||||
            dtype=torch.float32,
 | 
			
		||||
            device=R.device,
 | 
			
		||||
        tmp_cams = PerspectiveCameras(
 | 
			
		||||
            R=R.unsqueeze(0), T=T.unsqueeze(0), device=R.device
 | 
			
		||||
        )
 | 
			
		||||
        cam_rot = torch.matmul(norm_mat, R[:3, :3][None, ...]).permute((0, 2, 1))
 | 
			
		||||
        norm_mat = torch.tensor(
 | 
			
		||||
            [[-1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]],
 | 
			
		||||
            dtype=torch.float32,
 | 
			
		||||
            device=R.device,
 | 
			
		||||
        size_tensor = torch.tensor(
 | 
			
		||||
            [[self.renderer._renderer.height, self.renderer._renderer.width]]
 | 
			
		||||
        )
 | 
			
		||||
        cam_rot = torch.matmul(norm_mat, cam_rot)
 | 
			
		||||
        cam_pos = torch.flatten(torch.matmul(cam_rot, T[..., None]))
 | 
			
		||||
        cam_rot = torch.flatten(matrix_to_rotation_6d(cam_rot))
 | 
			
		||||
        pulsar_cam = pulsar_from_cameras_projection(tmp_cams, size_tensor)
 | 
			
		||||
        cam_pos = pulsar_cam[0, :3]
 | 
			
		||||
        cam_rot = pulsar_cam[0, 3:9]
 | 
			
		||||
        return cam_pos, cam_rot
 | 
			
		||||
 | 
			
		||||
    def _get_vert_rad(
 | 
			
		||||
@ -547,15 +551,17 @@ class PulsarPointsRenderer(nn.Module):
 | 
			
		||||
                otherargs["bg_col"] = bg_col
 | 
			
		||||
            # Go!
 | 
			
		||||
            images.append(
 | 
			
		||||
                self.renderer(
 | 
			
		||||
                    vert_pos=vert_pos,
 | 
			
		||||
                    vert_col=vert_col,
 | 
			
		||||
                    vert_rad=vert_rad,
 | 
			
		||||
                    cam_params=cam_params,
 | 
			
		||||
                    gamma=gamma,
 | 
			
		||||
                    max_depth=zfar,
 | 
			
		||||
                    min_depth=znear,
 | 
			
		||||
                    **otherargs,
 | 
			
		||||
                torch.flipud(
 | 
			
		||||
                    self.renderer(
 | 
			
		||||
                        vert_pos=vert_pos,
 | 
			
		||||
                        vert_col=vert_col,
 | 
			
		||||
                        vert_rad=vert_rad,
 | 
			
		||||
                        cam_params=cam_params,
 | 
			
		||||
                        gamma=gamma,
 | 
			
		||||
                        max_depth=zfar,
 | 
			
		||||
                        min_depth=znear,
 | 
			
		||||
                        **otherargs,
 | 
			
		||||
                    )
 | 
			
		||||
                )
 | 
			
		||||
            )
 | 
			
		||||
        return torch.stack(images, dim=0)
 | 
			
		||||
 | 
			
		||||
@ -7,6 +7,8 @@
 | 
			
		||||
from .camera_conversions import (
 | 
			
		||||
    cameras_from_opencv_projection,
 | 
			
		||||
    opencv_from_cameras_projection,
 | 
			
		||||
    pulsar_from_opencv_projection,
 | 
			
		||||
    pulsar_from_cameras_projection,
 | 
			
		||||
)
 | 
			
		||||
from .ico_sphere import ico_sphere
 | 
			
		||||
from .torus import torus
 | 
			
		||||
 | 
			
		||||
@ -4,12 +4,16 @@
 | 
			
		||||
# This source code is licensed under the BSD-style license found in the
 | 
			
		||||
# LICENSE file in the root directory of this source tree.
 | 
			
		||||
 | 
			
		||||
import logging
 | 
			
		||||
from typing import Tuple
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
 | 
			
		||||
from ..renderer import PerspectiveCameras
 | 
			
		||||
from ..transforms import so3_exp_map, so3_log_map
 | 
			
		||||
from ..transforms import matrix_to_rotation_6d
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
LOGGER = logging.getLogger(__name__)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def cameras_from_opencv_projection(
 | 
			
		||||
@ -54,7 +58,6 @@ def cameras_from_opencv_projection(
 | 
			
		||||
    Returns:
 | 
			
		||||
        cameras_pytorch3d: A batch of `N` cameras in the PyTorch3D convention.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    focal_length = torch.stack([camera_matrix[:, 0, 0], camera_matrix[:, 1, 1]], dim=-1)
 | 
			
		||||
    principal_point = camera_matrix[:, :2, 2]
 | 
			
		||||
 | 
			
		||||
@ -68,7 +71,7 @@ def cameras_from_opencv_projection(
 | 
			
		||||
    # For R, T we flip x, y axes (opencv screen space has an opposite
 | 
			
		||||
    # orientation of screen axes).
 | 
			
		||||
    # We also transpose R (opencv multiplies points from the opposite=left side).
 | 
			
		||||
    R_pytorch3d = R.permute(0, 2, 1)
 | 
			
		||||
    R_pytorch3d = R.clone().permute(0, 2, 1)
 | 
			
		||||
    T_pytorch3d = tvec.clone()
 | 
			
		||||
    R_pytorch3d[:, :, :2] *= -1
 | 
			
		||||
    T_pytorch3d[:, :2] *= -1
 | 
			
		||||
@ -103,20 +106,22 @@ def opencv_from_cameras_projection(
 | 
			
		||||
        cameras: A batch of `N` cameras in the PyTorch3D convention.
 | 
			
		||||
        image_size: A tensor of shape `(N, 2)` containing the sizes of the images
 | 
			
		||||
            (height, width) attached to each camera.
 | 
			
		||||
        return_as_rotmat (bool): If set to True, return the full 3x3 rotation
 | 
			
		||||
            matrices. Otherwise, return an axis-angle vector (default).
 | 
			
		||||
 | 
			
		||||
    Returns:
 | 
			
		||||
        R: A batch of rotation matrices of shape `(N, 3, 3)`.
 | 
			
		||||
        tvec: A batch of translation vectors of shape `(N, 3)`.
 | 
			
		||||
        camera_matrix: A batch of camera calibration matrices of shape `(N, 3, 3)`.
 | 
			
		||||
    """
 | 
			
		||||
    R_pytorch3d = cameras.R
 | 
			
		||||
    T_pytorch3d = cameras.T
 | 
			
		||||
    R_pytorch3d = cameras.R.clone()  # pyre-ignore
 | 
			
		||||
    T_pytorch3d = cameras.T.clone()  # pyre-ignore
 | 
			
		||||
    focal_pytorch3d = cameras.focal_length
 | 
			
		||||
    p0_pytorch3d = cameras.principal_point
 | 
			
		||||
    T_pytorch3d[:, :2] *= -1  # pyre-ignore
 | 
			
		||||
    R_pytorch3d[:, :, :2] *= -1  # pyre-ignore
 | 
			
		||||
    tvec = T_pytorch3d.clone()  # pyre-ignore
 | 
			
		||||
    R = R_pytorch3d.permute(0, 2, 1)  # pyre-ignore
 | 
			
		||||
    T_pytorch3d[:, :2] *= -1
 | 
			
		||||
    R_pytorch3d[:, :, :2] *= -1
 | 
			
		||||
    tvec = T_pytorch3d
 | 
			
		||||
    R = R_pytorch3d.permute(0, 2, 1)
 | 
			
		||||
 | 
			
		||||
    # Retype the image_size correctly and flip to width, height.
 | 
			
		||||
    image_size_wh = image_size.to(R).flip(dims=(1,))
 | 
			
		||||
@ -130,3 +135,151 @@ def opencv_from_cameras_projection(
 | 
			
		||||
    camera_matrix[:, 0, 0] = focal_length[:, 0]
 | 
			
		||||
    camera_matrix[:, 1, 1] = focal_length[:, 1]
 | 
			
		||||
    return R, tvec, camera_matrix
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def pulsar_from_opencv_projection(
 | 
			
		||||
    R: torch.Tensor,
 | 
			
		||||
    tvec: torch.Tensor,
 | 
			
		||||
    camera_matrix: torch.Tensor,
 | 
			
		||||
    image_size: torch.Tensor,
 | 
			
		||||
    znear: float = 0.1,
 | 
			
		||||
) -> torch.Tensor:
 | 
			
		||||
    """
 | 
			
		||||
    Convert OpenCV style camera parameters to Pulsar style camera parameters.
 | 
			
		||||
 | 
			
		||||
    Note:
 | 
			
		||||
        * Pulsar does NOT support different focal lengths for x and y.
 | 
			
		||||
          For conversion, we use the average of fx and fy.
 | 
			
		||||
        * The Pulsar renderer MUST use a left-handed coordinate system for this
 | 
			
		||||
          mapping to work.
 | 
			
		||||
        * The resulting image will be vertically flipped - which has to be
 | 
			
		||||
          addressed AFTER rendering by the user.
 | 
			
		||||
        * The parameters `R, tvec, camera_matrix` correspond to the outputs
 | 
			
		||||
          of `cv2.decomposeProjectionMatrix`.
 | 
			
		||||
 | 
			
		||||
    Args:
 | 
			
		||||
        R: A batch of rotation matrices of shape `(N, 3, 3)`.
 | 
			
		||||
        tvec: A batch of translation vectors of shape `(N, 3)`.
 | 
			
		||||
        camera_matrix: A batch of camera calibration matrices of shape `(N, 3, 3)`.
 | 
			
		||||
        image_size: A tensor of shape `(N, 2)` containing the sizes of the images
 | 
			
		||||
            (height, width) attached to each camera.
 | 
			
		||||
        znear (float): The near clipping value to use for Pulsar.
 | 
			
		||||
 | 
			
		||||
    Returns:
 | 
			
		||||
        cameras_pulsar: A batch of `N` Pulsar camera vectors in the Pulsar
 | 
			
		||||
            convention `(N, 13)` (3 translation, 6 rotation, focal_length, sensor_width,
 | 
			
		||||
            c_x, c_y).
 | 
			
		||||
    """
 | 
			
		||||
    assert len(camera_matrix.size()) == 3, "This function requires batched inputs!"
 | 
			
		||||
    assert len(R.size()) == 3, "This function requires batched inputs!"
 | 
			
		||||
    assert len(tvec.size()) in (2, 3), "This function reuqires batched inputs!"
 | 
			
		||||
 | 
			
		||||
    # Validate parameters.
 | 
			
		||||
    image_size_wh = image_size.to(R).flip(dims=(1,))
 | 
			
		||||
    assert torch.all(
 | 
			
		||||
        image_size_wh > 0
 | 
			
		||||
    ), "height and width must be positive but min is: %s" % (
 | 
			
		||||
        str(image_size_wh.min().item())
 | 
			
		||||
    )
 | 
			
		||||
    assert (
 | 
			
		||||
        camera_matrix.size(1) == 3 and camera_matrix.size(2) == 3
 | 
			
		||||
    ), "Incorrect camera matrix shape: expected 3x3 but got %dx%d" % (
 | 
			
		||||
        camera_matrix.size(1),
 | 
			
		||||
        camera_matrix.size(2),
 | 
			
		||||
    )
 | 
			
		||||
    assert (
 | 
			
		||||
        R.size(1) == 3 and R.size(2) == 3
 | 
			
		||||
    ), "Incorrect R shape: expected 3x3 but got %dx%d" % (
 | 
			
		||||
        R.size(1),
 | 
			
		||||
        R.size(2),
 | 
			
		||||
    )
 | 
			
		||||
    if len(tvec.size()) == 2:
 | 
			
		||||
        tvec = tvec.unsqueeze(2)
 | 
			
		||||
    assert (
 | 
			
		||||
        tvec.size(1) == 3 and tvec.size(2) == 1
 | 
			
		||||
    ), "Incorrect tvec shape: expected 3x1 but got %dx%d" % (
 | 
			
		||||
        tvec.size(1),
 | 
			
		||||
        tvec.size(2),
 | 
			
		||||
    )
 | 
			
		||||
    # Check batch size.
 | 
			
		||||
    batch_size = camera_matrix.size(0)
 | 
			
		||||
    assert R.size(0) == batch_size, "Expected R to have batch size %d. Has size %d." % (
 | 
			
		||||
        batch_size,
 | 
			
		||||
        R.size(0),
 | 
			
		||||
    )
 | 
			
		||||
    assert (
 | 
			
		||||
        tvec.size(0) == batch_size
 | 
			
		||||
    ), "Expected tvec to have batch size %d. Has size %d." % (
 | 
			
		||||
        batch_size,
 | 
			
		||||
        tvec.size(0),
 | 
			
		||||
    )
 | 
			
		||||
    # Check image sizes.
 | 
			
		||||
    image_w = image_size_wh[0, 0]
 | 
			
		||||
    image_h = image_size_wh[0, 1]
 | 
			
		||||
    assert torch.all(
 | 
			
		||||
        image_size_wh[:, 0] == image_w
 | 
			
		||||
    ), "All images in a batch must have the same width!"
 | 
			
		||||
    assert torch.all(
 | 
			
		||||
        image_size_wh[:, 1] == image_h
 | 
			
		||||
    ), "All images in a batch must have the same height!"
 | 
			
		||||
    # Focal length.
 | 
			
		||||
    fx = camera_matrix[:, 0, 0].unsqueeze(1)
 | 
			
		||||
    fy = camera_matrix[:, 1, 1].unsqueeze(1)
 | 
			
		||||
    # Check that we introduce less than 1% error by averaging the focal lengths.
 | 
			
		||||
    fx_y = fx / fy
 | 
			
		||||
    if torch.any(fx_y > 1.01) or torch.any(fx_y < 0.99):
 | 
			
		||||
        LOGGER.warning(
 | 
			
		||||
            "Pulsar only supports a single focal lengths. For converting OpenCV "
 | 
			
		||||
            "focal lengths, we average them for x and y directions. "
 | 
			
		||||
            "The focal lengths for x and y you provided differ by more than 1%, "
 | 
			
		||||
            "which means this could introduce a noticeable error."
 | 
			
		||||
        )
 | 
			
		||||
    f = (fx + fy) / 2
 | 
			
		||||
    # Normalize f into normalized device coordinates.
 | 
			
		||||
    focal_length_px = f / image_w
 | 
			
		||||
    # Transfer into focal_length and sensor_width.
 | 
			
		||||
    focal_length = torch.tensor([znear - 1e-5], dtype=torch.float32, device=R.device)
 | 
			
		||||
    focal_length = focal_length[None, :].repeat(batch_size, 1)
 | 
			
		||||
    sensor_width = focal_length / focal_length_px
 | 
			
		||||
    # Principal point.
 | 
			
		||||
    cx = camera_matrix[:, 0, 2].unsqueeze(1)
 | 
			
		||||
    cy = camera_matrix[:, 1, 2].unsqueeze(1)
 | 
			
		||||
    # Transfer principal point offset into centered offset.
 | 
			
		||||
    cx = -(cx - image_w / 2)
 | 
			
		||||
    cy = cy - image_h / 2
 | 
			
		||||
    # Concatenate to final vector.
 | 
			
		||||
    param = torch.cat([focal_length, sensor_width, cx, cy], dim=1)
 | 
			
		||||
    R_trans = R.permute(0, 2, 1)
 | 
			
		||||
    cam_pos = -torch.bmm(R_trans, tvec).squeeze(2)
 | 
			
		||||
    cam_rot = matrix_to_rotation_6d(R_trans)
 | 
			
		||||
    cam_params = torch.cat([cam_pos, cam_rot, param], dim=1)
 | 
			
		||||
    return cam_params
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def pulsar_from_cameras_projection(
 | 
			
		||||
    cameras: PerspectiveCameras,
 | 
			
		||||
    image_size: torch.Tensor,
 | 
			
		||||
) -> torch.Tensor:
 | 
			
		||||
    """
 | 
			
		||||
    Convert PyTorch3D `PerspectiveCameras` to Pulsar style camera parameters.
 | 
			
		||||
 | 
			
		||||
    Note:
 | 
			
		||||
        * Pulsar does NOT support different focal lengths for x and y.
 | 
			
		||||
          For conversion, we use the average of fx and fy.
 | 
			
		||||
        * The Pulsar renderer MUST use a left-handed coordinate system for this
 | 
			
		||||
          mapping to work.
 | 
			
		||||
        * The resulting image will be vertically flipped - which has to be
 | 
			
		||||
          addressed AFTER rendering by the user.
 | 
			
		||||
 | 
			
		||||
    Args:
 | 
			
		||||
        cameras: A batch of `N` cameras in the PyTorch3D convention.
 | 
			
		||||
        image_size: A tensor of shape `(N, 2)` containing the sizes of the images
 | 
			
		||||
            (height, width) attached to each camera.
 | 
			
		||||
 | 
			
		||||
    Returns:
 | 
			
		||||
        cameras_pulsar: A batch of `N` Pulsar camera vectors in the Pulsar
 | 
			
		||||
            convention `(N, 13)` (3 translation, 6 rotation, focal_length, sensor_width,
 | 
			
		||||
            c_x, c_y).
 | 
			
		||||
    """
 | 
			
		||||
    opencv_R, opencv_T, opencv_K = opencv_from_cameras_projection(cameras, image_size)
 | 
			
		||||
    return pulsar_from_opencv_projection(opencv_R, opencv_T, opencv_K, image_size)
 | 
			
		||||
 | 
			
		||||
| 
		 Before Width: | Height: | Size: 1.9 KiB After Width: | Height: | Size: 1.9 KiB  | 
| 
		 Before Width: | Height: | Size: 1.9 KiB After Width: | Height: | Size: 1.9 KiB  | 
| 
		 Before Width: | Height: | Size: 2.1 KiB After Width: | Height: | Size: 2.1 KiB  | 
| 
		 Before Width: | Height: | Size: 2.1 KiB After Width: | Height: | Size: 2.1 KiB  | 
@ -12,10 +12,12 @@ import numpy as np
 | 
			
		||||
import torch
 | 
			
		||||
from common_testing import TestCaseMixin, get_tests_dir
 | 
			
		||||
from pytorch3d.ops import eyes
 | 
			
		||||
from pytorch3d.renderer.points.pulsar import Renderer as PulsarRenderer
 | 
			
		||||
from pytorch3d.transforms import so3_exp_map, so3_log_map
 | 
			
		||||
from pytorch3d.utils import (
 | 
			
		||||
    cameras_from_opencv_projection,
 | 
			
		||||
    opencv_from_cameras_projection,
 | 
			
		||||
    pulsar_from_opencv_projection,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -111,6 +113,9 @@ class TestCameraConversions(TestCaseMixin, unittest.TestCase):
 | 
			
		||||
            [105.0, 105.0],
 | 
			
		||||
            [120.0, 120.0],
 | 
			
		||||
        ]
 | 
			
		||||
        # These values are in y, x format, but they should be in x, y format.
 | 
			
		||||
        # The tests work like this because they only test for consistency,
 | 
			
		||||
        # but this format is misleading.
 | 
			
		||||
        principal_point = [
 | 
			
		||||
            [240, 320],
 | 
			
		||||
            [240.5, 320.3],
 | 
			
		||||
@ -160,3 +165,80 @@ class TestCameraConversions(TestCaseMixin, unittest.TestCase):
 | 
			
		||||
        self.assertClose(R, R_i)
 | 
			
		||||
        self.assertClose(tvec, tvec_i)
 | 
			
		||||
        self.assertClose(camera_matrix, camera_matrix_i)
 | 
			
		||||
 | 
			
		||||
    def test_pulsar_conversion(self):
 | 
			
		||||
        """
 | 
			
		||||
        Tests that the cameras converted from opencv to pulsar convention
 | 
			
		||||
        return correct projections of random 3D points. The check is done
 | 
			
		||||
        against a set of results precomputed using `cv2.projectPoints` function.
 | 
			
		||||
        """
 | 
			
		||||
        image_size = [[480, 640]]
 | 
			
		||||
        R = [
 | 
			
		||||
            [
 | 
			
		||||
                [1.0, 0.0, 0.0],
 | 
			
		||||
                [0.0, 1.0, 0.0],
 | 
			
		||||
                [0.0, 0.0, 1.0],
 | 
			
		||||
            ],
 | 
			
		||||
            [
 | 
			
		||||
                [0.1968, -0.6663, -0.7192],
 | 
			
		||||
                [0.7138, -0.4055, 0.5710],
 | 
			
		||||
                [-0.6721, -0.6258, 0.3959],
 | 
			
		||||
            ],
 | 
			
		||||
        ]
 | 
			
		||||
        tvec = [
 | 
			
		||||
            [10.0, 10.0, 3.0],
 | 
			
		||||
            [-0.0, -0.0, 20.0],
 | 
			
		||||
        ]
 | 
			
		||||
        focal_length = [
 | 
			
		||||
            [100.0, 100.0],
 | 
			
		||||
            [10.0, 10.0],
 | 
			
		||||
        ]
 | 
			
		||||
        principal_point = [
 | 
			
		||||
            [320, 240],
 | 
			
		||||
            [320, 240],
 | 
			
		||||
        ]
 | 
			
		||||
 | 
			
		||||
        principal_point, focal_length, R, tvec, image_size = [
 | 
			
		||||
            torch.FloatTensor(x)
 | 
			
		||||
            for x in (principal_point, focal_length, R, tvec, image_size)
 | 
			
		||||
        ]
 | 
			
		||||
        camera_matrix = eyes(dim=3, N=2)
 | 
			
		||||
        camera_matrix[:, 0, 0] = focal_length[:, 0]
 | 
			
		||||
        camera_matrix[:, 1, 1] = focal_length[:, 1]
 | 
			
		||||
        camera_matrix[:, :2, 2] = principal_point
 | 
			
		||||
        rvec = so3_log_map(R)
 | 
			
		||||
        pts = torch.tensor(
 | 
			
		||||
            [[[0.0, 0.0, 120.0]], [[0.0, 0.0, 120.0]]], dtype=torch.float32
 | 
			
		||||
        )
 | 
			
		||||
        radii = torch.tensor([[1e-5], [1e-5]], dtype=torch.float32)
 | 
			
		||||
        col = torch.zeros((2, 1, 1), dtype=torch.float32)
 | 
			
		||||
 | 
			
		||||
        # project the 3D points with the opencv projection function
 | 
			
		||||
        pts_proj_opencv = cv2_project_points(pts, rvec, tvec, camera_matrix)
 | 
			
		||||
        pulsar_cam = pulsar_from_opencv_projection(
 | 
			
		||||
            R, tvec, camera_matrix, image_size, znear=100.0
 | 
			
		||||
        )
 | 
			
		||||
        pulsar_rend = PulsarRenderer(
 | 
			
		||||
            640, 480, 1, right_handed_system=False, n_channels=1
 | 
			
		||||
        )
 | 
			
		||||
        rendered = torch.flip(
 | 
			
		||||
            pulsar_rend(
 | 
			
		||||
                pts,
 | 
			
		||||
                col,
 | 
			
		||||
                radii,
 | 
			
		||||
                pulsar_cam,
 | 
			
		||||
                1e-5,
 | 
			
		||||
                max_depth=150.0,
 | 
			
		||||
                min_depth=100.0,
 | 
			
		||||
            ),
 | 
			
		||||
            dims=(1,),
 | 
			
		||||
        )
 | 
			
		||||
        for batch_id in range(2):
 | 
			
		||||
            point_pos = torch.where(rendered[batch_id] == rendered[batch_id].min())
 | 
			
		||||
            point_pos = point_pos[1][0], point_pos[0][0]
 | 
			
		||||
            self.assertLess(
 | 
			
		||||
                torch.abs(point_pos[0] - pts_proj_opencv[batch_id, 0, 0]), 2
 | 
			
		||||
            )
 | 
			
		||||
            self.assertLess(
 | 
			
		||||
                torch.abs(point_pos[1] - pts_proj_opencv[batch_id, 0, 1]), 2
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||