diff --git a/pytorch3d/transforms/se3.py b/pytorch3d/transforms/se3.py index 693b84c2..dab69d02 100644 --- a/pytorch3d/transforms/se3.py +++ b/pytorch3d/transforms/se3.py @@ -195,15 +195,15 @@ def _se3_V_matrix( V = ( torch.eye(3, dtype=log_rotation.dtype, device=log_rotation.device)[None] + log_rotation_hat - # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and `int`. - * ((1 - torch.cos(rotation_angles)) / (rotation_angles**2))[:, None, None] + * ((1 - torch.cos(rotation_angles)) / torch.square(rotation_angles))[ + :, None, None + ] + ( log_rotation_hat_square - # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and - # `int`. - * ((rotation_angles - torch.sin(rotation_angles)) / (rotation_angles**3))[ - :, None, None - ] + * ( + (rotation_angles - torch.sin(rotation_angles)) + / torch.pow(rotation_angles, 3) + )[:, None, None] ) ) @@ -215,8 +215,7 @@ def _get_se3_V_input(log_rotation: torch.Tensor, eps: float = 1e-4): A helper function that computes the input variables to the `_se3_V_matrix` function. """ - # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and `int`. - nrms = (log_rotation**2).sum(-1) + nrms = torch.square(log_rotation).sum(-1) rotation_angles = torch.clamp(nrms, eps).sqrt() log_rotation_hat = hat(log_rotation) log_rotation_hat_square = torch.bmm(log_rotation_hat, log_rotation_hat)