mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
Symmetric eigen 3x3 implementation + benchmark & tests
Summary: Symmetric eigenvalues 3x3 implementation from https://github.com/fairinternal/denseposeslim/blob/roman_c3dpo/tools/functions.py#L612 based on https://en.wikipedia.org/wiki/Eigenvalue_algorithm#3.C3.973_matrices and https://www.geometrictools.com/Documentation/RobustEigenSymmetric3x3.pdf Benchmarks show significant outperformance of symeig3x3 in comparison with torch implementations (torch.symeig and torch.linalg.eigh) on GPU (P100), especially for large batches: 70-280ns per sample vs 3400ns per sample for torch_linalg_eigh_1048576_cpu It's worth mentioning that torch.linalg.eigh is still comparably fast for batches up to 8192 on CPU. Some tests are still failing as the error thresholds need to be adjusted appropriately. Reviewed By: patricklabatut Differential Revision: D29915453 fbshipit-source-id: 7c1b062da631c57c4e22a42dd0027ea5e205f1b5
This commit is contained in:
parent
9585a58d10
commit
d7d740abe9
8
pytorch3d/common/workaround/__init__.py
Normal file
8
pytorch3d/common/workaround/__init__.py
Normal file
@ -0,0 +1,8 @@
|
|||||||
|
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the BSD-style license found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
|
from .utils import _safe_det_3x3
|
||||||
|
from .symeig3x3 import symeig3x3
|
316
pytorch3d/common/workaround/symeig3x3.py
Normal file
316
pytorch3d/common/workaround/symeig3x3.py
Normal file
@ -0,0 +1,316 @@
|
|||||||
|
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the BSD-style license found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
|
import math
|
||||||
|
from typing import Tuple, Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
|
||||||
|
class _SymEig3x3(nn.Module):
|
||||||
|
"""
|
||||||
|
Optimized implementation of eigenvalues and eigenvectors computation for symmetric 3x3
|
||||||
|
matrices.
|
||||||
|
|
||||||
|
Please see https://en.wikipedia.org/wiki/Eigenvalue_algorithm#3.C3.973_matrices
|
||||||
|
and https://www.geometrictools.com/Documentation/RobustEigenSymmetric3x3.pdf
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, eps: Optional[float] = None) -> None:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
eps: epsilon to specify, if None then use torch.float eps
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.register_buffer("_identity", torch.eye(3))
|
||||||
|
self.register_buffer("_rotation_2d", torch.tensor([[0.0, -1.0], [1.0, 0.0]]))
|
||||||
|
self.register_buffer(
|
||||||
|
"_rotations_3d", self._create_rotation_matrices(self._rotation_2d)
|
||||||
|
)
|
||||||
|
|
||||||
|
self._eps = eps or torch.finfo(torch.float).eps
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _create_rotation_matrices(rotation_2d) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Compute rotations for later use in U V computation
|
||||||
|
|
||||||
|
Args:
|
||||||
|
rotation_2d: a π/2 rotation matrix.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
a (3, 3, 3) tensor containing 3 rotation matrices around each of the coordinate axes
|
||||||
|
by π/2
|
||||||
|
"""
|
||||||
|
|
||||||
|
rotations_3d = torch.zeros((3, 3, 3))
|
||||||
|
rotation_axes = set(range(3))
|
||||||
|
for rotation_axis in rotation_axes:
|
||||||
|
rest = list(rotation_axes - {rotation_axis})
|
||||||
|
rotations_3d[rotation_axis][rest[0], rest] = rotation_2d[0]
|
||||||
|
rotations_3d[rotation_axis][rest[1], rest] = rotation_2d[1]
|
||||||
|
|
||||||
|
return rotations_3d
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self, inputs: torch.Tensor, eigenvectors: bool = True
|
||||||
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||||
|
"""
|
||||||
|
Compute eigenvalues and (optionally) eigenvectors
|
||||||
|
|
||||||
|
Args:
|
||||||
|
inputs: symmetric matrices with shape of (..., 3, 3)
|
||||||
|
eigenvectors: whether should we compute only eigenvalues or eigenvectors as well
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Either a tuple of (eigenvalues, eigenvectors) or eigenvalues only, depending on
|
||||||
|
given params. Eigenvalues are of shape (..., 3) and eigenvectors (..., 3, 3)
|
||||||
|
"""
|
||||||
|
if inputs.shape[-2:] != (3, 3):
|
||||||
|
raise ValueError("Only inputs of shape (..., 3, 3) are supported.")
|
||||||
|
|
||||||
|
inputs_diag = inputs.diagonal(dim1=-2, dim2=-1) # pyre-ignore[16]
|
||||||
|
inputs_trace = inputs_diag.sum(-1)
|
||||||
|
q = inputs_trace / 3.0
|
||||||
|
|
||||||
|
# Calculate squared sum of elements outside the main diagonal / 2
|
||||||
|
p1 = ((inputs ** 2).sum(dim=(-1, -2)) - (inputs_diag ** 2).sum(-1)) / 2
|
||||||
|
p2 = ((inputs_diag - q[..., None]) ** 2).sum(dim=-1) + 2.0 * p1.clamp(self._eps)
|
||||||
|
|
||||||
|
p = torch.sqrt(p2 / 6.0)
|
||||||
|
B = (inputs - q[..., None, None] * self._identity) / p[..., None, None]
|
||||||
|
|
||||||
|
r = torch.det(B) / 2.0
|
||||||
|
# Keep r within (-1.0, 1.0) boundaries with a margin to prevent exploding gradients.
|
||||||
|
r = r.clamp(-1.0 + self._eps, 1.0 - self._eps)
|
||||||
|
|
||||||
|
phi = torch.acos(r) / 3.0
|
||||||
|
eig1 = q + 2 * p * torch.cos(phi)
|
||||||
|
eig2 = q + 2 * p * torch.cos(phi + 2 * math.pi / 3)
|
||||||
|
eig3 = 3 * q - eig1 - eig2
|
||||||
|
# eigenvals[..., i] is the i-th eigenvalue of the input, α0 ≤ α1 ≤ α2.
|
||||||
|
eigenvals = torch.stack((eig2, eig3, eig1), dim=-1)
|
||||||
|
|
||||||
|
# Soft dispatch between the degenerate case (diagonal A) and general.
|
||||||
|
# diag_soft_cond -> 1.0 when p1 < 6 * eps and diag_soft_cond -> 0.0 otherwise.
|
||||||
|
# We use 6 * eps to take into account the error accumulated during the p1 summation
|
||||||
|
diag_soft_cond = torch.exp(-((p1 / (6 * self._eps)) ** 2)).detach()[..., None]
|
||||||
|
|
||||||
|
# Eigenvalues are the ordered elements of main diagonal in the degenerate case
|
||||||
|
diag_eigenvals, _ = torch.sort(inputs_diag, dim=-1)
|
||||||
|
eigenvals = diag_soft_cond * diag_eigenvals + (1.0 - diag_soft_cond) * eigenvals
|
||||||
|
|
||||||
|
if eigenvectors:
|
||||||
|
eigenvecs = self._construct_eigenvecs_set(inputs, eigenvals)
|
||||||
|
else:
|
||||||
|
eigenvecs = None
|
||||||
|
|
||||||
|
return eigenvals, eigenvecs
|
||||||
|
|
||||||
|
def _construct_eigenvecs_set(
|
||||||
|
self, inputs: torch.Tensor, eigenvals: torch.Tensor
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Construct orthonormal set of eigenvectors by given inputs and pre-computed eigenvalues
|
||||||
|
|
||||||
|
Args:
|
||||||
|
inputs: tensor of symmetric matrices of shape (..., 3, 3)
|
||||||
|
eigenvals: tensor of pre-computed eigenvalues of of shape (..., 3, 3)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of three eigenvector tensors of shape (..., 3, 3), composing an orthonormal
|
||||||
|
set
|
||||||
|
"""
|
||||||
|
eigenvecs_tuple_for_01 = self._construct_eigenvecs(
|
||||||
|
inputs, eigenvals[..., 0], eigenvals[..., 1]
|
||||||
|
)
|
||||||
|
eigenvecs_for_01 = torch.stack(eigenvecs_tuple_for_01, dim=-1)
|
||||||
|
|
||||||
|
eigenvecs_tuple_for_21 = self._construct_eigenvecs(
|
||||||
|
inputs, eigenvals[..., 2], eigenvals[..., 1]
|
||||||
|
)
|
||||||
|
eigenvecs_for_21 = torch.stack(eigenvecs_tuple_for_21[::-1], dim=-1)
|
||||||
|
|
||||||
|
# The result will be smooth here even if both parts of comparison
|
||||||
|
# are close, because eigenvecs_01 and eigenvecs_21 would be mostly equal as well
|
||||||
|
eigenvecs_cond = (
|
||||||
|
eigenvals[..., 1] - eigenvals[..., 0]
|
||||||
|
> eigenvals[..., 2] - eigenvals[..., 1]
|
||||||
|
).detach()
|
||||||
|
eigenvecs = torch.where(
|
||||||
|
eigenvecs_cond[..., None, None], eigenvecs_for_01, eigenvecs_for_21
|
||||||
|
)
|
||||||
|
|
||||||
|
return eigenvecs
|
||||||
|
|
||||||
|
def _construct_eigenvecs(
|
||||||
|
self, inputs: torch.Tensor, alpha0: torch.Tensor, alpha1: torch.Tensor
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||||
|
"""
|
||||||
|
Construct an orthonormal set of eigenvectors by given pair of eigenvalues.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
inputs: tensor of symmetric matrices of shape (..., 3, 3)
|
||||||
|
alpha0: first eigenvalues of shape (..., 3)
|
||||||
|
alpha1: second eigenvalues of shape (..., 3)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of three eigenvector tensors of shape (..., 3, 3), composing an orthonormal
|
||||||
|
set
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Find the eigenvector corresponding to alpha0, its eigenvalue is distinct
|
||||||
|
ev0 = self._get_ev0(inputs - alpha0[..., None, None] * self._identity)
|
||||||
|
u, v = self._get_uv(ev0)
|
||||||
|
ev1 = self._get_ev1(inputs - alpha1[..., None, None] * self._identity, u, v)
|
||||||
|
# Third eigenvector is computed as the cross-product of the other two
|
||||||
|
ev2 = torch.cross(ev0, ev1, dim=-1)
|
||||||
|
|
||||||
|
return ev0, ev1, ev2
|
||||||
|
|
||||||
|
def _get_ev0(self, char_poly: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Construct the first normalized eigenvector given a characteristic polynomial
|
||||||
|
|
||||||
|
Args:
|
||||||
|
char_poly: a characteristic polynomials of the input matrices of shape (..., 3, 3)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tensor of first eigenvectors of shape (..., 3)
|
||||||
|
"""
|
||||||
|
|
||||||
|
r01 = torch.cross(char_poly[..., 0, :], char_poly[..., 1, :], dim=-1)
|
||||||
|
r12 = torch.cross(char_poly[..., 1, :], char_poly[..., 2, :], dim=-1)
|
||||||
|
r02 = torch.cross(char_poly[..., 0, :], char_poly[..., 2, :], dim=-1)
|
||||||
|
|
||||||
|
cross_products = torch.stack((r01, r12, r02), dim=-2)
|
||||||
|
# Regularize it with + or -eps depending on the sign of the first vector
|
||||||
|
cross_products += self._eps * self._sign_without_zero(
|
||||||
|
cross_products[..., :1, :]
|
||||||
|
)
|
||||||
|
|
||||||
|
norms_sq = (cross_products ** 2).sum(dim=-1)
|
||||||
|
max_norms_index = norms_sq.argmax(dim=-1) # pyre-ignore[16]
|
||||||
|
|
||||||
|
# Pick only the cross-product with highest squared norm for each input
|
||||||
|
max_cross_products = self._gather_by_index(
|
||||||
|
cross_products, max_norms_index[..., None, None], -2
|
||||||
|
)
|
||||||
|
# Pick corresponding squared norms for each cross-product
|
||||||
|
max_norms_sq = self._gather_by_index(norms_sq, max_norms_index[..., None], -1)
|
||||||
|
|
||||||
|
# Normalize cross-product vectors by thier norms
|
||||||
|
return max_cross_products / torch.sqrt(max_norms_sq[..., None])
|
||||||
|
|
||||||
|
def _gather_by_index(
|
||||||
|
self, source: torch.Tensor, index: torch.Tensor, dim: int
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Selects elements from the given source tensor by provided index tensor.
|
||||||
|
Number of dimensions should be the same for source and index tensors.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
source: input tensor to gather from
|
||||||
|
index: index tensor with indices to gather from source
|
||||||
|
dim: dimension to gather across
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tensor of shape same as the source with exception of specified dimension.
|
||||||
|
"""
|
||||||
|
|
||||||
|
index_shape = list(source.shape)
|
||||||
|
index_shape[dim] = 1
|
||||||
|
|
||||||
|
return source.gather(dim, index.expand(index_shape)).squeeze( # pyre-ignore[16]
|
||||||
|
dim
|
||||||
|
)
|
||||||
|
|
||||||
|
def _get_uv(self, w: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
"""
|
||||||
|
Computes unit-length vectors U and V such that {U, V, W} is a right-handed
|
||||||
|
orthonormal set.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
w: eigenvector tensor of shape (..., 3)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of U and V unit-length vector tensors of shape (..., 3)
|
||||||
|
"""
|
||||||
|
|
||||||
|
min_idx = w.abs().argmin(dim=-1) # pyre-ignore[16]
|
||||||
|
rotation_2d = self._rotations_3d[min_idx].to(w)
|
||||||
|
|
||||||
|
u = F.normalize((rotation_2d @ w[..., None])[..., 0], dim=-1)
|
||||||
|
v = torch.cross(w, u, dim=-1)
|
||||||
|
return u, v
|
||||||
|
|
||||||
|
def _get_ev1(
|
||||||
|
self, char_poly: torch.Tensor, u: torch.Tensor, v: torch.Tensor
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Computes the second normalized eigenvector given a characteristic polynomial
|
||||||
|
and U and V vectors
|
||||||
|
|
||||||
|
Args:
|
||||||
|
char_poly: a characteristic polynomials of the input matrices of shape (..., 3, 3)
|
||||||
|
u: unit-length vectors from _get_uv method
|
||||||
|
v: unit-length vectors from _get_uv method
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
desc
|
||||||
|
"""
|
||||||
|
|
||||||
|
j = torch.stack((u, v), dim=-1)
|
||||||
|
m = j.transpose(-1, -2) @ char_poly @ j
|
||||||
|
|
||||||
|
# If angle between those vectors is acute, take their sum = m[..., 0, :] + m[..., 1, :],
|
||||||
|
# otherwise take the difference = m[..., 0, :] - m[..., 1, :]
|
||||||
|
# m is in theory of rank 1 (or 0), so it snaps only when one of the rows is close to 0
|
||||||
|
is_acute_sign = self._sign_without_zero(
|
||||||
|
(m[..., 0, :] * m[..., 1, :]).sum(dim=-1)
|
||||||
|
).detach()
|
||||||
|
|
||||||
|
rowspace = m[..., 0, :] + is_acute_sign[..., None] * m[..., 1, :]
|
||||||
|
# rowspace will be near zero for second-order eigenvalues
|
||||||
|
# this regularization guarantees abs(rowspace[0]) >= eps in a smooth'ish way
|
||||||
|
rowspace += self._eps * self._sign_without_zero(rowspace[..., :1])
|
||||||
|
|
||||||
|
return (
|
||||||
|
j
|
||||||
|
@ F.normalize(rowspace @ self._rotation_2d.to(rowspace), dim=-1)[..., None]
|
||||||
|
)[..., 0]
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _sign_without_zero(tensor):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
tensor: an arbitrary shaped tensor
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tensor of the same shape as an input, but with 1.0 if tensor > 0.0 and -1.0
|
||||||
|
otherwise
|
||||||
|
"""
|
||||||
|
return 2.0 * (tensor > 0.0).to(tensor.dtype) - 1.0
|
||||||
|
|
||||||
|
|
||||||
|
def symeig3x3(
|
||||||
|
inputs: torch.Tensor, eigenvectors: bool = True
|
||||||
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||||
|
"""
|
||||||
|
Compute eigenvalues and (optionally) eigenvectors
|
||||||
|
|
||||||
|
Args:
|
||||||
|
inputs: symmetric matrices with shape of (..., 3, 3)
|
||||||
|
eigenvectors: whether should we compute only eigenvalues or eigenvectors as well
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Either a tuple of (eigenvalues, eigenvectors) or eigenvalues only, depending on
|
||||||
|
given params. Eigenvalues are of shape (..., 3) and eigenvectors (..., 3, 3)
|
||||||
|
"""
|
||||||
|
return _SymEig3x3().to(inputs.device)(inputs, eigenvectors=eigenvectors)
|
93
tests/bm_symeig3x3.py
Normal file
93
tests/bm_symeig3x3.py
Normal file
@ -0,0 +1,93 @@
|
|||||||
|
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the BSD-style license found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
|
|
||||||
|
from itertools import product
|
||||||
|
from typing import Callable, Any
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from common_testing import get_random_cuda_device
|
||||||
|
from fvcore.common.benchmark import benchmark
|
||||||
|
from pytorch3d.common.workaround import symeig3x3
|
||||||
|
from test_symeig3x3 import TestSymEig3x3
|
||||||
|
|
||||||
|
torch.set_num_threads(1)
|
||||||
|
|
||||||
|
CUDA_DEVICE = get_random_cuda_device()
|
||||||
|
|
||||||
|
|
||||||
|
def create_traced_func(func, device, batch_size):
|
||||||
|
traced_func = torch.jit.trace(
|
||||||
|
func, (TestSymEig3x3.create_random_sym3x3(device, batch_size),)
|
||||||
|
)
|
||||||
|
|
||||||
|
return traced_func
|
||||||
|
|
||||||
|
|
||||||
|
FUNC_NAME_TO_FUNC = {
|
||||||
|
"sym3x3eig": (lambda inputs: symeig3x3(inputs, eigenvectors=True)),
|
||||||
|
"sym3x3eig_traced_cuda": create_traced_func(
|
||||||
|
(lambda inputs: symeig3x3(inputs, eigenvectors=True)), CUDA_DEVICE, 1024
|
||||||
|
),
|
||||||
|
"torch_symeig": (lambda inputs: torch.symeig(inputs, eigenvectors=True)),
|
||||||
|
"torch_linalg_eigh": (lambda inputs: torch.linalg.eigh(inputs)),
|
||||||
|
"torch_pca_lowrank": (
|
||||||
|
lambda inputs: torch.pca_lowrank(inputs, center=False, niter=1)
|
||||||
|
),
|
||||||
|
"sym3x3eig_no_vecs": (lambda inputs: symeig3x3(inputs, eigenvectors=False)),
|
||||||
|
"torch_symeig_no_vecs": (lambda inputs: torch.symeig(inputs, eigenvectors=False)),
|
||||||
|
"torch_linalg_eigvalsh_no_vecs": (lambda inputs: torch.linalg.eigvalsh(inputs)),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def test_symeig3x3(func_name, batch_size, device) -> Callable[[], Any]:
|
||||||
|
func = FUNC_NAME_TO_FUNC[func_name]
|
||||||
|
inputs = TestSymEig3x3.create_random_sym3x3(device, batch_size)
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
|
def symeig3x3():
|
||||||
|
func(inputs)
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
|
return symeig3x3
|
||||||
|
|
||||||
|
|
||||||
|
def bm_symeig3x3() -> None:
|
||||||
|
devices = ["cpu"]
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
devices.append(CUDA_DEVICE)
|
||||||
|
|
||||||
|
kwargs_list = []
|
||||||
|
func_names = FUNC_NAME_TO_FUNC.keys()
|
||||||
|
batch_sizes = [16, 128, 1024, 8192, 65536, 1048576]
|
||||||
|
|
||||||
|
for func_name, batch_size, device in product(func_names, batch_sizes, devices):
|
||||||
|
# Run CUDA-only implementations only on GPU
|
||||||
|
if "cuda" in func_name and not device.startswith("cuda"):
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Torch built-ins are quite slow on larger batches
|
||||||
|
if "torch" in func_name and batch_size > 8192:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Avoid running CPU implementations on larger batches as well
|
||||||
|
if device == "cpu" and batch_size > 8192:
|
||||||
|
continue
|
||||||
|
|
||||||
|
kwargs_list.append(
|
||||||
|
{"func_name": func_name, "batch_size": batch_size, "device": device}
|
||||||
|
)
|
||||||
|
|
||||||
|
benchmark(
|
||||||
|
test_symeig3x3,
|
||||||
|
"SYMEIG3X3",
|
||||||
|
kwargs_list,
|
||||||
|
warmup_iters=3,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
bm_symeig3x3()
|
263
tests/test_symeig3x3.py
Normal file
263
tests/test_symeig3x3.py
Normal file
@ -0,0 +1,263 @@
|
|||||||
|
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the BSD-style license found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from common_testing import TestCaseMixin, get_random_cuda_device
|
||||||
|
from pytorch3d.common.workaround import symeig3x3
|
||||||
|
from pytorch3d.transforms.rotation_conversions import random_rotations
|
||||||
|
|
||||||
|
|
||||||
|
class TestSymEig3x3(TestCaseMixin, unittest.TestCase):
|
||||||
|
TEST_BATCH_SIZE = 1024
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def create_random_sym3x3(device, n):
|
||||||
|
random_3x3 = torch.randn((n, 3, 3), device=device)
|
||||||
|
random_3x3_T = torch.transpose(random_3x3, 1, 2)
|
||||||
|
random_sym_3x3 = (random_3x3 * random_3x3_T).contiguous()
|
||||||
|
|
||||||
|
return random_sym_3x3
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def create_diag_sym3x3(device, n, noise=0.0):
|
||||||
|
# Create purly diagonal matrices
|
||||||
|
random_diag_3x3 = torch.randn((n, 3), device=device).diag_embed()
|
||||||
|
|
||||||
|
# Make them 'almost' diagonal
|
||||||
|
random_diag_3x3 += noise * TestSymEig3x3.create_random_sym3x3(device, n)
|
||||||
|
|
||||||
|
return random_diag_3x3
|
||||||
|
|
||||||
|
def setUp(self) -> None:
|
||||||
|
super().setUp()
|
||||||
|
torch.manual_seed(42)
|
||||||
|
|
||||||
|
self._gpu = get_random_cuda_device()
|
||||||
|
self._cpu = torch.device("cpu")
|
||||||
|
|
||||||
|
def test_is_eigen_gpu(self):
|
||||||
|
test_input = self.create_random_sym3x3(self._gpu, n=self.TEST_BATCH_SIZE)
|
||||||
|
|
||||||
|
self._test_is_eigen(test_input)
|
||||||
|
|
||||||
|
def test_is_eigen_cpu(self):
|
||||||
|
test_input = self.create_random_sym3x3(self._cpu, n=self.TEST_BATCH_SIZE)
|
||||||
|
|
||||||
|
self._test_is_eigen(test_input)
|
||||||
|
|
||||||
|
def _test_is_eigen(self, test_input, atol=1e-04, rtol=1e-02):
|
||||||
|
"""
|
||||||
|
Verify that values and vectors produced are really eigenvalues and eigenvectors
|
||||||
|
and can restore the original input matrix with good precision
|
||||||
|
"""
|
||||||
|
eigenvalues, eigenvectors = symeig3x3(test_input, eigenvectors=True)
|
||||||
|
|
||||||
|
self.assertClose(
|
||||||
|
test_input,
|
||||||
|
eigenvectors @ eigenvalues.diag_embed() @ eigenvectors.transpose(-2, -1),
|
||||||
|
atol=atol,
|
||||||
|
rtol=rtol,
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_eigenvectors_are_orthonormal_gpu(self):
|
||||||
|
test_input = self.create_random_sym3x3(self._gpu, n=self.TEST_BATCH_SIZE)
|
||||||
|
|
||||||
|
self._test_eigenvectors_are_orthonormal(test_input)
|
||||||
|
|
||||||
|
def test_eigenvectors_are_orthonormal_cpu(self):
|
||||||
|
test_input = self.create_random_sym3x3(self._cpu, n=self.TEST_BATCH_SIZE)
|
||||||
|
|
||||||
|
self._test_eigenvectors_are_orthonormal(test_input)
|
||||||
|
|
||||||
|
def _test_eigenvectors_are_orthonormal(self, test_input):
|
||||||
|
"""
|
||||||
|
Verify that eigenvectors are an orthonormal set
|
||||||
|
"""
|
||||||
|
eigenvalues, eigenvectors = symeig3x3(test_input, eigenvectors=True)
|
||||||
|
|
||||||
|
batched_eye = torch.zeros_like(test_input)
|
||||||
|
batched_eye[..., :, :] = torch.eye(3, device=batched_eye.device)
|
||||||
|
|
||||||
|
self.assertClose(
|
||||||
|
batched_eye, eigenvectors @ eigenvectors.transpose(-2, -1), atol=1e-06
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_is_not_nan_or_inf_gpu(self):
|
||||||
|
test_input = self.create_random_sym3x3(self._gpu, n=self.TEST_BATCH_SIZE)
|
||||||
|
|
||||||
|
self._test_is_not_nan_or_inf(test_input)
|
||||||
|
|
||||||
|
def test_is_not_nan_or_inf_cpu(self):
|
||||||
|
test_input = self.create_random_sym3x3(self._cpu, n=self.TEST_BATCH_SIZE)
|
||||||
|
|
||||||
|
self._test_is_not_nan_or_inf(test_input)
|
||||||
|
|
||||||
|
def _test_is_not_nan_or_inf(self, test_input):
|
||||||
|
eigenvalues, eigenvectors = symeig3x3(test_input, eigenvectors=True)
|
||||||
|
|
||||||
|
self.assertTrue(torch.isfinite(eigenvalues).all())
|
||||||
|
self.assertTrue(torch.isfinite(eigenvectors).all())
|
||||||
|
|
||||||
|
def test_degenerate_inputs_gpu(self):
|
||||||
|
self._test_degenerate_inputs(self._gpu)
|
||||||
|
|
||||||
|
def test_degenerate_inputs_cpu(self):
|
||||||
|
self._test_degenerate_inputs(self._cpu)
|
||||||
|
|
||||||
|
def _test_degenerate_inputs(self, device):
|
||||||
|
"""
|
||||||
|
Test degenerate case when input matrices are diagonal or near-diagonal
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Purely diagonal case
|
||||||
|
test_input = self.create_diag_sym3x3(device, self.TEST_BATCH_SIZE)
|
||||||
|
|
||||||
|
self._test_is_not_nan_or_inf(test_input)
|
||||||
|
self._test_is_eigen(test_input)
|
||||||
|
self._test_eigenvectors_are_orthonormal(test_input)
|
||||||
|
|
||||||
|
# Almost-diagonal case
|
||||||
|
test_input = self.create_diag_sym3x3(device, self.TEST_BATCH_SIZE, noise=1e-4)
|
||||||
|
|
||||||
|
self._test_is_not_nan_or_inf(test_input)
|
||||||
|
self._test_is_eigen(test_input)
|
||||||
|
self._test_eigenvectors_are_orthonormal(test_input)
|
||||||
|
|
||||||
|
def test_gradients_cpu(self):
|
||||||
|
self._test_gradients(self._cpu)
|
||||||
|
|
||||||
|
def test_gradients_gpu(self):
|
||||||
|
self._test_gradients(self._gpu)
|
||||||
|
|
||||||
|
def _test_gradients(self, device):
|
||||||
|
"""
|
||||||
|
Tests if gradients pass though without any problems (infs, nans etc) and
|
||||||
|
also performs gradcheck (compares numerical and analytical gradients)
|
||||||
|
"""
|
||||||
|
test_random_input = self.create_random_sym3x3(device, n=16)
|
||||||
|
test_diag_input = self.create_diag_sym3x3(device, n=16)
|
||||||
|
test_almost_diag_input = self.create_diag_sym3x3(device, n=16, noise=1e-4)
|
||||||
|
|
||||||
|
test_input = torch.cat(
|
||||||
|
(test_random_input, test_diag_input, test_almost_diag_input)
|
||||||
|
)
|
||||||
|
test_input.requires_grad = True
|
||||||
|
|
||||||
|
with torch.autograd.detect_anomaly():
|
||||||
|
eigenvalues, eigenvectors = symeig3x3(test_input, eigenvectors=True)
|
||||||
|
|
||||||
|
loss = eigenvalues.mean() + eigenvectors.mean()
|
||||||
|
loss.backward()
|
||||||
|
|
||||||
|
test_random_input.requires_grad = True
|
||||||
|
# Inputs are converted to double to increase the precision of gradcheck.
|
||||||
|
torch.autograd.gradcheck(
|
||||||
|
symeig3x3, test_random_input.double(), eps=1e-6, atol=1e-2, rtol=1e-2
|
||||||
|
)
|
||||||
|
|
||||||
|
def _test_eigenvalues_and_eigenvectors(
|
||||||
|
self, test_eigenvectors, test_eigenvalues, atol=1e-04, rtol=1e-04
|
||||||
|
):
|
||||||
|
test_input = (
|
||||||
|
test_eigenvectors.transpose(-2, -1)
|
||||||
|
@ test_eigenvalues.diag_embed()
|
||||||
|
@ test_eigenvectors
|
||||||
|
)
|
||||||
|
|
||||||
|
test_eigenvalues_sorted, _ = torch.sort(test_eigenvalues, dim=-1)
|
||||||
|
|
||||||
|
eigenvalues, eigenvectors = symeig3x3(test_input, eigenvectors=True)
|
||||||
|
|
||||||
|
self.assertClose(
|
||||||
|
test_eigenvalues_sorted,
|
||||||
|
eigenvalues,
|
||||||
|
atol=atol,
|
||||||
|
rtol=rtol,
|
||||||
|
)
|
||||||
|
|
||||||
|
self._test_is_not_nan_or_inf(test_input)
|
||||||
|
self._test_is_eigen(test_input, atol=atol, rtol=rtol)
|
||||||
|
self._test_eigenvectors_are_orthonormal(test_input)
|
||||||
|
|
||||||
|
def test_degenerate_eigenvalues_gpu(self):
|
||||||
|
self._test_degenerate_eigenvalues(self._gpu)
|
||||||
|
|
||||||
|
def test_degenerate_eigenvalues_cpu(self):
|
||||||
|
self._test_degenerate_eigenvalues(self._cpu)
|
||||||
|
|
||||||
|
def _test_degenerate_eigenvalues(self, device):
|
||||||
|
"""
|
||||||
|
Test degenerate eigenvalues like zero-valued and with 2-/3-multiplicity
|
||||||
|
"""
|
||||||
|
# Error tolerances for degenerate values are increased as things might become
|
||||||
|
# numerically unstable
|
||||||
|
deg_atol = 1e-3
|
||||||
|
deg_rtol = 1.0
|
||||||
|
|
||||||
|
# Construct random orthonormal sets
|
||||||
|
test_eigenvecs = random_rotations(n=self.TEST_BATCH_SIZE, device=device)
|
||||||
|
|
||||||
|
# Construct random eigenvalues
|
||||||
|
test_eigenvals = torch.randn(
|
||||||
|
(self.TEST_BATCH_SIZE, 3), device=test_eigenvecs.device
|
||||||
|
)
|
||||||
|
self._test_eigenvalues_and_eigenvectors(
|
||||||
|
test_eigenvecs, test_eigenvals, atol=deg_atol, rtol=deg_rtol
|
||||||
|
)
|
||||||
|
|
||||||
|
# First eigenvalue is always 0.0 here: [0.0 X Y]
|
||||||
|
test_eigenvals_with_zero = test_eigenvals.clone()
|
||||||
|
test_eigenvals_with_zero[..., 0] = 0.0
|
||||||
|
self._test_eigenvalues_and_eigenvectors(
|
||||||
|
test_eigenvecs, test_eigenvals_with_zero, atol=deg_atol, rtol=deg_rtol
|
||||||
|
)
|
||||||
|
|
||||||
|
# First two eigenvalues are always the same here: [X X Y]
|
||||||
|
test_eigenvals_with_multiplicity2 = test_eigenvals.clone()
|
||||||
|
test_eigenvals_with_multiplicity2[..., 1] = test_eigenvals_with_multiplicity2[
|
||||||
|
..., 0
|
||||||
|
]
|
||||||
|
self._test_eigenvalues_and_eigenvectors(
|
||||||
|
test_eigenvecs,
|
||||||
|
test_eigenvals_with_multiplicity2,
|
||||||
|
atol=deg_atol,
|
||||||
|
rtol=deg_rtol,
|
||||||
|
)
|
||||||
|
|
||||||
|
# All three eigenvalues are the same here: [X X X]
|
||||||
|
test_eigenvals_with_multiplicity3 = test_eigenvals_with_multiplicity2.clone()
|
||||||
|
test_eigenvals_with_multiplicity3[..., 2] = test_eigenvals_with_multiplicity2[
|
||||||
|
..., 0
|
||||||
|
]
|
||||||
|
self._test_eigenvalues_and_eigenvectors(
|
||||||
|
test_eigenvecs,
|
||||||
|
test_eigenvals_with_multiplicity3,
|
||||||
|
atol=deg_atol,
|
||||||
|
rtol=deg_rtol,
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_more_dimensions(self):
|
||||||
|
"""
|
||||||
|
Tests if function supports arbitrary leading dimensions
|
||||||
|
"""
|
||||||
|
repeat = 4
|
||||||
|
|
||||||
|
test_input = self.create_random_sym3x3(self._cpu, n=16)
|
||||||
|
test_input_4d = test_input[None, ...].expand((repeat,) + test_input.shape)
|
||||||
|
|
||||||
|
eigenvalues, eigenvectors = symeig3x3(test_input, eigenvectors=True)
|
||||||
|
eigenvalues_4d, eigenvectors_4d = symeig3x3(test_input_4d, eigenvectors=True)
|
||||||
|
|
||||||
|
eigenvalues_4d_gt = eigenvalues[None, ...].expand((repeat,) + eigenvalues.shape)
|
||||||
|
eigenvectors_4d_gt = eigenvectors[None, ...].expand(
|
||||||
|
(repeat,) + eigenvectors.shape
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertClose(eigenvalues_4d_gt, eigenvalues_4d)
|
||||||
|
self.assertClose(eigenvectors_4d_gt, eigenvectors_4d)
|
Loading…
x
Reference in New Issue
Block a user