mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-12-21 14:50:36 +08:00
suppress errors in vision/fair/pytorch3d
Differential Revision: D37172764 fbshipit-source-id: a2ec367e56de2781a17f5e708eb5832ec9d7e6b4
This commit is contained in:
committed by
Facebook GitHub Bot
parent
ea4f3260e4
commit
7978ffd1e4
@@ -50,6 +50,7 @@ def quaternion_to_matrix(quaternions: torch.Tensor) -> torch.Tensor:
|
||||
Rotation matrices as tensor of shape (..., 3, 3).
|
||||
"""
|
||||
r, i, j, k = torch.unbind(quaternions, -1)
|
||||
# pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`.
|
||||
two_s = 2.0 / (quaternions * quaternions).sum(-1)
|
||||
|
||||
o = torch.stack(
|
||||
@@ -131,9 +132,17 @@ def matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor:
|
||||
# we produce the desired quaternion multiplied by each of r, i, j, k
|
||||
quat_by_rijk = torch.stack(
|
||||
[
|
||||
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
|
||||
# `int`.
|
||||
torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1),
|
||||
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
|
||||
# `int`.
|
||||
torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1),
|
||||
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
|
||||
# `int`.
|
||||
torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1),
|
||||
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
|
||||
# `int`.
|
||||
torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1),
|
||||
],
|
||||
dim=-2,
|
||||
@@ -148,7 +157,7 @@ def matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor:
|
||||
# forall i; we pick the best-conditioned one (with the largest denominator)
|
||||
|
||||
return quat_candidates[
|
||||
F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, : # pyre-ignore[16]
|
||||
F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, :
|
||||
].reshape(batch_dim + (4,))
|
||||
|
||||
|
||||
@@ -314,6 +323,7 @@ def random_quaternions(
|
||||
"""
|
||||
if isinstance(device, str):
|
||||
device = torch.device(device)
|
||||
# pyre-fixme[6]: For 2nd param expected `dtype` but got `Optional[dtype]`.
|
||||
o = torch.randn((n, 4), dtype=dtype, device=device)
|
||||
s = (o * o).sum(1)
|
||||
o = o / _copysign(torch.sqrt(s), o[:, 0])[:, None]
|
||||
|
||||
@@ -194,9 +194,12 @@ def _se3_V_matrix(
|
||||
V = (
|
||||
torch.eye(3, dtype=log_rotation.dtype, device=log_rotation.device)[None]
|
||||
+ log_rotation_hat
|
||||
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and `int`.
|
||||
* ((1 - torch.cos(rotation_angles)) / (rotation_angles**2))[:, None, None]
|
||||
+ (
|
||||
log_rotation_hat_square
|
||||
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
|
||||
# `int`.
|
||||
* ((rotation_angles - torch.sin(rotation_angles)) / (rotation_angles**3))[
|
||||
:, None, None
|
||||
]
|
||||
@@ -211,6 +214,7 @@ def _get_se3_V_input(log_rotation: torch.Tensor, eps: float = 1e-4):
|
||||
A helper function that computes the input variables to the `_se3_V_matrix`
|
||||
function.
|
||||
"""
|
||||
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and `int`.
|
||||
nrms = (log_rotation**2).sum(-1)
|
||||
rotation_angles = torch.clamp(nrms, eps).sqrt()
|
||||
log_rotation_hat = hat(log_rotation)
|
||||
|
||||
@@ -160,6 +160,7 @@ def _so3_exp_map(
|
||||
nrms = (log_rot * log_rot).sum(1)
|
||||
# phis ... rotation angles
|
||||
rot_angles = torch.clamp(nrms, eps).sqrt()
|
||||
# pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`.
|
||||
rot_angles_inv = 1.0 / rot_angles
|
||||
fac1 = rot_angles_inv * rot_angles.sin()
|
||||
fac2 = rot_angles_inv * rot_angles_inv * (1.0 - rot_angles.cos())
|
||||
@@ -167,8 +168,8 @@ def _so3_exp_map(
|
||||
skews_square = torch.bmm(skews, skews)
|
||||
|
||||
R = (
|
||||
# pyre-fixme[16]: `float` has no attribute `__getitem__`.
|
||||
fac1[:, None, None] * skews
|
||||
# pyre-fixme[16]: `float` has no attribute `__getitem__`.
|
||||
+ fac2[:, None, None] * skews_square
|
||||
+ torch.eye(3, dtype=log_rot.dtype, device=log_rot.device)[None]
|
||||
)
|
||||
@@ -216,6 +217,7 @@ def so3_log_map(
|
||||
# 2nd order Taylor expansion: phi_factor = 0.5 + (1.0 / 12) * phi**2
|
||||
phi_factor = torch.empty_like(phi)
|
||||
ok_denom = phi_sin.abs() > (0.5 * eps)
|
||||
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and `int`.
|
||||
phi_factor[~ok_denom] = 0.5 + (phi[~ok_denom] ** 2) * (1.0 / 12)
|
||||
phi_factor[ok_denom] = phi[ok_denom] / (2.0 * phi_sin[ok_denom])
|
||||
|
||||
|
||||
@@ -556,7 +556,9 @@ class Scale(Transform3d):
|
||||
Return the inverse of self._matrix.
|
||||
"""
|
||||
xyz = torch.stack([self._matrix[:, i, i] for i in range(4)], dim=1)
|
||||
# pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`.
|
||||
ixyz = 1.0 / xyz
|
||||
# pyre-fixme[6]: For 1st param expected `Tensor` but got `float`.
|
||||
imat = torch.diag_embed(ixyz, dim1=1, dim2=2)
|
||||
return imat
|
||||
|
||||
|
||||
Reference in New Issue
Block a user