fbcode/vision/fair/pytorch3d/pytorch3d/transforms/se3.py

Reviewed By: sgrigory

Differential Revision: D93709801

fbshipit-source-id: e4bae81fe1a88fed547304e6e21b248c5a345277
This commit is contained in:
generatedunixname1417043136753450
2026-02-23 14:51:32 -08:00
committed by meta-codesync[bot]
parent e3c80a4368
commit 7b5c78460a

View File

@@ -195,15 +195,15 @@ def _se3_V_matrix(
V = (
torch.eye(3, dtype=log_rotation.dtype, device=log_rotation.device)[None]
+ log_rotation_hat
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and `int`.
* ((1 - torch.cos(rotation_angles)) / (rotation_angles**2))[:, None, None]
* ((1 - torch.cos(rotation_angles)) / torch.square(rotation_angles))[
:, None, None
]
+ (
log_rotation_hat_square
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
# `int`.
* ((rotation_angles - torch.sin(rotation_angles)) / (rotation_angles**3))[
:, None, None
]
* (
(rotation_angles - torch.sin(rotation_angles))
/ torch.pow(rotation_angles, 3)
)[:, None, None]
)
)
@@ -215,8 +215,7 @@ def _get_se3_V_input(log_rotation: torch.Tensor, eps: float = 1e-4):
A helper function that computes the input variables to the `_se3_V_matrix`
function.
"""
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and `int`.
nrms = (log_rotation**2).sum(-1)
nrms = torch.square(log_rotation).sum(-1)
rotation_angles = torch.clamp(nrms, eps).sqrt()
log_rotation_hat = hat(log_rotation)
log_rotation_hat_square = torch.bmm(log_rotation_hat, log_rotation_hat)