diff --git a/pytorch3d/common/compat.py b/pytorch3d/common/compat.py index 278ecb24..e15a6107 100644 --- a/pytorch3d/common/compat.py +++ b/pytorch3d/common/compat.py @@ -49,3 +49,12 @@ def qr(A: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: # pragma: no cove # PyTorch version >= 1.9 return torch.linalg.qr(A) return torch.qr(A) + + +def eigh(A: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: # pragma: no cover + """ + Like torch.linalg.eigh, assuming the argument is a symmetric real matrix. + """ + if hasattr(torch, "linalg") and hasattr(torch.linalg, "eigh"): + return torch.linalg.eigh(A) + return torch.symeig(A, eigenvalues=True) diff --git a/pytorch3d/ops/perspective_n_points.py b/pytorch3d/ops/perspective_n_points.py index 017c345b..92cf4dfb 100644 --- a/pytorch3d/ops/perspective_n_points.py +++ b/pytorch3d/ops/perspective_n_points.py @@ -16,6 +16,7 @@ from typing import NamedTuple, Optional import torch import torch.nn.functional as F +from pytorch3d.common.compat import eigh from pytorch3d.ops import points_alignment, utils as oputil @@ -105,7 +106,7 @@ def _null_space(m, kernel_dim): kernel vectors, of size B x kernel_dim """ mTm = torch.bmm(m.transpose(1, 2), m) - s, v = torch.symeig(mTm, eigenvectors=True) + s, v = eigh(mTm) return v[:, :, :kernel_dim].reshape(-1, 4, 3, kernel_dim), s[:, :kernel_dim] diff --git a/pytorch3d/ops/points_normals.py b/pytorch3d/ops/points_normals.py index 702c0bb7..70128814 100644 --- a/pytorch3d/ops/points_normals.py +++ b/pytorch3d/ops/points_normals.py @@ -7,8 +7,9 @@ from typing import TYPE_CHECKING, Tuple, Union import torch +from pytorch3d.common.compat import eigh +from pytorch3d.common.workaround import symeig3x3 -from ..common.workaround import symeig3x3 from .utils import convert_pointclouds_to_tensor, get_point_covariances @@ -139,14 +140,14 @@ def estimate_pointcloud_local_coord_frames( # get the local coord frames as principal directions of # the per-point covariance - # this is done with torch.symeig, which returns the + # this is done with torch.symeig / torch.linalg.eigh, which returns the # eigenvectors (=principal directions) in an ascending order of their - # corresponding eigenvalues, while the smallest eigenvalue's eigenvector - # corresponds to the normal direction + # corresponding eigenvalues, and the smallest eigenvalue's eigenvector + # corresponds to the normal direction; or with a custom equivalent. if use_symeig_workaround: curvatures, local_coord_frames = symeig3x3(cov, eigenvectors=True) else: - curvatures, local_coord_frames = torch.symeig(cov, eigenvectors=True) + curvatures, local_coord_frames = eigh(cov) # disambiguate the directions of individual principal vectors if disambiguate_directions: