Omit _check_valid_rotation_matrix by default

Summary:
According to the profiler trace D40326775, _check_valid_rotation_matrix is slow because of aten::all_close operation and _safe_det_3x3 bottlenecks. Disable the check by default unless environment variable PYTORCH3D_CHECK_ROTATION_MATRICES is set to 1.

Comparison after applying the change:
```
Profiling/Function    get_world_to_view (ms)   Transform_points(ms)    specular(ms)
before                12.751                    18.577                  21.384
after                 4.432 (34.7%)             9.248 (49.8%)           11.507 (53.8%)
```

Profiling trace:
https://pxl.cl/2h687
More details in https://docs.google.com/document/d/1kfhEQfpeQToikr5OH9ZssM39CskxWoJ2p8DO5-t6eWk/edit?usp=sharing

Reviewed By: kjchalup

Differential Revision: D40442503

fbshipit-source-id: 954b58de47de235c9d93af441643c22868b547d0
This commit is contained in:
Jiali Duan 2022-10-20 16:05:22 -07:00 committed by Facebook GitHub Bot
parent 8339cf2610
commit 46cb5aaaae
2 changed files with 26 additions and 3 deletions

View File

@ -5,6 +5,7 @@
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
import math import math
import os
import warnings import warnings
from typing import List, Optional, Union from typing import List, Optional, Union
@ -636,7 +637,10 @@ class Rotate(Transform3d):
msg = "R must have shape (3, 3) or (N, 3, 3); got %s" msg = "R must have shape (3, 3) or (N, 3, 3); got %s"
raise ValueError(msg % repr(R.shape)) raise ValueError(msg % repr(R.shape))
R = R.to(device=device_, dtype=dtype) R = R.to(device=device_, dtype=dtype)
_check_valid_rotation_matrix(R, tol=orthogonal_tol) if os.environ.get("PYTORCH3D_CHECK_ROTATION_MATRICES", "0") == "1":
# Note: aten::all_close in the check is computationally slow, so we
# only run the check when PYTORCH3D_CHECK_ROTATION_MATRICES is on.
_check_valid_rotation_matrix(R, tol=orthogonal_tol)
N = R.shape[0] N = R.shape[0]
mat = torch.eye(4, dtype=dtype, device=device_) mat = torch.eye(4, dtype=dtype, device=device_)
mat = mat.view(1, 4, 4).repeat(N, 1, 1) mat = mat.view(1, 4, 4).repeat(N, 1, 1)

View File

@ -4,9 +4,10 @@
# This source code is licensed under the BSD-style license found in the # This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
import math import math
import os
import unittest import unittest
from unittest import mock
import torch import torch
from pytorch3d.transforms import random_rotations from pytorch3d.transforms import random_rotations
@ -191,7 +192,25 @@ class TestTransform(TestCaseMixin, unittest.TestCase):
self.assertTrue(torch.allclose(points_out, points_out_expected)) self.assertTrue(torch.allclose(points_out, points_out_expected))
self.assertTrue(torch.allclose(normals_out, normals_out_expected)) self.assertTrue(torch.allclose(normals_out, normals_out_expected))
def test_rotate(self): @mock.patch.dict(os.environ, {"PYTORCH3D_CHECK_ROTATION_MATRICES": "1"}, clear=True)
def test_rotate_check_rot_valid_on(self):
R = so3_exp_map(torch.randn((1, 3)))
t = Transform3d().rotate(R)
points = torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.5, 0.5, 0.0]]).view(
1, 3, 3
)
normals = torch.tensor(
[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [1.0, 1.0, 0.0]]
).view(1, 3, 3)
points_out = t.transform_points(points)
normals_out = t.transform_normals(normals)
points_out_expected = torch.bmm(points, R)
normals_out_expected = torch.bmm(normals, R)
self.assertTrue(torch.allclose(points_out, points_out_expected))
self.assertTrue(torch.allclose(normals_out, normals_out_expected))
@mock.patch.dict(os.environ, {"PYTORCH3D_CHECK_ROTATION_MATRICES": "0"}, clear=True)
def test_rotate_check_rot_valid_off(self):
R = so3_exp_map(torch.randn((1, 3))) R = so3_exp_map(torch.randn((1, 3)))
t = Transform3d().rotate(R) t = Transform3d().rotate(R)
points = torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.5, 0.5, 0.0]]).view( points = torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.5, 0.5, 0.0]]).view(