fbcode/vision/fair/pytorch3d/pytorch3d/transforms/se3.py

Reviewed By: sgrigory

Differential Revision: D93709801

fbshipit-source-id: e4bae81fe1a88fed547304e6e21b248c5a345277
This commit is contained in:
generatedunixname1417043136753450
2026-02-23 14:51:32 -08:00
committed by meta-codesync[bot]
parent e3c80a4368
commit 7b5c78460a

View File

@@ -195,15 +195,15 @@ def _se3_V_matrix(
V = ( V = (
torch.eye(3, dtype=log_rotation.dtype, device=log_rotation.device)[None] torch.eye(3, dtype=log_rotation.dtype, device=log_rotation.device)[None]
+ log_rotation_hat + log_rotation_hat
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and `int`. * ((1 - torch.cos(rotation_angles)) / torch.square(rotation_angles))[
* ((1 - torch.cos(rotation_angles)) / (rotation_angles**2))[:, None, None] :, None, None
]
+ ( + (
log_rotation_hat_square log_rotation_hat_square
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and * (
# `int`. (rotation_angles - torch.sin(rotation_angles))
* ((rotation_angles - torch.sin(rotation_angles)) / (rotation_angles**3))[ / torch.pow(rotation_angles, 3)
:, None, None )[:, None, None]
]
) )
) )
@@ -215,8 +215,7 @@ def _get_se3_V_input(log_rotation: torch.Tensor, eps: float = 1e-4):
A helper function that computes the input variables to the `_se3_V_matrix` A helper function that computes the input variables to the `_se3_V_matrix`
function. function.
""" """
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and `int`. nrms = torch.square(log_rotation).sum(-1)
nrms = (log_rotation**2).sum(-1)
rotation_angles = torch.clamp(nrms, eps).sqrt() rotation_angles = torch.clamp(nrms, eps).sqrt()
log_rotation_hat = hat(log_rotation) log_rotation_hat = hat(log_rotation)
log_rotation_hat_square = torch.bmm(log_rotation_hat, log_rotation_hat) log_rotation_hat_square = torch.bmm(log_rotation_hat, log_rotation_hat)