mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
(eye, at, up) extraction function
Summary: Plotly viewing from a specific camera location requires converting that location into an (eye, at, up) specification. There may be other reasons to want to do this as well, so I create a separate utility function for it. I envisage more such utility functions for manipulating camera information, so I create a separate camera_utils.py file for such things. Reviewed By: nikhilaravi Differential Revision: D25981184 fbshipit-source-id: 0947bf98b212676c021f2fddf775bf436dee3487
This commit is contained in:
parent
ddebdfbcd7
commit
cf9bb7c48c
64
pytorch3d/renderer/camera_utils.py
Normal file
64
pytorch3d/renderer/camera_utils.py
Normal file
@ -0,0 +1,64 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
from pytorch3d.transforms import Transform3d
|
||||
|
||||
|
||||
def camera_to_eye_at_up(
    world_to_view_transform: Transform3d,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Given a world to view transform, return the eye, at and up vectors which
    represent its position.

    For example, if cam is a camera object, then after running

    .. code-block::

        from pytorch3d.renderer.cameras import look_at_view_transform
        eye, at, up = camera_to_eye_at_up(cam.get_world_to_view_transform())
        R, T = look_at_view_transform(eye=eye, at=at, up=up)

    any other camera created from R and T will have the same world to view
    transform as cam.

    Also, given a camera position R and T, then after running:

    .. code-block::

        from pytorch3d.renderer.cameras import (
            get_world_to_view_transform,
            look_at_view_transform,
        )
        eye, at, up = camera_to_eye_at_up(get_world_to_view_transform(R=R, T=T))
        R2, T2 = look_at_view_transform(eye=eye, at=at, up=up)

    R2 will equal R and T2 will equal T.

    Args:
        world_to_view_transform: Transform3d representing the extrinsic
            transformation of N cameras.

    Returns:
        eye: FloatTensor of shape [N, 3] representing the camera centers in
            world space.
        at: FloatTensor of shape [N, 3] representing points in world space
            directly in front of the cameras e.g. the positions of objects to
            be viewed by the cameras. Concretely this is the world-space image
            of the view-space point (0, 0, 1).
        up: FloatTensor of shape [N, 3] representing vectors in world space
            which when projected on to the camera plane point upwards.

        The returned tensors live on the transform's device and share the
        dtype of its matrix.
    """
    cam_trans = world_to_view_transform.inverse()
    # In the PyTorch3D right handed coordinate system, the camera in view space
    # is always at the origin looking along the +z axis.

    # The up vector is not a position so cannot be transformed with
    # transform_points. However the position eye+up above the camera
    # (whose position vector in the camera coordinate frame is an up vector)
    # can be transformed with transform_points.

    # Build the probe points in the same dtype as the transform's matrix so
    # that non-float32 (e.g. float64) transforms do not trigger a dtype
    # mismatch inside transform_points. For float32 transforms this is
    # identical to hard-coding torch.float32.
    eye_at_up_view = torch.tensor(
        [[0, 0, 0], [0, 0, 1], [0, 1, 0]],
        dtype=cam_trans.get_matrix().dtype,
        device=cam_trans.device,
    )
    # Rows along dim 1 are, in order: eye, at, eye + up (all in world space).
    eye_at_up_world = cam_trans.transform_points(eye_at_up_view).reshape(-1, 3, 3)

    eye, at, up_plus_eye = eye_at_up_world.unbind(1)
    # Subtract the camera centre to turn the transformed position back into
    # a direction vector.
    up = up_plus_eye - eye
    return eye, at, up
|
51
tests/test_camera_utils.py
Normal file
51
tests/test_camera_utils.py
Normal file
@ -0,0 +1,51 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
from common_testing import TestCaseMixin
|
||||
from pytorch3d.renderer.camera_utils import camera_to_eye_at_up
|
||||
from pytorch3d.renderer.cameras import PerspectiveCameras, look_at_view_transform
|
||||
from torch.nn.functional import normalize
|
||||
|
||||
|
||||
class TestCameraUtils(TestCaseMixin, unittest.TestCase):
    """Round-trip checks for camera_to_eye_at_up."""

    def setUp(self) -> None:
        # Fixed seed so the randomly generated cameras are reproducible.
        torch.manual_seed(42)

    def test_invert_eye_at_up(self):
        # Build a batch of random look-at cameras, recover (eye, at, up) from
        # their world-to-view transform, and verify the recovered values
        # describe the same cameras.
        num_cameras = 13
        eye = torch.rand(num_cameras, 3)
        at = torch.rand(num_cameras, 3)
        up = torch.rand(num_cameras, 3)

        R, T = look_at_view_transform(eye=eye, at=at, up=up)
        cameras = PerspectiveCameras(R=R, T=T)

        recovered = camera_to_eye_at_up(cameras.get_world_to_view_transform())
        eye2, at2, up2 = recovered

        # Camera centres must be recovered exactly (up to float error).
        self.assertClose(eye, eye2, atol=1e-5)

        # The recovered `at` need not be the same point, but it must lie in
        # the same direction from the eye as the original one.
        view_dir = at - eye
        view_dir2 = at2 - eye2
        self.assertClose(normalize(view_dir), normalize(view_dir2))

        # Only the rotation of `up` about the viewing direction matters: its
        # component along at-eye and its length are irrelevant. Comparing the
        # direction of (up x (at-eye)) before and after captures exactly that.
        side = torch.cross(up, view_dir, dim=-1)
        side2 = torch.cross(up2, view_dir, dim=-1)
        self.assertClose(normalize(side), normalize(side2))

        # Final end-to-end check: cameras rebuilt from the recovered vectors
        # must have the same world-to-view matrix as the originals.
        R2, T2 = look_at_view_transform(eye=eye2, at=at2, up=up2)
        cameras2 = PerspectiveCameras(R=R2, T=T2)
        expected_matrix = cameras.get_world_to_view_transform().get_matrix()
        rebuilt_matrix = cameras2.get_world_to_view_transform().get_matrix()

        self.assertClose(expected_matrix, rebuilt_matrix, atol=1e-5)
|
Loading…
x
Reference in New Issue
Block a user