Remove pytorch3d's wrappers for eigh, solve, lstsq, qr

Summary: Remove the compat functions eigh, solve, lstsq, and qr. Migrate callers to use torch.linalg directly.

Reviewed By: bottler

Differential Revision: D39172949

fbshipit-source-id: 484230a553237808f06ee5cdfde64651cba91c4c
This commit is contained in:
Chris Lambert
2022-08-31 13:04:07 -07:00
committed by Facebook GitHub Bot
parent 9a1213e0e5
commit d4a1051e0f
9 changed files with 11 additions and 67 deletions

View File

@@ -14,53 +14,6 @@ Some functions which depend on PyTorch or Python versions.
"""
def solve(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor: # pragma: no cover
"""
Like torch.linalg.solve, tries to return X
such that AX=B, with A square.
"""
if hasattr(torch, "linalg") and hasattr(torch.linalg, "solve"):
# PyTorch version >= 1.8.0
return torch.linalg.solve(A, B)
# pyre-fixme[16]: `Tuple` has no attribute `solution`.
return torch.solve(B, A).solution
def lstsq(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor: # pragma: no cover
"""
Like torch.linalg.lstsq, tries to return X
such that AX=B.
"""
if hasattr(torch, "linalg") and hasattr(torch.linalg, "lstsq"):
# PyTorch version >= 1.9
return torch.linalg.lstsq(A, B).solution
solution = torch.lstsq(B, A).solution
if A.shape[1] < A.shape[0]:
return solution[: A.shape[1]]
return solution
def qr(A: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: # pragma: no cover
"""
Like torch.linalg.qr.
"""
if hasattr(torch, "linalg") and hasattr(torch.linalg, "qr"):
# PyTorch version >= 1.9
return torch.linalg.qr(A)
return torch.qr(A)
def eigh(A: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: # pragma: no cover
"""
Like torch.linalg.eigh, assuming the argument is a symmetric real matrix.
"""
if hasattr(torch, "linalg") and hasattr(torch.linalg, "eigh"):
return torch.linalg.eigh(A)
return torch.symeig(A, eigenvectors=True)
def meshgrid_ij(
*A: Union[torch.Tensor, Sequence[torch.Tensor]]
) -> Tuple[torch.Tensor, ...]: # pragma: no cover

View File

@@ -10,7 +10,6 @@ from math import pi
from typing import Optional
import torch
from pytorch3d.common.compat import eigh, lstsq
def _get_rotation_to_best_fit_xy(
@@ -28,7 +27,7 @@ def _get_rotation_to_best_fit_xy(
(3,3) tensor rotation matrix
"""
points_centered = points - centroid[None]
return eigh(points_centered.t() @ points_centered)[1][:, [1, 2, 0]]
return torch.linalg.eigh(points_centered.t() @ points_centered)[1][:, [1, 2, 0]]
def _signed_area(path: torch.Tensor) -> torch.Tensor:
@@ -106,9 +105,8 @@ def fit_circle_in_2d(
n_provided = points2d.shape[0]
if n_provided < 3:
raise ValueError(f"{n_provided} points are not enough to determine a circle")
solution = lstsq(design, rhs[:, None])
solution = torch.linalg.lstsq(design, rhs[:, None]).solution
center = solution[:2, 0] / 2
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and `int`.
radius = torch.sqrt(solution[2, 0] + (center**2).sum())
if n_points > 0:
if angles is not None:

View File

@@ -9,7 +9,6 @@ import math
from typing import Optional, Tuple
import torch
from pytorch3d.common.compat import eigh
from pytorch3d.implicitron.tools import utils
from pytorch3d.implicitron.tools.circle_fitting import fit_circle_in_3d
from pytorch3d.renderer import look_at_view_transform, PerspectiveCameras
@@ -205,7 +204,7 @@ def _disambiguate_normal(normal, up):
def _fit_plane(x):
x = x - x.mean(dim=0)[None]
cov = (x.t() @ x) / x.shape[0]
_, e_vec = eigh(cov)
_, e_vec = torch.linalg.eigh(cov)
return e_vec

View File

@@ -16,7 +16,6 @@ from typing import NamedTuple, Optional
import torch
import torch.nn.functional as F
from pytorch3d.common.compat import eigh
from pytorch3d.ops import points_alignment, utils as oputil
@@ -106,7 +105,7 @@ def _null_space(m, kernel_dim):
kernel vectors, of size B x kernel_dim
"""
mTm = torch.bmm(m.transpose(1, 2), m)
s, v = eigh(mTm)
s, v = torch.linalg.eigh(mTm)
return v[:, :, :kernel_dim].reshape(-1, 4, 3, kernel_dim), s[:, :kernel_dim]

View File

@@ -7,7 +7,6 @@
from typing import Tuple, TYPE_CHECKING, Union
import torch
from pytorch3d.common.compat import eigh
from pytorch3d.common.workaround import symeig3x3
from .utils import convert_pointclouds_to_tensor, get_point_covariances
@@ -147,7 +146,7 @@ def estimate_pointcloud_local_coord_frames(
if use_symeig_workaround:
curvatures, local_coord_frames = symeig3x3(cov, eigenvectors=True)
else:
curvatures, local_coord_frames = eigh(cov)
curvatures, local_coord_frames = torch.linalg.eigh(cov)
# disambiguate the directions of individual principal vectors
if disambiguate_directions:

View File

@@ -5,7 +5,6 @@
# LICENSE file in the root directory of this source tree.
import torch
from pytorch3d.common.compat import solve
from .so3 import _so3_exp_map, hat, so3_log_map
@@ -174,7 +173,7 @@ def se3_log_map(
# log_translation is V^-1 @ T
T = transform[:, 3, :3]
V = _se3_V_matrix(*_get_se3_V_input(log_rotation), eps=eps)
log_translation = solve(V, T[:, :, None])[:, :, 0]
log_translation = torch.linalg.solve(V, T[:, :, None])[:, :, 0]
return torch.cat((log_translation, log_rotation), dim=1)