mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2026-03-07 04:36:00 +08:00
suppress errors in vision/fair/pytorch3d
Differential Revision: D37172764 fbshipit-source-id: a2ec367e56de2781a17f5e708eb5832ec9d7e6b4
This commit is contained in:
committed by
Facebook GitHub Bot
parent
ea4f3260e4
commit
7978ffd1e4
@@ -75,12 +75,14 @@ class _SymEig3x3(nn.Module):
|
||||
if inputs.shape[-2:] != (3, 3):
|
||||
raise ValueError("Only inputs of shape (..., 3, 3) are supported.")
|
||||
|
||||
inputs_diag = inputs.diagonal(dim1=-2, dim2=-1) # pyre-ignore[16]
|
||||
inputs_diag = inputs.diagonal(dim1=-2, dim2=-1)
|
||||
inputs_trace = inputs_diag.sum(-1)
|
||||
q = inputs_trace / 3.0
|
||||
|
||||
# Calculate squared sum of elements outside the main diagonal / 2
|
||||
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and `int`.
|
||||
p1 = ((inputs**2).sum(dim=(-1, -2)) - (inputs_diag**2).sum(-1)) / 2
|
||||
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and `int`.
|
||||
p2 = ((inputs_diag - q[..., None]) ** 2).sum(dim=-1) + 2.0 * p1.clamp(self._eps)
|
||||
|
||||
p = torch.sqrt(p2 / 6.0)
|
||||
@@ -195,8 +197,9 @@ class _SymEig3x3(nn.Module):
|
||||
cross_products[..., :1, :]
|
||||
)
|
||||
|
||||
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and `int`.
|
||||
norms_sq = (cross_products**2).sum(dim=-1)
|
||||
max_norms_index = norms_sq.argmax(dim=-1) # pyre-ignore[16]
|
||||
max_norms_index = norms_sq.argmax(dim=-1)
|
||||
|
||||
# Pick only the cross-product with highest squared norm for each input
|
||||
max_cross_products = self._gather_by_index(
|
||||
@@ -227,9 +230,7 @@ class _SymEig3x3(nn.Module):
|
||||
index_shape = list(source.shape)
|
||||
index_shape[dim] = 1
|
||||
|
||||
return source.gather(dim, index.expand(index_shape)).squeeze( # pyre-ignore[16]
|
||||
dim
|
||||
)
|
||||
return source.gather(dim, index.expand(index_shape)).squeeze(dim)
|
||||
|
||||
def _get_uv(self, w: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
@@ -243,7 +244,7 @@ class _SymEig3x3(nn.Module):
|
||||
Tuple of U and V unit-length vector tensors of shape (..., 3)
|
||||
"""
|
||||
|
||||
min_idx = w.abs().argmin(dim=-1) # pyre-ignore[16]
|
||||
min_idx = w.abs().argmin(dim=-1)
|
||||
rotation_2d = self._rotations_3d[min_idx].to(w)
|
||||
|
||||
u = F.normalize((rotation_2d @ w[..., None])[..., 0], dim=-1)
|
||||
|
||||
Reference in New Issue
Block a user