fbcode/vision/fair/pytorch3d/pytorch3d/common/workaround/symeig3x3.py

Reviewed By: sgrigory

Differential Revision: D93715209

fbshipit-source-id: 1880a8dd72e35ce5cc93cdeecf770aab6469ca31
This commit is contained in:
generatedunixname1417043136753450
2026-02-23 12:42:24 -08:00
committed by meta-codesync[bot]
parent 0e435c297c
commit b9b5ea3428

View File

@@ -82,10 +82,12 @@ class _SymEig3x3(nn.Module):
q = inputs_trace / 3.0
# Calculate squared sum of elements outside the main diagonal / 2
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and `int`.
p1 = ((inputs**2).sum(dim=(-1, -2)) - (inputs_diag**2).sum(-1)) / 2
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and `int`.
p2 = ((inputs_diag - q[..., None]) ** 2).sum(dim=-1) + 2.0 * p1.clamp(self._eps)
p1 = (
torch.square(inputs).sum(dim=(-1, -2)) - torch.square(inputs_diag).sum(-1)
) / 2
p2 = torch.square(inputs_diag - q[..., None]).sum(dim=-1) + 2.0 * p1.clamp(
self._eps
)
p = torch.sqrt(p2 / 6.0)
B = (inputs - q[..., None, None] * self._identity) / p[..., None, None]
@@ -104,7 +106,9 @@ class _SymEig3x3(nn.Module):
# Soft dispatch between the degenerate case (diagonal A) and general.
# diag_soft_cond -> 1.0 when p1 < 6 * eps and diag_soft_cond -> 0.0 otherwise.
# We use 6 * eps to take into account the error accumulated during the p1 summation
diag_soft_cond = torch.exp(-((p1 / (6 * self._eps)) ** 2)).detach()[..., None]
diag_soft_cond = torch.exp(-torch.square(p1 / (6 * self._eps))).detach()[
..., None
]
# Eigenvalues are the ordered elements of main diagonal in the degenerate case
diag_eigenvals, _ = torch.sort(inputs_diag, dim=-1)
@@ -199,8 +203,7 @@ class _SymEig3x3(nn.Module):
cross_products[..., :1, :]
)
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and `int`.
norms_sq = (cross_products**2).sum(dim=-1)
norms_sq = torch.square(cross_products).sum(dim=-1)
max_norms_index = norms_sq.argmax(dim=-1)
# Pick only the cross-product with highest squared norm for each input