diff --git a/pytorch3d/transforms/transform3d.py b/pytorch3d/transforms/transform3d.py index d093af4e..11543044 100644 --- a/pytorch3d/transforms/transform3d.py +++ b/pytorch3d/transforms/transform3d.py @@ -566,7 +566,7 @@ class Rotate(Transform3d): if R.shape[-2:] != (3, 3): msg = "R must have shape (3, 3) or (N, 3, 3); got %s" raise ValueError(msg % repr(R.shape)) - R = R.to(dtype=dtype).to(device=device_) + R = R.to(device=device_, dtype=dtype) _check_valid_rotation_matrix(R, tol=orthogonal_tol) N = R.shape[0] mat = torch.eye(4, dtype=dtype, device=device_) @@ -752,6 +752,9 @@ def _broadcast_bmm(a, b): return a.bmm(b) +# pyre-fixme[56]: Decorator `torch.no_grad(...)` could not be called, because +# its type `no_grad` is not callable. +@torch.no_grad() def _check_valid_rotation_matrix(R, tol: float = 1e-7): """ Determine if R is a valid rotation matrix by checking it satisfies the