mirror of
				https://github.com/facebookresearch/pytorch3d.git
				synced 2025-11-04 18:02:14 +08:00 
			
		
		
		
	Conversion from OpenCV cameras
Summary: Implements a conversion function between OpenCV and PyTorch3D cameras. Reviewed By: patricklabatut Differential Revision: D28992470 fbshipit-source-id: dbcc9f213ec293c2f6938261c704aea09aad3c90
This commit is contained in:
		
							parent
							
								
									b2ac2655b3
								
							
						
					
					
						commit
						8006842f2a
					
				@ -1,5 +1,6 @@
 | 
			
		||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
 | 
			
		||||
 | 
			
		||||
from .camera_conversions import cameras_from_opencv_projection
 | 
			
		||||
from .ico_sphere import ico_sphere
 | 
			
		||||
from .torus import torus
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										70
									
								
								pytorch3d/utils/camera_conversions.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										70
									
								
								pytorch3d/utils/camera_conversions.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,70 @@
 | 
			
		||||
import torch
 | 
			
		||||
 | 
			
		||||
from ..renderer import PerspectiveCameras
 | 
			
		||||
from ..transforms import so3_exponential_map
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def cameras_from_opencv_projection(
    rvec: torch.Tensor,
    tvec: torch.Tensor,
    camera_matrix: torch.Tensor,
    image_size: torch.Tensor,
) -> PerspectiveCameras:
    """
    Converts a batch of OpenCV-conventioned cameras parametrized with the
    axis-angle rotation vectors `rvec`, translation vectors `tvec`, and the camera
    calibration matrices `camera_matrix` to `PerspectiveCameras` in PyTorch3D
    convention.

    More specifically, the conversion is carried out such that a projection
    of a 3D shape to the OpenCV-conventioned screen of size `image_size` results
    in the same image as a projection with the corresponding PyTorch3D camera
    to the NDC screen convention of PyTorch3D.

    The OpenCV convention projects points to the OpenCV screen space as follows:
        ```
        x_screen_opencv = camera_matrix @ (exp(rvec) @ x_world + tvec)
        ```
    followed by the homogenization of `x_screen_opencv`.

    Note:
        The parameters `rvec, tvec, camera_matrix` correspond e.g. to the inputs
        of `cv2.projectPoints`, or to the outputs of `cv2.calibrateCamera`.

    Args:
        rvec: A batch of axis-angle rotation vectors of shape `(N, 3)`.
        tvec: A batch of translation vectors of shape `(N, 3)`.
        camera_matrix: A batch of camera calibration matrices of shape `(N, 3, 3)`.
        image_size: A tensor of shape `(N, 2)` containing the sizes of the images
            (height, width) attached to each camera.

    Returns:
        cameras_pytorch3d: A batch of `N` cameras in the PyTorch3D convention.
    """
    R = so3_exponential_map(rvec)

    # Only the diagonal focal terms and the last-column principal point of the
    # calibration matrix are used; any skew term is ignored.
    focal_length = torch.stack([camera_matrix[:, 0, 0], camera_matrix[:, 1, 1]], dim=-1)
    principal_point = camera_matrix[:, :2, 2]

    # Retype the image_size to match R and flip (height, width) -> (width, height).
    image_size_wh = image_size.to(R).flip(dims=(1,))
    half_size_wh = 0.5 * image_size_wh

    # Get the PyTorch3D focal length and principal point.
    focal_pytorch3d = focal_length / half_size_wh
    p0_pytorch3d = -(principal_point / half_size_wh - 1)

    # For R, T we flip x, y axes (opencv screen space has an opposite
    # orientation of screen axes).
    # We also transpose R (opencv multiplies points from the opposite=left side).
    # The flip is done out-of-place: `R.permute(...)` is a view of `R`, so an
    # in-place `*=` would write into the storage of `R` itself.
    R_pytorch3d = R.permute(0, 2, 1) * R.new_tensor([-1, -1, 1])
    T_pytorch3d = tvec * tvec.new_tensor([-1, -1, 1])

    return PerspectiveCameras(
        R=R_pytorch3d,
        T=T_pytorch3d,
        focal_length=focal_pytorch3d,
        principal_point=p0_pytorch3d,
    )
 | 
			
		||||
							
								
								
									
										1230
									
								
								tests/data/cv_project_points_precomputed.json
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										1230
									
								
								tests/data/cv_project_points_precomputed.json
									
									
									
									
									
										Normal file
									
								
							
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							
							
								
								
									
										149
									
								
								tests/test_camera_conversions.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										149
									
								
								tests/test_camera_conversions.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,149 @@
 | 
			
		||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
import json
 | 
			
		||||
import unittest
 | 
			
		||||
 | 
			
		||||
import numpy as np
 | 
			
		||||
import torch
 | 
			
		||||
from common_testing import TestCaseMixin, get_tests_dir
 | 
			
		||||
from pytorch3d.ops import eyes
 | 
			
		||||
from pytorch3d.transforms import so3_exponential_map, so3_log_map
 | 
			
		||||
from pytorch3d.utils import (
 | 
			
		||||
    cameras_from_opencv_projection,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
DATA_DIR = get_tests_dir() / "data"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _coords_opencv_screen_to_pytorch3d_ndc(xy_opencv, image_size):
 | 
			
		||||
    """
 | 
			
		||||
    Converts the OpenCV screen coordinates `xy_opencv` to PyTorch3D NDC coordinates.
 | 
			
		||||
    """
 | 
			
		||||
    xy_pytorch3d = -(2.0 * xy_opencv / image_size.flip(dims=(1,))[:, None] - 1.0)
 | 
			
		||||
    return xy_pytorch3d
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def cv2_project_points(pts, rvec, tvec, camera_matrix):
    """
    Reproduces the `cv2.projectPoints` function from OpenCV using PyTorch.

    No lens-distortion coefficients are taken; only the rigid transform, the
    calibration matrix and the perspective division are applied.

    Args:
        pts: A batch of 3D points of shape `(N, P, 3)`.
        rvec: A batch of axis-angle rotation vectors of shape `(N, 3)`.
        tvec: A batch of translation vectors of shape `(N, 3)`.
        camera_matrix: A batch of calibration matrices of shape `(N, 3, 3)`.

    Returns:
        A tensor of shape `(N, P, 2)` with the projected screen coordinates.
    """
    R = so3_exponential_map(rvec)
    # Points are moved to column layout (N, 3, P) so that the rotation and the
    # calibration matrix multiply them from the left.
    pts_camera = R.bmm(pts.permute(0, 2, 1)) + tvec[:, :, None]
    pts_proj_3d = camera_matrix.bmm(pts_camera).permute(0, 2, 1)
    # Homogenize: divide the xy components by the depth component.
    depth = pts_proj_3d[..., 2:]
    pts_proj_2d = pts_proj_3d[..., :2] / depth
    return pts_proj_2d
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TestCameraConversions(TestCaseMixin, unittest.TestCase):
    """Tests for the conversion of OpenCV cameras to PyTorch3D cameras."""

    def setUp(self) -> None:
        super().setUp()
        # Fix the seeds so that the random test points are reproducible.
        torch.manual_seed(42)
        np.random.seed(42)

    def test_cv2_project_points(self):
        """
        Tests that the local implementation of cv2_project_points gives the same
        results as OpenCV's `cv2.projectPoints`. The check is done against a set
        of precomputed results `cv_project_points_precomputed`.
        """
        with open(DATA_DIR / "cv_project_points_precomputed.json", "r") as f:
            cv_project_points_precomputed = json.load(f)

        for test_case in cv_project_points_precomputed:
            # Each stored case is unbatched; add a leading batch dimension.
            _pts_proj = cv2_project_points(
                **{
                    k: torch.tensor(test_case[k])[None]
                    for k in ("pts", "rvec", "tvec", "camera_matrix")
                }
            )
            pts_proj = torch.tensor(test_case["pts_proj"])[None]
            self.assertClose(_pts_proj, pts_proj, atol=1e-4)

    def test_opencv_conversion(self):
        """
        Tests that the cameras converted from opencv to pytorch3d convention
        return correct projections of random 3D points. The check is done
        against the projections computed with the local `cv2_project_points`
        re-implementation of `cv2.projectPoints`.
        """

        image_size = [[480, 640]] * 4
        R = [
            [
                [1.0, 0.0, 0.0],
                [0.0, 1.0, 0.0],
                [0.0, 0.0, 1.0],
            ],
            [
                [1.0, 0.0, 0.0],
                [0.0, 0.0, -1.0],
                [0.0, 1.0, 0.0],
            ],
            [
                [0.0, 0.0, 1.0],
                [1.0, 0.0, 0.0],
                [0.0, 1.0, 0.0],
            ],
            [
                [0.0, 0.0, 1.0],
                [1.0, 0.0, 0.0],
                [0.0, 1.0, 0.0],
            ],
        ]

        tvec = [
            [0.0, 0.0, 3.0],
            [0.3, -0.3, 3.0],
            [-0.15, 0.1, 4.0],
            [0.0, 0.0, 4.0],
        ]
        focal_length = [
            [100.0, 100.0],
            [115.0, 115.0],
            [105.0, 105.0],
            [120.0, 120.0],
        ]
        principal_point = [
            [240, 320],
            [240.5, 320.3],
            [241, 318],
            [242, 322],
        ]

        principal_point, focal_length, R, tvec, image_size = [
            torch.tensor(x, dtype=torch.float32)
            for x in (principal_point, focal_length, R, tvec, image_size)
        ]

        # assemble the calibration matrices from focal lengths + principal points
        camera_matrix = eyes(dim=3, N=4)
        camera_matrix[:, 0, 0] = focal_length[:, 0]
        camera_matrix[:, 1, 1] = focal_length[:, 1]
        camera_matrix[:, :2, 2] = principal_point

        rvec = so3_log_map(R)

        pts = torch.nn.functional.normalize(torch.randn(4, 1000, 3), dim=-1)

        # project the 3D points with the opencv projection function
        pts_proj_opencv = cv2_project_points(pts, rvec, tvec, camera_matrix)

        # make the pytorch3d cameras
        cameras_opencv_to_pytorch3d = cameras_from_opencv_projection(
            rvec, tvec, camera_matrix, image_size
        )

        # project the 3D points with converted cameras
        pts_proj_pytorch3d = cameras_opencv_to_pytorch3d.transform_points(pts)[..., :2]

        # convert the opencv-projected points to pytorch3d screen coords
        pts_proj_opencv_in_pytorch3d_screen = _coords_opencv_screen_to_pytorch3d_ndc(
            pts_proj_opencv, image_size
        )

        # compare the opencv projections with the pytorch3d projections
        self.assertClose(
            pts_proj_opencv_in_pytorch3d_screen, pts_proj_pytorch3d, atol=1e-5
        )
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user