mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-12-21 14:50:36 +08:00
work with old linalg
Summary: solve and lstsq have moved around in torch. Cope with both. Reviewed By: patricklabatut Differential Revision: D29302316 fbshipit-source-id: b34f0b923e90a357f20df359635929241eba6e74
This commit is contained in:
committed by
Facebook GitHub Bot
parent
5284de6e97
commit
b8790474f1
51
pytorch3d/common/compat.py
Normal file
51
pytorch3d/common/compat.py
Normal file
@@ -0,0 +1,51 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
"""
|
||||
Some functions which depend on PyTorch versions.
|
||||
"""
|
||||
|
||||
|
||||
def solve(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor: # pragma: no cover
|
||||
"""
|
||||
Like torch.linalg.solve, tries to return X
|
||||
such that AX=B, with A square.
|
||||
"""
|
||||
if hasattr(torch.linalg, "solve"):
|
||||
# PyTorch version >= 1.8.0
|
||||
return torch.linalg.solve(A, B)
|
||||
|
||||
return torch.solve(B, A).solution
|
||||
|
||||
|
||||
def lstsq(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor: # pragma: no cover
|
||||
"""
|
||||
Like torch.linalg.lstsq, tries to return X
|
||||
such that AX=B.
|
||||
"""
|
||||
if hasattr(torch.linalg, "lstsq"):
|
||||
# PyTorch version >= 1.9
|
||||
return torch.linalg.lstsq(A, B).solution
|
||||
|
||||
solution = torch.lstsq(B, A).solution
|
||||
if A.shape[1] < A.shape[0]:
|
||||
return solution[: A.shape[1]]
|
||||
return solution
|
||||
|
||||
|
||||
def qr(A: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: # pragma: no cover
|
||||
"""
|
||||
Like torch.linalg.qr.
|
||||
"""
|
||||
if hasattr(torch.linalg, "qr"):
|
||||
# PyTorch version >= 1.9
|
||||
return torch.linalg.qr(A)
|
||||
return torch.qr(A)
|
||||
@@ -26,8 +26,8 @@ from .rotation_conversions import (
|
||||
)
|
||||
from .se3 import se3_exp_map, se3_log_map
|
||||
from .so3 import (
|
||||
so3_exponential_map,
|
||||
so3_exp_map,
|
||||
so3_exponential_map,
|
||||
so3_log_map,
|
||||
so3_relative_angle,
|
||||
so3_rotation_angle,
|
||||
|
||||
@@ -5,8 +5,9 @@
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import torch
|
||||
from pytorch3d.common.compat import solve
|
||||
|
||||
from .so3 import hat, _so3_exp_map, so3_log_map
|
||||
from .so3 import _so3_exp_map, hat, so3_log_map
|
||||
|
||||
|
||||
def se3_exp_map(log_transform: torch.Tensor, eps: float = 1e-4) -> torch.Tensor:
|
||||
@@ -173,7 +174,7 @@ def se3_log_map(
|
||||
# log_translation is V^-1 @ T
|
||||
T = transform[:, 3, :3]
|
||||
V = _se3_V_matrix(*_get_se3_V_input(log_rotation), eps=eps)
|
||||
log_translation = torch.linalg.solve(V, T[:, :, None])[:, :, 0]
|
||||
log_translation = solve(V, T[:, :, None])[:, :, 0]
|
||||
|
||||
return torch.cat((log_translation, log_rotation), dim=1)
|
||||
|
||||
|
||||
@@ -11,6 +11,7 @@ import torch
|
||||
|
||||
from ..transforms import acos_linear_extrapolation
|
||||
|
||||
|
||||
HAT_INV_SKEW_SYMMETRIC_TOL = 1e-5
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user