From 2f668ecefe9898e2224b35a6608937a44e5d8109 Mon Sep 17 00:00:00 2001 From: Alexey Sidnev Date: Fri, 16 Jul 2021 01:57:28 -0700 Subject: [PATCH] Disable gradient calculation in `_check_valid_rotation_matrix()` Summary: # Make `transform3d.py` a little bit better (performance and code quality) ## 1. Add decorator `torch.no_grad()` to the function `_check_valid_rotation_matrix()` Function `_check_valid_rotation_matrix()` is needed to identify errors during forward pass only, it's not used for gradients. ## 2. Replace two calls `to` with the single one Reviewed By: bottler Differential Revision: D29656501 fbshipit-source-id: 4419e24dbf436c1b60abf77bda4376fb87a593be --- pytorch3d/transforms/transform3d.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/pytorch3d/transforms/transform3d.py b/pytorch3d/transforms/transform3d.py index d093af4e..11543044 100644 --- a/pytorch3d/transforms/transform3d.py +++ b/pytorch3d/transforms/transform3d.py @@ -566,7 +566,7 @@ class Rotate(Transform3d): if R.shape[-2:] != (3, 3): msg = "R must have shape (3, 3) or (N, 3, 3); got %s" raise ValueError(msg % repr(R.shape)) - R = R.to(dtype=dtype).to(device=device_) + R = R.to(device=device_, dtype=dtype) _check_valid_rotation_matrix(R, tol=orthogonal_tol) N = R.shape[0] mat = torch.eye(4, dtype=dtype, device=device_) @@ -752,6 +752,9 @@ def _broadcast_bmm(a, b): return a.bmm(b) +# pyre-fixme[56]: Decorator `torch.no_grad(...)` could not be called, because +# its type `no_grad` is not callable. +@torch.no_grad() def _check_valid_rotation_matrix(R, tol: float = 1e-7): """ Determine if R is a valid rotation matrix by checking it satisfies the