Disable gradient calculation in _check_valid_rotation_matrix()

Summary:
# Make `transform3d.py` a little bit better (performance and code quality)

## 1. Add decorator `torch.no_grad()` to the function `_check_valid_rotation_matrix()`

Function `_check_valid_rotation_matrix()` is needed to identify errors during forward pass only, it's not used for gradients.

## 2. Replace two calls `to` with the single one

Reviewed By: bottler

Differential Revision: D29656501

fbshipit-source-id: 4419e24dbf436c1b60abf77bda4376fb87a593be
This commit is contained in:
Alexey Sidnev 2021-07-16 01:57:28 -07:00 committed by Facebook GitHub Bot
parent 0c02ae907e
commit 2f668ecefe

View File

@ -566,7 +566,7 @@ class Rotate(Transform3d):
if R.shape[-2:] != (3, 3):
msg = "R must have shape (3, 3) or (N, 3, 3); got %s"
raise ValueError(msg % repr(R.shape))
R = R.to(dtype=dtype).to(device=device_)
R = R.to(device=device_, dtype=dtype)
_check_valid_rotation_matrix(R, tol=orthogonal_tol)
N = R.shape[0]
mat = torch.eye(4, dtype=dtype, device=device_)
@ -752,6 +752,9 @@ def _broadcast_bmm(a, b):
return a.bmm(b)
# pyre-fixme[56]: Decorator `torch.no_grad(...)` could not be called, because
# its type `no_grad` is not callable.
@torch.no_grad()
def _check_valid_rotation_matrix(R, tol: float = 1e-7):
"""
Determine if R is a valid rotation matrix by checking it satisfies the