mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 20:02:49 +08:00
rotate_on_spot
Summary: Function to relatively rotate a camera position. Also document how to relatively translate a camera position. Reviewed By: theschnitz Differential Revision: D25900166 fbshipit-source-id: 2ddaf06ee7c5e2a2e973c04d7dee6ccb61c6ff84
This commit is contained in:
parent
e12a08133f
commit
61e38de034
@ -6,6 +6,7 @@ from .blending import (
|
|||||||
sigmoid_alpha_blend,
|
sigmoid_alpha_blend,
|
||||||
softmax_rgb_blend,
|
softmax_rgb_blend,
|
||||||
)
|
)
|
||||||
|
from .camera_utils import rotate_on_spot
|
||||||
from .cameras import OpenGLOrthographicCameras # deprecated
|
from .cameras import OpenGLOrthographicCameras # deprecated
|
||||||
from .cameras import OpenGLPerspectiveCameras # deprecated
|
from .cameras import OpenGLPerspectiveCameras # deprecated
|
||||||
from .cameras import SfMOrthographicCameras # deprecated
|
from .cameras import SfMOrthographicCameras # deprecated
|
||||||
|
@ -62,3 +62,78 @@ def camera_to_eye_at_up(
|
|||||||
eye, at, up_plus_eye = eye_at_up_world.unbind(1)
|
eye, at, up_plus_eye = eye_at_up_world.unbind(1)
|
||||||
up = up_plus_eye - eye
|
up = up_plus_eye - eye
|
||||||
return eye, at, up
|
return eye, at, up
|
||||||
|
|
||||||
|
|
||||||
|
def rotate_on_spot(
    R: torch.Tensor, T: torch.Tensor, rotation: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Given a camera position as R and T (batched or not),
    and a rotation matrix (batched or not)
    return a new R and T representing camera position(s)
    in the same location but rotated on the spot by the
    given rotation. In particular the new world to view
    rotation will be the previous one followed by the inverse
    of the given rotation.

    For example, adding the following lines before constructing a camera
    will make the camera point a little to the right of where it
    otherwise would have been.

    .. code-block::

        from math import radians
        from pytorch3d.transforms import axis_angle_to_matrix
        angles = [0, radians(10), 0]
        rotation = axis_angle_to_matrix(torch.FloatTensor(angles))
        R, T = rotate_on_spot(R, T, rotation)

    Note here that if you have a column vector, then when you
    premultiply it by this `rotation` (see the rotation_conversions doc),
    then it will be rotated anticlockwise if facing the -y axis.
    In our context, where we postmultiply row vectors to transform them,
    `rotation` will rotate the camera clockwise around the -y axis
    (i.e. when looking down), which is a turn to the right.

    If angles was [radians(10), 0, 0], the camera would get pointed
    up a bit instead.

    If angles was [0, 0, radians(10)], the camera would be rotated anticlockwise
    a bit, so the image would appear rotated clockwise from how it
    otherwise would have been.

    If you want to translate the camera from the origin in camera
    coordinates, this is simple and does not need a separate function.
    In particular, a translation by X = [a, b, c] would cause
    the camera to move a units left, b units up, and c units
    forward. This is achieved by using T-X in place of T.

    Args:
        R: FloatTensor of shape [3, 3] or [N, 3, 3]
        T: FloatTensor of shape [3] or [N, 3]
        rotation: FloatTensor of shape [3, 3] or [n, 3, 3]
    where if neither n nor N is 1, then n and N must be equal.
    More generally the batch sizes of R, T and rotation are broadcast
    against each other, so any of them may be 1 (or omitted) while
    the others share a common size.

    Returns:
        R: FloatTensor of shape [max(N, n), 3, 3]
        T: FloatTensor of shape [max(N, n), 3]

    Raises:
        ValueError: if R, T or rotation does not have one of the
            shapes listed above.
    """
    # Promote unbatched inputs to a batch of one so that a single code
    # path handles every case.
    if R.ndim == 2:
        R = R[None]
    if T.ndim == 1:
        T = T[None]
    if rotation.ndim == 2:
        rotation = rotation[None]

    if R.ndim != 3 or R.shape[1:] != (3, 3):
        raise ValueError("Invalid R")
    if T.ndim != 2 or T.shape[1] != 3:
        raise ValueError("Invalid T")
    if rotation.ndim != 3 or rotation.shape[1:] != (3, 3):
        raise ValueError("Invalid rotation")

    # New world-to-view rotation: the old one followed by the inverse
    # (transpose) of the requested rotation.
    new_R = R @ rotation.transpose(1, 2)

    # matmul (rather than bmm) lets the batch sizes of R and T broadcast,
    # exactly as R and rotation already do in the line above.
    old_RT = torch.matmul(R, T[:, :, None])
    new_T = torch.matmul(new_R.transpose(1, 2), old_RT)[:, :, 0]

    # new_T carries the fully broadcast batch size; give new_R the same one
    # so both outputs always agree. This is a no-op when R and T were
    # supplied with equal batch sizes.
    new_R = new_R.expand(new_T.shape[0], 3, 3)

    return new_R, new_T
|
||||||
|
@ -38,7 +38,8 @@ class CamerasBase(TensorProperties):
|
|||||||
- Screen coordinate system: This is another representation of the view volume with
|
- Screen coordinate system: This is another representation of the view volume with
|
||||||
the XY coordinates defined in pixel space instead of a normalized space.
|
the XY coordinates defined in pixel space instead of a normalized space.
|
||||||
|
|
||||||
A better illustration of the coordinate systems can be found in pytorch3d/docs/notes/cameras.md.
|
A better illustration of the coordinate systems can be found in
|
||||||
|
pytorch3d/docs/notes/cameras.md.
|
||||||
|
|
||||||
It defines methods that are common to all camera models:
|
It defines methods that are common to all camera models:
|
||||||
- `get_camera_center` that returns the optical center of the camera in
|
- `get_camera_center` that returns the optical center of the camera in
|
||||||
|
@ -1,14 +1,29 @@
|
|||||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||||
|
|
||||||
import unittest
|
import unittest
|
||||||
|
from math import radians
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from common_testing import TestCaseMixin
|
from common_testing import TestCaseMixin
|
||||||
from pytorch3d.renderer.camera_utils import camera_to_eye_at_up
|
from pytorch3d.renderer.camera_utils import camera_to_eye_at_up, rotate_on_spot
|
||||||
from pytorch3d.renderer.cameras import PerspectiveCameras, look_at_view_transform
|
from pytorch3d.renderer.cameras import (
|
||||||
|
PerspectiveCameras,
|
||||||
|
get_world_to_view_transform,
|
||||||
|
look_at_view_transform,
|
||||||
|
)
|
||||||
|
from pytorch3d.transforms import axis_angle_to_matrix
|
||||||
from torch.nn.functional import normalize
|
from torch.nn.functional import normalize
|
||||||
|
|
||||||
|
|
||||||
|
def _batched_dotprod(x: torch.Tensor, y: torch.Tensor):
|
||||||
|
"""
|
||||||
|
Takes two tensors of shape (N,3) and returns their batched
|
||||||
|
dot product along the last dimension as a tensor of shape
|
||||||
|
(N,).
|
||||||
|
"""
|
||||||
|
return torch.einsum("ij,ij->i", x, y)
|
||||||
|
|
||||||
|
|
||||||
class TestCameraUtils(TestCaseMixin, unittest.TestCase):
|
class TestCameraUtils(TestCaseMixin, unittest.TestCase):
|
||||||
def setUp(self) -> None:
|
def setUp(self) -> None:
|
||||||
torch.manual_seed(42)
|
torch.manual_seed(42)
|
||||||
@ -28,6 +43,7 @@ class TestCameraUtils(TestCaseMixin, unittest.TestCase):
|
|||||||
|
|
||||||
# The retrieved eye matches
|
# The retrieved eye matches
|
||||||
self.assertClose(eye, eye2, atol=1e-5)
|
self.assertClose(eye, eye2, atol=1e-5)
|
||||||
|
self.assertClose(cameras.get_camera_center(), eye)
|
||||||
|
|
||||||
# at-eye as retrieved must be a vector in the same direction as
|
# at-eye as retrieved must be a vector in the same direction as
|
||||||
# the original.
|
# the original.
|
||||||
@ -49,3 +65,99 @@ class TestCameraUtils(TestCaseMixin, unittest.TestCase):
|
|||||||
cam_trans2 = cameras2.get_world_to_view_transform()
|
cam_trans2 = cameras2.get_world_to_view_transform()
|
||||||
|
|
||||||
self.assertClose(cam_trans.get_matrix(), cam_trans2.get_matrix(), atol=1e-5)
|
self.assertClose(cam_trans.get_matrix(), cam_trans2.get_matrix(), atol=1e-5)
|
||||||
|
|
||||||
|
def test_rotate_on_spot_yaw(self):
    """A negative rotation about y turns the camera left without moving it."""
    N = 14
    eye = torch.rand(N, 3)
    at = torch.rand(N, 3)
    up = torch.rand(N, 3)

    R, T = look_at_view_transform(eye=eye, at=at, up=up)

    # Moving around the y axis looks left.
    yaw = axis_angle_to_matrix(torch.FloatTensor([0, -radians(10), 0]))
    R_rot, T_rot = rotate_on_spot(R, T, yaw)

    rotated_transform = get_world_to_view_transform(R=R_rot, T=T_rot)
    eye_rot, at_rot, up_rot = camera_to_eye_at_up(rotated_transform)

    # Rotating on the spot must not move the camera.
    self.assertClose(eye, eye_rot, atol=1e-5)

    # Viewing directions before and after the rotation.
    forward = at - eye
    forward_rot = at_rot - eye_rot

    # Make vectors pointing exactly left and up
    left = torch.cross(up, forward, dim=-1)
    left_rot = torch.cross(up_rot, forward_rot, dim=-1)
    fully_up = torch.cross(forward, left, dim=-1)
    fully_up_rot = torch.cross(forward_rot, left_rot, dim=-1)

    # The up direction is unchanged
    self.assertClose(normalize(fully_up), normalize(fully_up_rot), atol=1e-5)

    # The camera has moved left
    turn = torch.cross(left, left_rot, dim=1)
    agree = _batched_dotprod(turn, fully_up)
    self.assertGreater(agree.min(), 0)

    yaw_batched = yaw.expand(N, 3, 3)

    # Batch dimension for rotation
    R_rot2, T_rot2 = rotate_on_spot(R, T, yaw_batched)
    self.assertClose(R_rot, R_rot2)
    self.assertClose(T_rot, T_rot2)

    # No batch dimension for either
    R_rot3, T_rot3 = rotate_on_spot(R[0], T[0], yaw)
    self.assertClose(R_rot[:1], R_rot3)
    self.assertClose(T_rot[:1], T_rot3)

    # No batch dimension for R, T
    R_rot4, T_rot4 = rotate_on_spot(R[0], T[0], yaw_batched)
    self.assertClose(R_rot[:1].expand(N, 3, 3), R_rot4)
    self.assertClose(T_rot[:1].expand(N, 3), T_rot4)
|
|
||||||
|
def test_rotate_on_spot_pitch(self):
    """A negative rotation about x tilts the camera down without moving it."""
    N = 14
    eye = torch.rand(N, 3)
    at = torch.rand(N, 3)
    up = torch.rand(N, 3)

    R, T = look_at_view_transform(eye=eye, at=at, up=up)

    # Moving around the x axis looks down.
    pitch = axis_angle_to_matrix(torch.FloatTensor([-radians(10), 0, 0]))
    R_rot, T_rot = rotate_on_spot(R, T, pitch)
    rotated_transform = get_world_to_view_transform(R=R_rot, T=T_rot)
    eye_rot, at_rot, up_rot = camera_to_eye_at_up(rotated_transform)

    # Rotating on the spot must not move the camera.
    self.assertClose(eye, eye_rot, atol=1e-5)

    # Viewing directions before and after the rotation.
    forward = at - eye
    forward_rot = at_rot - eye_rot

    # A vector pointing left is unchanged
    left = torch.cross(up, forward, dim=-1)
    left_rot = torch.cross(up_rot, forward_rot, dim=-1)
    self.assertClose(normalize(left), normalize(left_rot), atol=1e-5)

    # The camera has moved down
    fully_up = torch.cross(forward, left, dim=-1)
    fully_up_rot = torch.cross(forward_rot, left_rot, dim=-1)
    tilt = torch.cross(fully_up, fully_up_rot, dim=1)
    agree = _batched_dotprod(tilt, left)
    self.assertGreater(agree.min(), 0)
|
||||||
|
|
||||||
|
def test_rotate_on_spot_roll(self):
    """A negative rotation about z rolls the camera about its viewing axis."""
    N = 14
    eye = torch.rand(N, 3)
    at = torch.rand(N, 3)
    up = torch.rand(N, 3)

    R, T = look_at_view_transform(eye=eye, at=at, up=up)

    # Moving around the z axis rotates the image.
    angles = torch.FloatTensor([0, 0, -radians(10)])
    rotation = axis_angle_to_matrix(angles)
    R_rot, T_rot = rotate_on_spot(R, T, rotation)
    eye_rot, at_rot, up_rot = camera_to_eye_at_up(
        get_world_to_view_transform(R=R_rot, T=T_rot)
    )

    # Rotating on the spot must not move the camera.
    self.assertClose(eye, eye_rot, atol=1e-5)

    # The viewing direction is unchanged. The rotated direction is measured
    # from the rotated camera's own eye so that the comparison does not mix
    # pre- and post-rotation quantities.
    self.assertClose(
        normalize(at - eye), normalize(at_rot - eye_rot), atol=1e-5
    )

    # The camera has moved clockwise
    agree = _batched_dotprod(torch.cross(up, up_rot, dim=1), at - eye)
    self.assertGreater(agree.min(), 0)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user