mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 11:52:50 +08:00
avoid symeig
Summary: Use the newer eigh to avoid deprecation warnings in newer pytorch. Reviewed By: patricklabatut Differential Revision: D34375784 fbshipit-source-id: 40efe0d33fdfa071fba80fc97ed008cbfd2ef249
This commit is contained in:
parent
59972b121d
commit
db1f7c4506
@ -49,3 +49,12 @@ def qr(A: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: # pragma: no cove
|
||||
# PyTorch version >= 1.9
|
||||
return torch.linalg.qr(A)
|
||||
return torch.qr(A)
|
||||
|
||||
|
||||
def eigh(A: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: # pragma: no cover
|
||||
"""
|
||||
Like torch.linalg.eigh, assuming the argument is a symmetric real matrix.
|
||||
"""
|
||||
if hasattr(torch, "linalg") and hasattr(torch.linalg, "eigh"):
|
||||
return torch.linalg.eigh(A)
|
||||
return torch.symeig(A, eigenvalues=True)
|
||||
|
@ -16,6 +16,7 @@ from typing import NamedTuple, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from pytorch3d.common.compat import eigh
|
||||
from pytorch3d.ops import points_alignment, utils as oputil
|
||||
|
||||
|
||||
@ -105,7 +106,7 @@ def _null_space(m, kernel_dim):
|
||||
kernel vectors, of size B x kernel_dim
|
||||
"""
|
||||
mTm = torch.bmm(m.transpose(1, 2), m)
|
||||
s, v = torch.symeig(mTm, eigenvectors=True)
|
||||
s, v = eigh(mTm)
|
||||
return v[:, :, :kernel_dim].reshape(-1, 4, 3, kernel_dim), s[:, :kernel_dim]
|
||||
|
||||
|
||||
|
@ -7,8 +7,9 @@
|
||||
from typing import TYPE_CHECKING, Tuple, Union
|
||||
|
||||
import torch
|
||||
from pytorch3d.common.compat import eigh
|
||||
from pytorch3d.common.workaround import symeig3x3
|
||||
|
||||
from ..common.workaround import symeig3x3
|
||||
from .utils import convert_pointclouds_to_tensor, get_point_covariances
|
||||
|
||||
|
||||
@ -139,14 +140,14 @@ def estimate_pointcloud_local_coord_frames(
|
||||
|
||||
# get the local coord frames as principal directions of
|
||||
# the per-point covariance
|
||||
# this is done with torch.symeig, which returns the
|
||||
# this is done with torch.symeig / torch.linalg.eigh, which returns the
|
||||
# eigenvectors (=principal directions) in an ascending order of their
|
||||
# corresponding eigenvalues, while the smallest eigenvalue's eigenvector
|
||||
# corresponds to the normal direction
|
||||
# corresponding eigenvalues, and the smallest eigenvalue's eigenvector
|
||||
# corresponds to the normal direction; or with a custom equivalent.
|
||||
if use_symeig_workaround:
|
||||
curvatures, local_coord_frames = symeig3x3(cov, eigenvectors=True)
|
||||
else:
|
||||
curvatures, local_coord_frames = torch.symeig(cov, eigenvectors=True)
|
||||
curvatures, local_coord_frames = eigh(cov)
|
||||
|
||||
# disambiguate the directions of individual principal vectors
|
||||
if disambiguate_directions:
|
||||
|
Loading…
x
Reference in New Issue
Block a user