From 61e38de034b57f3c703d5049a117764e78f72fe2 Mon Sep 17 00:00:00 2001 From: Jeremy Reizenstein Date: Wed, 26 May 2021 04:52:46 -0700 Subject: [PATCH] rotate_on_spot Summary: Function to relatively rotate a camera position. Also document how to relatively translate a camera position. Reviewed By: theschnitz Differential Revision: D25900166 fbshipit-source-id: 2ddaf06ee7c5e2a2e973c04d7dee6ccb61c6ff84 --- pytorch3d/renderer/__init__.py | 1 + pytorch3d/renderer/camera_utils.py | 75 +++++++++++++++++++ pytorch3d/renderer/cameras.py | 3 +- tests/test_camera_utils.py | 116 ++++++++++++++++++++++++++++- 4 files changed, 192 insertions(+), 3 deletions(-) diff --git a/pytorch3d/renderer/__init__.py b/pytorch3d/renderer/__init__.py index f1f09057..382b0b41 100644 --- a/pytorch3d/renderer/__init__.py +++ b/pytorch3d/renderer/__init__.py @@ -6,6 +6,7 @@ from .blending import ( sigmoid_alpha_blend, softmax_rgb_blend, ) +from .camera_utils import rotate_on_spot from .cameras import OpenGLOrthographicCameras # deprecated from .cameras import OpenGLPerspectiveCameras # deprecated from .cameras import SfMOrthographicCameras # deprecated diff --git a/pytorch3d/renderer/camera_utils.py b/pytorch3d/renderer/camera_utils.py index 2444518a..bc1c2007 100644 --- a/pytorch3d/renderer/camera_utils.py +++ b/pytorch3d/renderer/camera_utils.py @@ -62,3 +62,78 @@ def camera_to_eye_at_up( eye, at, up_plus_eye = eye_at_up_world.unbind(1) up = up_plus_eye - eye return eye, at, up + + +def rotate_on_spot( + R: torch.Tensor, T: torch.Tensor, rotation: torch.Tensor +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Given a camera position as R and T (batched or not), + and a rotation matrix (batched or not) + return a new R and T representing camera position(s) + in the same location but rotated on the spot by the + given rotation. In particular the new world to view + rotation will be the previous one followed by the inverse + of the given rotation. + + For example, adding the following lines before constructing a camera + will make the camera point a little to the right of where it + otherwise would have been. + + .. code-block:: + + from math import radians + from pytorch3d.transforms import axis_angle_to_matrix + angles = [0, radians(10), 0] + rotation = axis_angle_to_matrix(torch.FloatTensor(angles)) + R, T = rotate_on_spot(R, T, rotation) + + Note here that if you have a column vector, then when you + premultiply it by this `rotation` (see the rotation_conversions doc), + then it will be rotated anticlockwise if facing the -y axis. + In our context, where we postmultiply row vectors to transform them, + `rotation` will rotate the camera clockwise around the -y axis + (i.e. when looking down), which is a turn to the right. + + If angles was [radians(10), 0, 0], the camera would get pointed + up a bit instead. + + If angles was [0, 0, radians(10)], the camera would be rotated anticlockwise + a bit, so the image would appear rotated clockwise from how it + otherwise would have been. + + If you want to translate the camera from the origin in camera + coordinates, this is simple and does not need a separate function. + In particular, a translation by X = [a, b, c] would cause + the camera to move a units left, b units up, and c units + forward. This is achieved by using T-X in place of T. + + Args: + R: FloatTensor of shape [3, 3] or [N, 3, 3] + T: FloatTensor of shape [3] or [N, 3] + rotation: FloatTensor of shape [3, 3] or [n, 3, 3] + where if neither n nor N is 1, then n and N must be equal. + + Returns: + R: FloatTensor of shape [max(N, n), 3, 3] + T: FloatTensor of shape [max(N, n), 3] + """ + if R.ndim == 2: + R = R[None] + if T.ndim == 1: + T = T[None] + if rotation.ndim == 2: + rotation = rotation[None] + + if R.ndim != 3 or R.shape[1:] != (3, 3): + raise ValueError("Invalid R") + if T.ndim != 2 or T.shape[1] != 3: + raise ValueError("Invalid T") + if rotation.ndim != 3 or rotation.shape[1:] != (3, 3): + raise ValueError("Invalid rotation") + + new_R = R @ rotation.transpose(1, 2) + old_RT = torch.bmm(R, T[:, :, None]) + new_T = torch.matmul(new_R.transpose(1, 2), old_RT)[:, :, 0] + + return new_R, new_T diff --git a/pytorch3d/renderer/cameras.py b/pytorch3d/renderer/cameras.py index ede7829e..35cf865c 100644 --- a/pytorch3d/renderer/cameras.py +++ b/pytorch3d/renderer/cameras.py @@ -38,7 +38,8 @@ class CamerasBase(TensorProperties): - Screen coordinate system: This is another representation of the view volume with the XY coordinates defined in pixel space instead of a normalized space. - A better illustration of the coordinate systems can be found in pytorch3d/docs/notes/cameras.md. + A better illustration of the coordinate systems can be found in + pytorch3d/docs/notes/cameras.md. It defines methods that are common to all camera models: - `get_camera_center` that returns the optical center of the camera in diff --git a/tests/test_camera_utils.py b/tests/test_camera_utils.py index 1e619067..b0abf130 100644 --- a/tests/test_camera_utils.py +++ b/tests/test_camera_utils.py @@ -1,14 +1,29 @@ # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. import unittest +from math import radians import torch from common_testing import TestCaseMixin -from pytorch3d.renderer.camera_utils import camera_to_eye_at_up -from pytorch3d.renderer.cameras import PerspectiveCameras, look_at_view_transform +from pytorch3d.renderer.camera_utils import camera_to_eye_at_up, rotate_on_spot +from pytorch3d.renderer.cameras import ( + PerspectiveCameras, + get_world_to_view_transform, + look_at_view_transform, +) +from pytorch3d.transforms import axis_angle_to_matrix from torch.nn.functional import normalize +def _batched_dotprod(x: torch.Tensor, y: torch.Tensor): + """ + Takes two tensors of shape (N,3) and returns their batched + dot product along the last dimension as a tensor of shape + (N,). + """ + return torch.einsum("ij,ij->i", x, y) + + class TestCameraUtils(TestCaseMixin, unittest.TestCase): def setUp(self) -> None: torch.manual_seed(42) @@ -28,6 +43,7 @@ class TestCameraUtils(TestCaseMixin, unittest.TestCase): # The retrieved eye matches self.assertClose(eye, eye2, atol=1e-5) + self.assertClose(cameras.get_camera_center(), eye) # at-eye as retrieved must be a vector in the same direction as # the original. @@ -49,3 +65,99 @@ class TestCameraUtils(TestCaseMixin, unittest.TestCase): cam_trans2 = cameras2.get_world_to_view_transform() self.assertClose(cam_trans.get_matrix(), cam_trans2.get_matrix(), atol=1e-5) + + def test_rotate_on_spot_yaw(self): + N = 14 + eye = torch.rand(N, 3) + at = torch.rand(N, 3) + up = torch.rand(N, 3) + + R, T = look_at_view_transform(eye=eye, at=at, up=up) + + # Moving around the y axis looks left. + angles = torch.FloatTensor([0, -radians(10), 0]) + rotation = axis_angle_to_matrix(angles) + R_rot, T_rot = rotate_on_spot(R, T, rotation) + + eye_rot, at_rot, up_rot = camera_to_eye_at_up( + get_world_to_view_transform(R=R_rot, T=T_rot) + ) + self.assertClose(eye, eye_rot, atol=1e-5) + + # Make vectors pointing exactly left and up + left = torch.cross(up, at - eye, dim=-1) + left_rot = torch.cross(up_rot, at_rot - eye_rot, dim=-1) + fully_up = torch.cross(at - eye, left, dim=-1) + fully_up_rot = torch.cross(at_rot - eye_rot, left_rot, dim=-1) + + # The up direction is unchanged + self.assertClose(normalize(fully_up), normalize(fully_up_rot), atol=1e-5) + + # The camera has moved left + agree = _batched_dotprod(torch.cross(left, left_rot, dim=1), fully_up) + self.assertGreater(agree.min(), 0) + + # Batch dimension for rotation + R_rot2, T_rot2 = rotate_on_spot(R, T, rotation.expand(N, 3, 3)) + self.assertClose(R_rot, R_rot2) + self.assertClose(T_rot, T_rot2) + + # No batch dimension for either + R_rot3, T_rot3 = rotate_on_spot(R[0], T[0], rotation) + self.assertClose(R_rot[:1], R_rot3) + self.assertClose(T_rot[:1], T_rot3) + + # No batch dimension for R, T + R_rot4, T_rot4 = rotate_on_spot(R[0], T[0], rotation.expand(N, 3, 3)) + self.assertClose(R_rot[:1].expand(N, 3, 3), R_rot4) + self.assertClose(T_rot[:1].expand(N, 3), T_rot4) + + def test_rotate_on_spot_pitch(self): + N = 14 + eye = torch.rand(N, 3) + at = torch.rand(N, 3) + up = torch.rand(N, 3) + + R, T = look_at_view_transform(eye=eye, at=at, up=up) + + # Moving around the x axis looks down. + angles = torch.FloatTensor([-radians(10), 0, 0]) + rotation = axis_angle_to_matrix(angles) + R_rot, T_rot = rotate_on_spot(R, T, rotation) + eye_rot, at_rot, up_rot = camera_to_eye_at_up( + get_world_to_view_transform(R=R_rot, T=T_rot) + ) + self.assertClose(eye, eye_rot, atol=1e-5) + + # A vector pointing left is unchanged + left = torch.cross(up, at - eye, dim=-1) + left_rot = torch.cross(up_rot, at_rot - eye_rot, dim=-1) + self.assertClose(normalize(left), normalize(left_rot), atol=1e-5) + + # The camera has moved down + fully_up = torch.cross(at - eye, left, dim=-1) + fully_up_rot = torch.cross(at_rot - eye_rot, left_rot, dim=-1) + agree = _batched_dotprod(torch.cross(fully_up, fully_up_rot, dim=1), left) + self.assertGreater(agree.min(), 0) + + def test_rotate_on_spot_roll(self): + N = 14 + eye = torch.rand(N, 3) + at = torch.rand(N, 3) + up = torch.rand(N, 3) + + R, T = look_at_view_transform(eye=eye, at=at, up=up) + + # Moving around the z axis rotates the image. + angles = torch.FloatTensor([0, 0, -radians(10)]) + rotation = axis_angle_to_matrix(angles) + R_rot, T_rot = rotate_on_spot(R, T, rotation) + eye_rot, at_rot, up_rot = camera_to_eye_at_up( + get_world_to_view_transform(R=R_rot, T=T_rot) + ) + self.assertClose(eye, eye_rot, atol=1e-5) + self.assertClose(normalize(at - eye), normalize(at_rot - eye), atol=1e-5) + + # The camera has moved clockwise + agree = _batched_dotprod(torch.cross(up, up_rot, dim=1), at - eye) + self.assertGreater(agree.min(), 0)