mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 20:02:49 +08:00
avoid symeig
Summary: Use the newer eigh to avoid deprecation warnings in newer pytorch. Reviewed By: patricklabatut Differential Revision: D34375784 fbshipit-source-id: 40efe0d33fdfa071fba80fc97ed008cbfd2ef249
This commit is contained in:
parent
59972b121d
commit
db1f7c4506
@ -49,3 +49,12 @@ def qr(A: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: # pragma: no cove
|
|||||||
# PyTorch version >= 1.9
|
# PyTorch version >= 1.9
|
||||||
return torch.linalg.qr(A)
|
return torch.linalg.qr(A)
|
||||||
return torch.qr(A)
|
return torch.qr(A)
|
||||||
|
|
||||||
|
|
||||||
|
def eigh(A: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: # pragma: no cover
|
||||||
|
"""
|
||||||
|
Like torch.linalg.eigh, assuming the argument is a symmetric real matrix.
|
||||||
|
"""
|
||||||
|
if hasattr(torch, "linalg") and hasattr(torch.linalg, "eigh"):
|
||||||
|
return torch.linalg.eigh(A)
|
||||||
|
return torch.symeig(A, eigenvalues=True)
|
||||||
|
@ -16,6 +16,7 @@ from typing import NamedTuple, Optional
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
from pytorch3d.common.compat import eigh
|
||||||
from pytorch3d.ops import points_alignment, utils as oputil
|
from pytorch3d.ops import points_alignment, utils as oputil
|
||||||
|
|
||||||
|
|
||||||
@ -105,7 +106,7 @@ def _null_space(m, kernel_dim):
|
|||||||
kernel vectors, of size B x kernel_dim
|
kernel vectors, of size B x kernel_dim
|
||||||
"""
|
"""
|
||||||
mTm = torch.bmm(m.transpose(1, 2), m)
|
mTm = torch.bmm(m.transpose(1, 2), m)
|
||||||
s, v = torch.symeig(mTm, eigenvectors=True)
|
s, v = eigh(mTm)
|
||||||
return v[:, :, :kernel_dim].reshape(-1, 4, 3, kernel_dim), s[:, :kernel_dim]
|
return v[:, :, :kernel_dim].reshape(-1, 4, 3, kernel_dim), s[:, :kernel_dim]
|
||||||
|
|
||||||
|
|
||||||
|
@ -7,8 +7,9 @@
|
|||||||
from typing import TYPE_CHECKING, Tuple, Union
|
from typing import TYPE_CHECKING, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from pytorch3d.common.compat import eigh
|
||||||
|
from pytorch3d.common.workaround import symeig3x3
|
||||||
|
|
||||||
from ..common.workaround import symeig3x3
|
|
||||||
from .utils import convert_pointclouds_to_tensor, get_point_covariances
|
from .utils import convert_pointclouds_to_tensor, get_point_covariances
|
||||||
|
|
||||||
|
|
||||||
@ -139,14 +140,14 @@ def estimate_pointcloud_local_coord_frames(
|
|||||||
|
|
||||||
# get the local coord frames as principal directions of
|
# get the local coord frames as principal directions of
|
||||||
# the per-point covariance
|
# the per-point covariance
|
||||||
# this is done with torch.symeig, which returns the
|
# this is done with torch.symeig / torch.linalg.eigh, which returns the
|
||||||
# eigenvectors (=principal directions) in an ascending order of their
|
# eigenvectors (=principal directions) in an ascending order of their
|
||||||
# corresponding eigenvalues, while the smallest eigenvalue's eigenvector
|
# corresponding eigenvalues, and the smallest eigenvalue's eigenvector
|
||||||
# corresponds to the normal direction
|
# corresponds to the normal direction; or with a custom equivalent.
|
||||||
if use_symeig_workaround:
|
if use_symeig_workaround:
|
||||||
curvatures, local_coord_frames = symeig3x3(cov, eigenvectors=True)
|
curvatures, local_coord_frames = symeig3x3(cov, eigenvectors=True)
|
||||||
else:
|
else:
|
||||||
curvatures, local_coord_frames = torch.symeig(cov, eigenvectors=True)
|
curvatures, local_coord_frames = eigh(cov)
|
||||||
|
|
||||||
# disambiguate the directions of individual principal vectors
|
# disambiguate the directions of individual principal vectors
|
||||||
if disambiguate_directions:
|
if disambiguate_directions:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user