suppress errors in vision/fair/pytorch3d

Differential Revision: D37172764

fbshipit-source-id: a2ec367e56de2781a17f5e708eb5832ec9d7e6b4
This commit is contained in:
Pyre Bot Jr
2022-06-15 06:27:35 -07:00
committed by Facebook GitHub Bot
parent ea4f3260e4
commit 7978ffd1e4
61 changed files with 188 additions and 80 deletions

View File

@@ -50,6 +50,7 @@ def quaternion_to_matrix(quaternions: torch.Tensor) -> torch.Tensor:
Rotation matrices as tensor of shape (..., 3, 3).
"""
r, i, j, k = torch.unbind(quaternions, -1)
# pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`.
two_s = 2.0 / (quaternions * quaternions).sum(-1)
o = torch.stack(
@@ -131,9 +132,17 @@ def matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor:
# we produce the desired quaternion multiplied by each of r, i, j, k
quat_by_rijk = torch.stack(
[
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
# `int`.
torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1),
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
# `int`.
torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1),
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
# `int`.
torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1),
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
# `int`.
torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1),
],
dim=-2,
@@ -148,7 +157,7 @@ def matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor:
# forall i; we pick the best-conditioned one (with the largest denominator)
return quat_candidates[
F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, : # pyre-ignore[16]
F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, :
].reshape(batch_dim + (4,))
@@ -314,6 +323,7 @@ def random_quaternions(
"""
if isinstance(device, str):
device = torch.device(device)
# pyre-fixme[6]: For 2nd param expected `dtype` but got `Optional[dtype]`.
o = torch.randn((n, 4), dtype=dtype, device=device)
s = (o * o).sum(1)
o = o / _copysign(torch.sqrt(s), o[:, 0])[:, None]

View File

@@ -194,9 +194,12 @@ def _se3_V_matrix(
V = (
torch.eye(3, dtype=log_rotation.dtype, device=log_rotation.device)[None]
+ log_rotation_hat
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and `int`.
* ((1 - torch.cos(rotation_angles)) / (rotation_angles**2))[:, None, None]
+ (
log_rotation_hat_square
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
# `int`.
* ((rotation_angles - torch.sin(rotation_angles)) / (rotation_angles**3))[
:, None, None
]
@@ -211,6 +214,7 @@ def _get_se3_V_input(log_rotation: torch.Tensor, eps: float = 1e-4):
A helper function that computes the input variables to the `_se3_V_matrix`
function.
"""
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and `int`.
nrms = (log_rotation**2).sum(-1)
rotation_angles = torch.clamp(nrms, eps).sqrt()
log_rotation_hat = hat(log_rotation)

View File

@@ -160,6 +160,7 @@ def _so3_exp_map(
nrms = (log_rot * log_rot).sum(1)
# phis ... rotation angles
rot_angles = torch.clamp(nrms, eps).sqrt()
# pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`.
rot_angles_inv = 1.0 / rot_angles
fac1 = rot_angles_inv * rot_angles.sin()
fac2 = rot_angles_inv * rot_angles_inv * (1.0 - rot_angles.cos())
@@ -167,8 +168,8 @@ def _so3_exp_map(
skews_square = torch.bmm(skews, skews)
R = (
# pyre-fixme[16]: `float` has no attribute `__getitem__`.
fac1[:, None, None] * skews
# pyre-fixme[16]: `float` has no attribute `__getitem__`.
+ fac2[:, None, None] * skews_square
+ torch.eye(3, dtype=log_rot.dtype, device=log_rot.device)[None]
)
@@ -216,6 +217,7 @@ def so3_log_map(
# 2nd order Taylor expansion: phi_factor = 0.5 + (1.0 / 12) * phi**2
phi_factor = torch.empty_like(phi)
ok_denom = phi_sin.abs() > (0.5 * eps)
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and `int`.
phi_factor[~ok_denom] = 0.5 + (phi[~ok_denom] ** 2) * (1.0 / 12)
phi_factor[ok_denom] = phi[ok_denom] / (2.0 * phi_sin[ok_denom])

View File

@@ -556,7 +556,9 @@ class Scale(Transform3d):
Return the inverse of self._matrix.
"""
xyz = torch.stack([self._matrix[:, i, i] for i in range(4)], dim=1)
# pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`.
ixyz = 1.0 / xyz
# pyre-fixme[6]: For 1st param expected `Tensor` but got `float`.
imat = torch.diag_embed(ixyz, dim1=1, dim2=2)
return imat