mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 20:02:49 +08:00
Disable gradient calculation in _check_valid_rotation_matrix()
Summary: # Make `transform3d.py` a little bit better (performance and code quality) ## 1. Add decorator `torch.no_grad()` to the function `_check_valid_rotation_matrix()` Function `_check_valid_rotation_matrix()` is needed to identify errors during forward pass only, it's not used for gradients. ## 2. Replace two calls `to` with the single one Reviewed By: bottler Differential Revision: D29656501 fbshipit-source-id: 4419e24dbf436c1b60abf77bda4376fb87a593be
This commit is contained in:
parent
0c02ae907e
commit
2f668ecefe
@ -566,7 +566,7 @@ class Rotate(Transform3d):
|
|||||||
if R.shape[-2:] != (3, 3):
|
if R.shape[-2:] != (3, 3):
|
||||||
msg = "R must have shape (3, 3) or (N, 3, 3); got %s"
|
msg = "R must have shape (3, 3) or (N, 3, 3); got %s"
|
||||||
raise ValueError(msg % repr(R.shape))
|
raise ValueError(msg % repr(R.shape))
|
||||||
R = R.to(dtype=dtype).to(device=device_)
|
R = R.to(device=device_, dtype=dtype)
|
||||||
_check_valid_rotation_matrix(R, tol=orthogonal_tol)
|
_check_valid_rotation_matrix(R, tol=orthogonal_tol)
|
||||||
N = R.shape[0]
|
N = R.shape[0]
|
||||||
mat = torch.eye(4, dtype=dtype, device=device_)
|
mat = torch.eye(4, dtype=dtype, device=device_)
|
||||||
@ -752,6 +752,9 @@ def _broadcast_bmm(a, b):
|
|||||||
return a.bmm(b)
|
return a.bmm(b)
|
||||||
|
|
||||||
|
|
||||||
|
# pyre-fixme[56]: Decorator `torch.no_grad(...)` could not be called, because
|
||||||
|
# its type `no_grad` is not callable.
|
||||||
|
@torch.no_grad()
|
||||||
def _check_valid_rotation_matrix(R, tol: float = 1e-7):
|
def _check_valid_rotation_matrix(R, tol: float = 1e-7):
|
||||||
"""
|
"""
|
||||||
Determine if R is a valid rotation matrix by checking it satisfies the
|
Determine if R is a valid rotation matrix by checking it satisfies the
|
||||||
|
Loading…
x
Reference in New Issue
Block a user