mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
Deprecate so3_exponential_map
Summary: Deprecate the `so3_exponential_map()` function in favor of its alias `so3_exp_map()`: this aligns with the naming of `so3_log_map()` and the recently introduced `se3_exp_map()` / `se3_log_map()` pair. Reviewed By: bottler Differential Revision: D29329966 fbshipit-source-id: b6f60b9e86b2995f70b1fbeb16f9feea05c55de9
This commit is contained in:
parent
f593bfd3c2
commit
5284de6e97
@ -41,7 +41,7 @@
|
||||
"Our optimization seeks to align the estimated (orange) cameras with the ground truth (purple) cameras, by minimizing the discrepancies between pairs of relative cameras. Thus, the solution to the problem should look as follows:\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"In practice, the camera extrinsics $g_{ij}$ and $g_i$ are represented using objects from the `SfMPerspectiveCameras` class initialized with the corresponding rotation and translation matrices `R_absolute` and `T_absolute` that define the extrinsic parameters $g = (R, T); R \\in SO(3); T \\in \\mathbb{R}^3$. In order to ensure that `R_absolute` is a valid rotation matrix, we represent it using an exponential map (implemented with `so3_exponential_map`) of the axis-angle representation of the rotation `log_R_absolute`.\n",
|
||||
"In practice, the camera extrinsics $g_{ij}$ and $g_i$ are represented using objects from the `SfMPerspectiveCameras` class initialized with the corresponding rotation and translation matrices `R_absolute` and `T_absolute` that define the extrinsic parameters $g = (R, T); R \\in SO(3); T \\in \\mathbb{R}^3$. In order to ensure that `R_absolute` is a valid rotation matrix, we represent it using an exponential map (implemented with `so3_exp_map`) of the axis-angle representation of the rotation `log_R_absolute`.\n",
|
||||
"\n",
|
||||
"Note that the solution to this problem could only be recovered up to an unknown global rigid transformation $g_{glob} \\in SE(3)$. Thus, for simplicity, we assume knowledge of the absolute extrinsics of the first camera $g_0$. We set $g_0$ as a trivial camera $g_0 = (I, \\vec{0})$.\n"
|
||||
]
|
||||
@ -122,7 +122,7 @@
|
||||
"# imports\n",
|
||||
"import torch\n",
|
||||
"from pytorch3d.transforms.so3 import (\n",
|
||||
" so3_exponential_map,\n",
|
||||
" so3_exp_map,\n",
|
||||
" so3_relative_angle,\n",
|
||||
")\n",
|
||||
"from pytorch3d.renderer.cameras import (\n",
|
||||
@ -328,7 +328,7 @@
|
||||
"\n",
|
||||
"As mentioned earlier, `log_R_absolute` is the axis angle representation of the rotation part of our absolute cameras. We can obtain the 3x3 rotation matrix `R_absolute` that corresponds to `log_R_absolute` with:\n",
|
||||
"\n",
|
||||
"`R_absolute = so3_exponential_map(log_R_absolute)`\n"
|
||||
"`R_absolute = so3_exp_map(log_R_absolute)`\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -378,7 +378,7 @@
|
||||
" # compute the absolute camera rotations as \n",
|
||||
" # an exponential map of the logarithms (=axis-angles)\n",
|
||||
" # of the absolute rotations\n",
|
||||
" R_absolute = so3_exponential_map(log_R_absolute * camera_mask)\n",
|
||||
" R_absolute = so3_exp_map(log_R_absolute * camera_mask)\n",
|
||||
"\n",
|
||||
" # get the current absolute cameras\n",
|
||||
" cameras_absolute = SfMPerspectiveCameras(\n",
|
||||
|
@ -95,7 +95,7 @@
|
||||
"\n",
|
||||
"# Data structures and functions for rendering\n",
|
||||
"from pytorch3d.structures import Volumes\n",
|
||||
"from pytorch3d.transforms import so3_exponential_map\n",
|
||||
"from pytorch3d.transforms import so3_exp_map\n",
|
||||
"from pytorch3d.renderer import (\n",
|
||||
" FoVPerspectiveCameras, \n",
|
||||
" NDCGridRaysampler,\n",
|
||||
@ -803,7 +803,7 @@
|
||||
"def generate_rotating_nerf(neural_radiance_field, n_frames = 50):\n",
|
||||
" logRs = torch.zeros(n_frames, 3, device=device)\n",
|
||||
" logRs[:, 1] = torch.linspace(-3.14, 3.14, n_frames, device=device)\n",
|
||||
" Rs = so3_exponential_map(logRs)\n",
|
||||
" Rs = so3_exp_map(logRs)\n",
|
||||
" Ts = torch.zeros(n_frames, 3, device=device)\n",
|
||||
" Ts[:, 2] = 2.7\n",
|
||||
" frames = []\n",
|
||||
|
@ -90,7 +90,7 @@
|
||||
" NDCGridRaysampler,\n",
|
||||
" EmissionAbsorptionRaymarcher\n",
|
||||
")\n",
|
||||
"from pytorch3d.transforms import so3_exponential_map\n",
|
||||
"from pytorch3d.transforms import so3_exp_map\n",
|
||||
"\n",
|
||||
"# add path for demo utils functions \n",
|
||||
"sys.path.append(os.path.abspath(''))\n",
|
||||
@ -405,7 +405,7 @@
|
||||
"def generate_rotating_volume(volume_model, n_frames = 50):\n",
|
||||
" logRs = torch.zeros(n_frames, 3, device=device)\n",
|
||||
" logRs[:, 1] = torch.linspace(0.0, 2.0 * 3.14, n_frames, device=device)\n",
|
||||
" Rs = so3_exponential_map(logRs)\n",
|
||||
" Rs = so3_exp_map(logRs)\n",
|
||||
" Ts = torch.zeros(n_frames, 3, device=device)\n",
|
||||
" Ts[:, 2] = 2.7\n",
|
||||
" frames = []\n",
|
||||
|
@ -4,6 +4,7 @@
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import warnings
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
@ -134,7 +135,15 @@ def so3_exp_map(log_rot: torch.Tensor, eps: float = 0.0001) -> torch.Tensor:
|
||||
return _so3_exp_map(log_rot, eps=eps)[0]
|
||||
|
||||
|
||||
so3_exponential_map = so3_exp_map
|
||||
def so3_exponential_map(log_rot: torch.Tensor, eps: float = 0.0001) -> torch.Tensor:
|
||||
warnings.warn(
|
||||
"""so3_exponential_map is deprecated,
|
||||
Use so3_exp_map instead.
|
||||
so3_exponential_map will be removed in future releases.""",
|
||||
PendingDeprecationWarning,
|
||||
)
|
||||
|
||||
return so3_exp_map(log_rot, eps)
|
||||
|
||||
|
||||
def _so3_exp_map(
|
||||
|
@ -9,7 +9,7 @@ from typing import Tuple
|
||||
import torch
|
||||
|
||||
from ..renderer import PerspectiveCameras
|
||||
from ..transforms import so3_exponential_map, so3_log_map
|
||||
from ..transforms import so3_exp_map, so3_log_map
|
||||
|
||||
|
||||
def cameras_from_opencv_projection(
|
||||
@ -51,7 +51,7 @@ def cameras_from_opencv_projection(
|
||||
cameras_pytorch3d: A batch of `N` cameras in the PyTorch3D convention.
|
||||
"""
|
||||
|
||||
R = so3_exponential_map(rvec)
|
||||
R = so3_exp_map(rvec)
|
||||
focal_length = torch.stack([camera_matrix[:, 0, 0], camera_matrix[:, 1, 1]], dim=-1)
|
||||
principal_point = camera_matrix[:, :2, 2]
|
||||
|
||||
|
@ -12,7 +12,7 @@ import numpy as np
|
||||
import torch
|
||||
from common_testing import TestCaseMixin, get_tests_dir
|
||||
from pytorch3d.ops import eyes
|
||||
from pytorch3d.transforms import so3_exponential_map, so3_log_map
|
||||
from pytorch3d.transforms import so3_exp_map, so3_log_map
|
||||
from pytorch3d.utils import (
|
||||
cameras_from_opencv_projection,
|
||||
opencv_from_cameras_projection,
|
||||
@ -33,7 +33,7 @@ def cv2_project_points(pts, rvec, tvec, camera_matrix):
|
||||
"""
|
||||
Reproduces the `cv2.projectPoints` function from OpenCV using PyTorch.
|
||||
"""
|
||||
R = so3_exponential_map(rvec)
|
||||
R = so3_exp_map(rvec)
|
||||
pts_proj_3d = (
|
||||
camera_matrix.bmm(R.bmm(pts.permute(0, 2, 1)) + tvec[:, :, None])
|
||||
).permute(0, 2, 1)
|
||||
|
@ -53,7 +53,7 @@ from pytorch3d.renderer.cameras import (
|
||||
)
|
||||
from pytorch3d.transforms import Transform3d
|
||||
from pytorch3d.transforms.rotation_conversions import random_rotations
|
||||
from pytorch3d.transforms.so3 import so3_exponential_map
|
||||
from pytorch3d.transforms.so3 import so3_exp_map
|
||||
|
||||
|
||||
# Naive function adapted from SoftRasterizer for test purposes.
|
||||
@ -145,7 +145,7 @@ def init_random_cameras(
|
||||
T = torch.randn(batch_size, 3) * 0.03
|
||||
if not random_z:
|
||||
T[:, 2] = 4
|
||||
R = so3_exponential_map(torch.randn(batch_size, 3) * 3.0)
|
||||
R = so3_exp_map(torch.randn(batch_size, 3) * 3.0)
|
||||
cam_params = {"R": R, "T": T}
|
||||
if cam_type in (OpenGLPerspectiveCameras, OpenGLOrthographicCameras):
|
||||
cam_params["znear"] = torch.rand(batch_size) * 10 + 0.1
|
||||
@ -509,7 +509,7 @@ class TestCamerasCommon(TestCaseMixin, unittest.TestCase):
|
||||
def init_equiv_cameras_ndc_screen(cam_type: CamerasBase, batch_size: int):
|
||||
T = torch.randn(batch_size, 3) * 0.03
|
||||
T[:, 2] = 4
|
||||
R = so3_exponential_map(torch.randn(batch_size, 3) * 3.0)
|
||||
R = so3_exp_map(torch.randn(batch_size, 3) * 3.0)
|
||||
screen_cam_params = {"R": R, "T": T}
|
||||
ndc_cam_params = {"R": R, "T": T}
|
||||
if cam_type in (OrthographicCameras, PerspectiveCameras):
|
||||
|
@ -17,7 +17,7 @@ from pytorch3d.renderer.cameras import (
|
||||
SfMPerspectiveCameras,
|
||||
)
|
||||
from pytorch3d.transforms.rotation_conversions import random_rotations
|
||||
from pytorch3d.transforms.so3 import so3_exponential_map, so3_relative_angle
|
||||
from pytorch3d.transforms.so3 import so3_exp_map, so3_relative_angle
|
||||
from test_cameras import init_random_cameras
|
||||
|
||||
|
||||
@ -95,9 +95,7 @@ class TestCamerasAlignment(TestCaseMixin, unittest.TestCase):
|
||||
) * s_align_gt
|
||||
|
||||
if add_noise != 0.0:
|
||||
R_new = torch.bmm(
|
||||
R_new, so3_exponential_map(torch.randn_like(T_new) * add_noise)
|
||||
)
|
||||
R_new = torch.bmm(R_new, so3_exp_map(torch.randn_like(T_new) * add_noise))
|
||||
T_new += torch.randn_like(T_new) * add_noise
|
||||
|
||||
# create new cameras from R_new and T_new
|
||||
|
@ -15,7 +15,7 @@ from pytorch3d.ops.sample_points_from_meshes import sample_points_from_meshes
|
||||
from pytorch3d.structures.meshes import Meshes
|
||||
from pytorch3d.structures.pointclouds import Pointclouds
|
||||
from pytorch3d.structures.volumes import Volumes
|
||||
from pytorch3d.transforms.so3 import so3_exponential_map
|
||||
from pytorch3d.transforms.so3 import so3_exp_map
|
||||
|
||||
|
||||
DEBUG = False
|
||||
@ -138,7 +138,7 @@ def init_uniform_y_rotations(batch_size: int = 10):
|
||||
angles = torch.linspace(0, 2.0 * np.pi, batch_size + 1, device=device)
|
||||
angles = angles[:batch_size]
|
||||
log_rots = axis[None, :] * angles[:, None]
|
||||
R = so3_exponential_map(log_rots)
|
||||
R = so3_exp_map(log_rots)
|
||||
return R
|
||||
|
||||
|
||||
|
@ -29,7 +29,7 @@ from pytorch3d.renderer import (
|
||||
)
|
||||
from pytorch3d.renderer.cameras import get_world_to_view_transform
|
||||
from pytorch3d.transforms import Transform3d
|
||||
from pytorch3d.transforms.so3 import so3_exponential_map
|
||||
from pytorch3d.transforms.so3 import so3_exp_map
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
|
||||
@ -316,7 +316,7 @@ class TestR2N2(TestCaseMixin, unittest.TestCase):
|
||||
"""
|
||||
# Test get_world_to_view_transform.
|
||||
T = torch.randn(10, 3)
|
||||
R = so3_exponential_map(torch.randn(10, 3) * 3.0)
|
||||
R = so3_exp_map(torch.randn(10, 3) * 3.0)
|
||||
RT = get_world_to_view_transform(R=R, T=T)
|
||||
cam = BlenderCamera(R=R, T=T)
|
||||
RT_class = cam.get_world_to_view_transform()
|
||||
|
@ -10,7 +10,7 @@ import unittest
|
||||
|
||||
import torch
|
||||
from common_testing import TestCaseMixin
|
||||
from pytorch3d.transforms.so3 import so3_exponential_map
|
||||
from pytorch3d.transforms.so3 import so3_exp_map
|
||||
from pytorch3d.transforms.transform3d import (
|
||||
Rotate,
|
||||
RotateAxisAngle,
|
||||
@ -146,7 +146,7 @@ class TestTransform(TestCaseMixin, unittest.TestCase):
|
||||
self.assertTrue(torch.allclose(normals_out, normals_out_expected))
|
||||
|
||||
def test_rotate(self):
|
||||
R = so3_exponential_map(torch.randn((1, 3)))
|
||||
R = so3_exp_map(torch.randn((1, 3)))
|
||||
t = Transform3d().rotate(R)
|
||||
points = torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.5, 0.5, 0.0]]).view(
|
||||
1, 3, 3
|
||||
@ -273,7 +273,7 @@ class TestTransform(TestCaseMixin, unittest.TestCase):
|
||||
)
|
||||
elif choice <= 2.0 / 3.0:
|
||||
t_ = Rotate(
|
||||
so3_exponential_map(
|
||||
so3_exp_map(
|
||||
torch.randn(
|
||||
(batch_size, 3), dtype=torch.float32, device=device
|
||||
)
|
||||
@ -894,7 +894,7 @@ class TestRotate(unittest.TestCase):
|
||||
def test_inverse(self, batch_size=5):
|
||||
device = torch.device("cuda:0")
|
||||
log_rot = torch.randn((batch_size, 3), dtype=torch.float32, device=device)
|
||||
R = so3_exponential_map(log_rot)
|
||||
R = so3_exp_map(log_rot)
|
||||
t = Rotate(R)
|
||||
im = t.inverse()._matrix
|
||||
im_2 = t._matrix.inverse()
|
||||
|
Loading…
x
Reference in New Issue
Block a user