diff --git a/pytorch3d/common/workaround/__init__.py b/pytorch3d/common/workaround/__init__.py new file mode 100644 index 00000000..95ec3422 --- /dev/null +++ b/pytorch3d/common/workaround/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from .utils import _safe_det_3x3 +from .symeig3x3 import symeig3x3 diff --git a/pytorch3d/common/workaround/symeig3x3.py b/pytorch3d/common/workaround/symeig3x3.py new file mode 100644 index 00000000..52198d41 --- /dev/null +++ b/pytorch3d/common/workaround/symeig3x3.py @@ -0,0 +1,316 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import math +from typing import Tuple, Optional + +import torch +import torch.nn.functional as F +from torch import nn + + +class _SymEig3x3(nn.Module): + """ + Optimized implementation of eigenvalues and eigenvectors computation for symmetric 3x3 + matrices. + + Please see https://en.wikipedia.org/wiki/Eigenvalue_algorithm#3.C3.973_matrices + and https://www.geometrictools.com/Documentation/RobustEigenSymmetric3x3.pdf + """ + + def __init__(self, eps: Optional[float] = None) -> None: + """ + Args: + eps: epsilon to specify, if None then use torch.float eps + """ + super().__init__() + + self.register_buffer("_identity", torch.eye(3)) + self.register_buffer("_rotation_2d", torch.tensor([[0.0, -1.0], [1.0, 0.0]])) + self.register_buffer( + "_rotations_3d", self._create_rotation_matrices(self._rotation_2d) + ) + + self._eps = eps or torch.finfo(torch.float).eps + + @staticmethod + def _create_rotation_matrices(rotation_2d) -> torch.Tensor: + """ + Compute rotations for later use in U V computation + + Args: + rotation_2d: a π/2 rotation matrix. + + Returns: + a (3, 3, 3) tensor containing 3 rotation matrices around each of the coordinate axes + by π/2 + """ + + rotations_3d = torch.zeros((3, 3, 3)) + rotation_axes = set(range(3)) + for rotation_axis in rotation_axes: + rest = list(rotation_axes - {rotation_axis}) + rotations_3d[rotation_axis][rest[0], rest] = rotation_2d[0] + rotations_3d[rotation_axis][rest[1], rest] = rotation_2d[1] + + return rotations_3d + + def forward( + self, inputs: torch.Tensor, eigenvectors: bool = True + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """ + Compute eigenvalues and (optionally) eigenvectors + + Args: + inputs: symmetric matrices with shape of (..., 3, 3) + eigenvectors: whether should we compute only eigenvalues or eigenvectors as well + + Returns: + Either a tuple of (eigenvalues, eigenvectors) or eigenvalues only, depending on + given params. Eigenvalues are of shape (..., 3) and eigenvectors (..., 3, 3) + """ + if inputs.shape[-2:] != (3, 3): + raise ValueError("Only inputs of shape (..., 3, 3) are supported.") + + inputs_diag = inputs.diagonal(dim1=-2, dim2=-1) # pyre-ignore[16] + inputs_trace = inputs_diag.sum(-1) + q = inputs_trace / 3.0 + + # Calculate squared sum of elements outside the main diagonal / 2 + p1 = ((inputs ** 2).sum(dim=(-1, -2)) - (inputs_diag ** 2).sum(-1)) / 2 + p2 = ((inputs_diag - q[..., None]) ** 2).sum(dim=-1) + 2.0 * p1.clamp(self._eps) + + p = torch.sqrt(p2 / 6.0) + B = (inputs - q[..., None, None] * self._identity) / p[..., None, None] + + r = torch.det(B) / 2.0 + # Keep r within (-1.0, 1.0) boundaries with a margin to prevent exploding gradients. + r = r.clamp(-1.0 + self._eps, 1.0 - self._eps) + + phi = torch.acos(r) / 3.0 + eig1 = q + 2 * p * torch.cos(phi) + eig2 = q + 2 * p * torch.cos(phi + 2 * math.pi / 3) + eig3 = 3 * q - eig1 - eig2 + # eigenvals[..., i] is the i-th eigenvalue of the input, α0 ≤ α1 ≤ α2. + eigenvals = torch.stack((eig2, eig3, eig1), dim=-1) + + # Soft dispatch between the degenerate case (diagonal A) and general. + # diag_soft_cond -> 1.0 when p1 < 6 * eps and diag_soft_cond -> 0.0 otherwise. + # We use 6 * eps to take into account the error accumulated during the p1 summation + diag_soft_cond = torch.exp(-((p1 / (6 * self._eps)) ** 2)).detach()[..., None] + + # Eigenvalues are the ordered elements of main diagonal in the degenerate case + diag_eigenvals, _ = torch.sort(inputs_diag, dim=-1) + eigenvals = diag_soft_cond * diag_eigenvals + (1.0 - diag_soft_cond) * eigenvals + + if eigenvectors: + eigenvecs = self._construct_eigenvecs_set(inputs, eigenvals) + else: + eigenvecs = None + + return eigenvals, eigenvecs + + def _construct_eigenvecs_set( + self, inputs: torch.Tensor, eigenvals: torch.Tensor + ) -> torch.Tensor: + """ + Construct orthonormal set of eigenvectors by given inputs and pre-computed eigenvalues + + Args: + inputs: tensor of symmetric matrices of shape (..., 3, 3) + eigenvals: tensor of pre-computed eigenvalues of of shape (..., 3, 3) + + Returns: + Tuple of three eigenvector tensors of shape (..., 3, 3), composing an orthonormal + set + """ + eigenvecs_tuple_for_01 = self._construct_eigenvecs( + inputs, eigenvals[..., 0], eigenvals[..., 1] + ) + eigenvecs_for_01 = torch.stack(eigenvecs_tuple_for_01, dim=-1) + + eigenvecs_tuple_for_21 = self._construct_eigenvecs( + inputs, eigenvals[..., 2], eigenvals[..., 1] + ) + eigenvecs_for_21 = torch.stack(eigenvecs_tuple_for_21[::-1], dim=-1) + + # The result will be smooth here even if both parts of comparison + # are close, because eigenvecs_01 and eigenvecs_21 would be mostly equal as well + eigenvecs_cond = ( + eigenvals[..., 1] - eigenvals[..., 0] + > eigenvals[..., 2] - eigenvals[..., 1] + ).detach() + eigenvecs = torch.where( + eigenvecs_cond[..., None, None], eigenvecs_for_01, eigenvecs_for_21 + ) + + return eigenvecs + + def _construct_eigenvecs( + self, inputs: torch.Tensor, alpha0: torch.Tensor, alpha1: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Construct an orthonormal set of eigenvectors by given pair of eigenvalues. + + Args: + inputs: tensor of symmetric matrices of shape (..., 3, 3) + alpha0: first eigenvalues of shape (..., 3) + alpha1: second eigenvalues of shape (..., 3) + + Returns: + Tuple of three eigenvector tensors of shape (..., 3, 3), composing an orthonormal + set + """ + + # Find the eigenvector corresponding to alpha0, its eigenvalue is distinct + ev0 = self._get_ev0(inputs - alpha0[..., None, None] * self._identity) + u, v = self._get_uv(ev0) + ev1 = self._get_ev1(inputs - alpha1[..., None, None] * self._identity, u, v) + # Third eigenvector is computed as the cross-product of the other two + ev2 = torch.cross(ev0, ev1, dim=-1) + + return ev0, ev1, ev2 + + def _get_ev0(self, char_poly: torch.Tensor) -> torch.Tensor: + """ + Construct the first normalized eigenvector given a characteristic polynomial + + Args: + char_poly: a characteristic polynomials of the input matrices of shape (..., 3, 3) + + Returns: + Tensor of first eigenvectors of shape (..., 3) + """ + + r01 = torch.cross(char_poly[..., 0, :], char_poly[..., 1, :], dim=-1) + r12 = torch.cross(char_poly[..., 1, :], char_poly[..., 2, :], dim=-1) + r02 = torch.cross(char_poly[..., 0, :], char_poly[..., 2, :], dim=-1) + + cross_products = torch.stack((r01, r12, r02), dim=-2) + # Regularize it with + or -eps depending on the sign of the first vector + cross_products += self._eps * self._sign_without_zero( + cross_products[..., :1, :] + ) + + norms_sq = (cross_products ** 2).sum(dim=-1) + max_norms_index = norms_sq.argmax(dim=-1) # pyre-ignore[16] + + # Pick only the cross-product with highest squared norm for each input + max_cross_products = self._gather_by_index( + cross_products, max_norms_index[..., None, None], -2 + ) + # Pick corresponding squared norms for each cross-product + max_norms_sq = self._gather_by_index(norms_sq, max_norms_index[..., None], -1) + + # Normalize cross-product vectors by thier norms + return max_cross_products / torch.sqrt(max_norms_sq[..., None]) + + def _gather_by_index( + self, source: torch.Tensor, index: torch.Tensor, dim: int + ) -> torch.Tensor: + """ + Selects elements from the given source tensor by provided index tensor. + Number of dimensions should be the same for source and index tensors. + + Args: + source: input tensor to gather from + index: index tensor with indices to gather from source + dim: dimension to gather across + + Returns: + Tensor of shape same as the source with exception of specified dimension. + """ + + index_shape = list(source.shape) + index_shape[dim] = 1 + + return source.gather(dim, index.expand(index_shape)).squeeze( # pyre-ignore[16] + dim + ) + + def _get_uv(self, w: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Computes unit-length vectors U and V such that {U, V, W} is a right-handed + orthonormal set. + + Args: + w: eigenvector tensor of shape (..., 3) + + Returns: + Tuple of U and V unit-length vector tensors of shape (..., 3) + """ + + min_idx = w.abs().argmin(dim=-1) # pyre-ignore[16] + rotation_2d = self._rotations_3d[min_idx].to(w) + + u = F.normalize((rotation_2d @ w[..., None])[..., 0], dim=-1) + v = torch.cross(w, u, dim=-1) + return u, v + + def _get_ev1( + self, char_poly: torch.Tensor, u: torch.Tensor, v: torch.Tensor + ) -> torch.Tensor: + """ + Computes the second normalized eigenvector given a characteristic polynomial + and U and V vectors + + Args: + char_poly: a characteristic polynomials of the input matrices of shape (..., 3, 3) + u: unit-length vectors from _get_uv method + v: unit-length vectors from _get_uv method + + Returns: + desc + """ + + j = torch.stack((u, v), dim=-1) + m = j.transpose(-1, -2) @ char_poly @ j + + # If angle between those vectors is acute, take their sum = m[..., 0, :] + m[..., 1, :], + # otherwise take the difference = m[..., 0, :] - m[..., 1, :] + # m is in theory of rank 1 (or 0), so it snaps only when one of the rows is close to 0 + is_acute_sign = self._sign_without_zero( + (m[..., 0, :] * m[..., 1, :]).sum(dim=-1) + ).detach() + + rowspace = m[..., 0, :] + is_acute_sign[..., None] * m[..., 1, :] + # rowspace will be near zero for second-order eigenvalues + # this regularization guarantees abs(rowspace[0]) >= eps in a smooth'ish way + rowspace += self._eps * self._sign_without_zero(rowspace[..., :1]) + + return ( + j + @ F.normalize(rowspace @ self._rotation_2d.to(rowspace), dim=-1)[..., None] + )[..., 0] + + @staticmethod + def _sign_without_zero(tensor): + """ + Args: + tensor: an arbitrary shaped tensor + + Returns: + Tensor of the same shape as an input, but with 1.0 if tensor > 0.0 and -1.0 + otherwise + """ + return 2.0 * (tensor > 0.0).to(tensor.dtype) - 1.0 + + +def symeig3x3( + inputs: torch.Tensor, eigenvectors: bool = True +) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """ + Compute eigenvalues and (optionally) eigenvectors + + Args: + inputs: symmetric matrices with shape of (..., 3, 3) + eigenvectors: whether should we compute only eigenvalues or eigenvectors as well + + Returns: + Either a tuple of (eigenvalues, eigenvectors) or eigenvalues only, depending on + given params. Eigenvalues are of shape (..., 3) and eigenvectors (..., 3, 3) + """ + return _SymEig3x3().to(inputs.device)(inputs, eigenvectors=eigenvectors) diff --git a/pytorch3d/common/workaround.py b/pytorch3d/common/workaround/utils.py similarity index 100% rename from pytorch3d/common/workaround.py rename to pytorch3d/common/workaround/utils.py diff --git a/tests/bm_symeig3x3.py b/tests/bm_symeig3x3.py new file mode 100644 index 00000000..7a345462 --- /dev/null +++ b/tests/bm_symeig3x3.py @@ -0,0 +1,93 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +from itertools import product +from typing import Callable, Any + +import torch +from common_testing import get_random_cuda_device +from fvcore.common.benchmark import benchmark +from pytorch3d.common.workaround import symeig3x3 +from test_symeig3x3 import TestSymEig3x3 + +torch.set_num_threads(1) + +CUDA_DEVICE = get_random_cuda_device() + + +def create_traced_func(func, device, batch_size): + traced_func = torch.jit.trace( + func, (TestSymEig3x3.create_random_sym3x3(device, batch_size),) + ) + + return traced_func + + +FUNC_NAME_TO_FUNC = { + "sym3x3eig": (lambda inputs: symeig3x3(inputs, eigenvectors=True)), + "sym3x3eig_traced_cuda": create_traced_func( + (lambda inputs: symeig3x3(inputs, eigenvectors=True)), CUDA_DEVICE, 1024 + ), + "torch_symeig": (lambda inputs: torch.symeig(inputs, eigenvectors=True)), + "torch_linalg_eigh": (lambda inputs: torch.linalg.eigh(inputs)), + "torch_pca_lowrank": ( + lambda inputs: torch.pca_lowrank(inputs, center=False, niter=1) + ), + "sym3x3eig_no_vecs": (lambda inputs: symeig3x3(inputs, eigenvectors=False)), + "torch_symeig_no_vecs": (lambda inputs: torch.symeig(inputs, eigenvectors=False)), + "torch_linalg_eigvalsh_no_vecs": (lambda inputs: torch.linalg.eigvalsh(inputs)), +} + + +def test_symeig3x3(func_name, batch_size, device) -> Callable[[], Any]: + func = FUNC_NAME_TO_FUNC[func_name] + inputs = TestSymEig3x3.create_random_sym3x3(device, batch_size) + torch.cuda.synchronize() + + def symeig3x3(): + func(inputs) + torch.cuda.synchronize() + + return symeig3x3 + + +def bm_symeig3x3() -> None: + devices = ["cpu"] + if torch.cuda.is_available(): + devices.append(CUDA_DEVICE) + + kwargs_list = [] + func_names = FUNC_NAME_TO_FUNC.keys() + batch_sizes = [16, 128, 1024, 8192, 65536, 1048576] + + for func_name, batch_size, device in product(func_names, batch_sizes, devices): + # Run CUDA-only implementations only on GPU + if "cuda" in func_name and not device.startswith("cuda"): + continue + + # Torch built-ins are quite slow on larger batches + if "torch" in func_name and batch_size > 8192: + continue + + # Avoid running CPU implementations on larger batches as well + if device == "cpu" and batch_size > 8192: + continue + + kwargs_list.append( + {"func_name": func_name, "batch_size": batch_size, "device": device} + ) + + benchmark( + test_symeig3x3, + "SYMEIG3X3", + kwargs_list, + warmup_iters=3, + ) + + +if __name__ == "__main__": + bm_symeig3x3() diff --git a/tests/test_symeig3x3.py b/tests/test_symeig3x3.py new file mode 100644 index 00000000..87352929 --- /dev/null +++ b/tests/test_symeig3x3.py @@ -0,0 +1,263 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +import unittest + +import torch +from common_testing import TestCaseMixin, get_random_cuda_device +from pytorch3d.common.workaround import symeig3x3 +from pytorch3d.transforms.rotation_conversions import random_rotations + + +class TestSymEig3x3(TestCaseMixin, unittest.TestCase): + TEST_BATCH_SIZE = 1024 + + @staticmethod + def create_random_sym3x3(device, n): + random_3x3 = torch.randn((n, 3, 3), device=device) + random_3x3_T = torch.transpose(random_3x3, 1, 2) + random_sym_3x3 = (random_3x3 * random_3x3_T).contiguous() + + return random_sym_3x3 + + @staticmethod + def create_diag_sym3x3(device, n, noise=0.0): + # Create purly diagonal matrices + random_diag_3x3 = torch.randn((n, 3), device=device).diag_embed() + + # Make them 'almost' diagonal + random_diag_3x3 += noise * TestSymEig3x3.create_random_sym3x3(device, n) + + return random_diag_3x3 + + def setUp(self) -> None: + super().setUp() + torch.manual_seed(42) + + self._gpu = get_random_cuda_device() + self._cpu = torch.device("cpu") + + def test_is_eigen_gpu(self): + test_input = self.create_random_sym3x3(self._gpu, n=self.TEST_BATCH_SIZE) + + self._test_is_eigen(test_input) + + def test_is_eigen_cpu(self): + test_input = self.create_random_sym3x3(self._cpu, n=self.TEST_BATCH_SIZE) + + self._test_is_eigen(test_input) + + def _test_is_eigen(self, test_input, atol=1e-04, rtol=1e-02): + """ + Verify that values and vectors produced are really eigenvalues and eigenvectors + and can restore the original input matrix with good precision + """ + eigenvalues, eigenvectors = symeig3x3(test_input, eigenvectors=True) + + self.assertClose( + test_input, + eigenvectors @ eigenvalues.diag_embed() @ eigenvectors.transpose(-2, -1), + atol=atol, + rtol=rtol, + ) + + def test_eigenvectors_are_orthonormal_gpu(self): + test_input = self.create_random_sym3x3(self._gpu, n=self.TEST_BATCH_SIZE) + + self._test_eigenvectors_are_orthonormal(test_input) + + def test_eigenvectors_are_orthonormal_cpu(self): + test_input = self.create_random_sym3x3(self._cpu, n=self.TEST_BATCH_SIZE) + + self._test_eigenvectors_are_orthonormal(test_input) + + def _test_eigenvectors_are_orthonormal(self, test_input): + """ + Verify that eigenvectors are an orthonormal set + """ + eigenvalues, eigenvectors = symeig3x3(test_input, eigenvectors=True) + + batched_eye = torch.zeros_like(test_input) + batched_eye[..., :, :] = torch.eye(3, device=batched_eye.device) + + self.assertClose( + batched_eye, eigenvectors @ eigenvectors.transpose(-2, -1), atol=1e-06 + ) + + def test_is_not_nan_or_inf_gpu(self): + test_input = self.create_random_sym3x3(self._gpu, n=self.TEST_BATCH_SIZE) + + self._test_is_not_nan_or_inf(test_input) + + def test_is_not_nan_or_inf_cpu(self): + test_input = self.create_random_sym3x3(self._cpu, n=self.TEST_BATCH_SIZE) + + self._test_is_not_nan_or_inf(test_input) + + def _test_is_not_nan_or_inf(self, test_input): + eigenvalues, eigenvectors = symeig3x3(test_input, eigenvectors=True) + + self.assertTrue(torch.isfinite(eigenvalues).all()) + self.assertTrue(torch.isfinite(eigenvectors).all()) + + def test_degenerate_inputs_gpu(self): + self._test_degenerate_inputs(self._gpu) + + def test_degenerate_inputs_cpu(self): + self._test_degenerate_inputs(self._cpu) + + def _test_degenerate_inputs(self, device): + """ + Test degenerate case when input matrices are diagonal or near-diagonal + """ + + # Purely diagonal case + test_input = self.create_diag_sym3x3(device, self.TEST_BATCH_SIZE) + + self._test_is_not_nan_or_inf(test_input) + self._test_is_eigen(test_input) + self._test_eigenvectors_are_orthonormal(test_input) + + # Almost-diagonal case + test_input = self.create_diag_sym3x3(device, self.TEST_BATCH_SIZE, noise=1e-4) + + self._test_is_not_nan_or_inf(test_input) + self._test_is_eigen(test_input) + self._test_eigenvectors_are_orthonormal(test_input) + + def test_gradients_cpu(self): + self._test_gradients(self._cpu) + + def test_gradients_gpu(self): + self._test_gradients(self._gpu) + + def _test_gradients(self, device): + """ + Tests if gradients pass though without any problems (infs, nans etc) and + also performs gradcheck (compares numerical and analytical gradients) + """ + test_random_input = self.create_random_sym3x3(device, n=16) + test_diag_input = self.create_diag_sym3x3(device, n=16) + test_almost_diag_input = self.create_diag_sym3x3(device, n=16, noise=1e-4) + + test_input = torch.cat( + (test_random_input, test_diag_input, test_almost_diag_input) + ) + test_input.requires_grad = True + + with torch.autograd.detect_anomaly(): + eigenvalues, eigenvectors = symeig3x3(test_input, eigenvectors=True) + + loss = eigenvalues.mean() + eigenvectors.mean() + loss.backward() + + test_random_input.requires_grad = True + # Inputs are converted to double to increase the precision of gradcheck. + torch.autograd.gradcheck( + symeig3x3, test_random_input.double(), eps=1e-6, atol=1e-2, rtol=1e-2 + ) + + def _test_eigenvalues_and_eigenvectors( + self, test_eigenvectors, test_eigenvalues, atol=1e-04, rtol=1e-04 + ): + test_input = ( + test_eigenvectors.transpose(-2, -1) + @ test_eigenvalues.diag_embed() + @ test_eigenvectors + ) + + test_eigenvalues_sorted, _ = torch.sort(test_eigenvalues, dim=-1) + + eigenvalues, eigenvectors = symeig3x3(test_input, eigenvectors=True) + + self.assertClose( + test_eigenvalues_sorted, + eigenvalues, + atol=atol, + rtol=rtol, + ) + + self._test_is_not_nan_or_inf(test_input) + self._test_is_eigen(test_input, atol=atol, rtol=rtol) + self._test_eigenvectors_are_orthonormal(test_input) + + def test_degenerate_eigenvalues_gpu(self): + self._test_degenerate_eigenvalues(self._gpu) + + def test_degenerate_eigenvalues_cpu(self): + self._test_degenerate_eigenvalues(self._cpu) + + def _test_degenerate_eigenvalues(self, device): + """ + Test degenerate eigenvalues like zero-valued and with 2-/3-multiplicity + """ + # Error tolerances for degenerate values are increased as things might become + # numerically unstable + deg_atol = 1e-3 + deg_rtol = 1.0 + + # Construct random orthonormal sets + test_eigenvecs = random_rotations(n=self.TEST_BATCH_SIZE, device=device) + + # Construct random eigenvalues + test_eigenvals = torch.randn( + (self.TEST_BATCH_SIZE, 3), device=test_eigenvecs.device + ) + self._test_eigenvalues_and_eigenvectors( + test_eigenvecs, test_eigenvals, atol=deg_atol, rtol=deg_rtol + ) + + # First eigenvalue is always 0.0 here: [0.0 X Y] + test_eigenvals_with_zero = test_eigenvals.clone() + test_eigenvals_with_zero[..., 0] = 0.0 + self._test_eigenvalues_and_eigenvectors( + test_eigenvecs, test_eigenvals_with_zero, atol=deg_atol, rtol=deg_rtol + ) + + # First two eigenvalues are always the same here: [X X Y] + test_eigenvals_with_multiplicity2 = test_eigenvals.clone() + test_eigenvals_with_multiplicity2[..., 1] = test_eigenvals_with_multiplicity2[ + ..., 0 + ] + self._test_eigenvalues_and_eigenvectors( + test_eigenvecs, + test_eigenvals_with_multiplicity2, + atol=deg_atol, + rtol=deg_rtol, + ) + + # All three eigenvalues are the same here: [X X X] + test_eigenvals_with_multiplicity3 = test_eigenvals_with_multiplicity2.clone() + test_eigenvals_with_multiplicity3[..., 2] = test_eigenvals_with_multiplicity2[ + ..., 0 + ] + self._test_eigenvalues_and_eigenvectors( + test_eigenvecs, + test_eigenvals_with_multiplicity3, + atol=deg_atol, + rtol=deg_rtol, + ) + + def test_more_dimensions(self): + """ + Tests if function supports arbitrary leading dimensions + """ + repeat = 4 + + test_input = self.create_random_sym3x3(self._cpu, n=16) + test_input_4d = test_input[None, ...].expand((repeat,) + test_input.shape) + + eigenvalues, eigenvectors = symeig3x3(test_input, eigenvectors=True) + eigenvalues_4d, eigenvectors_4d = symeig3x3(test_input_4d, eigenvectors=True) + + eigenvalues_4d_gt = eigenvalues[None, ...].expand((repeat,) + eigenvalues.shape) + eigenvectors_4d_gt = eigenvectors[None, ...].expand( + (repeat,) + eigenvectors.shape + ) + + self.assertClose(eigenvalues_4d_gt, eigenvalues_4d) + self.assertClose(eigenvectors_4d_gt, eigenvectors_4d)