From b9b5ea34285cd548e5d7554a60703396a43445bb Mon Sep 17 00:00:00 2001 From: generatedunixname1417043136753450 Date: Mon, 23 Feb 2026 12:42:24 -0800 Subject: [PATCH] fbcode/vision/fair/pytorch3d/pytorch3d/common/workaround/symeig3x3.py Reviewed By: sgrigory Differential Revision: D93715209 fbshipit-source-id: 1880a8dd72e35ce5cc93cdeecf770aab6469ca31 --- pytorch3d/common/workaround/symeig3x3.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/pytorch3d/common/workaround/symeig3x3.py b/pytorch3d/common/workaround/symeig3x3.py index 03d27359..c7fda114 100644 --- a/pytorch3d/common/workaround/symeig3x3.py +++ b/pytorch3d/common/workaround/symeig3x3.py @@ -82,10 +82,12 @@ class _SymEig3x3(nn.Module): q = inputs_trace / 3.0 # Calculate squared sum of elements outside the main diagonal / 2 - # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and `int`. - p1 = ((inputs**2).sum(dim=(-1, -2)) - (inputs_diag**2).sum(-1)) / 2 - # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and `int`. - p2 = ((inputs_diag - q[..., None]) ** 2).sum(dim=-1) + 2.0 * p1.clamp(self._eps) + p1 = ( + torch.square(inputs).sum(dim=(-1, -2)) - torch.square(inputs_diag).sum(-1) + ) / 2 + p2 = torch.square(inputs_diag - q[..., None]).sum(dim=-1) + 2.0 * p1.clamp( + self._eps + ) p = torch.sqrt(p2 / 6.0) B = (inputs - q[..., None, None] * self._identity) / p[..., None, None] @@ -104,7 +106,9 @@ class _SymEig3x3(nn.Module): # Soft dispatch between the degenerate case (diagonal A) and general. # diag_soft_cond -> 1.0 when p1 < 6 * eps and diag_soft_cond -> 0.0 otherwise. # We use 6 * eps to take into account the error accumulated during the p1 summation - diag_soft_cond = torch.exp(-((p1 / (6 * self._eps)) ** 2)).detach()[..., None] + diag_soft_cond = torch.exp(-torch.square(p1 / (6 * self._eps))).detach()[ + ..., None + ] # Eigenvalues are the ordered elements of main diagonal in the degenerate case diag_eigenvals, _ = torch.sort(inputs_diag, dim=-1) @@ -199,8 +203,7 @@ class _SymEig3x3(nn.Module): cross_products[..., :1, :] ) - # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and `int`. - norms_sq = (cross_products**2).sum(dim=-1) + norms_sq = torch.square(cross_products).sum(dim=-1) max_norms_index = norms_sq.argmax(dim=-1) # Pick only the cross-product with highest squared norm for each input