suppress errors in vision/fair/pytorch3d

Differential Revision: D37172764

fbshipit-source-id: a2ec367e56de2781a17f5e708eb5832ec9d7e6b4
This commit is contained in:
Pyre Bot Jr
2022-06-15 06:27:35 -07:00
committed by Facebook GitHub Bot
parent ea4f3260e4
commit 7978ffd1e4
61 changed files with 188 additions and 80 deletions

View File

@@ -75,12 +75,14 @@ class _SymEig3x3(nn.Module):
if inputs.shape[-2:] != (3, 3):
raise ValueError("Only inputs of shape (..., 3, 3) are supported.")
inputs_diag = inputs.diagonal(dim1=-2, dim2=-1) # pyre-ignore[16]
inputs_diag = inputs.diagonal(dim1=-2, dim2=-1)
inputs_trace = inputs_diag.sum(-1)
q = inputs_trace / 3.0
# Calculate squared sum of elements outside the main diagonal / 2
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and `int`.
p1 = ((inputs**2).sum(dim=(-1, -2)) - (inputs_diag**2).sum(-1)) / 2
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and `int`.
p2 = ((inputs_diag - q[..., None]) ** 2).sum(dim=-1) + 2.0 * p1.clamp(self._eps)
p = torch.sqrt(p2 / 6.0)
@@ -195,8 +197,9 @@ class _SymEig3x3(nn.Module):
cross_products[..., :1, :]
)
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and `int`.
norms_sq = (cross_products**2).sum(dim=-1)
max_norms_index = norms_sq.argmax(dim=-1) # pyre-ignore[16]
max_norms_index = norms_sq.argmax(dim=-1)
# Pick only the cross-product with highest squared norm for each input
max_cross_products = self._gather_by_index(
@@ -227,9 +230,7 @@ class _SymEig3x3(nn.Module):
index_shape = list(source.shape)
index_shape[dim] = 1
return source.gather(dim, index.expand(index_shape)).squeeze( # pyre-ignore[16]
dim
)
return source.gather(dim, index.expand(index_shape)).squeeze(dim)
def _get_uv(self, w: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
@@ -243,7 +244,7 @@ class _SymEig3x3(nn.Module):
Tuple of U and V unit-length vector tensors of shape (..., 3)
"""
min_idx = w.abs().argmin(dim=-1) # pyre-ignore[16]
min_idx = w.abs().argmin(dim=-1)
rotation_2d = self._rotations_3d[min_idx].to(w)
u = F.normalize((rotation_2d @ w[..., None])[..., 0], dim=-1)