avoid symeig

Summary: Use the newer eigh to avoid deprecation warnings in newer pytorch.

Reviewed By: patricklabatut

Differential Revision: D34375784

fbshipit-source-id: 40efe0d33fdfa071fba80fc97ed008cbfd2ef249
This commit is contained in:
Jeremy Reizenstein 2022-02-21 06:30:25 -08:00 committed by Facebook GitHub Bot
parent 59972b121d
commit db1f7c4506
3 changed files with 17 additions and 6 deletions

View File

@ -49,3 +49,12 @@ def qr(A: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: # pragma: no cove
# PyTorch version >= 1.9
return torch.linalg.qr(A)
return torch.qr(A)
def eigh(A: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: # pragma: no cover
"""
Like torch.linalg.eigh, assuming the argument is a symmetric real matrix.
"""
if hasattr(torch, "linalg") and hasattr(torch.linalg, "eigh"):
return torch.linalg.eigh(A)
return torch.symeig(A, eigenvalues=True)

View File

@ -16,6 +16,7 @@ from typing import NamedTuple, Optional
import torch
import torch.nn.functional as F
from pytorch3d.common.compat import eigh
from pytorch3d.ops import points_alignment, utils as oputil
@ -105,7 +106,7 @@ def _null_space(m, kernel_dim):
kernel vectors, of size B x kernel_dim
"""
mTm = torch.bmm(m.transpose(1, 2), m)
s, v = torch.symeig(mTm, eigenvectors=True)
s, v = eigh(mTm)
return v[:, :, :kernel_dim].reshape(-1, 4, 3, kernel_dim), s[:, :kernel_dim]

View File

@ -7,8 +7,9 @@
from typing import TYPE_CHECKING, Tuple, Union
import torch
from pytorch3d.common.compat import eigh
from pytorch3d.common.workaround import symeig3x3
from ..common.workaround import symeig3x3
from .utils import convert_pointclouds_to_tensor, get_point_covariances
@ -139,14 +140,14 @@ def estimate_pointcloud_local_coord_frames(
# get the local coord frames as principal directions of
# the per-point covariance
# this is done with torch.symeig, which returns the
# this is done with torch.symeig / torch.linalg.eigh, which returns the
# eigenvectors (=principal directions) in an ascending order of their
# corresponding eigenvalues, while the smallest eigenvalue's eigenvector
# corresponds to the normal direction
# corresponding eigenvalues, and the smallest eigenvalue's eigenvector
# corresponds to the normal direction; or with a custom equivalent.
if use_symeig_workaround:
curvatures, local_coord_frames = symeig3x3(cov, eigenvectors=True)
else:
curvatures, local_coord_frames = torch.symeig(cov, eigenvectors=True)
curvatures, local_coord_frames = eigh(cov)
# disambiguate the directions of individual principal vectors
if disambiguate_directions: