mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 20:02:49 +08:00
work with old linalg
Summary: solve and lstsq have moved around in torch. Cope with both. Reviewed By: patricklabatut Differential Revision: D29302316 fbshipit-source-id: b34f0b923e90a357f20df359635929241eba6e74
This commit is contained in:
parent
5284de6e97
commit
b8790474f1
51
pytorch3d/common/compat.py
Normal file
51
pytorch3d/common/compat.py
Normal file
@ -0,0 +1,51 @@
|
|||||||
|
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the BSD-style license found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
Some functions which depend on PyTorch versions.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def solve(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor: # pragma: no cover
|
||||||
|
"""
|
||||||
|
Like torch.linalg.solve, tries to return X
|
||||||
|
such that AX=B, with A square.
|
||||||
|
"""
|
||||||
|
if hasattr(torch.linalg, "solve"):
|
||||||
|
# PyTorch version >= 1.8.0
|
||||||
|
return torch.linalg.solve(A, B)
|
||||||
|
|
||||||
|
return torch.solve(B, A).solution
|
||||||
|
|
||||||
|
|
||||||
|
def lstsq(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor: # pragma: no cover
|
||||||
|
"""
|
||||||
|
Like torch.linalg.lstsq, tries to return X
|
||||||
|
such that AX=B.
|
||||||
|
"""
|
||||||
|
if hasattr(torch.linalg, "lstsq"):
|
||||||
|
# PyTorch version >= 1.9
|
||||||
|
return torch.linalg.lstsq(A, B).solution
|
||||||
|
|
||||||
|
solution = torch.lstsq(B, A).solution
|
||||||
|
if A.shape[1] < A.shape[0]:
|
||||||
|
return solution[: A.shape[1]]
|
||||||
|
return solution
|
||||||
|
|
||||||
|
|
||||||
|
def qr(A: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: # pragma: no cover
|
||||||
|
"""
|
||||||
|
Like torch.linalg.qr.
|
||||||
|
"""
|
||||||
|
if hasattr(torch.linalg, "qr"):
|
||||||
|
# PyTorch version >= 1.9
|
||||||
|
return torch.linalg.qr(A)
|
||||||
|
return torch.qr(A)
|
@ -26,8 +26,8 @@ from .rotation_conversions import (
|
|||||||
)
|
)
|
||||||
from .se3 import se3_exp_map, se3_log_map
|
from .se3 import se3_exp_map, se3_log_map
|
||||||
from .so3 import (
|
from .so3 import (
|
||||||
so3_exponential_map,
|
|
||||||
so3_exp_map,
|
so3_exp_map,
|
||||||
|
so3_exponential_map,
|
||||||
so3_log_map,
|
so3_log_map,
|
||||||
so3_relative_angle,
|
so3_relative_angle,
|
||||||
so3_rotation_angle,
|
so3_rotation_angle,
|
||||||
|
@ -5,8 +5,9 @@
|
|||||||
# LICENSE file in the root directory of this source tree.
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from pytorch3d.common.compat import solve
|
||||||
|
|
||||||
from .so3 import hat, _so3_exp_map, so3_log_map
|
from .so3 import _so3_exp_map, hat, so3_log_map
|
||||||
|
|
||||||
|
|
||||||
def se3_exp_map(log_transform: torch.Tensor, eps: float = 1e-4) -> torch.Tensor:
|
def se3_exp_map(log_transform: torch.Tensor, eps: float = 1e-4) -> torch.Tensor:
|
||||||
@ -173,7 +174,7 @@ def se3_log_map(
|
|||||||
# log_translation is V^-1 @ T
|
# log_translation is V^-1 @ T
|
||||||
T = transform[:, 3, :3]
|
T = transform[:, 3, :3]
|
||||||
V = _se3_V_matrix(*_get_se3_V_input(log_rotation), eps=eps)
|
V = _se3_V_matrix(*_get_se3_V_input(log_rotation), eps=eps)
|
||||||
log_translation = torch.linalg.solve(V, T[:, :, None])[:, :, 0]
|
log_translation = solve(V, T[:, :, None])[:, :, 0]
|
||||||
|
|
||||||
return torch.cat((log_translation, log_rotation), dim=1)
|
return torch.cat((log_translation, log_rotation), dim=1)
|
||||||
|
|
||||||
|
@ -11,6 +11,7 @@ import torch
|
|||||||
|
|
||||||
from ..transforms import acos_linear_extrapolation
|
from ..transforms import acos_linear_extrapolation
|
||||||
|
|
||||||
|
|
||||||
HAT_INV_SKEW_SYMMETRIC_TOL = 1e-5
|
HAT_INV_SKEW_SYMMETRIC_TOL = 1e-5
|
||||||
|
|
||||||
|
|
||||||
|
@ -10,6 +10,7 @@ import unittest
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from common_testing import TestCaseMixin
|
from common_testing import TestCaseMixin
|
||||||
|
from pytorch3d.common.compat import lstsq
|
||||||
from pytorch3d.transforms import acos_linear_extrapolation
|
from pytorch3d.transforms import acos_linear_extrapolation
|
||||||
|
|
||||||
|
|
||||||
@ -64,8 +65,7 @@ class TestAcosLinearExtrapolation(TestCaseMixin, unittest.TestCase):
|
|||||||
bound_t = torch.tensor(bound, device=x.device, dtype=x.dtype)
|
bound_t = torch.tensor(bound, device=x.device, dtype=x.dtype)
|
||||||
# fit a line: slope * x + bias = y
|
# fit a line: slope * x + bias = y
|
||||||
x_1 = torch.stack([x, torch.ones_like(x)], dim=-1)
|
x_1 = torch.stack([x, torch.ones_like(x)], dim=-1)
|
||||||
solution = torch.linalg.lstsq(x_1, y[:, None]).solution
|
slope, bias = lstsq(x_1, y[:, None]).view(-1)[:2]
|
||||||
slope, bias = solution.view(-1)[:2]
|
|
||||||
desired_slope = (-1.0) / torch.sqrt(1.0 - bound_t ** 2)
|
desired_slope = (-1.0) / torch.sqrt(1.0 - bound_t ** 2)
|
||||||
# test that the desired slope is the same as the fitted one
|
# test that the desired slope is the same as the fitted one
|
||||||
self.assertClose(desired_slope.view(1), slope.view(1), atol=1e-2)
|
self.assertClose(desired_slope.view(1), slope.view(1), atol=1e-2)
|
||||||
|
@ -10,13 +10,10 @@ import unittest
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from common_testing import TestCaseMixin
|
from common_testing import TestCaseMixin
|
||||||
|
from pytorch3d.common.compat import qr
|
||||||
from pytorch3d.transforms.rotation_conversions import random_rotations
|
from pytorch3d.transforms.rotation_conversions import random_rotations
|
||||||
from pytorch3d.transforms.se3 import se3_exp_map, se3_log_map
|
from pytorch3d.transforms.se3 import se3_exp_map, se3_log_map
|
||||||
from pytorch3d.transforms.so3 import (
|
from pytorch3d.transforms.so3 import so3_exp_map, so3_log_map, so3_rotation_angle
|
||||||
so3_exp_map,
|
|
||||||
so3_log_map,
|
|
||||||
so3_rotation_angle,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class TestSE3(TestCaseMixin, unittest.TestCase):
|
class TestSE3(TestCaseMixin, unittest.TestCase):
|
||||||
@ -201,7 +198,7 @@ class TestSE3(TestCaseMixin, unittest.TestCase):
|
|||||||
r = [identity, rot180]
|
r = [identity, rot180]
|
||||||
r.extend(
|
r.extend(
|
||||||
[
|
[
|
||||||
torch.qr(identity + torch.randn_like(identity) * 1e-6)[0]
|
qr(identity + torch.randn_like(identity) * 1e-6)[0]
|
||||||
+ float(i > batch_size // 2) * (0.5 - torch.rand_like(identity)) * 1e-8
|
+ float(i > batch_size // 2) * (0.5 - torch.rand_like(identity)) * 1e-8
|
||||||
# this adds random noise to the second half
|
# this adds random noise to the second half
|
||||||
# of the random orthogonal matrices to generate
|
# of the random orthogonal matrices to generate
|
||||||
|
@ -11,6 +11,7 @@ import unittest
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from common_testing import TestCaseMixin
|
from common_testing import TestCaseMixin
|
||||||
|
from pytorch3d.common.compat import qr
|
||||||
from pytorch3d.transforms.so3 import (
|
from pytorch3d.transforms.so3 import (
|
||||||
hat,
|
hat,
|
||||||
so3_exp_map,
|
so3_exp_map,
|
||||||
@ -46,7 +47,7 @@ class TestSO3(TestCaseMixin, unittest.TestCase):
|
|||||||
# TODO(dnovotny): replace with random_rotation from random_rotation.py
|
# TODO(dnovotny): replace with random_rotation from random_rotation.py
|
||||||
rot = []
|
rot = []
|
||||||
for _ in range(batch_size):
|
for _ in range(batch_size):
|
||||||
r = torch.qr(torch.randn((3, 3), device=device))[0]
|
r = qr(torch.randn((3, 3), device=device))[0]
|
||||||
f = torch.randint(2, (3,), device=device, dtype=torch.float32)
|
f = torch.randint(2, (3,), device=device, dtype=torch.float32)
|
||||||
if f.sum() % 2 == 0:
|
if f.sum() % 2 == 0:
|
||||||
f = 1 - f
|
f = 1 - f
|
||||||
@ -142,7 +143,7 @@ class TestSO3(TestCaseMixin, unittest.TestCase):
|
|||||||
# add random rotations and random almost orthonormal matrices
|
# add random rotations and random almost orthonormal matrices
|
||||||
r.extend(
|
r.extend(
|
||||||
[
|
[
|
||||||
torch.qr(identity + torch.randn_like(identity) * 1e-4)[0]
|
qr(identity + torch.randn_like(identity) * 1e-4)[0]
|
||||||
+ float(i > batch_size // 2) * (0.5 - torch.rand_like(identity)) * 1e-3
|
+ float(i > batch_size // 2) * (0.5 - torch.rand_like(identity)) * 1e-3
|
||||||
# this adds random noise to the second half
|
# this adds random noise to the second half
|
||||||
# of the random orthogonal matrices to generate
|
# of the random orthogonal matrices to generate
|
||||||
@ -242,7 +243,7 @@ class TestSO3(TestCaseMixin, unittest.TestCase):
|
|||||||
r = [identity, rot180]
|
r = [identity, rot180]
|
||||||
r.extend(
|
r.extend(
|
||||||
[
|
[
|
||||||
torch.qr(identity + torch.randn_like(identity) * 1e-4)[0]
|
qr(identity + torch.randn_like(identity) * 1e-4)[0]
|
||||||
for _ in range(batch_size - 2)
|
for _ in range(batch_size - 2)
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user