rotate_on_spot

Summary: Add a function to rotate a camera pose on the spot, i.e. relative to its current orientation. Also document how to translate a camera relative to its own axes.

Reviewed By: theschnitz

Differential Revision: D25900166

fbshipit-source-id: 2ddaf06ee7c5e2a2e973c04d7dee6ccb61c6ff84
Jeremy Reizenstein 2021-05-26 04:52:46 -07:00 committed by Facebook GitHub Bot
parent e12a08133f
commit 61e38de034
4 changed files with 192 additions and 3 deletions


@@ -6,6 +6,7 @@ from .blending import (
    sigmoid_alpha_blend,
    softmax_rgb_blend,
)
from .camera_utils import rotate_on_spot
from .cameras import OpenGLOrthographicCameras # deprecated
from .cameras import OpenGLPerspectiveCameras # deprecated
from .cameras import SfMOrthographicCameras # deprecated


@@ -62,3 +62,78 @@ def camera_to_eye_at_up(
    eye, at, up_plus_eye = eye_at_up_world.unbind(1)
    up = up_plus_eye - eye
    return eye, at, up
def rotate_on_spot(
    R: torch.Tensor, T: torch.Tensor, rotation: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Given a camera position as R and T (batched or not),
    and a rotation matrix (batched or not),
    return a new R and T representing camera position(s)
    in the same location but rotated on the spot by the
    given rotation. In particular, the new world-to-view
    rotation will be the previous one followed by the inverse
    of the given rotation.

    For example, adding the following lines before constructing a camera
    will make the camera point a little to the right of where it
    otherwise would have been.

    .. code-block::

        from math import radians
        from pytorch3d.transforms import axis_angle_to_matrix

        angles = [0, radians(10), 0]
        rotation = axis_angle_to_matrix(torch.FloatTensor(angles))
        R, T = rotate_on_spot(R, T, rotation)

    Note here that if you have a column vector, premultiplying it by this
    `rotation` (see the rotation_conversions doc) rotates it anticlockwise
    when facing the -y axis. In our context, where we postmultiply row
    vectors to transform them, `rotation` will rotate the camera clockwise
    around the -y axis (i.e. when looking down), which is a turn to the right.

    If angles was [radians(10), 0, 0], the camera would get pointed
    up a bit instead.

    If angles was [0, 0, radians(10)], the camera would be rotated anticlockwise
    a bit, so the image would appear rotated clockwise from how it
    otherwise would have been.

    If you want to translate the camera from the origin in camera
    coordinates, this is simple and does not need a separate function.
    In particular, a translation by X = [a, b, c] would cause
    the camera to move a units left, b units up, and c units
    forward. This is achieved by using T - X in place of T.

    Args:
        R: FloatTensor of shape [3, 3] or [N, 3, 3]
        T: FloatTensor of shape [3] or [N, 3]
        rotation: FloatTensor of shape [3, 3] or [n, 3, 3]
            where if neither n nor N is 1, then n and N must be equal.

    Returns:
        R: FloatTensor of shape [max(N, n), 3, 3]
        T: FloatTensor of shape [max(N, n), 3]
    """
    if R.ndim == 2:
        R = R[None]
    if T.ndim == 1:
        T = T[None]
    if rotation.ndim == 2:
        rotation = rotation[None]

    if R.ndim != 3 or R.shape[1:] != (3, 3):
        raise ValueError("Invalid R")
    if T.ndim != 2 or T.shape[1] != 3:
        raise ValueError("Invalid T")
    if rotation.ndim != 3 or rotation.shape[1:] != (3, 3):
        raise ValueError("Invalid rotation")
    new_R = R @ rotation.transpose(1, 2)
    old_RT = torch.bmm(R, T[:, :, None])
    new_T = torch.matmul(new_R.transpose(1, 2), old_RT)[:, :, 0]

    return new_R, new_T
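
For context, here is a minimal usage sketch that combines the new function with the T - X translation described in the docstring. It is not part of the diff, and the look-at pose, the 10-degree yaw and the 0.5-unit forward step are arbitrary values chosen for illustration:

    from math import radians

    import torch
    from pytorch3d.renderer.camera_utils import rotate_on_spot
    from pytorch3d.renderer.cameras import PerspectiveCameras, look_at_view_transform
    from pytorch3d.transforms import axis_angle_to_matrix

    # Start from an ordinary look-at pose.
    R, T = look_at_view_transform(eye=((0.0, 0.0, -3.0),), at=((0.0, 0.0, 0.0),))

    # Turn the camera 10 degrees to the right, keeping its position fixed.
    rotation = axis_angle_to_matrix(torch.FloatTensor([0.0, radians(10), 0.0]))
    R, T = rotate_on_spot(R, T, rotation)

    # Move the camera 0.5 units forward along its own viewing direction,
    # i.e. use T - X with X = [0, 0, 0.5] as the docstring describes.
    X = torch.FloatTensor([[0.0, 0.0, 0.5]])
    T = T - X

    cameras = PerspectiveCameras(R=R, T=T)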


@@ -38,7 +38,8 @@ class CamerasBase(TensorProperties):
- Screen coordinate system: This is another representation of the view volume with
the XY coordinates defined in pixel space instead of a normalized space.
A better illustration of the coordinate systems can be found in pytorch3d/docs/notes/cameras.md.
A better illustration of the coordinate systems can be found in
pytorch3d/docs/notes/cameras.md.
It defines methods that are common to all camera models:
- `get_camera_center` that returns the optical center of the camera in


@@ -1,14 +1,29 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import unittest
from math import radians
import torch
from common_testing import TestCaseMixin
from pytorch3d.renderer.camera_utils import camera_to_eye_at_up
from pytorch3d.renderer.cameras import PerspectiveCameras, look_at_view_transform
from pytorch3d.renderer.camera_utils import camera_to_eye_at_up, rotate_on_spot
from pytorch3d.renderer.cameras import (
    PerspectiveCameras,
    get_world_to_view_transform,
    look_at_view_transform,
)
from pytorch3d.transforms import axis_angle_to_matrix
from torch.nn.functional import normalize
def _batched_dotprod(x: torch.Tensor, y: torch.Tensor):
    """
    Takes two tensors of shape (N,3) and returns their batched
    dot product along the last dimension as a tensor of shape
    (N,).
    """
    return torch.einsum("ij,ij->i", x, y)
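
# Note: _batched_dotprod is equivalent to (x * y).sum(dim=1); the tests below
# combine it with torch.cross to check the direction in which the camera turned.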

class TestCameraUtils(TestCaseMixin, unittest.TestCase):
    def setUp(self) -> None:
        torch.manual_seed(42)
@@ -28,6 +43,7 @@ class TestCameraUtils(TestCaseMixin, unittest.TestCase):
        # The retrieved eye matches
        self.assertClose(eye, eye2, atol=1e-5)
        self.assertClose(cameras.get_camera_center(), eye)

        # at-eye as retrieved must be a vector in the same direction as
        # the original.
@@ -49,3 +65,99 @@ class TestCameraUtils(TestCaseMixin, unittest.TestCase):
        cam_trans2 = cameras2.get_world_to_view_transform()
        self.assertClose(cam_trans.get_matrix(), cam_trans2.get_matrix(), atol=1e-5)

    def test_rotate_on_spot_yaw(self):
        N = 14
        eye = torch.rand(N, 3)
        at = torch.rand(N, 3)
        up = torch.rand(N, 3)

        R, T = look_at_view_transform(eye=eye, at=at, up=up)

        # Moving around the y axis looks left.
        angles = torch.FloatTensor([0, -radians(10), 0])
        rotation = axis_angle_to_matrix(angles)
        R_rot, T_rot = rotate_on_spot(R, T, rotation)
        eye_rot, at_rot, up_rot = camera_to_eye_at_up(
            get_world_to_view_transform(R=R_rot, T=T_rot)
        )
        self.assertClose(eye, eye_rot, atol=1e-5)

        # Make vectors pointing exactly left and up
        left = torch.cross(up, at - eye, dim=-1)
        left_rot = torch.cross(up_rot, at_rot - eye_rot, dim=-1)
        fully_up = torch.cross(at - eye, left, dim=-1)
        fully_up_rot = torch.cross(at_rot - eye_rot, left_rot, dim=-1)

        # The up direction is unchanged
        self.assertClose(normalize(fully_up), normalize(fully_up_rot), atol=1e-5)

        # The camera has moved left
        agree = _batched_dotprod(torch.cross(left, left_rot, dim=1), fully_up)
        self.assertGreater(agree.min(), 0)

        # Batch dimension for rotation
        R_rot2, T_rot2 = rotate_on_spot(R, T, rotation.expand(N, 3, 3))
        self.assertClose(R_rot, R_rot2)
        self.assertClose(T_rot, T_rot2)

        # No batch dimension for either
        R_rot3, T_rot3 = rotate_on_spot(R[0], T[0], rotation)
        self.assertClose(R_rot[:1], R_rot3)
        self.assertClose(T_rot[:1], T_rot3)

        # No batch dimension for R, T
        R_rot4, T_rot4 = rotate_on_spot(R[0], T[0], rotation.expand(N, 3, 3))
        self.assertClose(R_rot[:1].expand(N, 3, 3), R_rot4)
        self.assertClose(T_rot[:1].expand(N, 3), T_rot4)

    def test_rotate_on_spot_pitch(self):
        N = 14
        eye = torch.rand(N, 3)
        at = torch.rand(N, 3)
        up = torch.rand(N, 3)

        R, T = look_at_view_transform(eye=eye, at=at, up=up)

        # Moving around the x axis looks down.
        angles = torch.FloatTensor([-radians(10), 0, 0])
        rotation = axis_angle_to_matrix(angles)
        R_rot, T_rot = rotate_on_spot(R, T, rotation)
        eye_rot, at_rot, up_rot = camera_to_eye_at_up(
            get_world_to_view_transform(R=R_rot, T=T_rot)
        )
        self.assertClose(eye, eye_rot, atol=1e-5)

        # A vector pointing left is unchanged
        left = torch.cross(up, at - eye, dim=-1)
        left_rot = torch.cross(up_rot, at_rot - eye_rot, dim=-1)
        self.assertClose(normalize(left), normalize(left_rot), atol=1e-5)

        # The camera has moved down
        fully_up = torch.cross(at - eye, left, dim=-1)
        fully_up_rot = torch.cross(at_rot - eye_rot, left_rot, dim=-1)
        agree = _batched_dotprod(torch.cross(fully_up, fully_up_rot, dim=1), left)
        self.assertGreater(agree.min(), 0)

    def test_rotate_on_spot_roll(self):
        N = 14
        eye = torch.rand(N, 3)
        at = torch.rand(N, 3)
        up = torch.rand(N, 3)

        R, T = look_at_view_transform(eye=eye, at=at, up=up)

        # Moving around the z axis rotates the image.
        angles = torch.FloatTensor([0, 0, -radians(10)])
        rotation = axis_angle_to_matrix(angles)
        R_rot, T_rot = rotate_on_spot(R, T, rotation)
        eye_rot, at_rot, up_rot = camera_to_eye_at_up(
            get_world_to_view_transform(R=R_rot, T=T_rot)
        )
        self.assertClose(eye, eye_rot, atol=1e-5)
        self.assertClose(normalize(at - eye), normalize(at_rot - eye), atol=1e-5)

        # The camera has moved clockwise
        agree = _batched_dotprod(torch.cross(up, up_rot, dim=1), at - eye)
        self.assertGreater(agree.min(), 0)