From b8790474f16994d75e6cf64447080f1b52bcc292 Mon Sep 17 00:00:00 2001 From: Jeremy Reizenstein Date: Mon, 28 Jun 2021 06:30:27 -0700 Subject: [PATCH] work with old linalg Summary: solve and lstsq have moved around in torch. Cope with both. Reviewed By: patricklabatut Differential Revision: D29302316 fbshipit-source-id: b34f0b923e90a357f20df359635929241eba6e74 --- pytorch3d/common/compat.py | 51 +++++++++++++++++++++++++ pytorch3d/transforms/__init__.py | 2 +- pytorch3d/transforms/se3.py | 5 ++- pytorch3d/transforms/so3.py | 1 + tests/test_acos_linear_extrapolation.py | 4 +- tests/test_se3.py | 9 ++--- tests/test_so3.py | 7 ++-- 7 files changed, 65 insertions(+), 14 deletions(-) create mode 100644 pytorch3d/common/compat.py diff --git a/pytorch3d/common/compat.py b/pytorch3d/common/compat.py new file mode 100644 index 00000000..f2edde0e --- /dev/null +++ b/pytorch3d/common/compat.py @@ -0,0 +1,51 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Tuple + +import torch + + +""" +Some functions which depend on PyTorch versions. +""" + + +def solve(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor: # pragma: no cover + """ + Like torch.linalg.solve, tries to return X + such that AX=B, with A square. + """ + if hasattr(torch.linalg, "solve"): + # PyTorch version >= 1.8.0 + return torch.linalg.solve(A, B) + + return torch.solve(B, A).solution + + +def lstsq(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor: # pragma: no cover + """ + Like torch.linalg.lstsq, tries to return X + such that AX=B. + """ + if hasattr(torch.linalg, "lstsq"): + # PyTorch version >= 1.9 + return torch.linalg.lstsq(A, B).solution + + solution = torch.lstsq(B, A).solution + if A.shape[1] < A.shape[0]: + return solution[: A.shape[1]] + return solution + + +def qr(A: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: # pragma: no cover + """ + Like torch.linalg.qr. + """ + if hasattr(torch.linalg, "qr"): + # PyTorch version >= 1.9 + return torch.linalg.qr(A) + return torch.qr(A) diff --git a/pytorch3d/transforms/__init__.py b/pytorch3d/transforms/__init__.py index 14d51f0f..448ea66f 100644 --- a/pytorch3d/transforms/__init__.py +++ b/pytorch3d/transforms/__init__.py @@ -26,8 +26,8 @@ from .rotation_conversions import ( ) from .se3 import se3_exp_map, se3_log_map from .so3 import ( - so3_exponential_map, so3_exp_map, + so3_exponential_map, so3_log_map, so3_relative_angle, so3_rotation_angle, diff --git a/pytorch3d/transforms/se3.py b/pytorch3d/transforms/se3.py index fd17eb56..212c277a 100644 --- a/pytorch3d/transforms/se3.py +++ b/pytorch3d/transforms/se3.py @@ -5,8 +5,9 @@ # LICENSE file in the root directory of this source tree. import torch +from pytorch3d.common.compat import solve -from .so3 import hat, _so3_exp_map, so3_log_map +from .so3 import _so3_exp_map, hat, so3_log_map def se3_exp_map(log_transform: torch.Tensor, eps: float = 1e-4) -> torch.Tensor: @@ -173,7 +174,7 @@ def se3_log_map( # log_translation is V^-1 @ T T = transform[:, 3, :3] V = _se3_V_matrix(*_get_se3_V_input(log_rotation), eps=eps) - log_translation = torch.linalg.solve(V, T[:, :, None])[:, :, 0] + log_translation = solve(V, T[:, :, None])[:, :, 0] return torch.cat((log_translation, log_rotation), dim=1) diff --git a/pytorch3d/transforms/so3.py b/pytorch3d/transforms/so3.py index 1a5d5a2b..0c6ced7d 100644 --- a/pytorch3d/transforms/so3.py +++ b/pytorch3d/transforms/so3.py @@ -11,6 +11,7 @@ import torch from ..transforms import acos_linear_extrapolation + HAT_INV_SKEW_SYMMETRIC_TOL = 1e-5 diff --git a/tests/test_acos_linear_extrapolation.py b/tests/test_acos_linear_extrapolation.py index 672fb2bb..35076509 100644 --- a/tests/test_acos_linear_extrapolation.py +++ b/tests/test_acos_linear_extrapolation.py @@ -10,6 +10,7 @@ import unittest import numpy as np import torch from common_testing import TestCaseMixin +from pytorch3d.common.compat import lstsq from pytorch3d.transforms import acos_linear_extrapolation @@ -64,8 +65,7 @@ class TestAcosLinearExtrapolation(TestCaseMixin, unittest.TestCase): bound_t = torch.tensor(bound, device=x.device, dtype=x.dtype) # fit a line: slope * x + bias = y x_1 = torch.stack([x, torch.ones_like(x)], dim=-1) - solution = torch.linalg.lstsq(x_1, y[:, None]).solution - slope, bias = solution.view(-1)[:2] + slope, bias = lstsq(x_1, y[:, None]).view(-1)[:2] desired_slope = (-1.0) / torch.sqrt(1.0 - bound_t ** 2) # test that the desired slope is the same as the fitted one self.assertClose(desired_slope.view(1), slope.view(1), atol=1e-2) diff --git a/tests/test_se3.py b/tests/test_se3.py index 40121f27..fc8a0b04 100644 --- a/tests/test_se3.py +++ b/tests/test_se3.py @@ -10,13 +10,10 @@ import unittest import numpy as np import torch from common_testing import TestCaseMixin +from pytorch3d.common.compat import qr from pytorch3d.transforms.rotation_conversions import random_rotations from pytorch3d.transforms.se3 import se3_exp_map, se3_log_map -from pytorch3d.transforms.so3 import ( - so3_exp_map, - so3_log_map, - so3_rotation_angle, -) +from pytorch3d.transforms.so3 import so3_exp_map, so3_log_map, so3_rotation_angle class TestSE3(TestCaseMixin, unittest.TestCase): @@ -201,7 +198,7 @@ class TestSE3(TestCaseMixin, unittest.TestCase): r = [identity, rot180] r.extend( [ - torch.qr(identity + torch.randn_like(identity) * 1e-6)[0] + qr(identity + torch.randn_like(identity) * 1e-6)[0] + float(i > batch_size // 2) * (0.5 - torch.rand_like(identity)) * 1e-8 # this adds random noise to the second half # of the random orthogonal matrices to generate diff --git a/tests/test_so3.py b/tests/test_so3.py index 2d129869..414733d8 100644 --- a/tests/test_so3.py +++ b/tests/test_so3.py @@ -11,6 +11,7 @@ import unittest import numpy as np import torch from common_testing import TestCaseMixin +from pytorch3d.common.compat import qr from pytorch3d.transforms.so3 import ( hat, so3_exp_map, @@ -46,7 +47,7 @@ class TestSO3(TestCaseMixin, unittest.TestCase): # TODO(dnovotny): replace with random_rotation from random_rotation.py rot = [] for _ in range(batch_size): - r = torch.qr(torch.randn((3, 3), device=device))[0] + r = qr(torch.randn((3, 3), device=device))[0] f = torch.randint(2, (3,), device=device, dtype=torch.float32) if f.sum() % 2 == 0: f = 1 - f @@ -142,7 +143,7 @@ class TestSO3(TestCaseMixin, unittest.TestCase): # add random rotations and random almost orthonormal matrices r.extend( [ - torch.qr(identity + torch.randn_like(identity) * 1e-4)[0] + qr(identity + torch.randn_like(identity) * 1e-4)[0] + float(i > batch_size // 2) * (0.5 - torch.rand_like(identity)) * 1e-3 # this adds random noise to the second half # of the random orthogonal matrices to generate @@ -242,7 +243,7 @@ class TestSO3(TestCaseMixin, unittest.TestCase): r = [identity, rot180] r.extend( [ - torch.qr(identity + torch.randn_like(identity) * 1e-4)[0] + qr(identity + torch.randn_like(identity) * 1e-4)[0] for _ in range(batch_size - 2) ] )