mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-07-31 10:52:50 +08:00
Remove pytorch3d's wrappers for eigh, solve, lstsq, qr
Summary: Remove the compat functions eigh, solve, lstsq, and qr. Migrate callers to use torch.linalg directly. Reviewed By: bottler Differential Revision: D39172949 fbshipit-source-id: 484230a553237808f06ee5cdfde64651cba91c4c
This commit is contained in:
parent
9a1213e0e5
commit
d4a1051e0f
@ -14,53 +14,6 @@ Some functions which depend on PyTorch or Python versions.
|
||||
"""
|
||||
|
||||
|
||||
def solve(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor: # pragma: no cover
|
||||
"""
|
||||
Like torch.linalg.solve, tries to return X
|
||||
such that AX=B, with A square.
|
||||
"""
|
||||
if hasattr(torch, "linalg") and hasattr(torch.linalg, "solve"):
|
||||
# PyTorch version >= 1.8.0
|
||||
return torch.linalg.solve(A, B)
|
||||
|
||||
# pyre-fixme[16]: `Tuple` has no attribute `solution`.
|
||||
return torch.solve(B, A).solution
|
||||
|
||||
|
||||
def lstsq(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor: # pragma: no cover
|
||||
"""
|
||||
Like torch.linalg.lstsq, tries to return X
|
||||
such that AX=B.
|
||||
"""
|
||||
if hasattr(torch, "linalg") and hasattr(torch.linalg, "lstsq"):
|
||||
# PyTorch version >= 1.9
|
||||
return torch.linalg.lstsq(A, B).solution
|
||||
|
||||
solution = torch.lstsq(B, A).solution
|
||||
if A.shape[1] < A.shape[0]:
|
||||
return solution[: A.shape[1]]
|
||||
return solution
|
||||
|
||||
|
||||
def qr(A: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: # pragma: no cover
|
||||
"""
|
||||
Like torch.linalg.qr.
|
||||
"""
|
||||
if hasattr(torch, "linalg") and hasattr(torch.linalg, "qr"):
|
||||
# PyTorch version >= 1.9
|
||||
return torch.linalg.qr(A)
|
||||
return torch.qr(A)
|
||||
|
||||
|
||||
def eigh(A: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: # pragma: no cover
|
||||
"""
|
||||
Like torch.linalg.eigh, assuming the argument is a symmetric real matrix.
|
||||
"""
|
||||
if hasattr(torch, "linalg") and hasattr(torch.linalg, "eigh"):
|
||||
return torch.linalg.eigh(A)
|
||||
return torch.symeig(A, eigenvectors=True)
|
||||
|
||||
|
||||
def meshgrid_ij(
|
||||
*A: Union[torch.Tensor, Sequence[torch.Tensor]]
|
||||
) -> Tuple[torch.Tensor, ...]: # pragma: no cover
|
||||
|
@ -10,7 +10,6 @@ from math import pi
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from pytorch3d.common.compat import eigh, lstsq
|
||||
|
||||
|
||||
def _get_rotation_to_best_fit_xy(
|
||||
@ -28,7 +27,7 @@ def _get_rotation_to_best_fit_xy(
|
||||
(3,3) tensor rotation matrix
|
||||
"""
|
||||
points_centered = points - centroid[None]
|
||||
return eigh(points_centered.t() @ points_centered)[1][:, [1, 2, 0]]
|
||||
return torch.linalg.eigh(points_centered.t() @ points_centered)[1][:, [1, 2, 0]]
|
||||
|
||||
|
||||
def _signed_area(path: torch.Tensor) -> torch.Tensor:
|
||||
@ -106,9 +105,8 @@ def fit_circle_in_2d(
|
||||
n_provided = points2d.shape[0]
|
||||
if n_provided < 3:
|
||||
raise ValueError(f"{n_provided} points are not enough to determine a circle")
|
||||
solution = lstsq(design, rhs[:, None])
|
||||
solution = torch.linalg.lstsq(design, rhs[:, None]).solution
|
||||
center = solution[:2, 0] / 2
|
||||
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and `int`.
|
||||
radius = torch.sqrt(solution[2, 0] + (center**2).sum())
|
||||
if n_points > 0:
|
||||
if angles is not None:
|
||||
|
@ -9,7 +9,6 @@ import math
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
from pytorch3d.common.compat import eigh
|
||||
from pytorch3d.implicitron.tools import utils
|
||||
from pytorch3d.implicitron.tools.circle_fitting import fit_circle_in_3d
|
||||
from pytorch3d.renderer import look_at_view_transform, PerspectiveCameras
|
||||
@ -205,7 +204,7 @@ def _disambiguate_normal(normal, up):
|
||||
def _fit_plane(x):
|
||||
x = x - x.mean(dim=0)[None]
|
||||
cov = (x.t() @ x) / x.shape[0]
|
||||
_, e_vec = eigh(cov)
|
||||
_, e_vec = torch.linalg.eigh(cov)
|
||||
return e_vec
|
||||
|
||||
|
||||
|
@ -16,7 +16,6 @@ from typing import NamedTuple, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from pytorch3d.common.compat import eigh
|
||||
from pytorch3d.ops import points_alignment, utils as oputil
|
||||
|
||||
|
||||
@ -106,7 +105,7 @@ def _null_space(m, kernel_dim):
|
||||
kernel vectors, of size B x kernel_dim
|
||||
"""
|
||||
mTm = torch.bmm(m.transpose(1, 2), m)
|
||||
s, v = eigh(mTm)
|
||||
s, v = torch.linalg.eigh(mTm)
|
||||
return v[:, :, :kernel_dim].reshape(-1, 4, 3, kernel_dim), s[:, :kernel_dim]
|
||||
|
||||
|
||||
|
@ -7,7 +7,6 @@
|
||||
from typing import Tuple, TYPE_CHECKING, Union
|
||||
|
||||
import torch
|
||||
from pytorch3d.common.compat import eigh
|
||||
from pytorch3d.common.workaround import symeig3x3
|
||||
|
||||
from .utils import convert_pointclouds_to_tensor, get_point_covariances
|
||||
@ -147,7 +146,7 @@ def estimate_pointcloud_local_coord_frames(
|
||||
if use_symeig_workaround:
|
||||
curvatures, local_coord_frames = symeig3x3(cov, eigenvectors=True)
|
||||
else:
|
||||
curvatures, local_coord_frames = eigh(cov)
|
||||
curvatures, local_coord_frames = torch.linalg.eigh(cov)
|
||||
|
||||
# disambiguate the directions of individual principal vectors
|
||||
if disambiguate_directions:
|
||||
|
@ -5,7 +5,6 @@
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import torch
|
||||
from pytorch3d.common.compat import solve
|
||||
|
||||
from .so3 import _so3_exp_map, hat, so3_log_map
|
||||
|
||||
@ -174,7 +173,7 @@ def se3_log_map(
|
||||
# log_translation is V^-1 @ T
|
||||
T = transform[:, 3, :3]
|
||||
V = _se3_V_matrix(*_get_se3_V_input(log_rotation), eps=eps)
|
||||
log_translation = solve(V, T[:, :, None])[:, :, 0]
|
||||
log_translation = torch.linalg.solve(V, T[:, :, None])[:, :, 0]
|
||||
|
||||
return torch.cat((log_translation, log_rotation), dim=1)
|
||||
|
||||
|
@ -9,7 +9,6 @@ import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from pytorch3d.common.compat import lstsq
|
||||
from pytorch3d.transforms import acos_linear_extrapolation
|
||||
|
||||
from .common_testing import TestCaseMixin
|
||||
@ -66,7 +65,7 @@ class TestAcosLinearExtrapolation(TestCaseMixin, unittest.TestCase):
|
||||
bound_t = torch.tensor(bound, device=x.device, dtype=x.dtype)
|
||||
# fit a line: slope * x + bias = y
|
||||
x_1 = torch.stack([x, torch.ones_like(x)], dim=-1)
|
||||
slope, bias = lstsq(x_1, y[:, None]).view(-1)[:2]
|
||||
slope, bias = torch.linalg.lstsq(x_1, y[:, None]).solution.view(-1)[:2]
|
||||
desired_slope = (-1.0) / torch.sqrt(1.0 - bound_t**2)
|
||||
# test that the desired slope is the same as the fitted one
|
||||
self.assertClose(desired_slope.view(1), slope.view(1), atol=1e-2)
|
||||
|
@ -9,7 +9,6 @@ import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from pytorch3d.common.compat import qr
|
||||
from pytorch3d.transforms.rotation_conversions import random_rotations
|
||||
from pytorch3d.transforms.se3 import se3_exp_map, se3_log_map
|
||||
from pytorch3d.transforms.so3 import so3_exp_map, so3_log_map, so3_rotation_angle
|
||||
@ -199,7 +198,7 @@ class TestSE3(TestCaseMixin, unittest.TestCase):
|
||||
r = [identity, rot180]
|
||||
r.extend(
|
||||
[
|
||||
qr(identity + torch.randn_like(identity) * 1e-6)[0]
|
||||
torch.linalg.qr(identity + torch.randn_like(identity) * 1e-6)[0]
|
||||
+ float(i > batch_size // 2) * (0.5 - torch.rand_like(identity)) * 1e-8
|
||||
# this adds random noise to the second half
|
||||
# of the random orthogonal matrices to generate
|
||||
|
@ -11,7 +11,6 @@ from distutils.version import LooseVersion
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from pytorch3d.common.compat import qr
|
||||
from pytorch3d.transforms.so3 import (
|
||||
hat,
|
||||
so3_exp_map,
|
||||
@ -49,7 +48,7 @@ class TestSO3(TestCaseMixin, unittest.TestCase):
|
||||
# TODO(dnovotny): replace with random_rotation from random_rotation.py
|
||||
rot = []
|
||||
for _ in range(batch_size):
|
||||
r = qr(torch.randn((3, 3), device=device))[0]
|
||||
r = torch.linalg.qr(torch.randn((3, 3), device=device))[0]
|
||||
f = torch.randint(2, (3,), device=device, dtype=torch.float32)
|
||||
if f.sum() % 2 == 0:
|
||||
f = 1 - f
|
||||
@ -145,7 +144,7 @@ class TestSO3(TestCaseMixin, unittest.TestCase):
|
||||
# add random rotations and random almost orthonormal matrices
|
||||
r.extend(
|
||||
[
|
||||
qr(identity + torch.randn_like(identity) * 1e-4)[0]
|
||||
torch.linalg.qr(identity + torch.randn_like(identity) * 1e-4)[0]
|
||||
+ float(i > batch_size // 2) * (0.5 - torch.rand_like(identity)) * 1e-3
|
||||
# this adds random noise to the second half
|
||||
# of the random orthogonal matrices to generate
|
||||
@ -245,7 +244,7 @@ class TestSO3(TestCaseMixin, unittest.TestCase):
|
||||
r = [identity, rot180]
|
||||
r.extend(
|
||||
[
|
||||
qr(identity + torch.randn_like(identity) * 1e-4)[0]
|
||||
torch.linalg.qr(identity + torch.randn_like(identity) * 1e-4)[0]
|
||||
for _ in range(batch_size - 2)
|
||||
]
|
||||
)
|
||||
|
Loading…
x
Reference in New Issue
Block a user